mse pytorch_torch和pytorch

mse pytorch_torch和pytorch在 pytorch 中 经常使用 nn MSELoss 作为损失函数 例如 loss nn MSELoss input torch randn 3 5 requires grad True target torch randn 3 5 error loss input target error backward 这个地方有一个巨坑 就是一定要小心 input 和 target 的位置

在pytorch中,经常使用nn.MSELoss作为损失函数,例如

loss=nn.MSELoss()
input=torch.randn(3,5,requires_grad=True)
target=torch.randn(3,5)
error=loss(input,target)
error.backward()

这个地方有一个巨坑,就是一定要小心input和target的位置,说的更具体一些,target一定需要是一个不能被训练更新的、requires_grad=False的值,否则会报错!!!

另外,关于MSELoss的设定

若设定loss=torch.nn.MSELoss(reduction=’mean’),最终输出值是(target-input)每个元素数字平方和除以width x height,也就是在batch和特征维度上都做了平均。如果只想在batch上做平均,则可以写成这个样子:

#需要注意的是,这里的input和target是mini-batch的形式
loss=torch.nn.MSELoss(reduction='sum')
loss=loss(input,target)/target.size(0)
编程小号
上一篇 2025-02-07 10:57
下一篇 2025-03-09 12:46

相关推荐

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/hz/131574.html