torch.utils.data¶
PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader
类。它表示数据集上的 Python 可迭代对象,并支持
这些选项由 DataLoader
的构造函数参数配置,其签名为
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
以下各节将详细描述这些选项的效果和用法。
数据集类型¶
DataLoader
构造函数最重要的参数是 dataset
,它指示要从中加载数据的数据集对象。PyTorch 支持两种不同类型的数据集
映射式数据集¶
映射式数据集是实现 __getitem__()
和 __len__()
协议的数据集,它表示从(可能是非整数值)索引/键到数据样本的映射。
例如,当使用 dataset[idx]
访问此类数据集时,它可以从磁盘上的文件夹中读取第 idx
幅图像及其对应的标签。
有关更多详细信息,请参阅 Dataset
。
可迭代式数据集¶
可迭代式数据集是 IterableDataset
子类的实例,它实现了 __iter__()
协议,并表示数据样本上的可迭代对象。这种类型的数据集特别适合随机读取代价高昂甚至不可能的情况,以及批大小取决于获取的数据的情况。
例如,当调用 iter(dataset)
时,此类数据集可以返回从数据库、远程服务器甚至实时生成的日志中读取的数据流。
有关更多详细信息,请参阅 IterableDataset
。
注意
当将 IterableDataset
与 多进程数据加载 一起使用时。同一数据集对象会在每个工作进程中复制,因此必须对副本进行不同的配置以避免数据重复。有关如何实现此目的,请参阅 IterableDataset
文档。
数据加载顺序和 Sampler
¶
对于 可迭代式数据集,数据加载顺序完全由用户定义的可迭代对象控制。这使得更容易实现块读取和动态批大小(例如,每次产生一个批处理样本)。
本节的其余部分涉及 映射式数据集 的情况。torch.utils.data.Sampler
类用于指定数据加载中使用的索引/键的序列。它们表示数据集索引上的可迭代对象。例如,在随机梯度下降 (SGD) 的常见情况下,Sampler
可以随机排列索引列表并每次产生一个,或者为小批量 SGD 产生少量索引。
根据传递给DataLoader
的shuffle
参数,将自动构建顺序或随机采样器。或者,用户可以使用sampler
参数指定一个自定义的Sampler
对象,该对象每次都会生成下一个要获取的索引/键。
可以将一个自定义的Sampler
(每次生成一批索引列表)作为batch_sampler
参数传递。还可以通过batch_size
和drop_last
参数启用自动批处理。有关这方面的更多详细信息,请参阅下一节。
注意
sampler
和batch_sampler
都不兼容可迭代式数据集,因为此类数据集没有键或索引的概念。
加载批处理和非批处理数据¶
DataLoader
支持通过参数batch_size
、drop_last
、batch_sampler
和collate_fn
(具有默认函数)将单个获取的数据样本自动合并成批次。
自动批处理(默认)¶
这是最常见的情况,对应于获取数据的小批量并将其合并成批处理样本,即包含一个维度作为批处理维度(通常是第一个)的张量。
当batch_size
(默认值为1
)不为None
时,数据加载器会生成批处理样本而不是单个样本。batch_size
和drop_last
参数用于指定数据加载器如何获取数据集键的批次。对于映射式数据集,用户还可以指定batch_sampler
,它每次都会生成一个键列表。
注意
batch_size
和drop_last
参数实际上用于从sampler
构建batch_sampler
。对于映射式数据集,sampler
要么由用户提供,要么根据shuffle
参数构建。对于可迭代式数据集,sampler
是一个虚拟的无限采样器。有关采样器的更多详细信息,请参阅此部分。
使用采样器中的索引获取样本列表后,将作为collate_fn
参数传递的函数用于将样本列表合并成批次。
在这种情况下,从映射式数据集加载大致等效于
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
从可迭代式数据集加载大致等效于
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
可以使用自定义的collate_fn
来自定义合并,例如,将顺序数据填充到批次的最大长度。有关collate_fn
的更多信息,请参阅此部分。
禁用自动批处理¶
在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者只加载单个样本。例如,直接加载批处理数据(例如,从数据库批量读取或读取连续的内存块)可能更便宜,或者批处理大小取决于数据,或者程序设计为处理单个样本。在这些情况下,最好不要使用自动批处理(其中collate_fn
用于合并样本),而是让数据加载器直接返回dataset
对象的每个成员。
当batch_size
和batch_sampler
都为None
(batch_sampler
的默认值已为None
)时,自动批处理将被禁用。从dataset
获取的每个样本都将使用作为collate_fn
参数传递的函数进行处理。
**当禁用自动批处理时**,默认的collate_fn
只会将NumPy数组转换为PyTorch张量,并保持其他所有内容不变。
在这种情况下,从映射式数据集加载大致等效于
for index in sampler:
yield collate_fn(dataset[index])
从可迭代式数据集加载大致等效于
for data in iter(dataset):
yield collate_fn(data)
有关collate_fn
的更多信息,请参阅此部分。
使用collate_fn
¶
启用或禁用自动批处理时,collate_fn
的使用略有不同。
**当禁用自动批处理时**,collate_fn
将使用每个单独的数据样本调用,并且输出将从数据加载器迭代器生成。在这种情况下,默认的collate_fn
只会将NumPy数组转换为PyTorch张量。
**当启用自动批处理时**,collate_fn
每次都会使用数据样本列表调用。它应该将输入样本合并成批次,以便从数据加载器迭代器生成。本节的其余部分描述了默认collate_fn
(default_collate()
)的行为。
例如,如果每个数据样本包含一个3通道图像和一个整数类标签,即数据集的每个元素都返回一个元组(image, class_index)
,则默认的collate_fn
会将此类元组列表合并成一个元组,该元组包含批处理图像张量和批处理类标签张量。特别是,默认的collate_fn
具有以下属性
它始终在前面添加一个新的维度作为批处理维度。
它会自动将NumPy数组和Python数值转换为PyTorch张量。
它保留数据结构,例如,如果每个样本都是一个字典,它会输出一个具有相同键集但批处理张量作为值的字典(如果值不能转换为张量,则为列表)。对于
list
、tuple
、namedtuple
等也是如此。
用户可以使用自定义的collate_fn
来实现自定义批处理,例如,沿着除第一个维度以外的其他维度合并,填充不同长度的序列,或添加对自定义数据类型的支持。
如果您遇到DataLoader
的输出具有与预期不同的维度或类型的状况,则可能需要检查您的collate_fn
。
单进程和多进程数据加载¶
DataLoader
默认使用单进程数据加载。
在Python进程内,全局解释器锁 (GIL)阻止了跨线程真正完全并行化Python代码。为了避免数据加载阻塞计算代码,PyTorch 提供了一个简单的切换方式,只需将参数num_workers
设置为正整数即可执行多进程数据加载。
单进程数据加载(默认)¶
在此模式下,数据获取在初始化DataLoader
的同一进程中完成。因此,数据加载可能会阻塞计算。但是,当用于在进程之间共享数据的资源(例如,共享内存、文件描述符)有限,或者整个数据集很小并且可以完全加载到内存中时,可能更喜欢此模式。此外,单进程加载通常显示更易读的错误跟踪,因此对于调试很有用。
多进程数据加载¶
将参数num_workers
设置为正整数将开启多进程数据加载,并使用指定的加载器工作进程数量。
警告
经过多次迭代后,加载器工作进程将消耗与父进程相同的 CPU 内存,用于父进程中所有从工作进程访问的 Python 对象。如果数据集包含大量数据(例如,您在数据集构建时加载了非常大的文件名列表)和/或您使用了大量工作进程(总内存使用量为number of workers * size of parent process
),这可能会成为问题。最简单的解决方法是用非引用计数表示(例如 Pandas、Numpy 或 PyArrow 对象)替换 Python 对象。查看issue #13246,了解为什么会发生这种情况以及如何解决这些问题的示例代码。
在此模式下,每次创建DataLoader
的迭代器时(例如,当您调用enumerate(dataloader)
时),都会创建num_workers
个工作进程。此时,dataset
、collate_fn
和worker_init_fn
将传递给每个工作进程,用于初始化和获取数据。这意味着数据集访问及其内部 IO、转换(包括collate_fn
)在工作进程中运行。
torch.utils.data.get_worker_info()
在工作进程中返回各种有用的信息(包括工作进程 ID、数据集副本、初始种子等),并在主进程中返回None
。用户可以在数据集代码和/或worker_init_fn
中使用此函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。例如,这在对数据集进行分片时特别有用。
对于映射式数据集,主进程使用sampler
生成索引并将其发送到工作进程。因此,任何随机洗牌操作都在主进程中完成,该进程通过分配要加载的索引来指导加载。
对于可迭代式数据集,由于每个工作进程都获得dataset
对象的副本,因此简单的多进程加载通常会导致数据重复。使用torch.utils.data.get_worker_info()
和/或worker_init_fn
,用户可以独立配置每个副本。(请参阅IterableDataset
文档以了解如何实现此目的。)出于类似原因,在多进程加载中,drop_last
参数会删除每个工作进程的可迭代式数据集副本的最后一个非完整批次。
一旦达到迭代结束或迭代器被垃圾回收,工作进程就会关闭。
警告
通常不建议在多进程加载中返回 CUDA 张量,因为在多处理中使用 CUDA 和共享 CUDA 张量存在许多细微差别(请参阅CUDA in multiprocessing)。相反,我们建议使用自动内存固定(即,设置pin_memory=True
),这可以实现快速的数据传输到支持 CUDA 的 GPU。
平台特定行为¶
由于工作进程依赖于 Python multiprocessing
,因此工作进程启动行为在 Windows 上与 Unix 上不同。
在 Unix 上,
fork()
是默认的multiprocessing
启动方法。使用fork()
,子工作进程通常可以通过克隆的地址空间直接访问dataset
和 Python 参数函数。在 Windows 或 MacOS 上,
spawn()
是默认的multiprocessing
启动方法。使用spawn()
,将启动另一个解释器,该解释器运行您的主脚本,然后是接收dataset
、collate_fn
和其他参数的内部工作进程函数,这些参数通过pickle
序列化。
这种单独的序列化意味着您应该采取两个步骤来确保在使用多进程数据加载时与 Windows 兼容
将大部分主脚本代码包装在
if __name__ == '__main__':
块中,以确保在启动每个工作进程时不会再次运行它(很可能产生错误)。您可以将您的数据集和DataLoader
实例创建逻辑放在这里,因为它不需要在工作进程中重新执行。确保任何自定义
collate_fn
、worker_init_fn
或dataset
代码都声明为顶级定义,位于__main__
检查之外。这确保了它们在工作进程中可用。(这是必要的,因为函数仅作为引用进行腌制,而不是bytecode
。)
多进程数据加载中的随机性¶
默认情况下,每个工作进程的 PyTorch 种子将设置为base_seed + worker_id
,其中base_seed
是由主进程使用其 RNG 生成的长整型数(因此,强制性地消耗 RNG 状态)或指定的generator
。但是,其他库的种子在初始化工作进程时可能会重复,导致每个工作进程返回相同的随机数。(请参阅 FAQ 中的此部分。)
在worker_init_fn
中,您可以使用torch.utils.data.get_worker_info().seed
或torch.initial_seed()
访问为每个工作进程设置的 PyTorch 种子,并在数据加载前使用它来设置其他库的种子。
内存固定¶
当主机到 GPU 的复制源自固定(页面锁定)内存时,速度会快得多。有关何时以及如何一般使用固定内存的更多详细信息,请参阅Use pinned memory buffers。
对于数据加载,将pin_memory=True
传递给DataLoader
会自动将获取的数据张量放入固定内存,从而实现更快的数据传输到支持 CUDA 的 GPU。
默认的内存固定逻辑仅识别张量以及包含张量的映射和可迭代对象。默认情况下,如果固定逻辑看到一个批次是自定义类型(如果您的collate_fn
返回自定义批次类型,则会发生这种情况),或者如果批次的每个元素都是自定义类型,则固定逻辑将无法识别它们,并且它将返回该批次(或这些元素)而不会固定内存。要为自定义批次或数据类型启用内存固定,请在您的自定义类型上定义一个pin_memory()
方法。
请参阅下面的示例。
示例
class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# custom memory pinning method on custom type
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
- class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')[source]¶
数据加载器将数据集和采样器组合在一起,并提供给定数据集的可迭代对象。
DataLoader
支持使用单进程或多进程加载的映射式和可迭代式数据集,自定义加载顺序以及可选的自动批处理(整理)和内存固定。有关更多详细信息,请参阅
torch.utils.data
文档页面。- 参数
dataset (Dataset) – 要从中加载数据的dataset。
batch_size (int, 可选) – 要加载的每个批次的样本数(默认值:
1
)。shuffle (bool, 可选) – 设置为
True
以在每个 epoch 中对数据重新洗牌(默认值:False
)。sampler (Sampler 或 Iterable, 可选) – 定义从数据集中抽取样本的策略。可以是实现了
__len__
的任何Iterable
。如果指定,则不能指定shuffle
。batch_sampler (Sampler 或 Iterable, 可选) – 与
sampler
类似,但一次返回一批索引。与batch_size
、shuffle
、sampler
和drop_last
互斥。num_workers (int, 可选) – 用于数据加载的子进程数。
0
表示数据将在主进程中加载。(默认值:0
)collate_fn (Callable, 可选) – 将样本列表合并以形成 Tensor(s) 的小批量。在使用来自映射式数据集的批处理加载时使用。
pin_memory (bool, 可选) – 如果
True
,则数据加载器将在返回张量之前将其复制到设备/CUDA 固定内存中。如果您的数据元素是自定义类型,或者您的collate_fn
返回一个自定义类型的批次,请参见下面的示例。drop_last (bool, 可选) – 设置为
True
以删除最后一个不完整的批次,如果数据集大小不能被批次大小整除。如果False
并且数据集的大小不能被批次大小整除,则最后一个批次将更小。(默认值:False
)timeout (数值, 可选) – 如果为正数,则表示从工作进程中收集批次的超时值。应始终为非负数。(默认值:
0
)worker_init_fn (Callable, 可选) – 如果不为
None
,则在每个工作进程子进程上调用它,并以工作进程 ID([0, num_workers - 1]
中的整数)作为输入,在播种和数据加载之前。(默认值:None
)multiprocessing_context (str 或 multiprocessing.context.BaseContext, 可选) – 如果为
None
,则将使用操作系统的默认 多处理上下文。(默认值:None
)generator (torch.Generator, 可选) – 如果不为
None
,则 RandomSampler 将使用此 RNG 生成随机索引,并且多处理将使用此 RNG 为工作进程生成base_seed
。(默认值:None
)prefetch_factor (int, 可选, 仅限关键字参数) – 每个工作进程预先加载的批次数。
2
表示所有工作进程总共将预取 2 * num_workers 个批次。(默认值取决于为 num_workers 设置的值。如果 num_workers 的值为 0,则默认为None
。否则,如果 num_workers 的值> 0
,则默认为2
)。persistent_workers (bool, 可选) – 如果
True
,则数据加载器不会在数据集被使用一次后关闭工作进程。这允许保持工作进程的 Dataset 实例处于活动状态。(默认值:False
)pin_memory_device (str, 可选) – 如果
pin_memory
为True
,则要将内存固定到的设备。
警告
如果使用
spawn
启动方法,则worker_init_fn
不能是不可 pickle 的对象,例如 lambda 函数。有关 PyTorch 中多处理的更多详细信息,请参见 多处理最佳实践。警告
len(dataloader)
启发式方法基于所用采样器的长度。当dataset
是IterableDataset
时,它会根据len(dataset) / batch_size
返回一个估计值,并根据drop_last
进行适当的舍入,而不管多进程加载配置如何。这表示 PyTorch 可以做出的最佳猜测,因为 PyTorch 相信用户dataset
代码可以正确处理多进程加载以避免重复数据。但是,如果分片导致多个工作进程具有不完整的最后一个批次,则此估计值仍然可能不准确,因为 (1) 一个原本完整的批次可以被分成多个批次,以及 (2) 当设置
drop_last
时,可以丢弃多个批次数量的样本。不幸的是,PyTorch 通常无法检测到此类情况。有关这两种类型的数据集以及
IterableDataset
如何与 多进程数据加载 交互的更多详细信息,请参见 数据集类型。警告
有关随机种子相关问题的说明,请参见 可重复性、我的数据加载器工作进程返回相同的随机数 和 多进程数据加载中的随机性。
- 类 torch.utils.data.Dataset[源代码]¶
表示一个
Dataset
的抽象类。所有表示键到数据样本映射的数据集都应该继承它。所有子类都应该重写
__getitem__()
,支持为给定键获取数据样本。子类还可以选择性地重写__len__()
,许多Sampler
实现和DataLoader
的默认选项都期望它返回数据集的大小。子类还可以选择性地实现__getitems__()
,以加速批量样本加载。此方法接受批量样本索引列表并返回样本列表。注意
DataLoader
默认构造一个产生整数索引的索引采样器。为了使其与具有非整数索引/键的映射样式数据集一起使用,必须提供自定义采样器。
- 类 torch.utils.data.IterableDataset[源代码]¶
一个可迭代的数据集。
所有表示数据样本可迭代的数据集都应该继承它。当数据来自流时,这种形式的数据集特别有用。
所有子类都应该重写
__iter__()
,它将返回此数据集中样本的迭代器。当子类与
DataLoader
一起使用时,数据集中的每个项目都将从DataLoader
迭代器中产生。当num_workers > 0
时,每个工作进程将拥有数据集对象的副本,因此通常希望独立配置每个副本以避免工作进程返回重复的数据。get_worker_info()
在工作进程中调用时,返回有关工作进程的信息。它可以在数据集的__iter__()
方法或DataLoader
的worker_init_fn
选项中使用来修改每个副本的行为。示例 1:在
__iter__()
中跨所有工作进程分配工作负载>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [tensor([3]), tensor([4]), tensor([5]), tensor([6])] >>> # Mult-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
示例 2:使用
worker_init_fn
跨所有工作进程分配工作负载>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) [3, 4, 5, 6]
- 类 torch.utils.data.TensorDataset(*tensors)[源代码]¶
包装张量的数据集。
每个样本将通过沿第一个维度索引张量来检索。
- 参数
*tensors (张量) – 第一个维度大小相同的张量。
- 类 torch.utils.data.StackDataset(*args, **kwargs)[源代码]¶
将多个数据集堆叠起来的数据集。
此类用于组装作为数据集给出的复杂输入数据的不同部分。
示例
>>> images = ImageDataset() >>> texts = TextDataset() >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
- 类 torch.utils.data.ConcatDataset(datasets)[源代码]¶
将多个数据集连接起来的数据集。
此类用于组装不同的现有数据集。
- 参数
datasets (序列) – 要连接的数据集列表
- 类 torch.utils.data.ChainDataset(datasets)[源代码]¶
用于将多个
IterableDataset
连接起来的数据集。此类用于组装不同的现有数据集流。链接操作是即时进行的,因此使用此类连接大规模数据集将非常高效。
- 参数
datasets (IterableDataset的可迭代对象) – 要连接在一起的数据集
- 类 torch.utils.data.Subset(dataset, indices)[源代码]¶
指定索引处数据集的子集。
- 参数
dataset (Dataset) – 整个数据集
indices (序列) – 为子集选择的整个集合中的索引
- torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[源代码]¶
处理每个批次内元素的集合类型的通用整理函数。
该函数还打开函数注册表以处理特定的元素类型。default_collate_fn_map为张量、NumPy 数组、数字和字符串提供默认的整理函数。
- 参数
示例
>>> def collate_tensor_fn(batch, *, collate_fn_map): ... # Extend this function to handle batch of tensors ... return torch.stack(batch, 0) >>> def custom_collate(batch): ... collate_map = {torch.Tensor: collate_tensor_fn} ... return collate(batch, collate_fn_map=collate_map) >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
注意
每个整理函数都需要一个用于批次的定位参数和一个用于整理函数字典的关键字参数,作为collate_fn_map。
- torch.utils.data.default_collate(batch)[源代码]¶
接收一批数据,并将批次内的元素放入具有额外外部维度(批次大小)的张量中。
确切的输出类型可以是
torch.Tensor
、Sequence oftorch.Tensor
、torch.Tensor
的集合,或保持不变,具体取决于输入类型。当在DataLoader
中定义batch_size或batch_sampler时,将其用作整理的默认函数。以下是根据批次内元素的类型进行的通用输入类型到输出类型映射
torch.Tensor
->torch.Tensor
(添加了一个外部批大小维度)NumPy 数组 ->
torch.Tensor
float ->
torch.Tensor
int ->
torch.Tensor
str -> str(不变)
bytes -> bytes(不变)
Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]
NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
- 参数
batch – 要整理的单个批次
示例
>>> # Example with a batch of `int`s: >>> default_collate([0, 1, 2, 3]) tensor([0, 1, 2, 3]) >>> # Example with a batch of `str`s: >>> default_collate(['a', 'b', 'c']) ['a', 'b', 'c'] >>> # Example with `Map` inside the batch: >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} >>> # Example with `NamedTuple` inside the batch: >>> Point = namedtuple('Point', ['x', 'y']) >>> default_collate([Point(0, 0), Point(1, 1)]) Point(x=tensor([0, 1]), y=tensor([0, 1])) >>> # Example with `Tuple` inside the batch: >>> default_collate([(0, 1), (2, 3)]) [tensor([0, 2]), tensor([1, 3])] >>> # Example with `List` inside the batch: >>> default_collate([[0, 1], [2, 3]]) [tensor([0, 2]), tensor([1, 3])] >>> # Two options to extend `default_collate` to handle specific type >>> # Option 1: Write custom collate function and invoke `default_collate` >>> def custom_collate(batch): ... elem = batch[0] ... if isinstance(elem, CustomType): # Some custom condition ... return ... ... else: # Fall back to `default_collate` ... return default_collate(batch) >>> # Option 2: In-place modify `default_collate_fn_map` >>> def collate_customtype_fn(batch, *, collate_fn_map=None): ... return ... >>> default_collate_fn_map.update(CustomType, collate_customtype_fn) >>> default_collate(batch) # Handle `CustomType` automatically
- torch.utils.data.default_convert(data)[source]¶
将每个 NumPy 数组元素转换为
torch.Tensor
。如果输入是 Sequence、Collection 或 Mapping,则尝试将其中的每个元素转换为
torch.Tensor
。如果输入不是 NumPy 数组,则保持不变。当DataLoader
中未定义 batch_sampler 和 batch_size 时,此函数用作整理的默认函数。输入类型到输出类型的映射与
default_collate()
的类似。有关更多详细信息,请参阅那里的描述。- 参数
data – 要转换的单个数据点
示例
>>> # Example with `int` >>> default_convert(0) 0 >>> # Example with NumPy array >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple >>> Point = namedtuple('Point', ['x', 'y']) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) Point(x=tensor(0), y=tensor(0)) >>> # Example with List >>> default_convert([np.array([0, 1]), np.array([2, 3])]) [tensor([0, 1]), tensor([2, 3])]
- torch.utils.data.get_worker_info()[source]¶
返回有关当前
DataLoader
迭代器工作进程的信息。在工作进程中调用时,这将返回一个保证具有以下属性的对象
id
:当前工作进程的 ID。num_workers
:工作进程的总数。seed
:为当前工作进程设置的随机种子。此值由主进程 RNG 和工作进程 ID 确定。有关更多详细信息,请参阅DataLoader
的文档。dataset
:**此**进程中数据集对象的副本。请注意,这将在与主进程中不同的进程中成为不同的对象。
在主进程中调用时,这将返回
None
。注意
在传递给
DataLoader
的worker_init_fn
中使用时,此方法可用于以不同的方式设置每个工作进程,例如,使用worker_id
将dataset
对象配置为仅读取分片数据集的特定部分,或使用seed
为数据集代码中使用的其他库设置种子。- 返回类型
Optional[WorkerInfo]
- torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[source]¶
将数据集随机拆分为给定长度的非重叠新数据集。
如果给定一个总和为 1 的分数列表,则将自动计算长度,对于每个提供的分数,计算结果为 floor(frac * len(dataset))。
计算长度后,如果存在任何余数,则将以循环方式将 1 个计数分配给长度,直到没有剩余余数。
可以选择修复生成器以获得可重复的结果,例如:
示例
>>> generator1 = torch.Generator().manual_seed(42) >>> generator2 = torch.Generator().manual_seed(42) >>> random_split(range(10), [3, 7], generator=generator1) >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
- class torch.utils.data.Sampler(data_source=None)[source]¶
所有采样器的基类。
每个 Sampler 子类都必须提供一个
__iter__()
方法,提供一种迭代数据集元素的索引或索引列表(批次)的方法,并且可以提供一个__len__()
方法,该方法返回返回的迭代器的长度。- 参数
data_source (Dataset) – 此参数未使用,将在 2.2.0 中删除。您可能仍然拥有利用它的自定义实现。
示例
>>> class AccedingSequenceLengthSampler(Sampler[int]): >>> def __init__(self, data: List[str]) -> None: >>> self.data = data >>> >>> def __len__(self) -> int: >>> return len(self.data) >>> >>> def __iter__(self) -> Iterator[int]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> yield from torch.argsort(sizes).tolist() >>> >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): >>> def __init__(self, data: List[str], batch_size: int) -> None: >>> self.data = data >>> self.batch_size = batch_size >>> >>> def __len__(self) -> int: >>> return (len(self.data) + self.batch_size - 1) // self.batch_size >>> >>> def __iter__(self) -> Iterator[List[int]]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): >>> yield batch.tolist()
注意
__len__()
方法不是DataLoader
严格要求的,但在涉及DataLoader
长度的任何计算中都预期。
- class torch.utils.data.SequentialSampler(data_source)[source]¶
按顺序采样元素,始终以相同的顺序。
- 参数
data_source (Dataset) – 要从中采样的数据集
- class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[source]¶
随机采样元素。如果不用替换,则从洗牌后的数据集中采样。
如果使用替换,则用户可以指定
num_samples
进行抽取。
- class torch.utils.data.SubsetRandomSampler(indices, generator=None)[source]¶
从给定的索引列表中随机采样元素,不进行替换。
- 参数
indices (sequence) – 索引序列
generator (Generator) – 用于采样的生成器。
- class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[source]¶
根据给定的概率(权重)从
[0,..,len(weights)-1]
中采样元素。- 参数
示例
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) [4, 4, 1, 4, 5] >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) [0, 1, 4, 3, 2]
- class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source]¶
包装另一个采样器以生成一个小批量的索引。
- 参数
示例
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source]¶
限制数据加载到数据集子集的采样器。
它与
torch.nn.parallel.DistributedDataParallel
结合使用时尤其有用。在这种情况下,每个进程都可以将DistributedSampler
实例作为DataLoader
采样器,并加载与其独有的原始数据集的子集。注意
假设数据集的大小恒定,并且它的任何实例始终以相同的顺序返回相同的元素。
- 参数
dataset (Dataset) – 用于采样的数据集。
num_replicas (int, 可选) – 参与分布式训练的进程数。默认情况下,
world_size
是从当前分布式组中检索的。rank (int, 可选) – 当前进程在
num_replicas
中的排名。默认情况下,rank
是从当前分布式组中检索的。shuffle (bool, 可选) – 如果为
True
(默认值),则采样器将打乱索引。seed (int, 可选) – 如果
shuffle=True
,则用于打乱采样器的随机种子。此数字在分布式组中的所有进程中应相同。默认值:0
。drop_last (bool, 可选) – 如果为
True
,则采样器将删除数据的尾部,使其能够在副本数量之间均匀划分。如果为False
,则采样器将添加额外的索引,以使数据能够在副本之间均匀划分。默认值:False
。
警告
在分布式模式下,在创建
DataLoader
迭代器**之前**,在每个 epoch 开始时调用set_epoch()
方法对于使跨多个 epoch 的混洗正常工作是必要的。否则,将始终使用相同的排序。示例
>>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), ... sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader)