pytorch分类模型绘制混淆矩阵以及可视化详解
作者:王延凯的博客 时间:2023-01-17 17:35:43
Step 1. 获取混淆矩阵
#首先定义一个 分类数*分类数 的空混淆矩阵
conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
# 使用torch.no_grad()可以显著降低测试用例的GPU占用
with torch.no_grad():
for step, (imgs, targets) in enumerate(test_loader):
# imgs: torch.Size([50, 3, 200, 200]) torch.FloatTensor
# targets: torch.Size([50, 1]), torch.LongTensor 多了一维,所以我们要把其去掉
targets = targets.squeeze() # [50,1] -----> [50]
# 将变量转为gpu
targets = targets.cuda()
imgs = imgs.cuda()
# print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
out = model(imgs)
#记录混淆矩阵参数
conf_matrix = confusion_matrix(out, targets, conf_matrix)
conf_matrix=conf_matrix.cpu()
混淆矩阵的求取用到了confusion_matrix函数,其定义如下:
def confusion_matrix(preds, labels, conf_matrix):
preds = torch.argmax(preds, 1)
for p, t in zip(preds, labels):
conf_matrix[p, t] += 1
return conf_matrix
在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:
conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数
print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
print(conf_matrix)
# 获取每种Emotion的识别准确率
print("每种情感总个数:",per_kinds)
print("每种情感预测正确的个数:",corrects)
print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))
执行此步的输出结果如下所示:
Step 2. 混淆矩阵可视化
对上边求得的混淆矩阵可视化
# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签
# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)
# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
for y in range(Emotion_kinds):
# 注意这里的matrix[y, x]不是matrix[x, y]
info = int(conf_matrix[y, x])
plt.text(x, y, info,
verticalalignment='center',
horizontalalignment='center',
color="white" if info > thresh else "black")
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()
来源:https://blog.csdn.net/weixin_38468077/article/details/121671139
标签:pytorch,混淆矩阵,可视化
0
投稿
猜你喜欢
nodejs+mysql实现用户相关的增删改查的详细操作
2024-01-26 10:34:45
Python实现用手机监控远程控制电脑的方法
2021-06-22 07:57:49
ORACLE中的的HINT详解
2024-01-26 23:29:53
Python3.10的一些新特性原理分析
2023-06-17 06:35:10
asp如何使用Office Chart 9.0 制作图表?
2010-06-05 12:41:00
是在客户端确认还是在服务器端确认?
2010-07-14 21:05:00
Python 利用base64库 解码本地txt文本字符串
2021-01-13 14:33:58
JavaScript图片放大镜效果
2009-10-19 22:15:00
什么是python类属性
2021-07-31 20:27:16
python getpass实现密文实例详解
2021-06-25 20:29:17
ASP 隐藏下载地址及防盗链代码
2011-02-26 11:17:00
玩转CSS3色彩[译]
2010-01-13 13:02:00
快速解决pandas.read_csv()乱码的问题
2023-07-10 21:14:47
使用OpenCV校准鱼眼镜头的方法
2022-04-02 01:58:48
MYSQL主从数据库同步备份配置的方法
2024-01-23 15:03:43
openfiledialog读取txt写入数据库示例
2024-01-16 02:03:35
Python多进程编程技术实例分析
2022-07-23 18:02:49
用户体验的另一种认识
2007-10-25 12:36:00
使用php操作xml教程
2023-06-14 03:10:45
python如何往列表头部和尾部添加元素
2021-12-17 07:05:17