发布时间:2023-09-02 09:00
\"\"\"
画混淆矩阵,需要(真实标签,预测标签,标签列表)
y_test, y_pred, display_labels
混淆矩阵用: sklearn库中的confusion_matrix
混淆矩阵画图用: sklearn库中的ConfusionMatrixDisplay
matplotlib库中的pyplot
这里用iris数据集做例子,SVM做分类器。
\"\"\"
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.datasets import load_iris
# 加载鸢尾花数据集,即Iris数据集,(训练集,测试集,标签名称)
X = load_iris().data
y = load_iris().target
labels = load_iris().target_names
# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# 创建一个SVM分类器
clf = SVC(random_state=0)
# 训练分类器(classifier, 简称clf)
clf.fit(X_train, y_train)
# 预测分类结果
y_pred = clf.predict(X_test)
# 你可以打印一下预测结果和分类结果
print(\"y_test: \", y_test)
print(\"y_pred: \", y_pred)
# 得到混淆矩阵(confusion matrix,简称cm)
# confusion_matrix 需要的参数:y_true(真实标签),y_pred(预测标签)
cm = confusion_matrix(y_true=y_test, y_pred=y_pred)
# 打印混淆矩阵
print(\"Confusion Matrix: \")
print(cm)
# 画出混淆矩阵
# ConfusionMatrixDisplay 需要的参数: confusion_matrix(混淆矩阵), display_labels(标签名称列表)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot()
plt.show()
得到的 “输出” 和 “混淆矩阵” 如下所示:
output:
y_test: [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 1]
y_pred: [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 2]
Confusion Matrix:
[[13 0 0]
[ 0 15 1]
[ 0 0 9]]
混淆矩阵图片:
这个图片看起来还不错,有那个味道了,但是我们看到主对角线颜色貌似不太一样,这是因为我们没有归一化(normalized),因为混淆矩阵在分类领域主要是希望所有的数量都集中在主对角线上,颜色最好是相似的,要不然有点迷惑。
这里我们再设置混淆矩阵为归一化格式,然后看什么效果。在confusion_matrix函数中加入了normalize选项,\'true’代表按照真实标签归一化,\'pred’按照预测标签归一化,\'all’对所有值归一化。
# 得到混淆矩阵(confusion matrix,简称cm)
# confusion_matrix 需要的参数:y_true(真实标签),y_pred(预测标签),normalize(归一化,\'true\', \'pred\', \'all\')
cm = confusion_matrix(y_true=y_test, y_pred=y_pred, normalize=\'true\')
output:
Confusion Matrix:
[[1. 0. 0. ]
[0. 0.9375 0.0625]
[0. 0. 1. ]]
混淆矩阵如下:
在对比原来的图是不是好多了?
第1类setosa和第3类virginica分类都正确,召回率为100%。
第2行第3列中,有0.062(即6.2%)的versicolor类的鸢尾花分成了virginica类。94%的分类正确了。
Note:另外,disp.plot()函数内还可以加其他参数,如cmap,意思是colormap,有很多种类型。
supported values are \'Accent\', \'Accent_r\', \'Blues\', \'Blues_r\', \'BrBG\', \'BrBG_r\', \'BuGn\', \'BuGn_r\', \'BuPu\', \'BuPu_r\', \'CMRmap\', \'CMRmap_r\', \'Dark2\', \'Dark2_r\', \'GnBu\', \'GnBu_r\', \'Greens\', \'Greens_r\', \'Greys\', \'Greys_r\', \'OrRd\', \'OrRd_r\', \'Oranges\', \'Oranges_r\', \'PRGn\', \'PRGn_r\', \'Paired\', \'Paired_r\', \'Pastel1\', \'Pastel1_r\', \'Pastel2\', \'Pastel2_r\', \'PiYG\', \'PiYG_r\', \'PuBu\', \'PuBuGn\', \'PuBuGn_r\', \'PuBu_r\', \'PuOr\', \'PuOr_r\', \'PuRd\', \'PuRd_r\', \'Purples\', \'Purples_r\', \'RdBu\', \'RdBu_r\', \'RdGy\', \'RdGy_r\', \'RdPu\', \'RdPu_r\', \'RdYlBu\', \'RdYlBu_r\', \'RdYlGn\', \'RdYlGn_r\', \'Reds\', \'Reds_r\', \'Set1\', \'Set1_r\', \'Set2\', \'Set2_r\', \'Set3\', \'Set3_r\', \'Spectral\', \'Spectral_r\', \'Wistia\', \'Wistia_r\', \'YlGn\', \'YlGnBu\', \'YlGnBu_r\', \'YlGn_r\', \'YlOrBr\', \'YlOrBr_r\', \'YlOrRd\', \'YlOrRd_r\', \'afmhot\', \'afmhot_r\', \'autumn\', \'autumn_r\', \'binary\', \'binary_r\', \'bone\', \'bone_r\', \'brg\', \'brg_r\', \'bwr\', \'bwr_r\', \'cividis\', \'cividis_r\', \'cool\', \'cool_r\', \'coolwarm\', \'coolwarm_r\', \'copper\', \'copper_r\', \'crest\', \'crest_r\', \'cubehelix\', \'cubehelix_r\', \'flag\', \'flag_r\', \'flare\', \'flare_r\', \'gist_earth\', \'gist_earth_r\', \'gist_gray\', \'gist_gray_r\', \'gist_heat\', \'gist_heat_r\', \'gist_ncar\', \'gist_ncar_r\', \'gist_rainbow\', \'gist_rainbow_r\', \'gist_stern\', \'gist_stern_r\', \'gist_yarg\', \'gist_yarg_r\', \'gnuplot\', \'gnuplot2\', \'gnuplot2_r\', \'gnuplot_r\', \'gray\', \'gray_r\', \'hot\', \'hot_r\', \'hsv\', \'hsv_r\', \'icefire\', \'icefire_r\', \'inferno\', \'inferno_r\', \'jet\', \'jet_r\', \'magma\', \'magma_r\', \'mako\', \'mako_r\', \'nipy_spectral\', \'nipy_spectral_r\', \'ocean\', \'ocean_r\', \'pink\', \'pink_r\', \'plasma\', \'plasma_r\', \'prism\', \'prism_r\', \'rainbow\', \'rainbow_r\', \'rocket\', \'rocket_r\', \'seismic\', \'seismic_r\', \'spring\', \'spring_r\', \'summer\', \'summer_r\', \'tab10\', \'tab10_r\', \'tab20\', \'tab20_r\', \'tab20b\', \'tab20b_r\', \'tab20c\', \'tab20c_r\', \'terrain\', \'terrain_r\', \'turbo\', \'turbo_r\', \'twilight\', \'twilight_r\', \'twilight_shifted\', \'twilight_shifted_r\', \'viridis\', \'viridis_r\', \'vlag\', \'vlag_r\', \'winter\', \'winter_r\'
详见https://matplotlib.org/stable/gallery/color/colormap_reference.html
参考地址: