对python中数据集划分函数StratifiedShuffleSplit的使用详解

作者:冻鸡hhhh 时间:2022-12-04 23:10:28 

文章开始先讲下交叉验证,这个概念同样适用于这个划分函数

1.交叉验证(Cross-validation)

交叉验证是指在给定的建模样本中,拿出其中的大部分样本进行模型训练,生成模型,留小部分样本用刚建立的模型进行预测,并求这小部分样本的预测误差,记录它们的平方加和。这个过程一直进行,直到所有的样本都被预测了一次而且仅被预测一次,比较每组的预测误差,选取误差最小的那一组作为训练模型。

下图所示

对python中数据集划分函数StratifiedShuffleSplit的使用详解

2.StratifiedShuffleSplit函数的使用

官方文档

用法:


from sklearn.model_selection import StratifiedShuffleSplit
StratifiedShuffleSplit(n_splits=10,test_size=None,train_size=None, random_state=None)

2.1 参数说明

参数 n_splits是将训练数据分成train/test对的组数,可根据需要进行设置,默认为10

参数test_size和train_size是用来设置train/test对中train和test所占的比例。例如:

1.提供10个数据num进行训练和测试集划分

2.设置train_size=0.8 test_size=0.2

3.train_num=num*train_size=8 test_num=num*test_size=2

4.即10个数据,进行划分以后8个是训练数据,2个是测试数据

注*:train_num≥2,test_num≥2 ;test_size+train_size可以小于1*

参数 random_state控制是将样本随机打乱

2.2 函数作用描述

1.其产生指定数量的独立的train/test数据集划分数据集划分成n组。

2.首先将样本随机打乱,然后根据设置参数划分出train/test对。

3.其创建的每一组划分将保证每组类比比例相同。即第一组训练数据类别比例为2:1,则后面每组类别都满足这个比例

2.3 具体实现


from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4],
[1, 2],[3, 4], [1, 2], [3, 4]])#训练数据集8*2
y = np.array([0, 0, 1, 1,0,0,1,1])#类别数据集8*1

ss=StratifiedShuffleSplit(n_splits=5,test_size=0.25,train_size=0.75,random_state=0)#分成5组,测试比例为0.25,训练比例是0.75

for train_index, test_index in ss.split(X, y):
print("TRAIN:", train_index, "TEST:", test_index)#获得索引值
X_train, X_test = X[train_index], X[test_index]#训练集对应的值
y_train, y_test = y[train_index], y[test_index]#类别集对应的值

运行结果:

对python中数据集划分函数StratifiedShuffleSplit的使用详解

从结果看出,1.训练集是6个,测试集是2,与设置的所对应;2.五组中每组对应的类别比例相同

来源:https://blog.csdn.net/m0_38061927/article/details/76180541

标签:python,Stratified,Shuffle,Split
0
投稿

猜你喜欢

  • 详解JavaScript中的Object.is()与"==="运算符总结

    2024-04-22 12:50:25
  • 详解pycharm的python包opencv(cv2)无代码提示问题的解决

    2022-01-10 06:45:34
  • 使用Vue自定义指令实现Select组件

    2024-05-09 15:26:41
  • mssql查找备注(text,ntext)类型字段为空的方法

    2024-01-28 04:13:18
  • Python操作数据库之数据库编程接口

    2024-01-25 01:55:41
  • mysql 基础教程之库与表的详解

    2024-01-19 13:14:38
  • oracle 优化的一点体会

    2009-10-02 17:59:00
  • SQL--JOIN之完全用法

    2008-09-12 17:30:00
  • 分享python机器学习中应用所产生的聚类数据集方法

    2021-06-05 13:28:39
  • golang1.16新特性速览(推荐)

    2023-07-06 10:34:47
  • 什么是python类属性

    2021-07-31 20:27:16
  • MySql版本问题sql_mode=only_full_group_by的完美解决方案

    2024-01-18 16:08:14
  • vue使用iframe嵌入网页的示例代码

    2024-05-05 09:12:04
  • Python牛刀小试密码爆破

    2021-10-05 14:18:47
  • 8个js表单验证函数

    2007-10-28 19:19:00
  • Python使用pyautogui模块实现自动化鼠标和键盘操作示例

    2022-10-27 16:02:25
  • 微软的jQuery国际化插件

    2010-07-02 12:46:00
  • python实现删除文件与目录的方法

    2023-11-12 23:34:46
  • python实现画循环圆

    2023-01-17 02:18:28
  • django 微信网页授权认证api的步骤详解

    2021-09-15 03:58:40
  • asp之家 网络编程 m.aspxhome.com