在pytorch 中计算精度、回归率、F1 score等指标的实例

作者:Link2Link 时间:2022-08-10 06:28:18 

pytorch中训练完网络后,需要对学习的结果进行测试。官网上例程用的方法统统都是正确率,使用的是torch.eq()这个函数。

但是为了更精细的评价结果,我们还需要计算其他各个指标。在把官网API翻了一遍之后发现并没有用于计算TP,TN,FP,FN的函数。。。

在动了无数歪脑筋之后,心想pytorch完全支持numpy,那能不能直接进行判断,试了一下果然可以,上代码:


# TP predict 和 label 同时为1
TP += ((pred_choice == 1) & (target.data == 1)).cpu().sum()
# TN predict 和 label 同时为0
TN += ((pred_choice == 0) & (target.data == 0)).cpu().sum()
# FN predict 0 label 1
FN += ((pred_choice == 0) & (target.data == 1)).cpu().sum()
# FP predict 1 label 0
FP += ((pred_choice == 1) & (target.data == 0)).cpu().sum()

p = TP / (TP + FP)
r = TP / (TP + FN)
F1 = 2 * r * p / (r + p)
acc = (TP + TN) / (TP + TN + FP + FN

这样就能看到各个指标了。

因为target是Variable所以需要用target.data取到对应的tensor,又因为是在gpu上算的,需要用 .cpu() 移到cpu上。

因为这是一个batch的统计,所以需要用+=累计出整个epoch的统计。当然,在epoch开始之前需要清零

来源:https://blog.csdn.net/qq_15602569/article/details/79565402

标签:pytorch,精度,回归率,F1,score
0
投稿

猜你喜欢

  • Python学习笔记之读取文件、OS模块、异常处理、with as语法示例

    2023-03-20 21:54:58
  • python matplotlib画图时坐标轴重叠显示不全和图片保存时不完整的问题解决

    2023-12-11 03:42:42
  • 使用OpenCV对运动员的姿势进行检测功能实现

    2022-06-08 03:23:14
  • python检测服务器是否正常

    2022-06-18 05:10:19
  • MySQL权限详解

    2011-02-16 12:20:00
  • 用Python创建声明性迷你语言的教程

    2023-08-10 04:49:42
  • Pytorch中膨胀卷积的用法详解

    2023-03-26 12:03:25
  • 详解Python3的TFTP文件传输

    2023-06-01 22:29:17
  • django的403/404/500错误自定义页面的配置方式

    2023-01-19 06:44:40
  • python下的opencv画矩形和文字注释的实现方法

    2022-12-26 22:27:17
  • asp.net C#实现解压缩文件的方法

    2023-07-14 10:34:01
  • Python基于域相关实现图像增强的方法教程

    2023-08-24 15:30:22
  • 一文讲清base64编码原理

    2023-04-10 23:51:48
  • 详解Python中命令行参数argparse的常用命令

    2022-06-06 15:59:30
  • vscode使用nuget包管理工具

    2023-10-30 13:37:55
  • python实现循环语句1到100累和

    2023-05-15 15:39:38
  • 常见SQL Server 2000漏洞及其相关利用

    2007-10-01 14:45:00
  • MySQL 数据库的监控方式小结

    2024-01-14 19:07:14
  • 浅谈mysql的中文乱码问题

    2024-01-21 10:31:16
  • Go Plugins插件的实现方式

    2023-10-15 01:51:20
  • asp之家 网络编程 m.aspxhome.com