DenseFusion系列代码全讲解目录:【DenseFusion系列目录】代码全讲解+可视化+计算评估指标_Panpanpan!的博客-CSDN博客
这些内容均为个人学习记录,欢迎大家提出错误一起讨论一起学习!
该部分是对refine网络部分的loss进行计算。代码位置在lib/loss_refiner.py
这里对代码的理解和loss.py的没有很多区别,有些一样的过程请参照loss.py
什么时候进行loss_refiner的计算呢?
如果refine过程没有开始,则进行的是PoseNet—loss—loss.backward()过程,如果开始了refine过程,则主干网络停止训练,改为eval模式,PoseRefineNet改为train模式,经过PoseNet—loss之后,将loss所输出的new_points输入到PoseRefineNet中进行训练,输出预测的姿态,然后就应该计算loss_refiner了。
train.py中对其的使用过程为:
from lib.loss_refiner import Loss_refine #第27行-首先import
criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list) #第109行-初始化
dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points) #第145行-进行forward过程求解loss_refine
首先来看Loss_refine类的定义:
from torch.nn.modules.loss import _Loss
class Loss_refine(_Loss):
def __init__(self, num_points_mesh, sym_list):
super(Loss_refine, self).__init__(True)
self.num_pt_mesh = num_points_mesh
self.sym_list = sym_list
def forward(self, pred_r, pred_t, target, model_points, idx, points):
return loss_calculation(pred_r, pred_t, target, model_points, idx, points, self.num_pt_mesh, self.sym_list)
Loss_refine类继承了torch.nn.modules.loss类,用于定义自己的损失函数。首先进行初始化,参数有mesh点数以及对称物体编号,在上述train.py的 criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list) 中实现初始化,定义criterion_refine。然后在训练过程中,使用criterion_refine()的时候调用forward函数,括号中的参数信息为:
- pred_r:refine预测的旋转R,大小为torch.Size([1, 4]),4是四元数表示,这里只有一个像素,为主干网络选取的置信度最大的像素的预测
- pred_t:预测的平移t,torch.Size([1, 3])
- target:目标点云,torch.Size([1, 500, 3]) ,为loss计算之后由target逆转而来的new_target
- model_points:模型第一帧的点云,torch.Size([1, 500, 3])
- idx:类别编号,torch.Size([1, 1])
- points:筛选的500个点云,torch.Size([1, 500, 3]),为loss计算之后由points逆转而来的new_points
forward函数中调用的计算loss_refine的函数loss_calculation,下面一行一行分析。
首先,传入的参数除了上述之外加了初始化的两个参数:
def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
knn = KNearestNeighbor(1)
pred_r = pred_r.view(1, 1, -1)
pred_t = pred_t.view(1, 1, -1)
bs, num_p, _ = pred_r.size()
num_input_points = len(points[0])
pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))
第一行定义了KNN算法,为了处理对称物体,后续再详细介绍。然后将pred_r和pred_t转换成大小为torch.Size([1, 1, 4])和torch.Size([1, 1, 3])。然后获取pred_r的大小,bs为1,num_p为1,获取输入点云个数,linemod为500.然后对pred_r进行标准化,torch.norm(pred_r, dim=2) 用来对pred_r在最后一维上求L2范数。接着:
base = torch.cat(((1.0 - 2.0*(pred_r[:, :, 2]**2 + pred_r[:, :, 3]**2)).view(bs, num_p, 1),\
(2.0*pred_r[:, :, 1]*pred_r[:, :, 2] - 2.0*pred_r[:, :, 0]*pred_r[:, :, 3]).view(bs, num_p, 1), \
(2.0*pred_r[:, :, 0]*pred_r[:, :, 2] + 2.0*pred_r[:, :, 1]*pred_r[:, :, 3]).view(bs, num_p, 1), \
(2.0*pred_r[:, :, 1]*pred_r[:, :, 2] + 2.0*pred_r[:, :, 3]*pred_r[:, :, 0]).view(bs, num_p, 1), \
(1.0 - 2.0*(pred_r[:, :, 1]**2 + pred_r[:, :, 3]**2)).view(bs, num_p, 1), \
(-2.0*pred_r[:, :, 0]*pred_r[:, :, 1] + 2.0*pred_r[:, :, 2]*pred_r[:, :, 3]).view(bs, num_p, 1), \
(-2.0*pred_r[:, :, 0]*pred_r[:, :, 2] + 2.0*pred_r[:, :, 1]*pred_r[:, :, 3]).view(bs, num_p, 1), \
(2.0*pred_r[:, :, 0]*pred_r[:, :, 1] + 2.0*pred_r[:, :, 2]*pred_r[:, :, 3]).view(bs, num_p, 1), \
(1.0 - 2.0*(pred_r[:, :, 1]**2 + pred_r[:, :, 2]**2)).view(bs, num_p, 1)), dim=2).contiguous().view(bs * num_p, 3, 3)
这一大段就是求旋转矩阵R,DenseFusion中使用的是四元数(常用的四元数、欧拉角等)来表示旋转矩阵,网络回归出的是4个数值,现在要把它们转换成原始的9个数值,公式如下:
上述求base的过程就是该公式的实现,base的大小为torch.Size([1, 3, 3]) 。
ori_base = base
base = base.contiguous().transpose(2, 1).contiguous()
model_points = model_points.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
target = target.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
ori_target = target
pred_t = pred_t.contiguous().view(bs * num_p, 1, 3)
ori_t = pred_t
pred = torch.add(torch.bmm(model_points, base), pred_t)
这里model_points和target实际没变,因为就一个像素,没有进行复制,最后,计算预测的模型pred,这里为什么没有加上points(loss.py里面加了points具体详见loss.py),我理解的是refine网络预测的就是绝对的平移,因为它的思想是逐点预测,但这里只有一个像素的预测结果,就没有必要加上points了。
if idx[0].item() in sym_list:
target = target[0].transpose(1, 0).contiguous().view(3, -1)
pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
target = torch.index_select(target, 1, inds.view(-1) - 1)
target = target.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()
pred = pred.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()
这里就是对对称物体计算ADD-S。
dis = torch.mean(torch.norm((pred - target), dim=2), dim=1)
对非对称物体计算ADD。也就是说,dis为该像素的loss,对于对称物体就是ADD-S的值,非对称物体就是ADD的值。
t = ori_t[0]
points = points.view(1, num_input_points, 3)
ori_base = ori_base[0].view(1, 3, 3).contiguous()
ori_t = t.repeat(bs * num_input_points, 1).contiguous().view(1, bs * num_input_points, 3)
new_points = torch.bmm((points - ori_t), ori_base).contiguous()
new_target = ori_target[0].view(1, num_point_mesh, 3).contiguous()
ori_t = t.repeat(num_point_mesh, 1).contiguous().view(1, num_point_mesh, 3)
new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
# print('------------> ', dis.item(), idx[0].item())
del knn
return dis, new_points.detach(), new_target.detach()
根据新的旋转和平移对points和target进行逆转操作。输出dis、new_points和new_target。可以看出,逆转操作是一种不断纠正姿态的过程,这也是迭代自优化的主要思想,将上一过程的预测结果作为下一过程的输入。
今天的文章dice loss代码_crossentropy loss分享到此就结束了,感谢您的阅读。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/65179.html