pytorch分类模型绘制混淆矩阵以及可视化详解

作者:王延凯的博客 时间:2023-01-17 17:35:43 

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
# 使用torch.no_grad()可以显著降低测试用例的GPU占用
   with torch.no_grad():
       for step, (imgs, targets) in enumerate(test_loader):
           # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
           # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
           targets = targets.squeeze()  # [50,1] ----->  [50]

# 将变量转为gpu
           targets = targets.cuda()
           imgs = imgs.cuda()
           # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())

out = model(imgs)
           #记录混淆矩阵参数
           conf_matrix = confusion_matrix(out, targets, conf_matrix)
           conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
   preds = torch.argmax(preds, 1)
   for p, t in zip(preds, labels):
       conf_matrix[p, t] += 1
   return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
print(conf_matrix)

# 获取每种Emotion的识别准确率
print("每种情感总个数:",per_kinds)
print("每种情感预测正确的个数:",corrects)
print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
   for y in range(Emotion_kinds):
       # 注意这里的matrix[y, x]不是matrix[x, y]
       info = int(conf_matrix[y, x])
       plt.text(x, y, info,
                verticalalignment='center',
                horizontalalignment='center',
                color="white" if info > thresh else "black")

plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

来源:https://blog.csdn.net/weixin_38468077/article/details/121671139

标签:pytorch,混淆矩阵,可视化
0
投稿

猜你喜欢

  • nodejs+mysql实现用户相关的增删改查的详细操作

    2024-01-26 10:34:45
  • Python实现用手机监控远程控制电脑的方法

    2021-06-22 07:57:49
  • ORACLE中的的HINT详解

    2024-01-26 23:29:53
  • Python3.10的一些新特性原理分析

    2023-06-17 06:35:10
  • asp如何使用Office Chart 9.0 制作图表?

    2010-06-05 12:41:00
  • 是在客户端确认还是在服务器端确认?

    2010-07-14 21:05:00
  • Python 利用base64库 解码本地txt文本字符串

    2021-01-13 14:33:58
  • JavaScript图片放大镜效果

    2009-10-19 22:15:00
  • 什么是python类属性

    2021-07-31 20:27:16
  • python getpass实现密文实例详解

    2021-06-25 20:29:17
  • ASP 隐藏下载地址及防盗链代码

    2011-02-26 11:17:00
  • 玩转CSS3色彩[译]

    2010-01-13 13:02:00
  • 快速解决pandas.read_csv()乱码的问题

    2023-07-10 21:14:47
  • 使用OpenCV校准鱼眼镜头的方法

    2022-04-02 01:58:48
  • MYSQL主从数据库同步备份配置的方法

    2024-01-23 15:03:43
  • openfiledialog读取txt写入数据库示例

    2024-01-16 02:03:35
  • Python多进程编程技术实例分析

    2022-07-23 18:02:49
  • 用户体验的另一种认识

    2007-10-25 12:36:00
  • 使用php操作xml教程

    2023-06-14 03:10:45
  • python如何往列表头部和尾部添加元素

    2021-12-17 07:05:17
  • asp之家 网络编程 m.aspxhome.com