一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系

阿里云国内75折 回扣 微信号:monov8
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6

很多文章都是从 D a t a s e t Dataset Dataset等对象自下网上进行介绍的但是对于初学者而言其实这并不好理解因为有时候会不自觉的陷入到一些细枝末节中去而不能把握重点所以本文将自上而下的对 P y t o r c h Pytorch Pytorch数据读取方法进行介绍。

自上而下理解三者关系

首先我们看一下 D a t a L o a d e r . n e x t DataLoader.next DataLoader.next的源代码长什么样为方便理解我只选取了num_works为0的情况num_works)简单理解都是能够并行化读取数据

 def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

在阅读上面代码时候,我们可以假设,我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取的数据就只需要对应index即可,即上面代码中的 i n d i c e s indices indices,而选取index的方式有多种:有按顺序的,也有乱序的,所以这个工作需要 S a m p l e r Sampler Sampler来完成,现在你不需要具体的细节,后面会介绍,只需要了解 D a t a L o a d e r DataLoader DataLoader S a m p l e r Sampler Sampler在这里产生关系.
那么 D a t a s e t Dataset Dataset D a t a L o a d e r DataLoader DataLoader在什么时候产生关系呢?没错就是下面一行,我们已经拿到了 i n d i c e s indices indices,那么下一步,我们只需要根据 i n d i c e s indices indices对数据进行读取即可.

在下面 i f if if语句的作用都是,如果 p i n m e m o r y = T r u e , pin_memory=True, pinmemory=True,,那么 P y t o r c h Pytorch Pytorch会采用一系列操作把数据拷贝到GPU中,总之为了加速.

综上,可以了解DataLoader Sampler和Dataset三者关系如下:
在这里插入图片描述
在阅读后文中,始终需要将上面的关系记在心里,这样能帮助你更好的理解

Sampler

参数传递

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None)

要更加细致的理解 S a m p l e r Sampler Sampler原理,我们需要先阅读以下 D a t a L o a d e r DataLoader DataLoader的源代码 如下:
可以看到初始化参数有两种 S a m p l e r Sampler Sampler : Sampler和batch_sampler
都默认为None,前者作用是生成一系列 i n d e x index index,而batch_sampler则是将sampler生成indices打包分组,得到一个又一个batch的index,例如,下面所示示例:
Batchsampler将 S e q u e n t i a l S a m p l e r SequentialSampler SequentialSampler,生成的index按照指定的batchsize分组.

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

pyTorch已经实现的sampler有以下几种

  • SequentialSampler

  • RandomSampler

  • WeightedSampler

  • SubsetRandomSampler
    需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读理解源码更深刻的理解,这里只做总结:

  • 源码

    • 如果自定义batch_sampler,那么这些参数都必须使用默认值:batch_size Shuffle sampler drop_last.
    • 如果自定义了sampler :那么shuffle需要设置为false
    • 如果sampler和batch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况
      • 若shuffle = True时,则sampler=RandomSampler(dataset)
      • 若shuffle = False时,则sampler=SequentialSampler(dataset)

如果定义sampler和BatchSampler

仔细查看源代码可以发现,所有采样器其实都继承同一个父类,即 S a m p l e r Sampler Sampler,其代码定义如下:

class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """
 
    def __init__(self, data_source):
        pass
 
    def __iter__(self):
        raise NotImplementedError
		
    def __len__(self):
        return len(self.data_source)

所以,你要做好的都是定义好__iter__(self) 函数,不过要注意的是该函数的返回值需要是可迭代的,例如 S e q u e n t i a l S a m p l e r SequentialSampler SequentialSampler返回的是:
iter(range(len(self.data_source)))
另外 B a t c h S a m p l e r BatchSampler BatchSampler与其他 S a m p l e r Sampler Sampler的主要区别是其需要将 S a m p l e r Sampler Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表,也就是说后面读取数据的过程中都是 b a t c h s a m p l e r batch sampler batchsampler.

Dataset

定义如下

class Dataset(object):
	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...

上面三个方法最基本的,其中__getitem__是最主要的方法,其规定了如何读取数据,但是其又不同于一般的方法,因为它是 p y t h o n b u i l t − i n python built-in pythonbuiltin方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问,加入你定义好一个dataset,那么可以直接通过dataset[0]来访问第一个数据,在之前,我一值没弄清__getitem__是什么作用,所以一值不知道该怎么进入这个函数进行调试,现在如果你想对__getitem__方法进行调试,可以写一个for循环遍历dataset来进行调试,而不用构建dataloader等一大堆东西啦,建议学会使用ipdb这个库非常实用以后有时间再写一篇ipdb的使用教程。另外其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了我直接复制了一遍如下

class DataLoader(object): 
    ... 
     
    def __next__(self): 
        if self.num_workers == 0:   
            indices = next(self.sample_iter)  
            batch = self.collate_fn([self.dataset[i] for i in indices]) # this line 
            if self.pin_memory: 
                batch = _utils.pin_memory.pin_memory_batch(batch) 
            return batch

我们仔细可以发现,前面有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前,我们需要知道每个参数的含义:

  • indices: 表示每一个iterationsampler返回的indices即一个batch size大小的索引列表
  • self.dataset[i] 这里对第i个数据进行读取操作.
    一般来说:self.dataset[i]=(img, label)

我们不难猜出,collate_fn的作用就是将一个batch的数据进行合并的操作,默认的是collate_fn是将img和label分别合并成 i m g s imgs imgs l a b e l s labels labels,所以,如果你的__getitem__方法只是返回img,label.那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据这样方便后续的训练步骤。

自己理解在这里插入图片描述

DataLoader Dataset和Sampler之间的关系

  • Sampler产生对数据进行采样
  • Dataset:产生数据
  • DataLoader将数据迭代产生batch_size数据格式.

总结

会自己看源代码,根据源代码了解,这里只是做总结
慢慢的将各种数据之间的关系都搞明白,全部都将其搞透彻.

阿里云国内75折 回扣 微信号:monov8
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6