网上关于混淆矩阵的代码参差不齐,没找到可用的线程的代码,所以自己尝试写了下
1、混淆矩阵:Confusion Matrix
首先它长这样:
怎么看?
Confusion Matrix最广泛的应用应该是分类,比如图中是7分类的真实标签和预测标签的效果。
首先图中表明了纵轴是truth label,横轴是predicted label,那么对于第一行第一个0.60的含义是:本来是angry标签的图,我的模型正确分类成angry的比例是60%,也即是angry这一类模型分类正确的精度只有60%。同时模型将angry分类成了happy的图占比0.04%,其他的以此类推。
注意:因为本身是angry,模型预测成7种类的数量占比。所以每一行的和为100%。
同时对于fear标签,模型分类成fear的占比41%,分类成sad的占比为20%,我们可以认为模型不能很好区分fear和sad两种类别。
2、怎么画(新)?
这里直接给出代码,在下一节中直接使用即可
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
def draw_confusion_matrix(label_true, label_pred, label_name, normlize, title="Confusion Matrix", pdf_save_path=None, dpi=100):
""" @param label_true: 真实标签,比如[0,1,2,7,4,5,...] @param label_pred: 预测标签,比如[0,5,4,2,1,4,...] @param label_name: 标签名字,比如['cat','dog','flower',...] @param normlize: 是否设元素为百分比形式 @param title: 图标题 @param pdf_save_path: 是否保存,是则为保存路径pdf_save_path=xxx.png | xxx.pdf | ...等其他plt.savefig支持的保存格式 @param dpi: 保存到文件的分辨率,论文一般要求至少300dpi @return: example: draw_confusion_matrix(label_true=y_gt, label_pred=y_pred, label_name=["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"], normlize=True, title="Confusion Matrix on Fer2013", pdf_save_path="Confusion_Matrix_on_Fer2013.png", dpi=300) """
cm = confusion_matrix(label_true, label_pred)
if normlize:
row_sums = np.sum(cm, axis=1) # 计算每行的和
cm = cm / row_sums[:, np.newaxis] # 广播计算每个元素占比
plt.imshow(cm, cmap='Blues')
plt.title(title)
plt.xlabel("Predict label")
plt.ylabel("Truth label")
plt.yticks(range(label_name.__len__()), label_name)
plt.xticks(range(label_name.__len__()), label_name, rotation=45)
plt.tight_layout()
plt.colorbar()
for i in range(label_name.__len__()):
for j in range(label_name.__len__()):
color = (1, 1, 1) if i == j else (0, 0, 0) # 对角线字体白色,其他黑色
value = float(format('%.2f' % cm[i, j]))
plt.text(i, j, value, verticalalignment='center', horizontalalignment='center', color=color)
# plt.show()
if not pdf_save_path is None:
plt.savefig(pdf_save_path, bbox_inches='tight',dpi=dpi)
3、怎么用?
给出一个简单的实例:
labels_name=['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
y_gt=[]
y_pred=[]
for index, (labels, imgs) in enumerate(test_loader):
labels_pd = model(imgs)
predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1) # array([0,5,1,6,3,...],dtype=int64)
labels_np = labels.numpy() # array([0,5,0,6,2,...],dtype=int64)
y_pred.append(labels_np)
y_gt.append(labels_np)
draw_confusion_matrix(label_true=y_gt, # y_gt=[0,5,1,6,3,...]
label_pred=y_pred, # y_pred=[0,5,1,6,3,...]
label_name=["An", "Di", "Fe", "Ha", "Sa", "Su", "Ne"],
normlize=True,
title="Confusion Matrix on Fer2013",
pdf_save_path="Confusion_Matrix_on_Fer2013.jpg",
dpi=300)
- cpu().detach():从device上获取数据
- .numpy():将tensor类型转换为numpy类型
在我的模型上的结果:
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:http://bianchenghao.cn/38150.html