发布时间:2024-01-29 15:00
这篇文章的好多代码都源自于其他博客,我只是把它们重新整合,变成了我需要的、漂亮的、适合放在论文中的图片绘制代码。
参考链接:
https://blog.csdn.net/weixin_38314865/article/details/88989506
https://www.cnblogs.com/ZHANG576433951/p/11233159.html
https://blog.csdn.net/qq_37851620/article/details/100642566?utm_source=app&app_version=4.7.1
https://blog.csdn.net/Poul_henry/article/details/88294297
https://mathpretty.com/10675.html
import numpy as np
import itertools
import matplotlib.pyplot as plt
# 绘制混淆矩阵
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Plot a confusion matrix as a heat map and save it to disk.

    The figure is written to ``confusion_matrix.eps`` and
    ``confusion_matrix.png`` in the current working directory (600 dpi).
    A one-line description of the mode is also printed to stdout.

    Parameters
    ----------
    cm : array-like
        The computed confusion-matrix values. Converted with
        ``np.asarray``, so nested lists are accepted as well as arrays.
        Values are rendered with the ``'d'`` format, so they must be
        integers.
    classes : sequence
        Labels for the rows and columns; used as tick labels on both axes.
    normalize : bool
        ``True``: every cell shows the row-normalized share as a percentage
        with the raw count underneath, and the heat map is colored by the
        share. ``False``: only the raw counts are shown.
    title : str
        Title drawn above the plot.
    cmap : matplotlib colormap
        Colormap used for the heat map.

    Returns
    -------
    None
    """
    cm = np.asarray(cm)

    if normalize:
        counts = cm  # keep the raw counts; they are printed below the percentages
        row_totals = counts.sum(axis=1, keepdims=True)
        # Rows that sum to zero are left at 0.0 instead of dividing by zero,
        # which would yield NaN cells and a NaN color threshold.
        cm = np.divide(
            counts.astype('float'),
            row_totals,
            out=np.zeros(counts.shape, dtype=float),
            where=row_totals != 0,
        )
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    # Size of the output figure. plt.subplots already creates the figure, so
    # no separate plt.figure() call is needed (it would leave an empty figure
    # open).
    _, ax = plt.subplots(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)

    # Title font and size.
    font_title = {
        'family': 'Times New Roman',
        'weight': 'normal',
        'size': 15,
    }
    plt.title(title, fontdict=font_title)
    plt.colorbar()

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Tick-label size and font.
    plt.tick_params(labelsize=15)
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname('Times New Roman')

    # Cells darker than half of the maximum get white text so it stays readable.
    thresh = cm.max() / 2.
    fm_int = 'd'

    if normalize:
        fm_float = '.3%'
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            text_color = "white" if cm[i, j] > thresh else "black"
            # Percentage sits above the cell center ...
            plt.text(j, i, format(cm[i, j], fm_float),
                     horizontalalignment="center", verticalalignment='bottom',
                     family="Times New Roman", weight="normal", size=15,
                     color=text_color)
            # ... and the raw count directly below it.
            plt.text(j, i, format(counts[i, j], fm_int),
                     horizontalalignment="center", verticalalignment='top',
                     family="Times New Roman", weight="normal", size=15,
                     color=text_color)
    else:
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fm_int),
                     horizontalalignment="center", verticalalignment='bottom',
                     color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()

    # Axis titles are not drawn. To add them, call
    # plt.ylabel('True label', font) and plt.xlabel('Predicted label', font)
    # here with a font dict like font_title.

    plt.savefig('confusion_matrix.eps', dpi=600, format='eps')
    plt.savefig('confusion_matrix.png', dpi=600, format='png')
# Example data: a 5x5 confusion matrix, one row and one column per entry of
# attack_types (same order).
cnf_matrix = np.array(
    [
        [109653, 2, 0, 1, 0],
        [0, 104180, 2, 0, 0],
        [1, 0, 110422, 1, 0],
        [9, 1, 1, 104380, 0],
        [13, 0, 0, 3, 767875],
    ]
)
attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']

# Draw the row-normalized version; the function saves the figure to
# confusion_matrix.eps and confusion_matrix.png.
plot_confusion_matrix(
    cnf_matrix,
    classes=attack_types,
    normalize=True,
    title='Normalized confusion matrix',
)