浅谈sklearn中predict与predict_proba区别
作者:GitzLiu 时间:2023-11-08 03:53:45
predict_proba 返回的是一个 n 行 k 列的数组,列是标签(有排序), 第 i 行 第 j 列上的数值是模型预测 第 i 个预测样本为某个标签的概率,并且每一行的概率和为1。
predict 直接返回的是预测 的标签。
具体见下面示例:
# conding :utf-8
from sklearn.linear_model import LogisticRegression
import numpy as np
x_train = np.array([[1,2,3],
[1,3,4],
[2,1,2],
[4,5,6],
[3,5,3],
[1,7,2]])
y_train = np.array([3, 3, 3, 2, 2, 2])
x_test = np.array([[2,2,2],
[3,2,6],
[1,7,4]])
clf = LogisticRegression()
clf.fit(x_train, y_train)
# 返回预测标签
print(clf.predict(x_test))
# 返回预测属于某标签的概率
print(clf.predict_proba(x_test))
# [2 3 2]
#
# [[0.56651809 0.43348191]
# [0.15598162 0.84401838]
# [0.86852502 0.13147498]]
# 分析结果:
# 标签是 2,3 共两个,所以predict_proba返回的为2列,且是排序的(第一列为标签2,第二列为标签3),
# 返回矩阵的行数是测试样本个数 因此为3行
# 预测[2,2,2]的标签是2的概率为0.56651809,3的概率为0.43348191
#
# 预测[3,2,6]的标签是2的概率为0.15598162,3的概率为0.84401838
#
# 预测[1,7,4]的标签是2的概率为0.86852502,3的概率为0.13147498
补充知识:sklearn中predict与predict_proba的识别结果不一致
今天训练了好久的决策树模型在测试的时候发现个bug,使用predict得到的结果居然不是predict_proba中最大数值的索引!因为脚本中需要模型的置信度,所以希望拿到predict_proba的类别概率。
经过胡乱分析发现predict_proba得到的维度比总类别数少了几个,经过测试发现就是这个造成的,即训练集中有部分类别样本数为0。这个问题比较隐蔽,记录一下方便天涯沦落人绕坑。
Tip:在sklearn的train_test_split中有一个参数可以强制测试集和训练集的数据分布一致,也就不会导致缺类别的问题。
来源:https://blog.csdn.net/GitzLiu/article/details/81952431
标签:sklearn,predict,proba
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
Python利用arcpy模块实现栅格的创建与拼接
2021-10-07 22:39:37
![](https://img.aspxhome.com/file/2023/8/81798_0s.png)
Python 加密与解密小结
2021-04-28 00:35:47
js“树”读取xml数据源码
2007-08-04 19:42:00
Kali Linux 2022.1安装和相关配置教程(图文详解)
2023-11-10 15:54:50
![](https://img.aspxhome.com/file/2023/5/99365_0s.jpg)
Linux下Python脚本自启动和定时启动的详细步骤
2022-08-13 20:51:22
![](https://img.aspxhome.com/file/2023/3/83983_0s.png)
vsCode安装使用教程和插件安装方法
2024-04-30 09:55:49
![](https://img.aspxhome.com/file/2023/7/132237_0s.jpg)
浅谈python中对于json写入txt文件的编码问题
2022-01-28 05:08:58
大数据量,海量数据处理方法总结
2024-01-12 21:59:38
Python函数和模块的使用详情
2023-10-11 13:51:20
解读golang plugin热更新尝试
2024-05-22 10:09:28
Golang pipe在不同场景下远程交互
2024-05-09 09:45:58
canvas实现手机端用来上传用户头像的代码
2023-09-16 02:30:54
使用Python编写提取日志中的中文的脚本的方法
2023-12-14 16:04:44
![](https://img.aspxhome.com/file/2023/9/81849_0s.png)
Python中的defaultdict模块和namedtuple模块的简单入门指南
2022-01-21 07:10:20
Vue 服务端渲染SSR示例详解
2024-05-28 15:50:39
![](https://img.aspxhome.com/file/2023/1/123161_0s.jpg)
Spring Boot如何解决Mysql断连问题
2024-01-14 23:52:42
![](https://img.aspxhome.com/file/2023/1/97981_0s.png)
vue动态菜单、动态路由加载以及刷新踩坑实战
2024-05-05 09:25:27
![](https://img.aspxhome.com/file/2023/0/128890_0s.png)
用SELECT... INTO OUTFILE语句导出MySQL数据的教程
2024-01-13 19:50:52
CentOS中使用virtualenv搭建python3环境
2022-08-30 07:28:43
PyCharm 2020.2下配置Anaconda环境的方法步骤
2022-10-08 14:25:00
![](https://img.aspxhome.com/file/2023/3/117383_0s.png)