pytorch sampler对数据进行采样的实现

作者:蓝鲸123 时间:2023-02-09 20:05:40 

PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。

下面举例说明。


from dataSet import *
dataset = DogCat('data/dogcat/', transform=transform)

from torch.utils.data import DataLoader
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]

print(weights)

from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
               num_samples=9,\
               replacement=True)
dataloader = DataLoader(dataset,
           batch_size=3,
           sampler=sampler)
for datas, labels in dataloader:
 print(labels.tolist())

输出:


[2, 2, 1, 1, 2, 1, 1, 2]
[1, 1, 0]
[1, 0, 0]
[0, 0, 1]

github 地址:

https://github.com/WebLearning17/CommonTool

来源:https://blog.csdn.net/TH_NUM/article/details/80877772

标签:pytorch,sampler,数据,采样
0
投稿

猜你喜欢

  • ORACLE中段的HEADER_BLOCK示例详析

    2024-01-26 02:35:09
  • pandas的qcut()方法详解

    2022-07-23 03:36:21
  • Python项目管理Git常用命令详图讲解

    2021-01-24 13:41:38
  • 如何使用Python Matplotlib绘制条形图

    2023-09-21 04:41:46
  • python 5个顶级异步框架推荐

    2021-12-23 06:21:47
  • mysql安装图解总结

    2024-01-15 04:12:21
  • MySQL 使用DQL命令查询数据的实现方法

    2024-01-16 18:53:13
  • 原生js实现下拉菜单

    2024-04-28 09:43:04
  • Python的Flask框架中web表单的教程

    2023-05-17 06:11:06
  • 浅谈python中拼接路径os.path.join斜杠的问题

    2023-08-21 23:41:23
  • 请注意页面head区域的编码是不是规范

    2008-08-06 13:14:00
  • Python中的 pass 占位语句

    2023-02-21 20:45:12
  • python import 上级目录的导入

    2021-09-13 00:54:29
  • Python中的self用法详解

    2023-08-22 15:34:19
  • python实现对任意大小图片均匀切割的示例

    2022-05-07 06:17:05
  • python 文件读写和数据清洗

    2021-02-10 22:58:43
  • 在Pycharm中项目解释器与环境变量的设置方法

    2023-12-31 02:49:53
  • MySQL Explain命令用于查看执行效果

    2009-02-27 15:30:00
  • 解决SpringBoot启动过后不能访问jsp页面的问题(超详细)

    2023-06-13 19:43:31
  • 如何使用python3获取当前路径及os.path.dirname的使用

    2023-07-22 06:29:37
  • asp之家 网络编程 m.aspxhome.com