Pytorch DataLoader

less than 1 minute read

Published:

pytroch的数据集模块Dataset,支持iterable-style datasetsmap-style datasets两种类型的数据,详见官网链接:pytorch dataset

  • Map-style datasets用dataset[idx]来获取数据,读取的是硬盘中的idx-th图像数据和对应的label
  • Iterable-style datasets则是使用的iter(dataset)这种语法,用于读取database中的流,比如一个远程的服务器(or logs generated in real time?)

在调用的时候直接使用:

from torch.utils.data import Dataset
class Mydataset(Dataset):

这种语法就可以继承自Dataset类,对类内的函数__len__方法(提供dataset的大小)以及__getitem__方法(提供整形索引)进行改写,就可以实现特定的功能,对数据进行索引

dataset = myDataset(csv_file,root_dir)
dataset[0]

在查看一些代码的时候发现并没有调用dataset[idx]这种语法,而是直接使用pytorch自带的DataLoader将dataset作为传参传入,语法:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

跑了程序之后感觉用这种方法加载dataset的时候dataset[idx]中的idx整数索引是随机的,官网上的解释是: >DataLoader by default constructs a index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

参考链接:pytorch DataLoader

先速记一下,TBC…