pytorch中dataloader 的sampler 参数详解

作者:mingqian_chu 时间:2023-09-16 21:00:13 

1. dataloader() 初始化函数

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
                worker_init_fn=None, multiprocessing_context=None):

其中几个常用的参数:

  • dataset 数据集,map-style and iterable-style 可以用index取值的对象、

  • batch_size 大小

  • shuffle 取batch是否随机取, 默认为False

  • sampler 定义取batch的方法,是一个迭代器, 每次生成一个key 用于读取dataset中的值

  • batch_sampler 也是一个迭代器, 每次生次一个batch_size的key

  • num_workers 参与工作的线程数collate_fn 对取出的batch进行处理

  • drop_last 对最后不足batchsize的数据的处理方法

下面看两段取自DataLoader中的__init__代码, 帮助我们理解几个常用参数之间的关系

2. shuffle 与sample之间的关系

当我们sampler有输入时,shuffle的值就没有意义,

if sampler is None:  # give default samplers
   if self._dataset_kind == _DatasetKind.Iterable:
       # See NOTE [ Custom Samplers and IterableDataset ]
       sampler = _InfiniteConstantSampler()
   else:  # map-style
       if shuffle:
           sampler = RandomSampler(dataset)
       else:
           sampler = SequentialSampler(dataset)

当dataset类型是map style时, shuffle其实就是改变sampler的取值

  • shuffle为默认值 False时,sampler是SequentialSampler,就是按顺序取样,

  • shuffle为True时,sampler是RandomSampler, 就是按随机取样

3. sample 的定义方法

3.1 sampler 参数的使用

sampler 是用来定义取batch方法的一个函数或者类,返回的是一个迭代器。

我们可以看下自带的RandomSampler类中最重要的iter函数

def __iter__(self):
       n = len(self.data_source)
       # dataset的长度, 按顺序索引
       if self.replacement:# 对应的replace参数
           return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
       return iter(torch.randperm(n).tolist())

可以看出,其实就是生成索引,然后随机的取值, 然后再迭代。

其实还有一些细节需要注意理解:

比如__len__函数,包括DataLoader的len和sample的len, 两者区别, 这部分代码比较简单,可以自行阅读,其实参考着RandomSampler写也不会出现问题。
比如,迭代器和生成器的使用, 以及区别

if batch_size is not None and batch_sampler is None:
       # auto_collation without custom batch_sampler
       batch_sampler = BatchSampler(sampler, batch_size, drop_last)

self.sampler = sampler
   self.batch_sampler = batch_sampler

BatchSampler的生成过程:

# 略去类的初始化
   def __iter__(self):
       batch = []
       for idx in self.sampler:
           batch.append(idx)
           if len(batch) == self.batch_size:
               yield batch
               batch = []
       if len(batch) > 0 and not self.drop_last:
           yield batch

就是按batch_size从sampler中读取索引, 并形成生成器返回。

以上可以看出, batch_sampler和sampler, batch_size, drop_last之间的关系

  • 如果batch_sampler没有定义的话且batch_size有定义, 会根据sampler, batch_size, drop_last生成一个batch_sampler

  • 自带的注释中对batch_sampler有一句话: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.

  • 意思就是b

  • atch_sampler 与这些参数冲突 ,即 如果你定义了batch_sampler, 其他参数都不需要有

4. batch 生成过程

每个batch都是由迭代器产生的:

# DataLoader中iter的部分
   def __iter__(self):
       if self.num_workers == 0:
           return _SingleProcessDataLoaderIter(self)
       else:
           return _MultiProcessingDataLoaderIter(self)

# 再看调用的另一个类
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
   def __init__(self, loader):
       super(_SingleProcessDataLoaderIter, self).__init__(loader)
       assert self._timeout == 0
       assert self._num_workers == 0

self._dataset_fetcher = _DatasetKind.create_fetcher(
           self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

def __next__(self):
       index = self._next_index()  
       data = self._dataset_fetcher.fetch(index)  
       if self._pin_memory:
           data = _utils.pin_memory.pin_memory(data)
       return data

来源:https://blog.csdn.net/chumingqian/article/details/126625724

标签:pytorch,dataloader,sampler
0
投稿

猜你喜欢

  • python使用xlrd模块读写Excel文件的方法

    2022-02-14 16:54:55
  • tensorflow的ckpt及pb模型持久化方式及转化详解

    2022-12-10 17:32:08
  • 阿里系的中国雅虎新首页浅谈

    2008-07-16 12:19:00
  • Python基于paramunittest模块实现excl参数化

    2023-12-27 00:29:02
  • 浅谈numpy数组的几种排序方式

    2022-04-24 12:48:15
  • Python pyecharts案例超市4年数据可视化分析

    2021-04-09 21:10:29
  • javascript ImgBox透明遮罩层背景图片展示

    2024-02-27 04:51:07
  • Nodejs实现短信验证码功能

    2024-05-08 09:37:32
  • 实操Python爬取觅知网素材图片示例

    2021-12-12 21:19:59
  • 在脚本中单独使用django的ORM模型详解

    2021-03-09 05:17:26
  • python timestamp和datetime之间转换详解

    2021-02-07 11:17:51
  • SQL Server的FileStream和FileTable深入剖析

    2023-07-17 01:17:24
  • Python socket实现的文件下载器功能示例

    2021-03-12 22:43:19
  • JS轮播图实现简单代码

    2024-04-28 09:38:41
  • mysql的sql语句特殊处理语句总结(必看)

    2024-01-17 02:10:05
  • python实现微信机器人: 登录微信、消息接收、自动回复功能

    2023-05-30 05:42:53
  • CentOS 7安装Mysql并设置开机自启动的方法

    2024-01-27 05:32:47
  • python 画条形图(柱状图)实例

    2021-12-06 19:09:26
  • MySQL 数据库锁的实现

    2024-01-13 01:30:29
  • python二维键值数组生成转json的例子

    2021-10-09 07:49:14
  • asp之家 网络编程 m.aspxhome.com