快捷方式

tensordict 包

The TensorDict class simplifies the process of passing multiple tensors from module to module by packing them in a dictionary-like object that inherits features from regular pytorch tensors.

TensorDictBase()

TensorDictBase 是 TensorDict 的抽象父类,TensorDict 是一种 torch.Tensor 数据容器。

TensorDict([source, batch_size, device, ...])

张量的批量字典。

LazyStackedTensorDict(*tensordicts[, ...])

TensorDict 的惰性堆叠。

PersistentTensorDict(*[, batch_size, ...])

持久化 TensorDict 实现。

TensorDictParams([parameters, no_convert, lock])

带有参数暴露功能的 TensorDictBase 包装器。

get_defaults_to_none([set_to_none])

返回 get 默认值的状态。

构造函数和处理程序

该库提供了一些方法来与 numpy 结构化数组、namedtuple 或 h5 文件等其他数据结构进行交互。该库还提供了专门的函数来操作 tensordict,例如 saveloadstackcat

cat(input[, dim, out])

沿着给定维度将 tensordict 连接成单个 tensordict。

default_is_leaf(cls)

如果一个类型不是张量集合(tensordict 或 tensorclass),则返回 True

from_any(obj, *[, auto_batch_size, ...])

将任意对象转换为 TensorDict。

from_consolidated(filename)

从合并文件中重建 tensordict。

from_dict(d, *[, auto_batch_size, ...])

将字典转换为 TensorDict。

from_h5(h5_file, *[, auto_batch_size, ...])

将 HDF5 文件转换为 TensorDict。

from_module(module[, as_module, lock, ...])

将模块的参数和缓冲区复制到 tensordict 中。

from_modules(*modules[, as_module, lock, ...])

通过 vmap 获取多个模块的参数,用于集成学习/期望应用的特性。

from_namedtuple(named_tuple, *[, ...])

将 namedtuple 转换为 TensorDict。

from_pytree(pytree, *[, batch_size, ...])

将 pytree 转换为 TensorDict 实例。

from_struct_array(struct_array, *[, ...])

将结构化 numpy 数组转换为 TensorDict。

from_tuple(obj, *[, auto_batch_size, ...])

将元组转换为 TensorDict。

fromkeys(keys[, value])

从键列表和单个值创建 tensordict。

is_batchedtensor(arg0)

is_leaf_nontensor(cls)

如果一个类型不是张量集合(tensordict 或 tensorclass)或不是张量,则返回 True

lazy_stack(input[, dim, out])

创建 tensordict 的惰性堆叠。

load(prefix[, device, non_blocking, out])

从磁盘加载 tensordict。

load_memmap(prefix[, device, non_blocking, out])

从磁盘加载内存映射的 tensordict。

maybe_dense_stack(input[, dim, out])

尝试对 tensordict 进行密集堆叠,并在需要时回退到惰性堆叠。

memmap(data[, prefix, copy_existing, ...])

将所有张量写入新 tensordict 中的相应内存映射张量。

save(data[, prefix, copy_existing, ...])

将 tensordict 保存到磁盘。

stack(input[, dim, out])

沿着给定维度将 tensordict 堆叠成单个 tensordict。

将 TensorDict 用作上下文管理器

TensorDict 可以在需要执行某个操作然后撤销该操作的情况下用作上下文管理器。这包括临时锁定/解锁 tensordict

>>> data.lock_()  # data.set will result in an exception
>>> with data.unlock_():
...     data.set("key", value)
>>> assert data.is_locked()

或使用包含模型参数和缓冲区的 TensorDict 实例执行函数调用

>>> params = TensorDict.from_module(module).clone()
>>> params.zero_()
>>> with params.to_module(module):
...     y = module(x)

在第一个示例中,我们可以修改 tensordict data,因为我们临时解锁了它。在第二个示例中,我们使用 params tensordict 实例中包含的参数和缓冲区填充模块,并在调用完成后重置原始参数。

内存映射张量

tensordict 提供了 MemoryMappedTensor 原语,它允许您方便地处理存储在物理内存中的张量。MemoryMappedTensor 的主要优点包括易于构建(无需处理张量的存储)、处理不适合内存的大块连续数据的能力、跨进程的高效序列化/反序列化以及存储张量的高效索引。

如果所有工作进程都可以访问相同的存储(多进程和分布式设置中均如此),传递 MemoryMappedTensor 仅需传递磁盘上文件的引用以及用于重建它的一堆额外元数据。只要索引内存映射张量的存储数据指针与原始数据指针相同,情况也是如此。

索引内存映射张量比从磁盘加载多个独立文件快得多,并且不需要将整个数组内容加载到内存中。但是,PyTorch 张量的物理存储应该没有区别

>>> my_images = MemoryMappedTensor.empty((1_000_000, 3, 480, 480), dtype=torch.unint8)
>>> mini_batch = my_images[:10]  # just reads the first 10 images of the dataset

MemoryMappedTensor(source, *[, dtype, ...])

内存映射张量。

逐点操作

Tensordict 支持各种逐点操作,允许您对其内部存储的张量执行元素级计算。这些操作与常规 PyTorch 张量上的操作类似。

支持的操作

目前支持以下逐点操作

  • 左加和右加 (+)

  • 左减和右减 (-)

  • 左乘和右乘 (*)

  • 左除和右除 (/)

  • 左乘方 (**)

还支持许多其他操作,例如 clamp()sqrt() 等。

执行逐点操作

您可以在两个 Tensordict 之间或在 Tensordict 与张量/标量值之间执行逐点操作。

示例 1:Tensordict-Tensordict 操作

>>> import torch
>>> from tensordict import TensorDict
>>> td1 = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> td2 = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> result = td1 * td2

在此示例中,* 运算符被逐元素地应用于 td1 和 td2 中对应的张量。

示例 2:Tensordict-张量操作

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> tensor = torch.randn(4)
>>> result = td * tensor

在这里,* 运算符被逐元素地应用于 td 中的每个张量和提供的张量。该张量会被广播以匹配 Tensordict 中每个张量的形状。

示例 3:Tensordict-标量操作

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> scalar = 2.0
>>> result = td * scalar

在这种情况下,* 运算符被逐元素地应用于 td 中的每个张量和提供的标量。

广播规则

当在 Tensordict 和张量/标量之间执行逐点操作时,张量/标量会被广播以匹配 Tensordict 中每个张量的形状:张量在左侧被广播以匹配 tensordict 的形状,然后在右侧单独广播以匹配张量的形状。如果将 TensorDict 视为单个张量实例,这遵循 PyTorch 中使用的标准广播规则。

例如,如果您有一个包含形状为 (3, 4) 的张量的 Tensordict,并将其乘以形状为 (4,) 的张量,该张量在应用操作之前将被广播为形状 (3, 4)。如果 tensordict 包含一个形状为 (3, 4, 5) 的张量,用于乘法的张量在该乘法中将在右侧广播为 (3, 4, 5)

如果在多个 tensordict 之间执行逐点操作且它们的批处理大小不同,它们将被广播到公共形状。

逐点操作的效率

如果可能,将使用 torch._foreach_<op> 融合核函数来加速逐点操作的计算。

处理缺失条目

当在两个 Tensordict 之间执行逐点操作时,它们必须具有相同的键。某些操作,如 add(),具有一个 default 关键字参数,可用于处理具有独占条目的 tensordict。如果 default=None(默认值),则两个 Tensordict 必须具有完全匹配的键集。如果 default="intersection",则仅考虑相交的键集,而忽略其他键。在所有其他情况下,default 将用于操作两侧所有缺失的条目。

工具函数

utils.expand_as_right(tensor, dest)

在右侧扩展张量以匹配另一个张量的形状。

utils.expand_right(tensor, shape)

在右侧扩展张量以匹配所需的形状。

utils.isin(input, reference, key[, dim])

测试输入中 dim 维度上 key 的每个元素是否存在于参考中。

utils.remove_duplicates(input, key[, dim, ...])

移除指定维度上 key 中重复的索引。

is_batchedtensor(arg0)

is_tensor_collection(datatype)

检查数据对象或类型是否是来自 tensordict 库的张量容器。

make_tensordict([input_dict, batch_size, ...])

返回从关键字参数或输入字典创建的 TensorDict。

merge_tensordicts(*tensordicts[, callback_exist])

合并 tensordict。

pad(tensordict, pad_size[, value])

使用常量值沿批量维度填充 tensordict 中的所有张量,并返回一个新的 tensordict。

pad_sequence(list_of_tensordicts[, pad_dim, ...])

填充 tensordict 列表,以便将它们以连续格式堆叠在一起。

dense_stack_tds(td_list[, dim])

密集堆叠具有相同结构的 TensorDictBase 对象列表(或 LazyStackedTensorDict)。

set_lazy_legacy(mode)

将某些方法的行为设置为惰性转换。

lazy_legacy([allow_none])

如果对选定方法使用惰性表示,则返回 True

parse_tensor_dict_string(s)

将 TensorDict repr 解析为 TensorDict。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源