交叉熵损失公式_生产损失一般包括什么

交叉熵损失公式_生产损失一般包括什么在机器学习中(特别是分类模型),模型训练时,通常都是使用交叉熵(Cross-Entropy)作为损失进行最小化:CE(p,q)=−∑i=1Cpilog(qi)CE(p,q)=−∑i=1Cpi

在机器学习中(特别是分类模型),模型训练时,通常都是使用交叉熵(Cross-Entropy)作为损失进行最小化:

C E ( p , q ) = − ∑ i = 1 C p i l o g ( q i ) CE(p,q)=- \sum_{i=1}^{C} p_i log(q_i) CE(p,q)=i=1Cpilog(qi)
其中 C C C代表类别数。 p i p_i pi为真实, q i q_i qi为预测。

我们以MNIST多分类为例,通常Label会编码为One-Hot,最后一层输出会使用Softmax函数进行概率化输出,如下表所示:

Sample True Predicted
这里写图片描述 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0] [0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0]
这里写图片描述 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0] [0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0]
这里写图片描述 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0] [0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]

对于第一个样本,交叉熵损失为:
− l n ( 0.6 ) ≈ 0.51 -ln(0.6) \approx 0.51 ln(0.6)0.51

对于第二个样本,交叉熵损失为:
− l n ( 0.5 ) ≈ 0.69 -ln(0.5) \approx 0.69 ln(0.5)0.69

对于第三个样本,交叉熵损失为:
− l n ( 0.1 ) ≈ 2.30 -ln(0.1) \approx 2.30 ln(0.1)2.30

平均交叉熵损失为:
− ( l n ( 0.6 ) + l n ( 0.5 ) + l n ( 0.1 ) ) 3 ≈ 1.17 -\frac{(ln(0.6)+ln(0.5)+ln(0.1))}{3} \approx 1.17 3(ln(0.6)+ln(0.5)+ln(0.1))1.17

从上面的计算可以知道,预测越准,损失越小。

Scikit-learn中提供了交叉熵损失的计算方法:

from sklearn.metrics import log_loss true = ['1', '4', '5'] pred=[[0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0], [0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0], [0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]] labels=['0','1','2','3','4','5','6','7','8','9'] log_loss(true, pred, labels) Out: 1.00008 

为什么训练时采取交叉熵损失,而不用均方误差(Mean Squared Error, MSE)呢?

Why You Should Use Cross-Entropy Error Instead Of Classification Error Or Mean Squared Error For Neural Network Classifier Training -> 翻译版

今天的文章
交叉熵损失公式_生产损失一般包括什么分享到此就结束了,感谢您的阅读。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:http://bianchenghao.cn/89703.html

(0)
编程小号编程小号

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注