本文提出了一种用于行人目标检测的标签分配策略,具体来说,主要包含以下几个步骤。
本文的作者与YOLOX的作者相同,YOLOX的标签分配策略可以看作是在本文方法的基础上稍作修改得到的。
参考链接
def get_lla_assignments_and_losses(self, shifts, targets, box_cls, box_delta, box_iou):
    """Assign anchors to ground truths by lowest cost and compute the losses.

    For every image a cost is computed between each ground truth and each
    anchor: focal classification loss + ``self.reg_cost`` * IoU loss, plus a
    penalty of 1e3 for anchors whose smallest delta to the box is not above
    0.01 (presumably anchors lying outside the box). Each ground truth
    takes its ``self.topk`` cheapest anchors; an anchor claimed by several
    ground truths stays only with the cheapest one, and every unclaimed
    anchor is assigned to background. The training losses are then evaluated
    against that assignment.

    Args:
        shifts: one entry per image; each entry is a sequence of tensors
            that is concatenated along dim 0 to obtain all anchor points of
            the image (presumably one tensor per feature level).
        targets: one entry per image, exposing ``gt_boxes`` (with a
            ``.tensor`` attribute), ``gt_classes`` and ``len()``. Negative
            class ids are treated as "ignore" labels.
        box_cls: sequence of classification outputs, each reshaped with
            ``permute_to_N_HWA_K(x, self.num_classes)``.
        box_delta: sequence of box regression outputs, each reshaped with
            ``permute_to_N_HWA_K(x, 4)``.
        box_iou: sequence of IoU-branch outputs, each reshaped with
            ``permute_to_N_HWA_K(x, 1)``.

    Returns:
        dict with the keys ``'loss_cls'``, ``'loss_box_reg'`` and
        ``'loss_iou'``. Each value is the loss summed over all images and
        divided by the number of matched anchors (averaged over the workers
        when ``self.norm_sync`` is set). If no anchor was matched at all,
        the divisor is 1 so that the losses stay finite.
    """
    # Reshape every output with the project helper (presumably to
    # (N, HWA, K), going by its name) and join them along dim 1.
    box_cls = torch.cat(
        [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls], dim=1)
    box_delta = torch.cat(
        [permute_to_N_HWA_K(x, 4) for x in box_delta], dim=1)
    box_iou = torch.cat(
        [permute_to_N_HWA_K(x, 1) for x in box_iou], dim=1)

    losses_cls = []
    losses_box_reg = []
    losses_iou = []
    num_fg = 0

    for (shifts_per_image, targets_per_image, box_cls_per_image,
         box_delta_per_image, box_iou_per_image) in zip(
             shifts, targets, box_cls, box_delta, box_iou):
        shifts_over_all = torch.cat(shifts_per_image, dim=0)

        gt_boxes = targets_per_image.gt_boxes
        gt_classes = targets_per_image.gt_classes

        # Pairwise layout used below: (num_gt, num_anchor, K).
        shape = (len(targets_per_image), len(shifts_over_all), -1)

        # Ignore labels are negative; clamp them to class 0 so that one_hot
        # accepts them. They are masked out again through `valid_flag`.
        gt_cls_per_image = F.one_hot(
            gt_classes.clamp(min=0), self.num_classes,
        ).float().unsqueeze(1).expand(shape)

        # The assignment itself is not differentiated; only the final
        # per-anchor losses computed after this block carry gradients.
        with torch.no_grad():
            # NOTE(review): computed once and reused for both the in-box
            # test and the regression target; assumes get_deltas has no
            # side effects.
            gt_delta_per_image = self.shift2box_transform.get_deltas(
                shifts_over_all, gt_boxes.tensor.unsqueeze(1))
            is_in_boxes = gt_delta_per_image.min(dim=-1).values > 0.01

            box_cls_expanded = box_cls_per_image.unsqueeze(0).expand(shape)
            box_delta_expanded = box_delta_per_image.unsqueeze(0).expand(shape)

            loss_cls = sigmoid_focal_loss_jit(
                box_cls_expanded,
                gt_cls_per_image,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)
            # Cost of labelling each anchor as background (all-zero target).
            loss_cls_bg = sigmoid_focal_loss_jit(
                box_cls_per_image,
                torch.zeros_like(box_cls_per_image),
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)
            loss_delta = iou_loss(
                box_delta_expanded,
                gt_delta_per_image,
                box_mode="ltrb",
                loss_type='iou')
            ious = get_ious(
                box_delta_expanded,
                gt_delta_per_image,
                box_mode="ltrb",
                loss_type='iou')

            cost = (loss_cls + self.reg_cost * loss_delta
                    + 1e3 * (1 - is_in_boxes.float()))
            # The background cost becomes the last row, so row index
            # `num_gt` stands for "background" from here on.
            cost = torch.cat([cost, loss_cls_bg.unsqueeze(0)], dim=0)

            num_gt = cost.shape[0] - 1
            num_anchor = cost.shape[1]
            # Index tensors are created on the same device as the data;
            # a CPU arange cannot be indexed with a mask that lives on
            # another device.
            anchor_range = torch.arange(num_anchor, device=cost.device)

            # Top-k: every ground truth claims its k cheapest anchors.
            matching_matrix = torch.zeros_like(cost)
            _, topk_idx = torch.topk(
                cost[:-1], k=self.topk, dim=1, largest=False)
            gt_rows = torch.arange(
                num_gt, device=cost.device,
            ).unsqueeze(1).expand_as(topk_idx).reshape(-1)
            matching_matrix[gt_rows, topk_idx.reshape(-1)] = 1.

            # Make sure each anchor is matched to at most one ground truth:
            # a contested anchor goes to the ground truth with lowest cost.
            anchor_matched_gt = matching_matrix.sum(0)
            contested = anchor_matched_gt > 1
            if contested.any():
                _, cheapest_gt = torch.min(cost[:-1, contested], dim=0)
                matching_matrix[:, contested] = 0.
                matching_matrix[cheapest_gt, contested] = 1.
                anchor_matched_gt = matching_matrix.sum(0)

            # Counted before the background row is filled in, so this is the
            # number of anchors matched to any ground truth (ignore-labelled
            # ones included).
            num_fg += matching_matrix.sum()
            matching_matrix[-1] = 1. - anchor_matched_gt  # background row
            assigned_gt_inds = torch.argmax(matching_matrix, dim=0)

            # Classification target: one-hot of the assigned ground truth,
            # all zeros for anchors assigned to background.
            gt_cls_per_image_bg = gt_cls_per_image.new_zeros(
                (1, gt_cls_per_image.size(1), gt_cls_per_image.size(2)))
            gt_cls_per_image_with_bg = torch.cat(
                [gt_cls_per_image, gt_cls_per_image_bg], dim=0)
            cls_target_per_image = gt_cls_per_image_with_bg[
                assigned_gt_inds, anchor_range]

            # Dealing with the CrowdHuman ignore label: anchors assigned to
            # a ground truth with a negative class contribute to no loss.
            gt_classes_ = torch.cat([gt_classes, gt_classes.new_zeros(1)])
            anchor_cls_labels = gt_classes_[assigned_gt_inds]
            valid_flag = anchor_cls_labels >= 0

            # Foreground = assigned to a real ground truth row.
            pos_mask = assigned_gt_inds != len(targets_per_image)
            valid_fg = pos_mask & valid_flag
            assigned_fg_inds = assigned_gt_inds[valid_fg]
            range_fg = anchor_range[valid_fg]
            ious_fg = ious[assigned_fg_inds, range_fg]

        anchor_loss_cls = sigmoid_focal_loss_jit(
            box_cls_per_image[valid_flag],
            cls_target_per_image[valid_flag],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma).sum(dim=-1)

        delta_target = gt_delta_per_image[assigned_fg_inds, range_fg]
        anchor_loss_delta = 2. * iou_loss(
            box_delta_per_image[valid_fg],
            delta_target,
            box_mode="ltrb",
            loss_type=self.iou_loss_type)

        # The IoU branch is trained to predict the IoU between the decoded
        # prediction and its assigned ground truth.
        anchor_loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
            box_iou_per_image.squeeze(1)[valid_fg],
            ious_fg,
            reduction='none')

        losses_cls.append(anchor_loss_cls.sum())
        losses_box_reg.append(anchor_loss_delta.sum())
        losses_iou.append(anchor_loss_iou.sum())

    if self.norm_sync:
        dist.all_reduce(num_fg)
        num_fg = num_fg.float() / dist.get_world_size()

    # A batch without any matched anchor would otherwise divide by zero and
    # produce inf/NaN losses; every non-zero count is left untouched.
    normalizer = torch.where(num_fg > 0, num_fg, torch.ones_like(num_fg))

    return {
        'loss_cls': torch.stack(losses_cls).sum() / normalizer,
        'loss_box_reg': torch.stack(losses_box_reg).sum() / normalizer,
        'loss_iou': torch.stack(losses_iou).sum() / normalizer,
    }