[Object Detection] LLA: Loss-aware label assignment for dense pedestrian detection [Label Assignment]
创始人
2024-02-19 12:53:26

Summary

This paper proposes a label assignment strategy for dense pedestrian detection. It consists of the following steps.

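  1. Build the cost matrix. A forward pass gives the network's predictions, from which a classification cost $C^{cls}$ (focal loss) and a regression cost $C^{reg}$ (IoU loss) are computed for every (GT box, anchor) pair. The cost matrix is $C = C^{cls} + \lambda C^{reg}$.
  2. For each GT box, take the top K anchors with the lowest cost (i.e., the smallest loss) as positives; all other anchors are negatives. If an anchor is selected by more than one GT box, the code keeps only the GT box with the lowest cost.
  3. To speed up convergence, positive candidates are restricted to anchors that lie inside the GT box. The code enforces this by adding a large constant (1e3) to the cost of anchors outside the box.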
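The three steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes the per-pair classification and regression costs have already been computed, and the function name `lla_assign`, its arguments and the `-1` background marker are hypothetical. Following the reference code, the in-box constraint is an additive penalty (1e3) rather than a hard mask, so an out-of-box anchor can still be chosen when fewer than K anchors lie inside the box.

```python
from typing import List

OUT_OF_BOX_PENALTY = 1e3  # mirrors the 1e3 constant in the reference code


def lla_assign(cls_cost: List[List[float]],
               reg_cost: List[List[float]],
               in_box: List[List[bool]],
               lam: float,
               k: int) -> List[int]:
    """Hypothetical sketch: return the assigned GT index per anchor, or -1 for background.

    cls_cost[g][a], reg_cost[g][a]: cost of matching GT box g with anchor a.
    in_box[g][a]: whether anchor a lies inside GT box g.
    """
    num_gt = len(cls_cost)
    num_anchor = len(cls_cost[0]) if num_gt else 0

    # Step 1 (+ step 3): C = C_cls + lambda * C_reg, penalized outside the GT box.
    cost = [[cls_cost[g][a] + lam * reg_cost[g][a]
             + (0.0 if in_box[g][a] else OUT_OF_BOX_PENALTY)
             for a in range(num_anchor)] for g in range(num_gt)]

    # Step 2: every GT box picks its k lowest-cost anchors as positive candidates.
    candidates: List[List[int]] = [[] for _ in range(num_anchor)]
    for g in range(num_gt):
        for a in sorted(range(num_anchor), key=lambda a: cost[g][a])[:k]:
            candidates[a].append(g)

    # An anchor picked by several GT boxes keeps only the cheapest one;
    # anchors picked by no GT box become background (-1).
    return [min(gts, key=lambda g: cost[g][a]) if gts else -1
            for a, gts in enumerate(candidates)]


cls_cost = [[0.2, 0.9, 0.1, 0.5], [0.8, 0.3, 0.1, 0.1]]
reg_cost = [[0.1, 0.2, 0.3, 0.9], [0.9, 0.1, 0.2, 0.25]]
in_box = [[True, True, True, False], [False, True, True, True]]
print(lla_assign(cls_cost, reg_cost, in_box, lam=1.0, k=2))  # [0, -1, 1, 1]
```

In this toy case anchor 2 is picked by both GT boxes and goes to GT 1 (cost about 0.3 versus 0.4), while anchor 1 is picked by neither and becomes background.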

LLA and YOLOX share the same author, and YOLOX's label assignment strategy (SimOTA) can be viewed as a slight modification of the one proposed here.

More details

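  1. Sensitivity to the hyperparameter TOP K
    The authors' experiments show that results are not sensitive to TOP K within a certain range.
  2. Ablation study of each term in the cost matrix
  3. Visualization results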

Code

Reference link

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

# permute_to_N_HWA_K, sigmoid_focal_loss_jit, iou_loss and get_ious are helpers
# from the original codebase; they are not reproduced here.


def get_lla_assignments_and_losses(self, shifts, targets, box_cls, box_delta, box_iou):
    # Flatten each FPN level to (N, H*W*A, K) and concatenate all levels along dim 1.
    box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
    box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]

    box_cls = torch.cat(box_cls, dim=1)
    box_delta = torch.cat(box_delta, dim=1)
    box_iou = torch.cat(box_iou, dim=1)

    losses_cls = []
    losses_box_reg = []
    losses_iou = []

    num_fg = 0

    for shifts_per_image, targets_per_image, box_cls_per_image, \
            box_delta_per_image, box_iou_per_image in zip(
                shifts, targets, box_cls, box_delta, box_iou):
        shifts_over_all = torch.cat(shifts_per_image, dim=0)

        gt_boxes = targets_per_image.gt_boxes
        gt_classes = targets_per_image.gt_classes

        # An anchor point is "in the box" if all four ltrb distances exceed 0.01.
        deltas = self.shift2box_transform.get_deltas(
            shifts_over_all, gt_boxes.tensor.unsqueeze(1))
        is_in_boxes = deltas.min(dim=-1).values > 0.01

        shape = (len(targets_per_image), len(shifts_over_all), -1)
        box_cls_per_image_unexpanded = box_cls_per_image
        box_delta_per_image_unexpanded = box_delta_per_image

        # Broadcast predictions and one-hot GT classes to (num_gt, num_anchor, K).
        # Negative (ignore) class ids are clamped to 0 here and masked out later.
        box_cls_per_image = box_cls_per_image.unsqueeze(0).expand(shape)
        gt_cls_per_image = F.one_hot(
            torch.max(gt_classes, torch.zeros_like(gt_classes)), self.num_classes
        ).float().unsqueeze(1).expand(shape)

        with torch.no_grad():
            # C^cls: focal loss of every anchor against every GT class.
            loss_cls = sigmoid_focal_loss_jit(
                box_cls_per_image,
                gt_cls_per_image,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)
            # Cost of treating each anchor as background (extra last row of the matrix).
            loss_cls_bg = sigmoid_focal_loss_jit(
                box_cls_per_image_unexpanded,
                torch.zeros_like(box_cls_per_image_unexpanded),
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)
            # C^reg: IoU loss of every predicted box against every GT box.
            box_delta_per_image = box_delta_per_image.unsqueeze(0).expand(shape)
            gt_delta_per_image = self.shift2box_transform.get_deltas(
                shifts_over_all, gt_boxes.tensor.unsqueeze(1))
            loss_delta = iou_loss(
                box_delta_per_image,
                gt_delta_per_image,
                box_mode="ltrb",
                loss_type='iou')
            ious = get_ious(
                box_delta_per_image,
                gt_delta_per_image,
                box_mode="ltrb",
                loss_type='iou')

            # C = C^cls + lambda * C^reg, plus a 1e3 penalty for anchors outside the GT box.
            loss = loss_cls + self.reg_cost * loss_delta + 1e3 * (1 - is_in_boxes.float())
            loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)

            num_gt = loss.shape[0] - 1
            num_anchor = loss.shape[1]

            # Topk: each GT takes the k anchors with the lowest cost.
            matching_matrix = torch.zeros_like(loss)
            _, topk_idx = torch.topk(loss[:-1], k=self.topk, dim=1, largest=False)
            matching_matrix[torch.arange(num_gt).unsqueeze(1).repeat(1, self.topk).view(-1),
                            topk_idx.view(-1)] = 1.

            # make sure one anchor with one gt
            anchor_matched_gt = matching_matrix.sum(0)
            if (anchor_matched_gt > 1).sum() > 0:
                loss_min, loss_argmin = torch.min(loss[:-1, anchor_matched_gt > 1], dim=0)
                matching_matrix[:, anchor_matched_gt > 1] *= 0.
                matching_matrix[loss_argmin, anchor_matched_gt > 1] = 1.
                anchor_matched_gt = matching_matrix.sum(0)

            num_fg += matching_matrix.sum()
            matching_matrix[-1] = 1. - anchor_matched_gt  # assignment for Background
            assigned_gt_inds = torch.argmax(matching_matrix, dim=0)

            gt_cls_per_image_bg = gt_cls_per_image.new_zeros(
                (gt_cls_per_image.size(1), gt_cls_per_image.size(2))).unsqueeze(0)
            gt_cls_per_image_with_bg = torch.cat([gt_cls_per_image, gt_cls_per_image_bg], dim=0)
            cls_target_per_image = gt_cls_per_image_with_bg[assigned_gt_inds, torch.arange(num_anchor)]

            # Dealing with Crowdhuman ignore label
            gt_classes_ = torch.cat([gt_classes, gt_classes.new_zeros(1)])
            anchor_cls_labels = gt_classes_[assigned_gt_inds]
            valid_flag = anchor_cls_labels >= 0

            pos_mask = assigned_gt_inds != len(targets_per_image)  # get foreground mask
            valid_fg = pos_mask & valid_flag
            assigned_fg_inds = assigned_gt_inds[valid_fg]
            range_fg = torch.arange(num_anchor)[valid_fg]
            ious_fg = ious[assigned_fg_inds, range_fg]

        # Training losses (with gradients) on the assigned targets.
        anchor_loss_cls = sigmoid_focal_loss_jit(
            box_cls_per_image_unexpanded[valid_flag],
            cls_target_per_image[valid_flag],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma).sum(dim=-1)

        delta_target = gt_delta_per_image[assigned_fg_inds, range_fg]
        anchor_loss_delta = 2. * iou_loss(
            box_delta_per_image_unexpanded[valid_fg],
            delta_target,
            box_mode="ltrb",
            loss_type=self.iou_loss_type)

        anchor_loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
            box_iou_per_image.squeeze(1)[valid_fg],
            ious_fg,
            reduction='none')

        losses_cls.append(anchor_loss_cls.sum())
        losses_box_reg.append(anchor_loss_delta.sum())
        losses_iou.append(anchor_loss_iou.sum())

    # Normalize by the number of foreground anchors, averaged across GPUs if requested.
    if self.norm_sync:
        dist.all_reduce(num_fg)
        num_fg = num_fg.float() / dist.get_world_size()

    return {
        'loss_cls': torch.stack(losses_cls).sum() / num_fg,
        'loss_box_reg': torch.stack(losses_box_reg).sum() / num_fg,
        'loss_iou': torch.stack(losses_iou).sum() / num_fg,
    }
```
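The bookkeeping at the end of the assignment is easy to miss in the tensor code: the extra background row, the `argmax`, and the CrowdHuman ignore label. The sketch below mirrors those lines in plain Python for one image. It assumes the matching matrix has already been reduced to at most one GT per anchor, and the function name `finalize_assignment` is hypothetical. As in the code above, index `num_gt` means background. A GT with a negative class id is treated as an ignore region, so anchors assigned to it are dropped from both the classification and the regression losses.

```python
from typing import List, Tuple


def finalize_assignment(matching: List[List[int]], gt_classes: List[int],
                        num_anchor: int) -> Tuple[List[int], List[bool], List[bool]]:
    """Hypothetical mirror of the background-row / ignore-label logic for one image.

    matching[g][a] is 1 if anchor a is matched to GT g (at most one 1 per column).
    gt_classes[g] < 0 marks an ignore region. Returns, per anchor:
    assigned GT index (num_gt = background), valid_flag and valid_fg.
    """
    num_gt = len(matching)
    # matching_matrix[-1] = 1. - anchor_matched_gt
    bg_row = [1 - sum(matching[g][a] for g in range(num_gt)) for a in range(num_anchor)]
    full = matching + [bg_row]
    # torch.argmax(matching_matrix, dim=0): each column holds exactly one 1.
    assigned = [max(range(num_gt + 1), key=lambda r: full[r][a]) for a in range(num_anchor)]

    # gt_classes_ = cat([gt_classes, 0]): background gets class 0, so it stays valid.
    labels = list(gt_classes) + [0]
    valid_flag = [labels[i] >= 0 for i in assigned]
    # Foreground = not background and not an ignore region.
    valid_fg = [ok and i != num_gt for ok, i in zip(valid_flag, assigned)]
    return assigned, valid_flag, valid_fg


# GT 0 is a pedestrian (class 0); GT 1 is an ignore region (class -1).
print(finalize_assignment([[1, 0, 0, 0], [0, 0, 1, 1]], [0, -1], 4))
# ([0, 2, 1, 1], [True, True, False, False], [True, False, False, False])
```

Anchor 1 is background but still contributes to the classification loss. Anchors 2 and 3 fall on the ignore region and contribute to neither loss.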
