Hard Negatie Mining与Online Hard Example Mining(OHEM)都属于难例挖掘,它是解决目标检测老大难问题的常用办法,运用于R-CNN,fast R-CNN,faster rcnn等two-stage模型与SSD等(有anchor的)one-stage模型训练时的训练方法。
OHEM和难负例挖掘名字上的不同。
- Hard Negative Mining只注意难负例
- OHEM 则注意所有难例,不论正负(Loss大的例子)
难例挖掘的思想可以解决很多样本不平衡/简单样本过多的问题,比如说分类网络,将hard sample 补充到数据集里,重新丢进网络当中,就好像给网络准备一个错题集,哪里不会点哪里。
难例挖掘与非极大值抑制 NMS 一样,都是为了解决目标检测老大难问题(样本不平衡+低召回率)及其带来的副作用。
根据每个RoIs的loss的大小来决定哪些是难样例, 哪些是简单样例, 通过这种方法, 可以更高效的训练网络, 并且可以使得网络获得更小的训练loss
Pytorch实现
def ohem_loss( batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0 ): """ Arguments: batch_size (int): number of sampled rois for bbox head training loc_pred (FloatTensor): [R, 4], location of positive rois loc_target (FloatTensor): [R, 4], location of positive rois pos_mask (FloatTensor): [R], binary mask for sampled positive rois cls_pred (FloatTensor): [R, C] cls_target (LongTensor): [R] Returns: cls_loss, loc_loss (FloatTensor) """ ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1) ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False) #这里先暂存下正常的分类loss和回归loss loss = ohem_cls_loss + ohem_loc_loss #然后对分类和回归loss求和 sorted_ohem_loss, idx = torch.sort(loss, descending=True) #再对loss进行降序排列 keep_num = min(sorted_ohem_loss.size()[0], batch_size) #得到需要保留的loss数量 if keep_num < sorted_ohem_loss.size()[0]: #这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留 keep_idx_cuda = idx[:keep_num] #保留到需要keep的数目 ohem_cls_loss = ohem_cls_loss[keep_idx_cuda] ohem_loc_loss = ohem_loc_loss[keep_idx_cuda] #分类和回归保留相同的数目 cls_loss = ohem_cls_loss.sum() / keep_num loc_loss = ohem_loc_loss.sum() / keep_num #然后分别对分类和回归loss求均值 return cls_loss, loc_loss
今天的文章OHEM(Online Hard Example Mining)在线难例挖掘(在线困难样例挖掘) & HNM (目标检测)分享到此就结束了,感谢您的阅读。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/10222.html