pytorch中DataLoader()过程中遇到的一些问题

作者:hfw6310 时间:2022-01-17 18:36:11 

如下所示:

RuntimeError: stack expects each tensor to be equal size, but got [3, 60, 32] at entry 0 and [3, 54, 32] at entry 2


train_dataset = datasets.ImageFolder(
   traindir,
   transforms.Compose([
       transforms.Resize((224)) ###

原因是

transforms.Resize() 的参数设置问题,改为如下设置就可以了


train_dataset = datasets.ImageFolder(
   traindir,
   transforms.Compose([
       transforms.Resize((224,224)),

同理,val_dataset中也调整为transforms.Resize((224,224))。

补充:pytorch之dataloader深入剖析

- dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问;

- 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问;

- 也可以使用`for inputs, labels in dataloaders`进行可迭代对象的访问;

- 一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据;

① DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存

② Queue的特点

当队列里面没有数据时: queue.get() 会阻塞, 阻塞的时候,其它进程/线程如果有queue.put() 操作,本线程/进程会被通知,然后就可以 get 成功。

当数据满了: queue.put() 会阻塞

③ DataLoader是一个高效,简洁,直观的网络输入数据结构,便于使用和扩展

输入数据PipeLine

pytorch 的数据加载到模型的操作顺序是这样的:

① 创建一个 Dataset 对象

② 创建一个 DataLoader 对象

③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练


dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for img, label in dataloader:
....

所以,作为直接对数据进入模型中的关键一步, DataLoader非常重要。

首先简单介绍一下DataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…),该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

官方对DataLoader的说明是:“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。”关于iterator和iterable的区别和概念请自行查阅,在实现中的差别就是iterators有__iter__和__next__方法,而iterable只有__iter__方法。

1.DataLoader

先介绍一下DataLoader(object)的参数:

dataset(Dataset): 传入的数据集

batch_size(int, optional): 每个batch有多少个样本

shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序

sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…

如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each


worker subprocess with the worker id (an int in [0, num_workers - 1]) as
input, after seeding and before data loading. (default: None)

- 首先dataloader初始化时得到datasets的采样list


class DataLoader(object):
   r"""
   Data loader. Combines a dataset and a sampler, and provides
   single- or multi-process iterators over the dataset.
   Arguments:
       dataset (Dataset): dataset from which to load the data.
       batch_size (int, optional): how many samples per batch to load
           (default: 1).
       shuffle (bool, optional): set to ``True`` to have the data reshuffled
           at every epoch (default: False).
       sampler (Sampler, optional): defines the strategy to draw samples from
           the dataset. If specified, ``shuffle`` must be False.
       batch_sampler (Sampler, optional): like sampler, but returns a batch of
           indices at a time. Mutually exclusive with batch_size, shuffle,
           sampler, and drop_last.
       num_workers (int, optional): how many subprocesses to use for data
           loading. 0 means that the data will be loaded in the main process.
           (default: 0)
       collate_fn (callable, optional): merges a list of samples to form a mini-batch.
       pin_memory (bool, optional): If ``True``, the data loader will copy tensors
           into CUDA pinned memory before returning them.
       drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
           if the dataset size is not divisible by the batch size. If ``False`` and
           the size of dataset is not divisible by the batch size, then the last batch
           will be smaller. (default: False)
       timeout (numeric, optional): if positive, the timeout value for collecting a batch
           from workers. Should always be non-negative. (default: 0)
       worker_init_fn (callable, optional): If not None, this will be called on each
           worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
           input, after seeding and before data loading. (default: None)
   .. note:: By default, each worker will have its PyTorch seed set to
             ``base_seed + worker_id``, where ``base_seed`` is a long generated
             by main process using its RNG. However, seeds for other libraies
             may be duplicated upon initializing workers (w.g., NumPy), causing
             each worker to return identical random numbers. (See
             :ref:`dataloader-workers-random-seed` section in FAQ.) You may
             use ``torch.initial_seed()`` to access the PyTorch seed for each
             worker in :attr:`worker_init_fn`, and use it to set other seeds
             before data loading.
   .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                unpicklable object, e.g., a lambda function.
   """
   __initialized = False
   def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
       self.dataset = dataset
       self.batch_size = batch_size
       self.num_workers = num_workers
       self.collate_fn = collate_fn
       self.pin_memory = pin_memory
       self.drop_last = drop_last
       self.timeout = timeout
       self.worker_init_fn = worker_init_fn
       if timeout < 0:
           raise ValueError('timeout option should be non-negative')
       if batch_sampler is not None:
           if batch_size > 1 or shuffle or sampler is not None or drop_last:
               raise ValueError('batch_sampler option is mutually exclusive '
                                'with batch_size, shuffle, sampler, and '
                                'drop_last')
           self.batch_size = None
           self.drop_last = None
       if sampler is not None and shuffle:
           raise ValueError('sampler option is mutually exclusive with '
                            'shuffle')
       if self.num_workers < 0:
           raise ValueError('num_workers option cannot be negative; '
                            'use num_workers=0 to disable multiprocessing.')
       if batch_sampler is None:
           if sampler is None:
               if shuffle:
                   sampler = RandomSampler(dataset)  //将list打乱
               else:
                   sampler = SequentialSampler(dataset)
           batch_sampler = BatchSampler(sampler, batch_size, drop_last)
       self.sampler = sampler
       self.batch_sampler = batch_sampler
       self.__initialized = True
   def __setattr__(self, attr, val):
       if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
           raise ValueError('{} attribute should not be set after {} is '
                            'initialized'.format(attr, self.__class__.__name__))
       super(DataLoader, self).__setattr__(attr, val)
   def __iter__(self):
       return _DataLoaderIter(self)
   def __len__(self):
       return len(self.batch_sampler)

其中:RandomSampler,BatchSampler已经得到了采用batch数据的index索引;yield batch机制已经在!!!


class RandomSampler(Sampler):
   r"""Samples elements randomly, without replacement.
   Arguments:
       data_source (Dataset): dataset to sample from
   """
   def __init__(self, data_source):
       self.data_source = data_source
   def __iter__(self):
       return iter(torch.randperm(len(self.data_source)).tolist())
   def __len__(self):
       return len(self.data_source)

class BatchSampler(Sampler):
   r"""Wraps another sampler to yield a mini-batch of indices.
   Args:
       sampler (Sampler): Base sampler.
       batch_size (int): Size of mini-batch.
       drop_last (bool): If ``True``, the sampler will drop the last batch if
           its size would be less than ``batch_size``
   Example:
       >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
       [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
       >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
       [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
   """
   def __init__(self, sampler, batch_size, drop_last):
       if not isinstance(sampler, Sampler):
           raise ValueError("sampler should be an instance of "
                            "torch.utils.data.Sampler, but got sampler={}"
                            .format(sampler))
       if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
               batch_size <= 0:
           raise ValueError("batch_size should be a positive integeral value, "
                            "but got batch_size={}".format(batch_size))
       if not isinstance(drop_last, bool):
           raise ValueError("drop_last should be a boolean value, but got "
                            "drop_last={}".format(drop_last))
       self.sampler = sampler
       self.batch_size = batch_size
       self.drop_last = drop_last
   def __iter__(self):
       batch = []
       for idx in self.sampler:
           batch.append(idx)
           if len(batch) == self.batch_size:
               yield batch
               batch = []
       if len(batch) > 0 and not self.drop_last:
           yield batch
   def __len__(self):
       if self.drop_last:
           return len(self.sampler) // self.batch_size
       else:
           return (len(self.sampler) + self.batch_size - 1) // self.batch_size

- 其中 _DataLoaderIter(self)输入为一个dataloader对象;如果num_workers=0很好理解,num_workers!=0引入多线程机制,加速数据加载过程;

- 没有多线程时:batch = self.collate_fn([self.dataset[i] for i in indices])进行将index转化为data数据,返回(image,label);self.dataset[i]会调用datasets对象的

__getitem__()方法

- 多线程下,会为每个线程创建一个索引队列index_queues;共享一个worker_result_queue数据队列!在_worker_loop方法中加载数据;


class _DataLoaderIter(object):
   r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
   def __init__(self, loader):
       self.dataset = loader.dataset
       self.collate_fn = loader.collate_fn
       self.batch_sampler = loader.batch_sampler
       self.num_workers = loader.num_workers
       self.pin_memory = loader.pin_memory and torch.cuda.is_available()
       self.timeout = loader.timeout
       self.done_event = threading.Event()
       self.sample_iter = iter(self.batch_sampler)
       base_seed = torch.LongTensor(1).random_().item()
       if self.num_workers > 0:
           self.worker_init_fn = loader.worker_init_fn
           self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
           self.worker_queue_idx = 0
           self.worker_result_queue = multiprocessing.SimpleQueue()
           self.batches_outstanding = 0
           self.worker_pids_set = False
           self.shutdown = False
           self.send_idx = 0
           self.rcvd_idx = 0
           self.reorder_dict = {}
           self.workers = [
               multiprocessing.Process(
                   target=_worker_loop,
                   args=(self.dataset, self.index_queues[i],
                         self.worker_result_queue, self.collate_fn, base_seed + i,
                         self.worker_init_fn, i))
               for i in range(self.num_workers)]
           if self.pin_memory or self.timeout > 0:
               self.data_queue = queue.Queue()
               if self.pin_memory:
                   maybe_device_id = torch.cuda.current_device()
               else:
                   # do not initialize cuda context if not necessary
                   maybe_device_id = None
               self.worker_manager_thread = threading.Thread(
                   target=_worker_manager_loop,
                   args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                         maybe_device_id))
               self.worker_manager_thread.daemon = True
               self.worker_manager_thread.start()
           else:
               self.data_queue = self.worker_result_queue
           for w in self.workers:
               w.daemon = True  # ensure that the worker exits on process exit
               w.start()
           _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
           _set_SIGCHLD_handler()
           self.worker_pids_set = True
           # prime the prefetch loop
           for _ in range(2 * self.num_workers):
               self._put_indices()
   def __len__(self):
       return len(self.batch_sampler)
   def _get_batch(self):
       if self.timeout > 0:
           try:
               return self.data_queue.get(timeout=self.timeout)
           except queue.Empty:
               raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
       else:
           return self.data_queue.get()
   def __next__(self):
       if self.num_workers == 0:  # same-process loading
           indices = next(self.sample_iter)  # may raise StopIteration
           batch = self.collate_fn([self.dataset[i] for i in indices])
           if self.pin_memory:
               batch = pin_memory_batch(batch)
           return batch
       # check if the next sample has already been generated
       if self.rcvd_idx in self.reorder_dict:
           batch = self.reorder_dict.pop(self.rcvd_idx)
           return self._process_next_batch(batch)
       if self.batches_outstanding == 0:
           self._shutdown_workers()
           raise StopIteration
       while True:
           assert (not self.shutdown and self.batches_outstanding > 0)
           idx, batch = self._get_batch()
           self.batches_outstanding -= 1
           if idx != self.rcvd_idx:
               # store out-of-order samples
               self.reorder_dict[idx] = batch
               continue
           return self._process_next_batch(batch)
   next = __next__  # Python 2 compatibility
   def __iter__(self):
       return self
   def _put_indices(self):
       assert self.batches_outstanding < 2 * self.num_workers
       indices = next(self.sample_iter, None)
       if indices is None:
           return
       self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
       self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
       self.batches_outstanding += 1
       self.send_idx += 1
   def _process_next_batch(self, batch):
       self.rcvd_idx += 1
       self._put_indices()
       if isinstance(batch, ExceptionWrapper):
           raise batch.exc_type(batch.exc_msg)
       return batch

def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
   global _use_shared_memory
   _use_shared_memory = True
   # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
   # module's handlers are executed after Python returns from C low-level
   # handlers, likely when the same fatal signal happened again already.
   # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
   _set_worker_signal_handlers()
   torch.set_num_threads(1)
   random.seed(seed)
   torch.manual_seed(seed)
   if init_fn is not None:
       init_fn(worker_id)
   watchdog = ManagerWatchdog()
   while True:
       try:
           r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
       except queue.Empty:
           if watchdog.is_alive():
               continue
           else:
               break
       if r is None:
           break
       idx, batch_indices = r
       try:
           samples = collate_fn([dataset[i] for i in batch_indices])
       except Exception:
           data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
       else:
           data_queue.put((idx, samples))
           del samples

- 需要对队列操作,缓存数据,使得加载提速!

来源:https://blog.csdn.net/hfw6310/article/details/106992968

标签:pytorch,DataLoader
0
投稿

猜你喜欢

  • 关于对Java正则表达式"\\\\"的理解

    2023-06-24 07:23:02
  • 浅谈Python numpy创建空数组的问题

    2022-10-10 07:11:08
  • golang三元表达式的使用方法

    2023-08-28 14:34:09
  • python调用系统中应用程序的函数示例

    2021-01-18 11:06:32
  • pydantic-resolve嵌套数据结构生成LoaderDepend管理contextvars

    2023-01-12 22:21:05
  • go实现文件的创建、删除与读取示例代码

    2023-06-17 05:10:50
  • Golang中switch语句和select语句的用法教程

    2023-09-02 09:09:06
  • python对站点数据做EOF且做插值绘制填色图

    2023-03-05 03:30:56
  • python tkinter实现弹窗的输入输出

    2021-10-03 14:58:42
  • 如何优化网站图片以快速显示

    2008-04-05 10:09:00
  • pandas 根据列的值选取所有行的示例

    2023-10-13 16:19:38
  • python脚本后台执行方式

    2021-02-12 20:57:58
  • 分布式爬虫scrapy-redis的实战踩坑记录

    2022-03-02 02:13:24
  • Python操作多维数组输出和矩阵运算示例

    2022-11-30 03:44:18
  • Pytorch数据类型与转换(torch.tensor,torch.FloatTensor)

    2023-03-31 13:32:36
  • python 实现docx与doc文件的互相转换

    2022-01-19 06:45:58
  • ASP使用MYSQL数据库全攻略

    2009-11-08 18:27:00
  • ASP获取ACCESS数据库表名及结构的代码

    2011-04-15 10:50:00
  • python实现登录与注册系统

    2022-04-26 02:32:38
  • Web标准的web UI

    2008-01-02 12:34:00
  • asp之家 网络编程 m.aspxhome.com