tensordict 包¶
TensorDict
类通过将多个张量打包到一个类似字典的对象中(该对象继承了常规 PyTorch 张量的特性)来简化模块之间传递多个张量的过程。
TensorDictBase 是 TensorDict 的抽象父类,一个 torch.Tensor 数据容器。 |
|
|
一个批处理的张量字典。 |
|
TensorDict 的延迟堆叠。 |
|
持久化 TensorDict 实现。 |
|
保存一个包含参数的 TensorDictBase 实例。 |
TensorDict 作为上下文管理器¶
TensorDict
可用作上下文管理器,用于需要执行某个操作并在之后撤消该操作的情况。这包括临时锁定/解锁 tensordict
>>> data.lock_() # data.set will result in an exception
>>> with data.unlock_():
... data.set("key", value)
>>> assert data.is_locked()
或使用包含模型参数和缓冲区的 TensorDict 实例执行函数调用
>>> params = TensorDict.from_module(module).clone()
>>> params.zero_()
>>> with params.to_module(module):
... y = module(x)
在第一个示例中,我们可以修改 tensordict data,因为我们已将其临时解锁。在第二个示例中,我们使用 params tensordict 实例中的参数和缓冲区填充模块,并在该调用完成后重置原始参数。
内存映射张量¶
tensordict 提供了 MemoryMappedTensor
原语,它允许您以方便的方式处理存储在物理内存中的张量。 MemoryMappedTensor
的主要优点在于其易于构造(无需处理张量的存储)、能够处理无法放入内存的大型连续数据、跨进程高效地进行(反)序列化以及对存储的张量进行高效索引。
如果所有工作器都可以访问相同的存储(在多进程和分布式设置中),则传递 MemoryMappedTensor
只需传递对磁盘上文件的引用以及一些重建所需的额外元数据。只要其存储的数据指针与原始数据指针相同,索引的内存映射张量也适用。
索引内存映射张量比从磁盘加载多个独立文件快得多,并且不需要将数组的全部内容加载到内存中。但是,PyTorch 张量的物理存储不应该有任何不同
>>> my_images = MemoryMappedTensor.empty((1_000_000, 3, 480, 480), dtype=torch.unint8)
>>> mini_batch = my_images[:10] # just reads the first 10 images of the dataset
|
一个内存映射张量。 |
实用程序¶
|
在右侧扩展张量以匹配另一个张量的形状。 |
|
在右侧扩展张量以匹配所需的形状。 |
|
测试输入 |
|
删除在指定维度上 key 中重复的索引。 |
|
|
|
检查数据对象或类型是否来自 tensordict 库的张量容器。 |
|
返回根据关键字参数或输入字典创建的 TensorDict。 |
|
将 tensordicts 合并在一起。 |
|
沿着批处理维度使用常量值填充 tensordict 中的所有张量,并返回一个新的 tensordict。 |
|
填充 tensordict 列表,以便以连续格式将它们堆叠在一起。 |
|
密集地堆叠 |
|
将某些方法的行为设置为延迟转换。 |
|
如果将为选定的方法使用延迟表示,则返回 True。 |