画混淆矩阵sklearn

画混淆矩阵sklearn”””画混淆矩阵,需要(真实标签,预测标签,标签列表)y_test,y_pred,display_labels混淆矩阵用:sklearn库中的confusion_matrix混淆矩阵画图用:sklearn库中的ConfusionMatrixDisplaymatplotlib库中的pyplot这里用iris数据集做例子,SVM做分类器。”””importmatplotlib.pyplot

""" 画混淆矩阵,需要(真实标签,预测标签,标签列表) y_test, y_pred, display_labels 混淆矩阵用: sklearn库中的confusion_matrix 混淆矩阵画图用: sklearn库中的ConfusionMatrixDisplay matplotlib库中的pyplot 这里用iris数据集做例子,SVM做分类器。 """

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.datasets import load_iris

# 加载鸢尾花数据集,即Iris数据集,(训练集,测试集,标签名称)
X = load_iris().data
y = load_iris().target
labels = load_iris().target_names

# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 创建一个SVM分类器
clf = SVC(random_state=0)

# 训练分类器(classifier, 简称clf)
clf.fit(X_train, y_train)

# 预测分类结果
y_pred = clf.predict(X_test)

# 你可以打印一下预测结果和分类结果
print("y_test: ", y_test)
print("y_pred: ", y_pred)

# 得到混淆矩阵(confusion matrix,简称cm)
# confusion_matrix 需要的参数:y_true(真实标签),y_pred(预测标签)
cm = confusion_matrix(y_true=y_test, y_pred=y_pred)

# 打印混淆矩阵
print("Confusion Matrix: ")
print(cm)

# 画出混淆矩阵
# ConfusionMatrixDisplay 需要的参数: confusion_matrix(混淆矩阵), display_labels(标签名称列表)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot()
plt.show()

得到的 “输出” 和 “混淆矩阵” 如下所示:
output:

y_test:  [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 1]
y_pred:  [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 2]
Confusion Matrix: 
[[13  0  0]
 [ 0 15  1]
 [ 0  0  9]]

混淆矩阵图片:
在这里插入图片描述
这个图片看起来还不错,有那个味道了,但是我们看到主对角线颜色貌似不太一样,这是因为我们没有归一化(normalized),因为混淆矩阵在分类领域主要是希望所有的数量都集中在主对角线上,颜色最好是相似的,要不然有点迷惑。

这里我们再设置混淆矩阵为归一化格式,然后看什么效果。在confusion_matrix函数中加入了normalize选项,’true’代表按照真实标签归一化,’pred’按照预测标签归一化,’all’对所有值归一化。

# 得到混淆矩阵(confusion matrix,简称cm)
# confusion_matrix 需要的参数:y_true(真实标签),y_pred(预测标签),normalize(归一化,'true', 'pred', 'all')
cm = confusion_matrix(y_true=y_test, y_pred=y_pred, normalize='true')

output:

Confusion Matrix: 
[[1.     0.     0.    ]
 [0.     0.9375 0.0625]
 [0.     0.     1.    ]]

混淆矩阵如下:
在这里插入图片描述
在对比原来的图是不是好多了?
第1类setosa和第3类virginica分类都正确,召回率为100%。
第2行第3列中,有0.062(即6.2%)的versicolor类的鸢尾花分成了virginica类。94%的分类正确了。

Note:另外,disp.plot()函数内还可以加其他参数,如cmap,意思是colormap,有很多种类型。

supported values are 'Accent', 'Accent_r', 'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', 'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', 'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', 'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', 'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', 'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', 'Reds_r', 'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', 'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', 'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', 'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cividis', 'cividis_r', 'cool', 'cool_r', 'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'crest', 'crest_r', 'cubehelix', 'cubehelix_r', 'flag', 'flag_r', 'flare', 'flare_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', 'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', 'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', 'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'icefire', 'icefire_r', 'inferno', 'inferno_r', 'jet', 'jet_r', 'magma', 'magma_r', 'mako', 'mako_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', 'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', 'rainbow_r', 'rocket', 'rocket_r', 'seismic', 'seismic_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', 'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', 'terrain_r', 'turbo', 'turbo_r', 'twilight', 'twilight_r', 'twilight_shifted', 'twilight_shifted_r', 'viridis', 'viridis_r', 'vlag', 'vlag_r', 'winter', 'winter_r'

详见https://matplotlib.org/stable/gallery/color/colormap_reference.html


参考地址

  1. https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html#sklearn.metrics.confusion_matrix
  2. https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ConfusionMatrixDisplay.html#sklearn.metrics.ConfusionMatrixDisplay

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/38276.html

(0)
编程小号编程小号

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注