Python 绘制混淆矩阵

发布时间:2024-01-29 15:00

这篇文章的很多代码都源自以下博客,我只是把它们重新整合,改写成了我需要的、美观且适合放在论文中的绘图代码。 

参考链接:

https://blog.csdn.net/weixin_38314865/article/details/88989506

https://www.cnblogs.com/ZHANG576433951/p/11233159.html

https://blog.csdn.net/qq_37851620/article/details/100642566?utm_source=app&app_version=4.7.1

https://blog.csdn.net/Poul_henry/article/details/88294297

https://mathpretty.com/10675.html

import numpy as np
import itertools
import matplotlib.pyplot as plt


# Plot a confusion matrix as an annotated heatmap (styled for paper figures).
def _format_count(value):
    """Format one cell count: integral values as ints, anything else with 4 significant digits."""
    # format(x, 'd') raises on floats, so float-typed matrices (e.g. averaged
    # over folds) would crash without this check.
    if float(value).is_integer():
        return format(int(value), 'd')
    return format(value, '.4g')


def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues,
                          *, figsize=(8, 6), filename='confusion_matrix', xlabel=None, ylabel=None):
    """Plot a confusion matrix as an annotated heatmap and save it as EPS and PNG.

    Row ``i`` of ``cm`` is drawn on the y axis and column ``j`` on the x axis.
    Tick labels, the title and the cell text use Times New Roman at size 15.

    Parameters
    ----------
    cm : array-like, 2-D
        Confusion-matrix counts. Lists are accepted and converted with ``np.asarray``.
    classes : sequence of str
        Tick labels, used for both axes (one per row/column).
    normalize : bool
        If True, the heatmap shows each row divided by its row sum. Each cell is
        annotated with that fraction (as a percentage, 3 decimals) above the raw
        count. Rows summing to zero are shown as 0 instead of NaN. If False, the
        heatmap and the single centred annotation are the raw counts.
    title : str
        Figure title.
    cmap : matplotlib colormap
        Colormap passed to ``imshow``.
    figsize : tuple of (width, height), keyword-only
        Figure size in inches.
    filename : str, keyword-only
        Output path without extension. The figure is written to
        ``<filename>.eps`` and ``<filename>.png`` at 600 dpi.
    xlabel, ylabel : str or None, keyword-only
        Optional axis titles (for example ``'Predicted label'`` and ``'True label'``).

    Returns
    -------
    (figure, ax)
        The created Figure and Axes. The figure is left open so the caller can
        show or modify it; call ``plt.close(figure)`` when done.
    """
    font_family = 'Times New Roman'
    font_size = 15
    font = {'family': font_family, 'weight': 'normal', 'size': font_size}

    counts = np.asarray(cm)
    if normalize:
        row_sums = counts.sum(axis=1, keepdims=True)
        # Empty rows would divide by zero; leave them at 0 instead of NaN.
        shown = np.divide(counts, row_sums,
                          out=np.zeros(counts.shape, dtype=float),
                          where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        shown = counts
        print('Confusion matrix, without normalization')

    # A single subplots() call: the old extra plt.figure() leaked an empty figure.
    figure, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(shown, interpolation='nearest', cmap=cmap)
    ax.set_title(title, fontdict=font)
    figure.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)
    ax.tick_params(labelsize=font_size)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontname(font_family)

    # Cells in the upper half of the colour scale get white text so it stays
    # readable on dark cells (sequential maps such as Blues).
    thresh = shown.max() / 2. if shown.size else 0
    text_style = {'horizontalalignment': 'center', 'family': font_family,
                  'weight': 'normal', 'size': font_size}
    for i, j in itertools.product(range(shown.shape[0]), range(shown.shape[1])):
        color = "white" if shown[i, j] > thresh else "black"
        if normalize:
            # Percentage sits just above the cell centre, raw count just below.
            ax.text(j, i, format(shown[i, j], '.3%'),
                    verticalalignment='bottom', color=color, **text_style)
            ax.text(j, i, _format_count(counts[i, j]),
                    verticalalignment='top', color=color, **text_style)
        else:
            ax.text(j, i, _format_count(counts[i, j]),
                    verticalalignment='center', color=color, **text_style)

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontdict=font)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontdict=font)

    figure.tight_layout()
    figure.savefig(filename + '.eps', dpi=600, format='eps')
    figure.savefig(filename + '.png', dpi=600, format='png')
    return figure, ax


# Example run: class names label both axes; the matrix rows/columns follow
# the same order.
attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
cnf_matrix = np.array([
    [109653,      2,      0,      1,      0],
    [     0, 104180,      2,      0,      0],
    [     1,      0, 110422,      1,      0],
    [     9,      1,      1, 104380,      0],
    [    13,      0,      0,      3, 767875],
])

plot_confusion_matrix(
    cnf_matrix,
    attack_types,
    normalize=True,
    title='Normalized confusion matrix',
)

 

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号