在pytorch 中计算精度、回归率、F1 score等指标的实例
作者:Link2Link 时间:2022-08-10 06:28:18
pytorch中训练完网络后,需要对学习的结果进行测试。官网上例程用的方法统统都是正确率,使用的是torch.eq()这个函数。
但是为了更精细的评价结果,我们还需要计算其他各个指标。在把官网API翻了一遍之后发现并没有用于计算TP,TN,FP,FN的函数。。。
在动了无数歪脑筋之后,心想pytorch完全支持numpy,那能不能直接进行判断,试了一下果然可以,上代码:
# TP predict 和 label 同时为1
TP += ((pred_choice == 1) & (target.data == 1)).cpu().sum()
# TN predict 和 label 同时为0
TN += ((pred_choice == 0) & (target.data == 0)).cpu().sum()
# FN predict 0 label 1
FN += ((pred_choice == 0) & (target.data == 1)).cpu().sum()
# FP predict 1 label 0
FP += ((pred_choice == 1) & (target.data == 0)).cpu().sum()
p = TP / (TP + FP)
r = TP / (TP + FN)
F1 = 2 * r * p / (r + p)
acc = (TP + TN) / (TP + TN + FP + FN
这样就能看到各个指标了。
因为target是Variable所以需要用target.data取到对应的tensor,又因为是在gpu上算的,需要用 .cpu() 移到cpu上。
因为这是一个batch的统计,所以需要用+=累计出整个epoch的统计。当然,在epoch开始之前需要清零
来源:https://blog.csdn.net/qq_15602569/article/details/79565402
标签:pytorch,精度,回归率,F1,score
0
投稿
猜你喜欢
Python学习笔记之读取文件、OS模块、异常处理、with as语法示例
2023-03-20 21:54:58
python matplotlib画图时坐标轴重叠显示不全和图片保存时不完整的问题解决
2023-12-11 03:42:42
使用OpenCV对运动员的姿势进行检测功能实现
2022-06-08 03:23:14
python检测服务器是否正常
2022-06-18 05:10:19
MySQL权限详解
2011-02-16 12:20:00
用Python创建声明性迷你语言的教程
2023-08-10 04:49:42
Pytorch中膨胀卷积的用法详解
2023-03-26 12:03:25
详解Python3的TFTP文件传输
2023-06-01 22:29:17
django的403/404/500错误自定义页面的配置方式
2023-01-19 06:44:40
python下的opencv画矩形和文字注释的实现方法
2022-12-26 22:27:17
asp.net C#实现解压缩文件的方法
2023-07-14 10:34:01
Python基于域相关实现图像增强的方法教程
2023-08-24 15:30:22
一文讲清base64编码原理
2023-04-10 23:51:48
详解Python中命令行参数argparse的常用命令
2022-06-06 15:59:30
vscode使用nuget包管理工具
2023-10-30 13:37:55
python实现循环语句1到100累和
2023-05-15 15:39:38
常见SQL Server 2000漏洞及其相关利用
2007-10-01 14:45:00
MySQL 数据库的监控方式小结
2024-01-14 19:07:14
浅谈mysql的中文乱码问题
2024-01-21 10:31:16
Go Plugins插件的实现方式
2023-10-15 01:51:20