TensorDict¶
- class tensordict.TensorDict(source: T | dict[str, CompatibleType] = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, names: Sequence[str] | None = None, non_blocking: bool = None, lock: bool = False, **kwargs)¶
张量的批处理字典。
TensorDict 是一个张量容器,其中所有张量都以键值对的方式存储,并且每个元素共享相同的前
N
个前导维度形状,其中 是任意数字,且N >= 0
。此外,如果 tensordict 有指定的设备,则每个元素必须共享该设备。
TensorDict 实例支持许多常规的张量操作,但值得注意的是代数运算除外
形状操作:当调用形状操作(索引、重塑、视图、扩展、转置、置换、unsqueeze、squeeze、掩码等)时,操作就像在与批大小形状相同的张量上执行,然后扩展到右侧,例如
>>> td = TensorDict({'a': torch.zeros(3, 4, 5)}, batch_size=[3, 4]) >>> # returns a TensorDict of batch size [3, 4, 1]: >>> td_unsqueeze = td.unsqueeze(-1) >>> # returns a TensorDict of batch size [12] >>> td_view = td.view(-1) >>> # returns a tensor of batch size [12, 4] >>> a_view = td.view(-1).get("a")
类型转换操作:可以使用
>>> td_cpu = td.to("cpu") >>> dictionary = td.to_dict()
使用数据类型调用 .to() 方法将返回错误。
克隆 (
clone()
),连续 (contiguous()
);读取:td.get(key),td.get_at(key, index)
内容修改:
td.set(key, value)
,td.set_(key, value)
,td.update(td_or_dict)
,td.update_(td_or_dict)
,td.fill_(key, value)
,td.rename_key_(old_name, new_name)
等。多个 tensordict 的操作:torch.cat(tensordict_list, dim),torch.stack(tensordict_list, dim),td1 == td2,td.apply(lambda x+y, other_td) 等。
- 参数:
source (TensorDict 或 Dict[NestedKey, Union[Tensor, TensorDictBase]]) – 数据源。如果为空,则随后可以填充 tensordict。也可以通过一系列关键字参数构建
TensorDict
,就像dict(...)
的情况一样。batch_size (整数的可迭代对象, 可选) – tensordict 的批大小。只要与内容兼容,批大小就可以随后修改。如果没有提供批大小,则假定为空批大小(不会从数据中自动推断)。要自动设置批大小,请参阅
auto_batch_size_()
。device (torch.device 或 兼容类型, 可选) – TensorDict 的设备。如果提供,所有张量都将存储在该设备上。如果没有,则允许使用不同设备上的张量。
names (字符串列表, 可选) – tensordict 维度的名称。如果提供,其长度必须与
batch_size
的长度匹配。默认为None
(无维度名称,或每个维度为None
)。non_blocking (bool, 可选) – 如果为
True
并且传递了设备,则 tensordict 将在没有同步的情况下传递。这是最快的方法,但仅在从 cpu 到 cuda 进行转换时安全(否则用户必须实现同步调用)。如果传递False
,则每个张量移动都将同步完成。如果为None
(默认值),则设备转换将异步执行,但如果需要,将在创建后执行同步。此选项通常比False
快,并且可能比True
慢。lock (bool, 可选) – 如果为
True
,则生成的 tensordict 将被锁定。
示例
>>> import torch >>> from tensordict import TensorDict >>> source = {'random': torch.randn(3, 4), ... 'zeros': torch.zeros(3, 4, 5)} >>> batch_size = [3] >>> td = TensorDict(source, batch_size=batch_size) >>> print(td.shape) # equivalent to td.batch_size torch.Size([3]) >>> td_unqueeze = td.unsqueeze(-1) >>> print(td_unqueeze.get("zeros").shape) torch.Size([3, 1, 4, 5]) >>> print(td_unqueeze[0].shape) torch.Size([1]) >>> print(td_unqueeze.view(-1).shape) torch.Size([3]) >>> print((td.clone()==td).all()) True
- abs() T ¶
计算 TensorDict 中每个元素的绝对值。
- abs_() T ¶
就地计算 TensorDict 中每个元素的绝对值。
- acos() T ¶
计算 TensorDict 中每个元素的
acos()
值。
- acos_() T ¶
就地计算 TensorDict 中每个元素的
acos()
值。
- add(other: TensorDictBase | torch.Tensor, *, alpha: float | None = None, default: str | CompatibleType | None = None) TensorDictBase ¶
将
other
(乘以alpha
)加到self
上。\[\text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i\]- 参数:
other (TensorDictBase 或 torch.Tensor) – 要添加到
self
的张量或 TensorDict。- 关键字参数:
alpha (数字,可选) –
other
的乘数。default (torch.Tensor 或 str,可选) – 用于独占条目的默认值。如果没有提供,则两个 TensorDict 的键列表必须完全匹配。如果传递了
default="intersection"
,则只考虑相交的键集,其他键将被忽略。在所有其他情况下,default
将用于操作双方所有缺失的条目。
- add_(other: TensorDictBase | float, *, alpha: float | None = None)¶
add()
的就地版本。注意
就地
add
不支持default
关键字参数。
- addcdiv(other1: TensorDictBase | torch.Tensor, other2: TensorDictBase | torch.Tensor, value: float | None = 1)¶
执行
other1
除以other2
的逐元素除法,将结果乘以标量value
并将其添加到self
中。\[\text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i}\]self
、other1
和other2
的元素的形状必须是可广播的。对于 FloatTensor 或 DoubleTensor 类型的输入,
value
必须是实数,否则为整数。- 参数:
other1 (TensorDict 或 Tensor) – 分子 TensorDict(或张量)
tensor2 (TensorDict 或 Tensor) – 分母 TensorDict(或张量)
- 关键字参数:
value (数字,可选) – \(\text{tensor1} / \text{tensor2}\) 的乘数
- addcmul(other1, other2, *, value: float | None = 1)¶
执行
other1
乘以other2
的逐元素乘法,将结果乘以标量value
并将其添加到self
中。\[\text{out}_i = \text{input}_i + \text{value} \times \text{other1}_i \times \text{other2}_i\]self
、other1
和other2
的形状必须是可广播的。对于 FloatTensor 或 DoubleTensor 类型的输入,
value
必须是实数,否则为整数。- 参数:
other1 (TensorDict 或 Tensor) – 要相乘的 TensorDict 或张量
other2 (TensorDict 或 Tensor) – 要相乘的 TensorDict 或张量
- 关键字参数:
value (数字,可选) – \(other1 .* other2\) 的乘数
- all(dim: int = None) bool | TensorDictBase ¶
检查张量字典中所有值是否为 True/非空。
- 参数:
dim (int, optional) – 如果为
None
,则返回一个布尔值,指示所有张量是否返回 tensor.all() == True。如果为整数,则仅当此维度与张量字典形状兼容时,才会在指定的维度上调用 all。
- any(dim: int = None) bool | TensorDictBase ¶
检查张量字典中是否有任何值为 True/非空。
- 参数:
dim (int, optional) – 如果为
None
,则返回一个布尔值,指示所有张量是否返回 tensor.any() == True。如果为整数,则仅当此维度与张量字典形状兼容时,才会在指定的维度上调用 all。
- apply(fn: Callable, *others: T, batch_size: Sequence[int] | None = None, device: torch.device | None = _NoDefault.ZERO, names: Sequence[str] | None = _NoDefault.ZERO, inplace: bool = False, default: Any = _NoDefault.ZERO, filter_empty: bool | None = None, propagate_lock: bool = False, call_on_nested: bool = False, out: TensorDictBase | None = None, **constructor_kwargs) T | None ¶
将一个可调用对象应用于张量字典中存储的所有值,并在新的张量字典中设置它们。
可调用对象的签名必须为
Callable[Tuple[Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]
。- 参数:
fn (Callable) – 要应用于张量字典中张量的函数。
*others (TensorDictBase 实例, optional) – 如果提供,这些张量字典实例应该具有与 self 相匹配的结构。
fn
参数应接收与张量字典数量一样多的未命名输入,包括 self。如果其他张量字典缺少条目,则可以通过default
关键字参数传递默认值。
- 关键字参数:
batch_size (整数序列, optional) – 如果提供,则生成的 TensorDict 将具有所需的 batch_size。
batch_size
参数应与转换后的 batch_size 相匹配。这是一个仅限关键字的参数。device (torch.device, optional) – 生成的设备(如果有)。
names (字符串列表, optional) – 新的维度名称,如果修改了 batch_size。
inplace (bool, optional) – 如果为 True,则就地进行更改。默认为 False。这是一个仅限关键字的参数。
default (Any, optional) – 其他张量字典中缺少条目的默认值。如果未提供,则缺少的条目将引发 KeyError。
filter_empty (bool, optional) – 如果为
True
,则将过滤掉空的张量字典。这也有较低的计算成本,因为不会创建和销毁空数据结构。非张量数据被视为叶子,因此即使函数未对其进行处理,也会保留在张量字典中。默认为False
以保持向后兼容性。propagate_lock (bool, optional) – 如果为
True
,则锁定的张量字典将生成另一个锁定的张量字典。默认为False
。call_on_nested (bool, optional) –
如果为
True
,则该函数将被调用在一级张量和容器(TensorDict 或 tensorclass)上。在这种情况下,func
负责将其调用传播到嵌套级别。这允许在将调用传播到嵌套的张量字典时进行细粒度的行为。如果为False
,则该函数仅在叶子上调用,并且apply
将负责将该函数分派到所有叶子上。>>> td = TensorDict({"a": {"b": [0.0, 1.0]}, "c": [1.0, 2.0]}) >>> def mean_tensor_only(val): ... if is_tensor_collection(val): ... raise RuntimeError("Unexpected!") ... return val.mean() >>> td_mean = td.apply(mean_tensor_only) >>> def mean_any(val): ... if is_tensor_collection(val): ... # Recurse ... return val.apply(mean_any, call_on_nested=True) ... return val.mean() >>> td_mean = td.apply(mean_any, call_on_nested=True)
out (TensorDictBase, optional) –
一个用于写入结果的张量字典。这可以用来避免创建新的张量字典
>>> td = TensorDict({"a": 0}) >>> td.apply(lambda x: x+1, out=td) >>> assert (td==1).all()
警告
如果对张量字典执行的操作需要访问多个键才能进行单次计算,则提供等于
self
的out
参数可能会导致操作静默地提供错误的结果。例如>>> td = TensorDict({"a": 1, "b": 1}) >>> td.apply(lambda x: x+td["a"])["b"] # Right! tensor(2) >>> td.apply(lambda x: x+td["a"], out=td)["b"] # Wrong! tensor(3)
**constructor_kwargs – 要传递给 TensorDict 构造函数的其他关键字参数。
- 返回值:
一个包含已转换张量的新张量字典。
示例
>>> td = TensorDict({ ... "a": -torch.ones(3), ... "b": {"c": torch.ones(3)}}, ... batch_size=[3]) >>> td_1 = td.apply(lambda x: x+1) >>> assert (td_1["a"] == 0).all() >>> assert (td_1["b", "c"] == 2).all() >>> td_2 = td.apply(lambda x, y: x+y, td) >>> assert (td_2["a"] == -2).all() >>> assert (td_2["b", "c"] == 2).all()
注意
如果函数返回
None
,则忽略该条目。这可用于过滤张量字典中的数据>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, []) >>> def filter(tensor): ... if tensor == 1: ... return tensor >>> td.apply(filter) TensorDict( fields={ 1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b: TensorDict( fields={ 1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
注意
apply 方法将返回一个
TensorDict
实例,无论输入类型是什么。要保持相同的类型,可以执行>>> out = td.clone(False).update(td.apply(...))
- apply_(fn: Callable, *others, **kwargs) T ¶
将可调用对象应用于 TensorDict 中存储的所有值,并就地重写它们。
- 参数:
fn (Callable) – 要应用于张量字典中张量的函数。
*others (sequence of TensorDictBase, optional) – 用于的其他 TensorDict。
关键字参数:参见
apply()
。- 返回值:
自身或应用了函数的自身副本
- asin() T ¶
计算 TensorDict 中每个元素的
asin()
值。
- asin_() T ¶
就地计算 TensorDict 中每个元素的
asin()
值。
- atan() T ¶
计算 TensorDict 中每个元素的
atan()
值。
- atan_() T ¶
就地计算 TensorDict 中每个元素的
atan()
值。
- auto_batch_size_(batch_dims: int | None = None) T ¶
设置 TensorDict 的最大批量大小,最多为可选的 batch_dims。
- 参数:
batch_dims (int, optional) – 如果提供,批量大小最多为
batch_dims
长度。- 返回值:
自身
示例
>>> from tensordict import TensorDict >>> import torch >>> td = TensorDict({"a": torch.randn(3, 4, 5), "b": {"c": torch.randn(3, 4, 6)}}, batch_size=[]) >>> td.auto_batch_size_() >>> print(td.batch_size) torch.Size([3, 4]) >>> td.auto_batch_size_(batch_dims=1) >>> print(td.batch_size) torch.Size([3])
- property batch_size: Size¶
TensorDict 的形状(或批量大小)。
TensorDict 的形状对应于其包含的张量的前 N 个共同维度,其中 N 是任意数量。批量大小与表示张量语义相关形状的“特征大小”形成对比。例如,一批视频可能具有形状
[B, T, C, W, H]
,其中[B, T]
是批量大小(批次和时间维度),而[C, W, H]
是特征维度(通道和空间维度)。TensorDict 的形状由用户在初始化时控制(即,不是从张量形状推断出来的)。
如果新大小与 TensorDict 内容兼容,则可以动态编辑批量大小。例如,始终允许将批量大小设置为空值。
- 返回值:
描述 TensorDict 批量大小的
Size
对象。
示例
>>> data = TensorDict({ ... "key 0": torch.randn(3, 4), ... "key 1": torch.randn(3, 5), ... "nested": TensorDict({"key 0": torch.randn(3, 4)}, batch_size=[3, 4])}, ... batch_size=[3]) >>> data.batch_size = () # resets the batch-size to an empty value
- bfloat16()¶
将所有张量转换为
torch.bfloat16
。
- bool()¶
将所有张量转换为
torch.bool
。
- classmethod cat(input, dim=0, *, out=None)¶
沿着给定维度将 TensorDict 连接成单个 TensorDict。
此调用等效于调用
torch.cat()
,但与 torch.compile 兼容。
- ceil() T ¶
计算 TensorDict 中每个元素的
ceil()
值。
- ceil_() T ¶
就地计算 TensorDict 中每个元素的
ceil()
值。
- chunk(chunks: int, dim: int = 0) tuple[TensorDictBase, ...] ¶
如果可能,将 TensorDict 分割成指定数量的块。
每个块都是输入 TensorDict 的视图。
示例
>>> td = TensorDict({ ... 'x': torch.arange(24).reshape(3, 4, 2), ... }, batch_size=[3, 4]) >>> td0, td1 = td.chunk(dim=-1, chunks=2) >>> td0['x'] tensor([[[ 0, 1], [ 2, 3]], [[ 8, 9], [10, 11]], [[16, 17], [18, 19]]])
- clamp_max(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T ¶
如果
self
的元素大于该值,则将其钳位到other
。- 参数:
other (TensorDict 或 Tensor) – 另一个输入 TensorDict 或张量。
- 关键字参数:
default (torch.Tensor 或 str,可选) – 用于独占条目的默认值。如果没有提供,则两个 TensorDict 的键列表必须完全匹配。如果传递了
default="intersection"
,则只考虑相交的键集,其他键将被忽略。在所有其他情况下,default
将用于操作双方所有缺失的条目。
- clamp_max_(other: TensorDictBase | torch.Tensor) T ¶
clamp_max()
的就地版本。注意
就地
clamp_max
不支持default
关键字参数。
- clamp_min(other: TensorDictBase | torch.Tensor, default: str | CompatibleType | None = None) T ¶
如果
self
的元素小于该值,则将其钳位到other
。- 参数:
other (TensorDict 或 Tensor) – 另一个输入 TensorDict 或张量。
- 关键字参数:
default (torch.Tensor 或 str,可选) – 用于独占条目的默认值。如果没有提供,则两个 TensorDict 的键列表必须完全匹配。如果传递了
default="intersection"
,则只考虑相交的键集,其他键将被忽略。在所有其他情况下,default
将用于操作双方所有缺失的条目。
- clamp_min_(other: TensorDictBase | torch.Tensor) T ¶
clamp_min()
的就地版本。注意
就地
clamp_min
不支持default
关键字参数。
- clear() T ¶
清除张量字典的内容。
- clear_device_() T ¶
清除张量字典的设备。
返回值:self
- clone(recurse: bool = True, **kwargs) T ¶
将 TensorDictBase 子类实例克隆到同一类型的新的 TensorDictBase 子类上。
要从任何其他 TensorDictBase 子类型创建 TensorDict 实例,请改用
to_tensordict()
方法。- 参数:
recurse (bool, optional) – 如果为
True
,则 TensorDict 中包含的每个张量也将被复制。否则,只会复制 TensorDict 树结构。默认为True
。
注意
与许多其他操作(逐点算术、形状操作等)不同,
clone
不会继承原始的锁属性。做出此设计选择是为了能够创建可修改的克隆,这是最常见的用法。
- complex128()¶
将所有张量转换为
torch.complex128
。
- complex32()¶
将所有张量转换为
torch.complex32
。
- complex64()¶
将所有张量转换为
torch.complex64
。
- consolidate(filename: Path | str | None = None, *, num_threads=0, device: torch.device | None = None, non_blocking: bool = False, inplace: bool = False, return_early: bool = False, use_buffer: bool = False, share_memory: bool = False, pin_memory: bool = False, metadata: bool = False) None ¶
将张量字典的内容整合到一个存储中,以实现快速序列化。
- 参数:
filename (Path, optional) – 用于作为张量字典存储的内存映射张量的可选文件路径。
- 关键字参数:
num_threads (integer, optional) – 用于填充存储的线程数。
device (torch.device, optional) – 存储必须在其中实例化的可选设备。
non_blocking (bool, optional) – 传递给
copy_()
的non_blocking
参数。inplace (bool, optional) – 如果为
True
,则结果张量字典与self
相同,但值已更新。默认为False
。return_early (bool, optional) – 如果为
True
且num_threads>0
,则该方法将返回张量字典的 future。可以使用 future.result() 查询结果张量字典。use_buffer (bool, optional) – 如果为
True
且传递了文件名,则将在共享内存中创建一个中间本地缓冲区,并将数据作为最后一步复制到存储位置。这可能比直接写入远程物理内存(例如,NFS)更快。默认为False
。share_memory (bool, optional) – 如果为
True
,则存储将放置在共享内存中。默认为False
。pin_memory (bool, 可选) – 是否将合并后的数据放置在固定内存中。默认为
False
。metadata (bool, 可选) – 如果为
True
,元数据将与公共存储一起存储。如果提供了文件名,则此选项无效。存储元数据在想要控制序列化方式时很有用,因为如果元数据存在或不存在,TensorDict 处理合并的 TD 的 pickle/unpickle 方式不同。
注意
如果 tensordict 已经合并,则所有参数都会被忽略,并返回
self
。调用contiguous()
重新合并。示例
>>> import pickle >>> import tempfile >>> import torch >>> import tqdm >>> from torch.utils.benchmark import Timer >>> from tensordict import TensorDict >>> data = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}}) >>> data_consolidated = data.consolidate() >>> # check that the data has a single data_ptr() >>> assert torch.tensor([ ... v.untyped_storage().data_ptr() for v in data_c.values(True, True) ... ]).unique().numel() == 1 >>> # Serializing the tensordict will be faster with data_consolidated >>> with open("data.pickle", "wb") as f: ... print("regular", Timer("pickle.dump(data, f)", globals=globals()).adaptive_autorange()) >>> with open("data_c.pickle", "wb") as f: ... print("consolidated", Timer("pickle.dump(data_consolidated, f)", globals=globals()).adaptive_autorange())
- contiguous() T ¶
返回一个具有连续值的新 tensordict(如果值已经是连续的,则返回自身)。
- copy()¶
返回 tensordict 的浅拷贝(即复制结构但不复制数据)。
等价于 TensorDictBase.clone(recurse=False)
- copy_(tensordict: T, non_blocking: bool = False) T ¶
-
non_blocking 参数将被忽略,仅为了与
torch.Tensor.copy_()
保持兼容。
- copy_at_(tensordict: T, idx: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], non_blocking: bool = False) T ¶
- cos() T ¶
计算 TensorDict 中每个元素的
cos()
值。
- cos_() T ¶
就地计算 TensorDict 中每个元素的
cos()
值。
- cosh() T ¶
计算 TensorDict 中每个元素的
cosh()
值。
- cosh_() T ¶
就地计算 TensorDict 中每个元素的
cosh()
值。
- create_nested(key)¶
创建与当前 tensordict 形状、设备和维度名称相同的嵌套 tensordict。
如果值已存在,则此操作将覆盖它。此操作在锁定的 tensordict 中是被阻塞的。
示例
>>> data = TensorDict({}, [3, 4, 5]) >>> data.create_nested("root") >>> data.create_nested(("some", "nested", "value")) >>> print(data) TensorDict( fields={ root: TensorDict( fields={ }, batch_size=torch.Size([3, 4, 5]), device=None, is_shared=False), some: TensorDict( fields={ nested: TensorDict( fields={ value: TensorDict( fields={ }, batch_size=torch.Size([3, 4, 5]), device=None, is_shared=False)}, batch_size=torch.Size([3, 4, 5]), device=None, is_shared=False)}, batch_size=torch.Size([3, 4, 5]), device=None, is_shared=False)}, batch_size=torch.Size([3, 4, 5]), device=None, is_shared=False)
- cuda(device: Optional[int] = None, **kwargs) T ¶
将 tensordict 转换为 cuda 设备(如果尚未在该设备上)。
- 参数:
device (int, 可选) – 如果提供,则为应将张量转换为的 cuda 设备。
此函数还支持
to()
的所有关键字参数。
- property data¶
返回一个包含叶子张量的 .data 属性的 tensordict。
- del_(key: NestedKey) T ¶
删除 tensordict 的一个键。
- 参数:
key (NestedKey) – 要删除的键
- 返回值:
自身
- detach() T ¶
分离 tensordict 中的张量。
- 返回值:
一个新的 tensordict,其中没有张量需要梯度。
- detach_() T ¶
就地分离 tensordict 中的张量。
- 返回值:
自身。
- property device: torch.device | None¶
tensordict 的设备。
如果在构造函数中未提供设备或通过 tensordict.to(device) 设置,则返回 None。
- dim() int ¶
参见
batch_dims()
。
- div(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T ¶
将输入
self
的每个元素除以other
的对应元素。\[\text{out}_i = \frac{\text{input}_i}{\text{other}_i}\]支持广播、类型提升以及整数、浮点数、TensorDict 或张量输入。始终将整数类型提升到默认标量类型。
- 参数:
other (TensorDict, Tensor 或 Number) – 除数。
- 关键字参数:
default (torch.Tensor 或 str,可选) – 用于独占条目的默认值。如果没有提供,则两个 TensorDict 的键列表必须完全匹配。如果传递了
default="intersection"
,则只考虑相交的键集,其他键将被忽略。在所有其他情况下,default
将用于操作双方所有缺失的条目。
- div_(other: TensorDictBase | torch.Tensor) T ¶
div()
的就地版本。注意
就地
div
不支持default
关键字参数。
- double()¶
将所有张量转换为
torch.bool
。
- property dtype¶
如果 TensorDict 中的值的数据类型唯一,则返回该数据类型。
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
将 TensorDict 保存到磁盘。
此函数是
memmap()
的代理。
- empty(recurse=False, *, batch_size=None, device=_NoDefault.ZERO, names=_NoDefault.ZERO) T ¶
返回一个新的、空的 TensorDict,具有相同的设备和批次大小。
- 参数:
recurse (bool, 可选) – 如果为
True
,则将复制TensorDict
的整个结构,但不包含内容。否则,仅复制根节点。默认为False
。- 关键字参数:
batch_size (torch.Size, 可选) – TensorDict 的新批次大小。
device (torch.device, 可选) – 新设备。
names (str 列表, 可选) – 维度名称。
- entry_class(key: NestedKey) type ¶
返回条目的类,可能避免调用 isinstance(td.get(key), type)。
当
get()
执行成本较高时,应优先使用此方法而不是tensordict.get(key).shape
。
- erf() T ¶
计算 TensorDict 中每个元素的
erf()
值。
- erf_() T ¶
就地计算 TensorDict 中每个元素的
erf()
值。
- erfc() T ¶
计算 TensorDict 中每个元素的
erfc()
值。
- erfc_() T ¶
就地计算 TensorDict 中每个元素的
erfc()
值。
- exclude(*keys: NestedKey, inplace: bool = False) T ¶
排除 TensorDict 的键,并返回一个不包含这些条目的新 TensorDict。
值不会被复制:对原始 TensorDict 或新 TensorDict 中的张量进行就地修改将导致两个 TensorDict 都发生变化。
- 参数:
- 返回值:
一个不包含排除条目的新 TensorDict(如果
inplace=True
,则为同一个)。
示例
>>> from tensordict import TensorDict >>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, []) >>> td.exclude("a", ("b", "c")) TensorDict( fields={ b: TensorDict( fields={ d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> td.exclude("a", "b") TensorDict( fields={ }, batch_size=torch.Size([]), device=None, is_shared=False)
- exp() T ¶
计算 TensorDict 中每个元素的
exp()
值。
- exp_() T ¶
就地计算 TensorDict 中每个元素的
exp()
值。
- expand(*args, **kwargs) T ¶
根据
expand()
函数扩展 TensorDict 中每个张量的形状,忽略特征维度。支持使用可迭代对象指定形状。
示例
>>> td = TensorDict({ ... 'a': torch.zeros(3, 4, 5), ... 'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4]) >>> td_expand = td.expand(10, 3, 4) >>> assert td_expand.shape == torch.Size([10, 3, 4]) >>> assert td_expand.get("a").shape == torch.Size([10, 3, 4, 5])
- expand_as(other: TensorDictBase | torch.Tensor) TensorDictBase ¶
将 TensorDict 的形状广播到 other 的形状,并相应地扩展它。
如果输入是张量集合(TensorDict 或 TensorClass),则叶子将一对一地扩展。
示例
>>> from tensordict import TensorDict >>> import torch >>> td0 = TensorDict({ ... "a": torch.ones(3, 1, 4), ... "b": {"c": torch.ones(3, 2, 1, 4)}}, ... batch_size=[3], ... ) >>> td1 = TensorDict({ ... "a": torch.zeros(2, 3, 5, 4), ... "b": {"c": torch.zeros(2, 3, 2, 6, 4)}}, ... batch_size=[2, 3], ... ) >>> expanded = td0.expand_as(td1) >>> assert (expanded==1).all() >>> print(expanded) TensorDict( fields={ a: Tensor(shape=torch.Size([2, 3, 5, 4]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([2, 3, 2, 6, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2, 3]), device=None, is_shared=False)}, batch_size=torch.Size([2, 3]), device=None, is_shared=False)
- expm1() T ¶
计算 TensorDict 中每个元素的
expm1()
值。
- expm1_() T ¶
就地计算 TensorDict 中每个元素的
expm1()
值。
- filter_empty_()¶
就地过滤掉所有空 TensorDict。
- filter_non_tensor_data() T ¶
过滤掉所有非张量数据。
- flatten(start_dim=0, end_dim=- 1)¶
展平 TensorDict 中的所有张量。
示例
>>> td = TensorDict({ ... "a": torch.arange(60).view(3, 4, 5), ... "b": torch.arange(12).view(3, 4)}, batch_size=[3, 4]) >>> td_flat = td.flatten(0, 1) >>> td_flat.batch_size torch.Size([12]) >>> td_flat["a"] tensor([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24], [25, 26, 27, 28, 29], [30, 31, 32, 33, 34], [35, 36, 37, 38, 39], [40, 41, 42, 43, 44], [45, 46, 47, 48, 49], [50, 51, 52, 53, 54], [55, 56, 57, 58, 59]]) >>> td_flat["b"] tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
- flatten_keys(separator: str = '.', inplace: bool = False, is_leaf: Callable[[Type], bool] | None = None) T ¶
递归地将嵌套的 TensorDict 转换为扁平化的 TensorDict。
TensorDict 类型将丢失,结果将是一个简单的 TensorDict 实例。
- 参数:
示例
>>> data = TensorDict({"a": 1, ("b", "c"): 2, ("e", "f", "g"): 3}, batch_size=[]) >>> data.flatten_keys(separator=" - ") TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b - c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), e - f - g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
此方法和
unflatten_keys()
在处理状态字典时特别有用,因为它们可以无缝地将扁平化的字典转换为模仿模型结构的数据结构。示例
>>> model = torch.nn.Sequential(torch.nn.Linear(3 ,4)) >>> ddp_model = torch.ao.quantization.QuantWrapper(model) >>> state_dict = TensorDict(ddp_model.state_dict(), batch_size=[]).unflatten_keys(".") >>> print(state_dict) TensorDict( fields={ module: TensorDict( fields={ 0: TensorDict( fields={ bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> model_state_dict = state_dict.get("module") >>> print(model_state_dict) TensorDict( fields={ 0: TensorDict( fields={ bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))
- float()¶
将所有张量转换为
torch.float
。
- float16()¶
将所有张量转换为
torch.float16
。
- float32()¶
将所有张量转换为
torch.float32
。
- float64()¶
将所有张量转换为
torch.float64
。
- floor() T ¶
计算 TensorDict 中每个元素的
floor()
值。
- floor_() T ¶
就地计算 TensorDict 中每个元素的
floor()
值。
- frac() T ¶
计算 TensorDict 中每个元素的
frac()
值。
- frac_() T ¶
就地计算 TensorDict 中每个元素的
frac()
值。
- classmethod from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None)¶
从字典或另一个
TensorDict
创建一个 TensorDict 并返回。如果未指定
batch_size
,则返回可能的最大批次大小。此函数也适用于嵌套字典,或者可用于确定嵌套 TensorDict 的批大小。
- 参数:
input_dict (字典, 可选) – 用作数据源的字典(兼容嵌套键)。
batch_size (整数可迭代对象, 可选) – TensorDict 的批大小。
device (torch.device 或 兼容类型, 可选) – TensorDict 的设备。
batch_dims (整数, 可选) –
batch_dims
(即要考虑用于batch_size
的前导维度的数量)。与batch_size
互斥。请注意,这是 TensorDict 的 **最大** 批维度数,可以容忍较小的数字。names (字符串列表, 可选) – TensorDict 的维度名称。
示例
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)} >>> print(TensorDict.from_dict(input_dict)) TensorDict( fields={ a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> # nested dict: the nested TensorDict can have a different batch-size >>> # as long as its leading dims match. >>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}} >>> print(TensorDict.from_dict(input_dict)) TensorDict( fields={ a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3, 4]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> # we can also use this to work out the batch sie of a tensordict >>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, []) >>> print(TensorDict.from_dict(input_td)) TensorDict( fields={ a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3, 4]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- from_dict_instance(input_dict, batch_size=None, device=None, batch_dims=None, names=None)¶
from_dict()
的实例方法版本。与
from_dict()
不同,此方法将尝试在现有树中保留 TensorDict 类型(对于任何现有的叶子)。示例
>>> from tensordict import TensorDict, tensorclass >>> import torch >>> >>> @tensorclass >>> class MyClass: ... x: torch.Tensor ... y: int >>> >>> td = TensorDict({"a": torch.randn(()), "b": MyClass(x=torch.zeros(()), y=1)}) >>> print(td.from_dict_instance(td.to_dict())) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: MyClass( x=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), y=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(td.from_dict(td.to_dict())) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- 类方法 from_h5(filename, mode='r')¶
从 h5 文件创建 PersistentTensorDict。
此函数将自动确定每个嵌套 TensorDict 的批大小。
- 类方法 from_module(module: Module, as_module: 布尔值 = False, lock: 布尔值 = False, use_state_dict: 布尔值 = False, filter_empty: 布尔值 = True)¶
将模块的参数和缓冲区复制到 TensorDict 中。
- 参数:
module (nn.Module) – 获取参数的模块。
as_module (布尔值, 可选) – 如果
True
,则将返回TensorDictParams
实例,该实例可用于在torch.nn.Module
中存储参数。默认为False
。lock (布尔值, 可选) – 如果
True
,则结果 TensorDict 将被锁定。默认为True
。use_state_dict (布尔值, 可选) –
如果
True
,则将使用模块中的状态字典并将其展平到具有模型树结构的 TensorDict 中。默认为False
。.. 注意This is particularly useful when state-dict hooks have to be used.
示例
>>> from torch import nn >>> module = nn.TransformerDecoder( ... decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4), ... num_layers=1) >>> params = TensorDict.from_module(module) >>> print(params["layers", "0", "linear1"]) TensorDict( fields={ bias: Parameter(shape=torch.Size([2048]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([2048, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- 类方法 from_modules(*modules, as_module: 布尔值 = False, lock: 布尔值 = True, use_state_dict: 布尔值 = False, lazy_stack: 布尔值 = False, expand_identical: 布尔值 = False)¶
通过 vmap 获取多个模块的参数,用于集成学习/特征期望应用。
- 参数:
modules (nn.Module 序列) – 获取参数的模块。如果模块的结构不同,则需要惰性堆叠(请参见下面的
lazy_stack
参数)。- 关键字参数:
as_module (布尔值, 可选) – 如果
True
,则将返回TensorDictParams
实例,该实例可用于在torch.nn.Module
中存储参数。默认为False
。lock (布尔值, 可选) – 如果
True
,则结果 TensorDict 将被锁定。默认为True
。use_state_dict (布尔值, 可选) –
如果
True
,则将使用模块中的状态字典并将其展平到具有模型树结构的 TensorDict 中。默认为False
。.. 注意This is particularly useful when state-dict hooks have to be used.
lazy_stack (布尔值, 可选) –
参数是否应密集或惰性堆叠。默认为
False
(密集堆叠)。注意
lazy_stack
和as_module
是互斥的功能。警告
惰性和非惰性输出之间存在一个关键区别,即非惰性输出将使用所需的批大小重新实例化参数,而
lazy_stack
将仅表示参数为惰性堆叠。这意味着,当lazy_stack=True
时,原始参数可以安全地传递给优化器,而当将其设置为True
时,需要传递新参数。警告
虽然使用惰性堆叠来保留原始参数引用可能很诱人,但请记住,每次调用
get()
时,惰性堆叠都会执行堆叠操作。这将需要内存(参数大小的 N 倍,如果构建图形则更多)和时间进行计算。这也意味着优化器将包含更多参数,并且step()
或zero_grad()
等操作将需要更长时间才能执行。一般来说,应将lazy_stack
保留用于极少数用例。expand_identical (布尔值, 可选) – 如果
True
并且正在将相同参数(相同标识)堆叠到自身,则将返回此参数的扩展版本。当lazy_stack=True
时,将忽略此参数。
示例
>>> from torch import nn >>> from tensordict import TensorDict >>> torch.manual_seed(0) >>> empty_module = nn.Linear(3, 4, device="meta") >>> n_models = 2 >>> modules = [nn.Linear(3, 4) for _ in range(n_models)] >>> params = TensorDict.from_modules(*modules) >>> print(params) TensorDict( fields={ bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False) >>> # example of batch execution >>> def exec_module(params, x): ... with params.to_module(empty_module): ... return empty_module(x) >>> x = torch.randn(3) >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> # since lazy_stack = False, backprop leaves the original params untouched >>> y.sum().backward() >>> assert params["weight"].grad.norm() > 0 >>> assert modules[0].weight.grad is None
使用
lazy_stack=True
时,情况略有不同。>>> params = TensorDict.from_modules(*modules, lazy_stack=True) >>> print(params) LazyStackedTensorDict( fields={ bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([2]), device=None, is_shared=False, stack_dim=0) >>> # example of batch execution >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> y.sum().backward() >>> assert modules[0].weight.grad is not None
- classmethod from_namedtuple(named_tuple, *, auto_batch_size: bool = False)¶
递归地将命名元组转换为 TensorDict。
- 关键字参数:
auto_batch_size (bool, 可选) – 如果
True
,则会自动计算批次大小。默认为False
。
示例
>>> from tensordict import TensorDict >>> import torch >>> data = TensorDict({ ... "a_tensor": torch.zeros((3)), ... "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3]) >>> nt = data.to_namedtuple() >>> print(nt) GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!')) >>> TensorDict.from_namedtuple(nt, auto_batch_size=True) TensorDict( fields={ a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None), a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- classmethod from_pytree(pytree, *, batch_size: torch.Size | None = None, auto_batch_size: bool = False, batch_dims: int | None = None)¶
将 pytree 转换为 TensorDict 实例。
此方法旨在尽可能地保留 pytree 的嵌套结构。
添加了额外的非张量键来跟踪每个级别的标识,从而提供了一个内置的 pytree 到 tensordict 的双射变换 API。
当前接受的类包括列表、元组、命名元组和字典。
注意
对于字典,非 NestedKey 键将作为
NonTensorData
实例单独注册。注意
可转换为张量的类型(如 int、float 或 np.ndarray)将转换为 torch.Tensor 实例。请注意,此转换是满射的:将 tensordict 转换回 pytree 将无法恢复原始类型。
示例
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key >>> class WeirdLookingClass: ... pass ... >>> weird_key = WeirdLookingClass() >>> # Make a pytree with tuple, lists, dict and namedtuple >>> pytree = ( ... [torch.randint(10, (3,)), torch.zeros(2)], ... { ... "tensor": torch.randn( ... 2, ... ), ... "td": TensorDict({"one": 1}), ... weird_key: torch.randint(10, (2,)), ... "list": [1, 2, 3], ... }, ... {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, ... ) >>> # Build a TensorDict from that pytree >>> td = TensorDict.from_pytree(pytree) >>> # Recover the pytree >>> pytree_recon = td.to_pytree() >>> # Check that the leaves match >>> def check(v1, v2): >>> assert (v1 == v2).all() >>> >>> torch.utils._pytree.tree_map(check, pytree, pytree_recon) >>> assert weird_key in pytree_recon[1]
- classmethod fromkeys(keys: List[NestedKey], value: Any = 0)¶
根据键列表和单个值创建 tensordict。
- 参数:
keys (NestedKey 列表) – 指定新字典键的可迭代对象。
value (兼容类型, 可选) – 所有键的值。默认为
0
。
- gather(dim: int, index: Tensor, out: T | None = None) T ¶
沿着由 dim 指定的轴收集值。
- 参数:
dim (int) – 收集元素的维度。
index (torch.Tensor) – 一个长张量,其维度数量与 tensordict 的维度数量匹配,只有两个张量之间的一个维度不同(收集维度)。其元素指的是沿着所需维度收集的索引。
out (TensorDictBase, 可选) – 目标 tensordict。它必须与索引具有相同的形状。
示例
>>> td = TensorDict( ... {"a": torch.randn(3, 4, 5), ... "b": TensorDict({"c": torch.zeros(3, 4, 5)}, [3, 4, 5])}, ... [3, 4]) >>> index = torch.randint(4, (3, 2)) >>> td_gather = td.gather(dim=1, index=index) >>> print(td_gather) TensorDict( fields={ a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3, 2, 5]), device=None, is_shared=False)}, batch_size=torch.Size([3, 2]), device=None, is_shared=False)
Gather 保留维度名称。
示例
>>> td.names = ["a", "b"] >>> td_gather = td.gather(dim=1, index=index) >>> td_gather.names ["a", "b"]
- gather_and_stack(dst: int, group: 'dist.ProcessGroup' | None = None) T | None ¶
从各个工作进程收集 tensordict,并将它们堆叠到目标工作进程中的 self 上。
- 参数:
dst (int) – 目标工作进程的秩,将在其中调用
gather_and_stack()
。group (torch.distributed.ProcessGroup, 可选) – 如果设置,则将使用指定的进程组进行通信。否则,将使用默认进程组。默认为
None
。
示例
>>> from torch import multiprocessing as mp >>> from tensordict import TensorDict >>> import torch >>> >>> def client(): ... torch.distributed.init_process_group( ... "gloo", ... rank=1, ... world_size=2, ... init_method=f"tcp://127.0.0.1:10003", ... ) ... # Create a single tensordict to be sent to server ... td = TensorDict( ... {("a", "b"): torch.randn(2), ... "c": torch.randn(2)}, [2] ... ) ... td.gather_and_stack(0) ... >>> def server(): ... torch.distributed.init_process_group( ... "gloo", ... rank=0, ... world_size=2, ... init_method=f"tcp://127.0.0.1:10003", ... ) ... # Creates the destination tensordict on server. ... # The first dim must be equal to world_size-1 ... td = TensorDict( ... {("a", "b"): torch.zeros(2), ... "c": torch.zeros(2)}, [2] ... ).expand(1, 2).contiguous() ... td.gather_and_stack(0) ... assert td["a", "b"] != 0 ... print("yuppie") ... >>> if __name__ == "__main__": ... mp.set_start_method("spawn") ... ... main_worker = mp.Process(target=server) ... secondary_worker = mp.Process(target=client) ... ... main_worker.start() ... secondary_worker.start() ... ... main_worker.join() ... secondary_worker.join()
- get(key: NestedKey, default: Any = _NoDefault.ZERO) Tensor ¶
获取使用输入键存储的值。
- 参数:
key (str, str 的元组) – 要查询的键。如果为 str 的元组,则等效于 getattr 的链式调用。
default – 如果 tensordict 中未找到键,则为默认值。
示例
>>> td = TensorDict({"x": 1}, batch_size=[]) >>> td.get("x") tensor(1) >>> td.get("y", default=None) None
- get_at(key: NestedKey, index: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], default: Tensor = _NoDefault.ZERO) Tensor ¶
根据键 key 和索引 idx 获取 tensordict 的值。
- 参数:
key (str, str 的元组) – 要检索的键。
index (int, slice, torch.Tensor, iterable) – 张量的索引。
default (torch.Tensor) – 如果键不存在于 tensordict 中,则返回的默认值。
- 返回值:
索引张量。
示例
>>> td = TensorDict({"x": torch.arange(3)}, batch_size=[]) >>> td.get_at("x", index=1) tensor(1)
- get_non_tensor(key: NestedKey, default=_NoDefault.ZERO)¶
获取非张量值(如果存在),或者如果未找到非张量值则返回default。
此方法对张量/TensorDict 值具有鲁棒性,这意味着如果收集到的值是常规张量,它也将被返回(尽管此方法会带来一些开销,不应超出其自然范围)。
有关如何在 tensordict 中设置非张量值的更多信息,请参阅
set_non_tensor()
。- 参数:
key (NestedKey) – NonTensorData 对象的位置。
default (Any, optional) – 如果找不到键,则返回的值。
- 返回:
tensordict.tensorclass.NonTensorData
的内容, 或者如果它不是
tensordict.tensorclass.NonTensorData
则返回对应于key
的条目(或者如果找不到条目则返回default
)。
示例
>>> data = TensorDict({}, batch_size=[]) >>> data.set_non_tensor(("nested", "the string"), "a string!") >>> assert data.get_non_tensor(("nested", "the string")) == "a string!" >>> # regular `get` works but returns a NonTensorData object >>> data.get(("nested", "the string")) NonTensorData( data='a string!', batch_size=torch.Size([]), device=None, is_shared=False)
- property grad¶
返回一个包含叶子张量的 .grad 属性的 tensordict。
- half()¶
将所有张量转换为
torch.half
。
- int()¶
将所有张量转换为
torch.int
。
- int16()¶
将所有张量转换为
torch.int16
。
- int32()¶
将所有张量转换为
torch.int32
。
- int64()¶
将所有张量转换为
torch.int64
。
- int8()¶
将所有张量转换为
torch.int8
。
- irecv(src: int, *, group: 'dist.ProcessGroup' | None = None, return_premature: bool = False, init_tag: int = 0, pseudo_rand: bool = False) tuple[int, list[torch.Future]] | list[torch.Future] | None ¶
异步接收 tensordict 的内容并使用它更新内容。
查看
isend()
方法中的示例以了解上下文。- 参数:
src (int) – 源工作器的排名。
- 关键字参数:
group (torch.distributed.ProcessGroup, 可选) – 如果设置,则将使用指定的进程组进行通信。否则,将使用默认进程组。默认为
None
。return_premature (bool) – 如果
True
,则返回一个 future 列表,等待直到 tensordict 更新。默认为False
,即在调用中等待直到更新完成。init_tag (int) – 源工作者使用的
init_tag
。pseudo_rand (bool) – 如果为 True,则标签序列将是伪随机的,允许从不同的节点发送多个数据而不会重叠。请注意,这些伪随机数的生成成本很高(1e-5 秒/数),这意味着它可能会减慢算法的运行时间。此值必须与传递给
isend()
的值匹配。默认为False
。
- 返回值:
- 如果
return_premature=True
,则返回一个 future 列表,等待 直到 tensordict 更新。
- 如果
- is_consolidated()¶
检查 TensorDict 是否具有合并的存储。
- is_empty()¶
检查 tensordict 是否包含任何叶子。
- is_memmap() bool ¶
检查 tensordict 是否是内存映射的。
如果 TensorDict 实例是内存映射的,则它被锁定(条目不能重命名、删除或添加)。如果使用全是内存映射的张量创建
TensorDict
,这并不意味着is_memmap
将返回True
(因为新的张量可能被内存映射也可能不被内存映射)。只有当调用tensordict.memmap_() 时,tensordict 才会被视为内存映射的。对于 CUDA 设备上的 tensordict,这始终为
True
。
检查 tensordict 是否在共享内存中。
如果 TensorDict 实例位于共享内存中,则会被锁定(无法重命名、删除或添加条目)。如果使用全部位于共享内存中的张量创建一个
TensorDict
,这并不意味着is_shared
会返回True
(因为新的张量可能在共享内存中,也可能不在)。只有当调用 tensordict.share_memory_() 或将 tensordict 放置在默认情况下内容共享的设备上(例如,"cuda"
)时,tensordict 才会被认为位于共享内存中。对于 CUDA 设备上的 tensordict,这始终为
True
。
- isend(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int ¶
异步发送 tensordict 的内容。
- 参数:
dst (int) – 目标工作者的等级,内容应发送到该工作者。
- 关键字参数:
示例
>>> import torch >>> from tensordict import TensorDict >>> from torch import multiprocessing as mp >>> def client(): ... torch.distributed.init_process_group( ... "gloo", ... rank=1, ... world_size=2, ... init_method=f"tcp://127.0.0.1:10003", ... ) ... ... td = TensorDict( ... { ... ("a", "b"): torch.randn(2), ... "c": torch.randn(2, 3), ... "_": torch.ones(2, 1, 5), ... }, ... [2], ... ) ... td.isend(0) ... >>> >>> def server(queue, return_premature=True): ... torch.distributed.init_process_group( ... "gloo", ... rank=0, ... world_size=2, ... init_method=f"tcp://127.0.0.1:10003", ... ) ... td = TensorDict( ... { ... ("a", "b"): torch.zeros(2), ... "c": torch.zeros(2, 3), ... "_": torch.zeros(2, 1, 5), ... }, ... [2], ... ) ... out = td.irecv(1, return_premature=return_premature) ... if return_premature: ... for fut in out: ... fut.wait() ... assert (td != 0).all() ... queue.put("yuppie") ... >>> >>> if __name__ == "__main__": ... queue = mp.Queue(1) ... main_worker = mp.Process( ... target=server, ... args=(queue, ) ... ) ... secondary_worker = mp.Process(target=client) ... ... main_worker.start() ... secondary_worker.start() ... out = queue.get(timeout=10) ... assert out == "yuppie" ... main_worker.join() ... secondary_worker.join()
- isfinite() T ¶
返回一个新的 tensordict,其中包含布尔元素,表示每个元素是否有限。
当实数值不是 NaN、负无穷大或无穷大时,它们是有限的。当复数的实部和虚部都是有限的时,它们是有限的。
- isnan() T ¶
返回一个新的 tensordict,其中包含布尔元素,表示输入的每个元素是否为 NaN。
当复数的实部和/或虚部为 NaN 时,则认为该复数为 NaN。
- isneginf() T ¶
测试输入的每个元素是否为负无穷大。
- isposinf() T ¶
测试输入的每个元素是否为负无穷大。
- isreal() T ¶
返回一个新的 tensordict,其中包含布尔元素,表示输入的每个元素是否为实数值。
- items(include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None) Iterator[tuple[str, CompatibleType]] ¶
返回 tensordict 的键值对生成器。
- keys(include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None) _TensorDictKeysView ¶
返回 tensordict 键的生成器。
- 参数:
示例
>>> from tensordict import TensorDict >>> data = TensorDict({"0": 0, "1": {"2": 2}}, batch_size=[]) >>> data.keys() ['0', '1'] >>> list(data.keys(leaves_only=True)) ['0'] >>> list(data.keys(include_nested=True, leaves_only=True)) ['0', '1', ('1', '2')]
- classmethod lazy_stack(input, dim=0, *, out=None, **kwargs)¶
创建 tensordict 的延迟堆叠。
有关详细信息,请参阅
lazy_stack()
。
- lerp(end: TensorDictBase | torch.Tensor, weight: TensorDictBase | torch.Tensor | float)¶
是否根据标量或张量
weight
对两个张量start
(由self
给定)和end
进行线性插值。\[\text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i)\]start
和end
的形状必须可广播。如果weight
是张量,则weight
、start
和end
的形状必须可广播。- 参数:
end (TensorDict) – 具有结束点的 TensorDict。
weight (TensorDict, tensor 或 float) – 插值公式的权重。
- lerp_(end: TensorDictBase | float, weight: TensorDictBase | float)¶
lerp()
的就地版本。
- lgamma() T ¶
计算 TensorDict 中每个元素的
lgamma()
值。
- lgamma_() T ¶
就地计算 TensorDict 中每个元素的
lgamma()
值。
- classmethod load(prefix: str | Path, *args, **kwargs) T ¶
从磁盘加载 TensorDict。
此类方法是
load_memmap()
的代理。
- load_(prefix: str | Path, *args, **kwargs)¶
将 TensorDict 从磁盘加载到当前 TensorDict 中。
此类方法是
load_memmap_()
的代理。
- classmethod load_memmap(prefix: str | Path, device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) T ¶
从磁盘加载内存映射的 TensorDict。
- 参数:
prefix (str 或 文件夹路径) – 保存的 TensorDict 应从中获取的文件夹的路径。
device (torch.device 或 等效项, 可选) – 如果提供,数据将异步转换为该设备。支持 “meta” 设备,在这种情况下,数据不会加载,但会创建一组空的“meta”张量。这对于在不实际打开任何文件的情况下了解模型的总大小和结构很有用。
non_blocking (bool, 可选) – 如果
True
,则在将张量加载到设备上后不会调用同步。默认为False
。out (TensorDictBase, 可选) – 数据应写入其中的可选 TensorDict。
示例
>>> from tensordict import TensorDict >>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0) >>> td.memmap("./saved_td") >>> td_load = TensorDict.load_memmap("./saved_td") >>> assert (td == td_load).all()
此方法还允许加载嵌套的 TensorDict。
>>> nested = TensorDict.load_memmap("./saved_td/nested") >>> assert nested["e"] == 0
TensorDict 也可以加载到“meta”设备上,或者作为伪张量。
>>> import tempfile >>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}}) >>> with tempfile.TemporaryDirectory() as path: ... td.save(path) ... td_load = TensorDict.load_memmap(path, device="meta") ... print("meta:", td_load) ... from torch._subclasses import FakeTensorMode ... with FakeTensorMode(): ... td_load = TensorDict.load_memmap(path) ... print("fake:", td_load) meta: TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False) fake: TensorDict( fields={ a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)
- load_memmap_(prefix: str | Path)¶
将内存映射的 TensorDict 的内容加载到调用
load_memmap_
的 TensorDict 中。有关更多信息,请参阅
load_memmap()
。
- load_state_dict(state_dict: OrderedDict[str, Any], strict=True, assign=False, from_flatten=False) T ¶
将格式化的状态字典(如
state_dict()
中所示)加载到 TensorDict 中。- 参数:
state_dict (OrderedDict) – 要复制的状态字典。
strict (bool, 可选) – 是否严格强制要求
state_dict
中的键与此 TensorDict 的torch.nn.Module.state_dict()
函数返回的键匹配。默认值:True
assign (bool, 可选) – 是否将状态字典中的项目分配到 TensorDict 中对应的键,而不是将它们就地复制到 TensorDict 的当前张量中。当
False
时,保留当前模块中张量的属性,而当True
时,保留状态字典中张量的属性。默认值:False
from_flatten (bool, 可选) – 如果
True
,则假设输入的 state_dict 已被扁平化。默认为False
。
示例
>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, []) >>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, []) >>> sd = data.state_dict() >>> data_zeroed.load_state_dict(sd) >>> print(data_zeroed["3", "3"]) tensor(3) >>> # with flattening >>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, []) >>> data_zeroed.load_state_dict(data.state_dict(flatten=True), from_flatten=True) >>> print(data_zeroed["3", "3"]) tensor(3)
- lock_() T ¶
锁定 TensorDict 以防止非就地操作。
诸如
set()
、__setitem__()
、update()
、rename_key_()
或其他添加或删除条目的操作将被阻止。此方法可用作装饰器。
示例
>>> from tensordict import TensorDict >>> td = TensorDict({"a": 1, "b": 2, "c": 3}, batch_size=[]) >>> with td.lock_(): ... assert td.is_locked ... try: ... td.set("d", 0) # error! ... except RuntimeError: ... print("td is locked!") ... try: ... del td["d"] ... except RuntimeError: ... print("td is locked!") ... try: ... td.rename_key_("a", "d") ... except RuntimeError: ... print("td is locked!") ... td.set("a", 0, inplace=True) # No storage is added, moved or removed ... td.set_("a", 0) # No storage is added, moved or removed ... td.update({"a": 0}, inplace=True) # No storage is added, moved or removed ... td.update_({"a": 0}) # No storage is added, moved or removed >>> assert not td.is_locked
- log() T ¶
计算 TensorDict 中每个元素的
log()
值。
- log10() T ¶
计算 TensorDict 中每个元素的
log10()
值。
- log10_() T ¶
就地计算 TensorDict 中每个元素的
log10()
值。
- log1p() T ¶
计算 TensorDict 中每个元素的
log1p()
值。
- log1p_() T ¶
就地计算 TensorDict 中每个元素的
log1p()
值。
- log2() T ¶
计算 TensorDict 中每个元素的
log2()
值。
- log2_() T ¶
就地计算 TensorDict 中每个元素的
log2()
值。
- log_() T ¶
就地计算 TensorDict 中每个元素的
log()
值。
- make_memmap(key: NestedKey, shape: torch.Size | torch.Tensor, *, dtype: torch.dtype | None = None) MemoryMappedTensor ¶
给定形状和可选的数据类型,创建一个空的内存映射张量。
警告
此方法在设计上不是锁安全的。存在于多个节点上的内存映射 TensorDict 实例需要使用
memmap_refresh_()
方法进行更新。写入现有条目将导致错误。
- 参数:
key (NestedKey) – 要写入的新条目的键。如果键已存在于 TensorDict 中,则会引发异常。
shape (torch.Size 或等效项, torch.Tensor 用于嵌套张量) – 要写入的张量的形状。
- 关键字参数:
dtype (torch.dtype, 可选) – 新张量的数据类型。
- 返回值:
一个新的内存映射张量。
- make_memmap_from_storage(key: NestedKey, storage: torch.UntypedStorage, shape: torch.Size | torch.Tensor, *, dtype: torch.dtype | None = None) MemoryMappedTensor ¶
给定存储、形状和可选的数据类型,创建一个空的内存映射张量。
警告
此方法在设计上不是锁安全的。存在于多个节点上的内存映射 TensorDict 实例需要使用
memmap_refresh_()
方法进行更新。注意
如果存储关联了文件名,则它必须与文件的新的文件名匹配。如果它没有关联文件名,但 TensorDict 关联了路径,则会导致异常。
- 参数:
key (NestedKey) – 要写入的新条目的键。如果键已存在于 TensorDict 中,则会引发异常。
storage (torch.UntypedStorage) – 用于新 MemoryMappedTensor 的存储。必须是物理内存存储。
shape (torch.Size 或等效项, torch.Tensor 用于嵌套张量) – 要写入的张量的形状。
- 关键字参数:
dtype (torch.dtype, 可选) – 新张量的数据类型。
- 返回值:
一个使用给定存储的新内存映射张量。
- make_memmap_from_tensor(key: NestedKey, tensor: Tensor, *, copy_data: bool = True) MemoryMappedTensor ¶
给定张量,创建一个空的内存映射张量。
警告
此方法在设计上不是锁安全的。存在于多个节点上的内存映射 TensorDict 实例需要使用
memmap_refresh_()
方法进行更新。如果
copy_data
为True
,此方法始终复制存储内容(即,存储不共享)。- 参数:
key (NestedKey) – 要写入的新条目的键。如果键已存在于 TensorDict 中,则会引发异常。
tensor (torch.Tensor) – 要在物理内存上复制的张量。
- 关键字参数:
copy_data (bool, 可选) – 如果
False
,则新张量将共享输入的元数据(如形状和数据类型),但内容将为空。默认为True
。- 返回值:
一个使用给定存储的新内存映射张量。
- map(fn: Callable[[TensorDictBase], TensorDictBase | None], dim: int = 0, num_workers: int | None = None, *, out: TensorDictBase | None = None, chunksize: int | None = None, num_chunks: int | None = None, pool: mp.Pool | None = None, generator: torch.Generator | None = None, max_tasks_per_child: int | None = None, worker_threads: int = 1, index_with_generator: bool = False, pbar: bool = False, mp_start_method: str | None = None)¶
将函数映射到张量字典在一个维度上的分割。
此方法将通过将张量字典分成大小相等的张量字典并在所需数量的工作进程上分派操作来将函数应用于张量字典实例。
函数签名应为
Callabe[[TensorDict], Union[TensorDict, Tensor]]
。输出必须支持torch.cat()
操作。该函数必须是可序列化的。- 参数:
- 关键字参数:
out (TensorDictBase, 可选) – 输出的可选容器。它沿提供的
dim
的批次大小必须与self.ndim
匹配。如果它是共享的或内存映射的(is_shared()
或is_memmap()
返回True
),它将在远程进程中填充,避免数据向内传输。否则,来自self
切片的数据将被发送到进程,在当前进程上收集并在out
中就地写入。chunksize (int, 可选) – 每个数据块的大小。
chunksize
为 0 将沿所需维度解绑张量字典并在应用函数后重新堆叠它,而chunksize>0
将拆分张量字典并在结果张量字典列表上调用torch.cat()
。如果未提供,则块数将等于工作进程数。对于非常大的张量字典,如此大的块可能无法容纳操作所需的内存,可能需要更多块才能使操作在实践中可行。此参数与num_chunks
互斥。num_chunks (int, 可选) – 将张量字典拆分成多少个块。如果未提供,则块数将等于工作进程数。对于非常大的张量字典,如此大的块可能无法容纳操作所需的内存,可能需要更多块才能使操作在实践中可行。此参数与
chunksize
互斥。pool (mp.Pool, 可选) – 用于执行作业的多进程池实例。如果未提供,则将在
map
方法中创建一个池。generator (torch.Generator, 可选) –
用于播种的生成器。将从中生成一个基本种子,并且池的每个工作进程将使用提供的种子加
0
到num_workers
的唯一整数进行播种。如果没有提供生成器,则将使用随机整数作为种子。要使用无种子的工作进程,应单独创建池并直接传递到map()
。.. 注意Caution should be taken when providing a low-valued seed as this can cause autocorrelation between experiments, example: if 8 workers are asked and the seed is 4, the workers seed will range from 4 to 11. If the seed is 5, the workers seed will range from 5 to 12. These two experiments will have an overlap of 7 seeds, which can have unexpected effects on the results.
注意
播种工作进程的目的是让每个工作进程具有独立的种子,而不是在 map 方法的调用之间获得可重复的结果。换句话说,两个实验可能会并且很可能返回不同的结果,因为不可能知道哪个工作进程将选择哪个作业。但是,我们可以确保每个工作进程都有一个不同的种子,并且每个工作进程上的伪随机操作将不相关。
max_tasks_per_child (int, 可选) – 每个子进程选择的作业的最大数量。默认为
None
,即对作业数量没有限制。worker_threads (int, 可选) – 工作进程的线程数。默认为
1
。index_with_generator (bool, 可选) – 如果为
True
,则 TensorDict 的分割/分块将在查询期间完成,从而节省初始化时间。请注意,chunk()
和split()
比索引(在生成器中使用)效率更高,因此初始化时处理时间的提升可能会对总运行时间产生负面影响。默认为False
。pbar (bool, 可选) – 如果为
True
,则将显示进度条。需要 tqdm 可用。默认为False
。mp_start_method (str, 可选) – 多处理的启动方法。如果未提供,将使用默认的启动方法。可接受的字符串为
"fork"
和"spawn"
。请记住,使用"fork"
启动方法时,"cuda"
张量不能在进程之间共享。如果将pool
传递给map
方法,则此参数无效。
示例
>>> import torch >>> from tensordict import TensorDict >>> >>> def process_data(data): ... data.set("y", data.get("x") + 1) ... return data >>> if __name__ == "__main__": ... data = TensorDict({"x": torch.zeros(1, 1_000_000)}, [1, 1_000_000]).memmap_() ... data = data.map(process_data, dim=1) ... print(data["y"][:, :10]) ... tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
注意
此方法在处理存储在磁盘上的大型数据集(例如内存映射的 tensordict)时特别有用,其中块将是原始数据的零拷贝切片,可以几乎零成本地传递给进程。这允许以极低的成本处理非常大的数据集(例如超过 Tb 的大小)。
- map_iter(fn: Callable[[TensorDictBase], TensorDictBase | None], dim: int = 0, num_workers: int | None = None, *, shuffle: bool = False, chunksize: int | None = None, num_chunks: int | None = None, pool: mp.Pool | None = None, generator: torch.Generator | None = None, max_tasks_per_child: int | None = None, worker_threads: int = 1, index_with_generator: bool = True, pbar: bool = False, mp_start_method: str | None = None)¶
迭代地将函数映射到 TensorDict 沿一个维度上的分割。
这是
map()
的可迭代版本。此方法将通过将其分块成大小相等的 tensordict 并将操作分派到所需数量的 worker 上,来将函数应用于 tensordict 实例。它将一次产出一个结果。
函数签名应为
Callabe[[TensorDict], Union[TensorDict, Tensor]]
。函数必须是可序列化的。- 参数:
- 关键字参数:
shuffle (bool, 可选) – 索引是否应该全局洗牌。如果为
True
,则每个批次将包含非连续样本。如果index_with_generator=False
且shuffle=True`,则会引发错误。默认为False
。chunksize (int, 可选) – 每个数据块的大小。
chunksize
为 0 将沿所需维度解绑张量字典并在应用函数后重新堆叠它,而chunksize>0
将拆分张量字典并在结果张量字典列表上调用torch.cat()
。如果未提供,则块数将等于工作进程数。对于非常大的张量字典,如此大的块可能无法容纳操作所需的内存,可能需要更多块才能使操作在实践中可行。此参数与num_chunks
互斥。num_chunks (int, 可选) – 将张量字典拆分成多少个块。如果未提供,则块数将等于工作进程数。对于非常大的张量字典,如此大的块可能无法容纳操作所需的内存,可能需要更多块才能使操作在实践中可行。此参数与
chunksize
互斥。pool (mp.Pool, 可选) – 用于执行作业的多进程池实例。如果未提供,则将在
map
方法中创建一个池。generator (torch.Generator, 可选) –
用于播种的生成器。将从中生成一个基本种子,并且池的每个工作进程将使用提供的种子加
0
到num_workers
的唯一整数进行播种。如果没有提供生成器,则将使用随机整数作为种子。要使用无种子的工作进程,应单独创建池并直接传递到map()
。.. 注意Caution should be taken when providing a low-valued seed as this can cause autocorrelation between experiments, example: if 8 workers are asked and the seed is 4, the workers seed will range from 4 to 11. If the seed is 5, the workers seed will range from 5 to 12. These two experiments will have an overlap of 7 seeds, which can have unexpected effects on the results.
注意
播种工作进程的目的是让每个工作进程具有独立的种子,而不是在 map 方法的调用之间获得可重复的结果。换句话说,两个实验可能会并且很可能返回不同的结果,因为不可能知道哪个工作进程将选择哪个作业。但是,我们可以确保每个工作进程都有一个不同的种子,并且每个工作进程上的伪随机操作将不相关。
max_tasks_per_child (int, 可选) – 每个子进程选择的作业的最大数量。默认为
None
,即对作业数量没有限制。worker_threads (int, 可选) – 工作进程的线程数。默认为
1
。index_with_generator (bool, 可选) –
如果为
True
,则 TensorDict 的分割/分块将在查询期间完成,从而节省初始化时间。请注意,chunk()
和split()
比索引(在生成器中使用)效率更高,因此初始化时处理时间的提升可能会对总运行时间产生负面影响。默认为True
。注意
index_with_generator
的默认值对于map_iter
和map
不同,前者假设在内存中存储 TensorDict 的分割版本成本过高。pbar (bool, 可选) – 如果为
True
,则将显示进度条。需要 tqdm 可用。默认为False
。mp_start_method (str, 可选) – 多处理的启动方法。如果未提供,将使用默认的启动方法。可接受的字符串为
"fork"
和"spawn"
。请记住,使用"fork"
启动方法时,"cuda"
张量不能在进程之间共享。如果将pool
传递给map
方法,则此参数无效。
示例
>>> import torch >>> from tensordict import TensorDict >>> >>> def process_data(data): ... data.unlock_() ... data.set("y", data.get("x") + 1) ... return data >>> if __name__ == "__main__": ... data = TensorDict({"x": torch.zeros(1, 1_000_000)}, [1, 1_000_000]).memmap_() ... for sample in data.map_iter(process_data, dim=1, chunksize=5): ... print(sample["y"]) ... break ... tensor([[1., 1., 1., 1., 1.]])
注意
此方法在处理存储在磁盘上的大型数据集(例如内存映射的 tensordict)时特别有用,其中块将是原始数据的零拷贝切片,可以几乎零成本地传递给进程。这允许以极低的成本处理非常大的数据集(例如超过 Tb 的大小)。
注意
此函数可用于表示数据集并从中加载数据,类似于数据加载器的方式。
- masked_fill(mask: Tensor, value: float | bool) T ¶
masked_fill 的非就地版本。
- 参数:
mask (布尔型 torch.Tensor) – 要填充的值的掩码。形状必须与 tensordict 的批大小匹配。
value – 用于填充张量的值。
- 返回值:
自身
示例
>>> td = TensorDict(source={'a': torch.zeros(3, 4)}, ... batch_size=[3]) >>> mask = torch.tensor([True, False, False]) >>> td1 = td.masked_fill(mask, 1.0) >>> td1.get("a") tensor([[1., 1., 1., 1.], [0., 0., 0., 0.], [0., 0., 0., 0.]])
- masked_fill_(mask: Tensor, value: float | int | bool) T ¶
使用指定的值填充对应于掩码的值。
- 参数:
mask (布尔型 torch.Tensor) – 要填充的值的掩码。形状必须与 tensordict 的批大小匹配。
value – 用于填充张量的值。
- 返回值:
自身
示例
>>> td = TensorDict(source={'a': torch.zeros(3, 4)}, ... batch_size=[3]) >>> mask = torch.tensor([True, False, False]) >>> td.masked_fill_(mask, 1.0) >>> td.get("a") tensor([[1., 1., 1., 1.], [0., 0., 0., 0.], [0., 0., 0., 0.]])
- masked_select(mask: Tensor) T ¶
掩盖 TensorDict 中的所有张量,并返回一个新的 TensorDict 实例,其键类似于指向掩码值的键。
- 参数:
mask (torch.Tensor) – 用于张量的布尔掩码。形状必须与 TensorDict 的
batch_size
匹配。
示例
>>> td = TensorDict(source={'a': torch.zeros(3, 4)}, ... batch_size=[3]) >>> mask = torch.tensor([True, False, False]) >>> td_mask = td.masked_select(mask) >>> td_mask.get("a") tensor([[0., 0., 0., 0.]])
- maximum(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T ¶
计算
self
和other
的逐元素最大值。- 参数:
other (TensorDict 或 Tensor) – 另一个输入 TensorDict 或张量。
- 关键字参数:
default (torch.Tensor 或 str,可选) – 用于独占条目的默认值。如果没有提供,则两个 TensorDict 的键列表必须完全匹配。如果传递了
default="intersection"
,则只考虑相交的键集,其他键将被忽略。在所有其他情况下,default
将用于操作双方所有缺失的条目。
- maximum_(other: TensorDictBase | torch.Tensor) T ¶
maximum()
的就地版本。注意
就地
maximum
不支持default
关键字参数。
- classmethod maybe_dense_stack(input, dim=0, *, out=None, **kwargs)¶
尝试对 tensordict 进行密集堆叠,并在需要时回退到惰性堆叠。
有关详细信息,请参阅
maybe_dense_stack()
。
- mean(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor ¶
返回输入 tensordict 中所有元素的平均值。
- 参数:
- 关键字参数:
dtype (torch.dtype, 可选) – 返回张量的所需数据类型。如果指定,则在执行操作之前将输入张量转换为 dtype。这对于防止数据类型溢出很有用。默认值:
None
。reduce (bool, 可选) – 如果为
True
,则缩减将在所有 TensorDict 值中发生,并将返回单个缩减后的张量。默认为False
。
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
将所有张量写入新 tensordict 中相应的内存映射张量。
- 参数:
- 关键字参数:
然后 TensorDict 被锁定,这意味着任何不在原地的写入操作都将抛出异常(例如,重命名、设置或删除条目)。一旦 TensorDict 解锁,内存映射属性将变为
False
,因为不再保证跨进程标识。- 返回值:
如果
return_early=False
,则创建一个新的 TensorDict,其张量存储在磁盘上,否则创建一个TensorDictFuture
实例。
注意
以这种方式序列化对于嵌套深度较深的 TensorDict 可能会很慢,因此不建议在训练循环内部调用此方法。
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
就地将所有张量写入相应的内存映射张量。
- 参数:
- 关键字参数:
num_threads (int, 可选) – 用于写入内存映射张量的线程数。默认为 0。
return_early (bool, optional) – 如果为
True
且num_threads>0
,则该方法将返回张量字典的 future。可以使用 future.result() 查询结果张量字典。share_non_tensor (bool, 可选) – 如果
True
,则非张量数据将在进程之间共享,并且对单个节点内任何工作程序进行的写入操作(例如就地更新或设置)将更新所有其他工作程序上的值。如果非张量叶节点的数量很高(例如,共享大型非张量数据堆栈),这可能会导致 OOM 或类似错误。默认为False
。
然后 TensorDict 被锁定,这意味着任何不在原地的写入操作都将抛出异常(例如,重命名、设置或删除条目)。一旦 TensorDict 解锁,内存映射属性将变为
False
,因为不再保证跨进程标识。- 返回值:
如果
return_early=False
,则为 self,否则为TensorDictFuture
实例。
注意
以这种方式序列化对于嵌套深度较深的 TensorDict 可能会很慢,因此不建议在训练循环内部调用此方法。
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
创建一个与原始 TensorDict 形状相同的、不包含内容的内存映射 TensorDict。
- 参数:
- 关键字参数:
然后 TensorDict 被锁定,这意味着任何不在原地的写入操作都将抛出异常(例如,重命名、设置或删除条目)。一旦 TensorDict 解锁,内存映射属性将变为
False
,因为不再保证跨进程标识。- 返回值:
如果
return_early=False
,则创建一个新的TensorDict
实例,其数据存储为内存映射张量,否则创建一个TensorDictFuture
实例。
注意
这是将一组大型缓冲区写入磁盘的推荐方法,因为
memmap_()
将复制信息,这对于大型内容来说可能很慢。示例
>>> td = TensorDict({ ... "a": torch.zeros((3, 64, 64), dtype=torch.uint8), ... "b": torch.zeros(1, dtype=torch.int64), ... }, batch_size=[]).expand(1_000_000) # expand does not allocate new memory >>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()¶
如果内存映射 TensorDict 具有
saved_path
,则刷新其内容。如果与此 TensorDict 没有关联路径,则此方法将引发异常。
- minimum(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T ¶
计算
self
和other
的逐元素最小值。- 参数:
other (TensorDict 或 Tensor) – 另一个输入 TensorDict 或张量。
- 关键字参数:
default (torch.Tensor 或 str,可选) – 用于独占条目的默认值。如果没有提供,则两个 TensorDict 的键列表必须完全匹配。如果传递了
default="intersection"
,则只考虑相交的键集,其他键将被忽略。在所有其他情况下,default
将用于操作双方所有缺失的条目。
- minimum_(other: TensorDictBase | torch.Tensor) T ¶
minimum()
的就地版本。注意
就地
minimum
不支持default
关键字参数。
- mul(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T ¶
将
other
乘以self
。\[\text{{out}}_i = \text{{input}}_i \times \text{{other}}_i\]支持广播、类型提升以及整数、浮点数和复数输入。
- 参数:
other (TensorDict, Tensor 或 数字) – 要从
self
中减去的张量或数字。- 关键字参数:
default (torch.Tensor 或 str,可选) – 用于独占条目的默认值。如果没有提供,则两个 TensorDict 的键列表必须完全匹配。如果传递了
default="intersection"
,则只考虑相交的键集,其他键将被忽略。在所有其他情况下,default
将用于操作双方所有缺失的条目。
- mul_(other: TensorDictBase | torch.Tensor) T ¶
mul()
的就地版本。注意
就地
mul
不支持default
关键字参数。
- named_apply(fn: Callable, *others: T, nested_keys: bool = False, batch_size: Sequence[int] | None = None, device: torch.device | None = _NoDefault.ZERO, names: Sequence[str] | None = _NoDefault.ZERO, inplace: bool = False, default: Any = _NoDefault.ZERO, filter_empty: bool | None = None, propagate_lock: bool = False, call_on_nested: bool = False, out: TensorDictBase | None = None, **constructor_kwargs) T | None ¶
将一个键条件可调用函数应用于 TensorDict 中存储的所有值,并在新的 TensorDict 中设置它们。
可调用函数的签名必须为
Callable[Tuple[str, Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]
。- 参数:
fn (Callable) – 应用于 TensorDict 中 (名称,张量) 对的函数。对于每个叶子节点,只使用其叶子节点名称(而不是完整的 NestedKey)。
*others (TensorDictBase 实例, optional) – 如果提供,这些张量字典实例应该具有与 self 相匹配的结构。
fn
参数应接收与张量字典数量一样多的未命名输入,包括 self。如果其他张量字典缺少条目,则可以通过default
关键字参数传递默认值。nested_keys (bool, 可选) – 如果为
True
,则将使用到叶子的完整路径。默认为False
,即仅将最后一个字符串传递给函数。batch_size (整数序列, optional) – 如果提供,则生成的 TensorDict 将具有所需的 batch_size。
batch_size
参数应与转换后的 batch_size 相匹配。这是一个仅限关键字的参数。device (torch.device, optional) – 生成的设备(如果有)。
names (字符串列表, optional) – 新的维度名称,如果修改了 batch_size。
inplace (bool, 可选) – 如果为 True,则更改将在原位进行。默认为 False。这是一个仅限关键字的参数。
default (Any, optional) – 其他张量字典中缺少条目的默认值。如果未提供,则缺少的条目将引发 KeyError。
filter_empty (bool, 可选) – 如果为
True
,则将过滤掉空的 TensorDict。这还可以降低计算成本,因为不会创建和销毁空数据结构。为了向后兼容,默认为False
。propagate_lock (bool, 可选) – 如果为
True
,则锁定的 TensorDict 将生成另一个锁定的 TensorDict。默认为False
。call_on_nested (bool, 可选) –
如果为
True
,则该函数将被调用在一级张量和容器(TensorDict 或 tensorclass)上。在这种情况下,func
负责将其调用传播到嵌套级别。这允许在将调用传播到嵌套的张量字典时进行细粒度的行为。如果为False
,则该函数仅在叶子上调用,并且apply
将负责将该函数分派到所有叶子上。>>> td = TensorDict({"a": {"b": [0.0, 1.0]}, "c": [1.0, 2.0]}) >>> def mean_tensor_only(val): ... if is_tensor_collection(val): ... raise RuntimeError("Unexpected!") ... return val.mean() >>> td_mean = td.apply(mean_tensor_only) >>> def mean_any(val): ... if is_tensor_collection(val): ... # Recurse ... return val.apply(mean_any, call_on_nested=True) ... return val.mean() >>> td_mean = td.apply(mean_any, call_on_nested=True)
out (TensorDictBase, optional) –
一个用于写入结果的张量字典。这可以用来避免创建新的张量字典
>>> td = TensorDict({"a": 0}) >>> td.apply(lambda x: x+1, out=td) >>> assert (td==1).all()
警告
如果对张量字典执行的操作需要访问多个键才能进行单次计算,则提供等于
self
的out
参数可能会导致操作静默地提供错误的结果。例如>>> td = TensorDict({"a": 1, "b": 1}) >>> td.apply(lambda x: x+td["a"])["b"] # Right! tensor(2) >>> td.apply(lambda x: x+td["a"], out=td)["b"] # Wrong! tensor(3)
**constructor_kwargs – 要传递给 TensorDict 构造函数的其他关键字参数。
- 返回值:
一个包含已转换张量的新张量字典。
示例
>>> td = TensorDict({ ... "a": -torch.ones(3), ... "nested": {"a": torch.ones(3), "b": torch.zeros(3)}}, ... batch_size=[3]) >>> def name_filter(name, tensor): ... if name == "a": ... return tensor >>> td.named_apply(name_filter) TensorDict( fields={ a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> def name_filter(name, *tensors): ... if name == "a": ... r = 0 ... for tensor in tensors: ... r = r + tensor ... return tensor >>> out = td.named_apply(name_filter, td) >>> print(out) TensorDict( fields={ a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> print(out["a"]) tensor([-1., -1., -1.])
注意
如果函数返回
None
,则忽略该条目。这可用于过滤张量字典中的数据>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, []) >>> def name_filter(name, tensor): ... if name == "1": ... return tensor >>> td.named_apply(name_filter) TensorDict( fields={ 1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b: TensorDict( fields={ 1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- property names¶
TensorDict 的维度名称。
可以在构造时使用
names
参数设置名称。有关如何在构造后设置名称的详细信息,请参见
refine_names()
。
- nanmean(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor ¶
返回输入 TensorDict 中所有非 NaN 元素的平均值。
- 参数:
- 关键字参数:
dtype (torch.dtype, 可选) – 返回张量的所需数据类型。如果指定,则在执行操作之前将输入张量转换为 dtype。这对于防止数据类型溢出很有用。默认值:
None
。reduce (bool, 可选) – 如果为
True
,则缩减将在所有 TensorDict 值中发生,并将返回单个缩减后的张量。默认为False
。
- nansum(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor ¶
返回输入张量字典中所有非 NaN 元素的总和。
- 参数:
- 关键字参数:
dtype (torch.dtype, 可选) – 返回张量的所需数据类型。如果指定,则在执行操作之前将输入张量转换为 dtype。这对于防止数据类型溢出很有用。默认值:
None
。reduce (bool, 可选) – 如果为
True
,则缩减将在所有 TensorDict 值中发生,并将返回单个缩减后的张量。默认为False
。
- property ndim: int¶
参见
batch_dims()
。
- ndimension() int ¶
参见
batch_dims()
。
- neg() T ¶
计算 TensorDict 中每个元素的
neg()
值。
- neg_() T ¶
就地计算 TensorDict 中每个元素的
neg()
值。
- new_empty(*size: Size, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)¶
返回大小为
size
且包含空张量的 TensorDict。默认情况下,返回的 TensorDict 与此张量字典具有相同的
torch.dtype
和torch.device
。- 参数:
size (int...) – 定义输出张量形状的整数列表、元组或 torch.Size。
- 关键字参数:
dtype (torch.dtype, 可选) – 期望返回的张量字典的类型。默认值:如果为
None
,则torch.dtype将保持不变。device (torch.device, 可选) – 期望返回的张量字典的设备。默认值:如果为
None
,则torch.device
将保持不变。requires_grad (bool, 可选) – 是否应记录返回张量上的操作。默认值:
False
。layout (torch.layout, 可选) – 期望返回的 TensorDict 值的布局。默认值:
torch.strided
。pin_memory (bool, 可选) – 如果设置,则返回的张量将在固定内存中分配。仅适用于 CPU 张量。默认值:
False
。
- new_full(size: Size, fill_value, *, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)¶
返回大小为
size
并填充为 1 的 TensorDict。默认情况下,返回的 TensorDict 与此张量字典具有相同的
torch.dtype
和torch.device
。- 参数:
size (整数序列) – 定义输出张量形状的整数列表、元组或 torch.Size。
fill_value (标量) – 用于填充输出张量的数值。
- 关键字参数:
dtype (torch.dtype, 可选) – 期望返回的张量字典的类型。默认值:如果为
None
,则torch.dtype将保持不变。device (torch.device, 可选) – 期望返回的张量字典的设备。默认值:如果为
None
,则torch.device
将保持不变。requires_grad (bool, 可选) – 是否应记录返回张量上的操作。默认值:
False
。layout (torch.layout, 可选) – 期望返回的 TensorDict 值的布局。默认值:
torch.strided
。pin_memory (bool, 可选) – 如果设置,则返回的张量将在固定内存中分配。仅适用于 CPU 张量。默认值:
False
。
- new_ones(*size: Size, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)¶
返回大小为
size
并填充为 1 的 TensorDict。默认情况下,返回的 TensorDict 与此张量字典具有相同的
torch.dtype
和torch.device
。- 参数:
size (int...) – 定义输出张量形状的整数列表、元组或 torch.Size。
- 关键字参数:
dtype (torch.dtype, 可选) – 期望返回的张量字典的类型。默认值:如果为
None
,则torch.dtype将保持不变。device (torch.device, 可选) – 期望返回的张量字典的设备。默认值:如果为
None
,则torch.device
将保持不变。requires_grad (bool, 可选) – 是否应记录返回张量上的操作。默认值:
False
。layout (torch.layout, 可选) – 期望返回的 TensorDict 值的布局。默认值:
torch.strided
。pin_memory (bool, 可选) – 如果设置,则返回的张量将在固定内存中分配。仅适用于 CPU 张量。默认值:
False
。
- new_tensor(data: torch.Tensor | TensorDictBase, *, dtype: torch.dtype = None, device: DeviceType = _NoDefault.ZERO, requires_grad: bool = False, pin_memory: bool | None = None)¶
返回一个新的 TensorDict,其中数据为张量
data
。默认情况下,返回的 TensorDict 值与该张量具有相同的
torch.dtype
和torch.device
。data
也可以是张量集合(TensorDict
或tensorclass
),在这种情况下,new_tensor
方法迭代self
和data
的张量对。- 参数:
data (torch.Tensor 或 TensorDictBase) – 要复制的数据。
- 关键字参数:
dtype (torch.dtype, 可选) – 期望返回的张量字典的类型。默认值:如果为
None
,则torch.dtype将保持不变。device (torch.device, 可选) – 期望返回的张量字典的设备。默认值:如果为
None
,则torch.device
将保持不变。requires_grad (bool, 可选) – 是否应记录返回张量上的操作。默认值:
False
。pin_memory (bool, 可选) – 如果设置,则返回的张量将在固定内存中分配。仅适用于 CPU 张量。默认值:
False
。
- new_zeros(*size: Size, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)¶
返回大小为
size
且填充为 0 的 TensorDict。默认情况下,返回的 TensorDict 与此张量字典具有相同的
torch.dtype
和torch.device
。- 参数:
size (int...) – 定义输出张量形状的整数列表、元组或 torch.Size。
- 关键字参数:
dtype (torch.dtype, 可选) – 期望返回的张量字典的类型。默认值:如果为
None
,则torch.dtype将保持不变。device (torch.device, 可选) – 期望返回的张量字典的设备。默认值:如果为
None
,则torch.device
将保持不变。requires_grad (bool, 可选) – 是否应记录返回张量上的操作。默认值:
False
。layout (torch.layout, 可选) – 期望返回的 TensorDict 值的布局。默认值:
torch.strided
。pin_memory (bool, 可选) – 如果设置,则返回的张量将在固定内存中分配。仅适用于 CPU 张量。默认值:
False
。
- norm(*, out=None, dtype: torch.dtype | None = None)¶
计算 tensordict 中每个张量的范数。
- 关键字参数:
out (TensorDict,可选) – 输出 tensordict。
dtype (torch.dtype,可选) – 输出数据类型 (torch>=2.4)。
- numpy()¶
将 tensordict 转换为 (可能嵌套的) numpy 数组字典。
非张量数据将原样暴露。
示例
>>> from tensordict import TensorDict >>> import torch >>> data = TensorDict({"a": {"b": torch.zeros(()), "c": "a string!"}}) >>> print(data) TensorDict( fields={ a: TensorDict( fields={ b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), c: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(data.numpy()) {'a': {'b': array(0., dtype=float32), 'c': 'a string!'}}
- permute(*args, **kwargs)¶
返回一个张量字典的视图,其中批处理维度根据 dims 进行置换。
- 参数:
*dims_list (int) – 张量字典的批处理维度的新的顺序。或者,可以提供一个整数的单个可迭代对象。
dims (int 列表) – 调用 permute(…) 的另一种方式。
- 返回值:
一个新的张量字典,其中批处理维度按所需的顺序排列。
示例
>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4]) >>> print(tensordict.permute([1, 0])) PermutedTensorDict( source=TensorDict( fields={ a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)}, batch_size=torch.Size([3, 4]), device=cpu, is_shared=False), op=permute(dims=[1, 0])) >>> print(tensordict.permute(1, 0)) PermutedTensorDict( source=TensorDict( fields={ a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)}, batch_size=torch.Size([3, 4]), device=cpu, is_shared=False), op=permute(dims=[1, 0])) >>> print(tensordict.permute(dims=[1, 0])) PermutedTensorDict( source=TensorDict( fields={ a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)}, batch_size=torch.Size([3, 4]), device=cpu, is_shared=False), op=permute(dims=[1, 0]))
- pin_memory(num_threads: int | None = None, inplace: bool = False) T ¶
在存储的张量上调用
pin_memory()
。- 参数:
num_threads (int 或 str) – 如果提供,则用于在叶子节点上调用
pin_memory
的线程数。默认为None
,这在ThreadPoolExecutor(max_workers=None)
中设置大量线程。要在主线程上执行对pin_memory()
的所有调用,请传递num_threads=0
。inplace (bool,可选) – 如果为
True
,则张量字典将被就地修改。默认为False
。
- pin_memory_(num_threads: int | str = 0) T ¶
在存储的张量上调用
pin_memory()
并返回就地修改的 TensorDict。
- pop(key: NestedKey, default: Any = _NoDefault.ZERO) Tensor ¶
从张量字典中移除并返回一个值。
如果该值不存在且未提供默认值,则会抛出 KeyError。
- 参数:
key (str 或 嵌套键) – 要查找的条目。
default (Any,可选) – 如果找不到键,则返回的值。
示例
>>> td = TensorDict({"1": 1}, []) >>> one = td.pop("1") >>> assert one == 1 >>> none = td.pop("1", default=None) >>> assert none is None
- pow(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T ¶
将
self
中每个元素的幂与other
相乘,并返回一个包含结果的张量。other
可以是单个float
数、Tensor 或TensorDict
。当
other
是张量时,input
和other
的形状必须是可广播的。- 参数:
other (float、张量 或 张量字典) – 指数值
- 关键字参数:
default (torch.Tensor 或 str,可选) – 用于独占条目的默认值。如果没有提供,则两个 TensorDict 的键列表必须完全匹配。如果传递了
default="intersection"
,则只考虑相交的键集,其他键将被忽略。在所有其他情况下,default
将用于操作双方所有缺失的条目。
- pow_(other: TensorDictBase | torch.Tensor) T ¶
pow()
的就地版本。注意
就地
pow
不支持default
关键字参数。
- prod(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor ¶
返回输入张量字典中所有元素值的乘积。
- 参数:
- 关键字参数:
dtype (torch.dtype, 可选) – 返回张量的所需数据类型。如果指定,则在执行操作之前将输入张量转换为 dtype。这对于防止数据类型溢出很有用。默认值:
None
。reduce (bool, 可选) – 如果为
True
,则缩减将在所有 TensorDict 值中发生,并将返回单个缩减后的张量。默认为False
。
- qint32()¶
将所有张量转换为
torch.qint32
。
- qint8()¶
将所有张量转换为
torch.qint8
。
- quint4x2()¶
将所有张量转换为
torch.quint4x2
。
- quint8()¶
将所有张量转换为
torch.quint8
。
- reciprocal() T ¶
计算TensorDict中每个元素的
reciprocal()
值。
- reciprocal_() T ¶
就地计算TensorDict中每个元素的
reciprocal()
值。
- recv(src: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int ¶
接收一个tensordict的内容并用它更新内容。
查看send方法中的示例以了解上下文。
- reduce(dst, op=None, async_op=False, return_premature=False, group=None)¶
在所有机器上减少tensordict。
只有
rank
为dst的进程将接收最终结果。
- refine_names(*names) T ¶
根据names细化self的维度名称。
细化是重命名的一个特例,它“提升”未命名的维度。None维度可以细化为任何名称;命名维度只能细化为相同的名称。
因为命名张量可以与未命名张量共存,所以细化名称提供了一种编写命名张量感知代码的好方法,该代码适用于命名和未命名张量。
names最多可以包含一个省略号(…)。省略号被贪婪地扩展;它被就地扩展以填充与self.dim()长度相同的names,使用来自self.names对应索引的names。
返回值:根据输入命名维度的相同tensordict。
示例
>>> td = TensorDict({}, batch_size=[3, 4, 5, 6]) >>> tdr = td.refine_names(None, None, None, "d") >>> assert tdr.names == [None, None, None, "d"] >>> tdr = td.refine_names("a", None, None, "d") >>> assert tdr.names == ["a", None, None, "d"]
- rename(*names, **rename_map)¶
返回tensordict的克隆,其维度已重命名。
示例
>>> td = TensorDict({}, batch_size=[1, 2, 3 ,4]) >>> td.names = list("abcd") >>> td_rename = td.rename(c="g") >>> assert td_rename.names == list("abgd")
- rename_(*names, **rename_map)¶
与
rename()
相同,但就地执行重命名。示例
>>> td = TensorDict({}, batch_size=[1, 2, 3 ,4]) >>> td.names = list("abcd") >>> assert td.rename_(c="g") >>> assert td.names == list("abgd")
- rename_key_(old_key: NestedKey, new_key: NestedKey, safe: bool = False) T ¶
使用新字符串重命名键,并返回具有更新键名称的相同tensordict。
- replace(*args, **kwargs)¶
创建tensordict的浅拷贝,其中条目已被替换。
接受一个未命名的参数,该参数必须是
TensorDictBase
子类的字典。此外,可以使用命名关键字参数更新第一级条目。- 返回值:
如果输入非空,则为
self
的副本,其中包含更新的条目。如果提供了空字典或没有字典并且kwargs为空,则返回self
。
- requires_grad_(requires_grad=True) T ¶
更改是否应记录此张量上的操作的自动梯度:就地设置此张量的requires_grad属性。
返回此tensordict。
- 参数:
requires_grad (bool,可选) – 自动梯度是否应记录此tensordict上的操作。默认为
True
。
- reshape(*args, **kwargs) T ¶
返回所需形状的连续重塑张量。
- 参数:
*shape (int) – 生成的tensordict的新形状。
- 返回值:
具有重塑键的TensorDict
示例
>>> td = TensorDict({ ... 'x': torch.arange(12).reshape(3, 4), ... }, batch_size=[3, 4]) >>> td = td.reshape(12) >>> print(td['x']) torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
- round() T ¶
计算TensorDict中每个元素的
round()
值。
- round_() T ¶
就地计算TensorDict中每个元素的
round()
值。
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
将 TensorDict 保存到磁盘。
此函数是
memmap()
的代理。
- property saved_path¶
返回存储内存映射 TensorDict 的路径。
当 is_memmap() 返回
False
(例如,当 TensorDict 解锁时)时,此参数消失。
- select(*keys: NestedKey, inplace: bool = False, strict: bool = True) T ¶
选择 TensorDict 的键,并返回一个仅包含所选键的新 TensorDict。
值不会被复制:对原始 TensorDict 或新 TensorDict 中的张量进行就地修改将导致两个 TensorDict 都发生变化。
- 参数:
- 返回值:
一个新的 TensorDict(如果
inplace=True
则为同一个),仅包含所选键。
注意
要选择 TensorDict 中的键并返回一个缺少这些键的此 TensorDict 版本,请参阅
split_keys()
方法。示例
>>> from tensordict import TensorDict >>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, []) >>> td.select("a", ("b", "c")) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> td.select("a", "b") TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> td.select("this key does not exist", strict=False) TensorDict( fields={ }, batch_size=torch.Size([]), device=None, is_shared=False)
- send(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) None ¶
将 TensorDict 的内容发送到远程工作节点。
- 参数:
dst (int) – 目标工作者的等级,内容应发送到该工作者。
- 关键字参数:
示例
>>> from torch import multiprocessing as mp >>> from tensordict import TensorDict >>> import torch >>> >>> >>> def client(): ... torch.distributed.init_process_group( ... "gloo", ... rank=1, ... world_size=2, ... init_method=f"tcp://127.0.0.1:10003", ... ) ... ... td = TensorDict( ... { ... ("a", "b"): torch.randn(2), ... "c": torch.randn(2, 3), ... "_": torch.ones(2, 1, 5), ... }, ... [2], ... ) ... td.send(0) ... >>> >>> def server(queue): ... torch.distributed.init_process_group( ... "gloo", ... rank=0, ... world_size=2, ... init_method=f"tcp://127.0.0.1:10003", ... ) ... td = TensorDict( ... { ... ("a", "b"): torch.zeros(2), ... "c": torch.zeros(2, 3), ... "_": torch.zeros(2, 1, 5), ... }, ... [2], ... ) ... td.recv(1) ... assert (td != 0).all() ... queue.put("yuppie") ... >>> >>> if __name__=="__main__": ... queue = mp.Queue(1) ... main_worker = mp.Process(target=server, args=(queue,)) ... secondary_worker = mp.Process(target=client) ... ... main_worker.start() ... secondary_worker.start() ... out = queue.get(timeout=10) ... assert out == "yuppie" ... main_worker.join() ... secondary_worker.join()
- set(key: NestedKey, item: Tensor, inplace: bool = False, *, non_blocking: bool = False, **kwargs: Any) T ¶
设置新的键值对。
- 参数:
key (str, str 元组) – 要设置的键的名称。
item (torch.Tensor 或等效类型, TensorDictBase 实例) – 要存储在 TensorDict 中的值。
inplace (bool, 可选) – 如果为
True
并且键与 TensorDict 中的现有键匹配,则将对该键值对进行就地更新。如果 inplace 为True
且找不到该条目,则会添加它。有关更严格的就地操作,请改用set_()
。默认为False
。
- 关键字参数:
non_blocking (bool, 可选) – 如果为
True
并且此副本位于不同的设备之间,则副本可能相对于主机异步发生。- 返回值:
自身
示例
>>> td = TensorDict({}, batch_size[3, 4]) >>> td.set("x", torch.randn(3, 4)) >>> y = torch.randn(3, 4, 5) >>> td.set("y", y, inplace=True) # works, even if 'y' is not present yet >>> td.set("y", torch.zeros_like(y), inplace=True) >>> assert (y==0).all() # y values are overwritten >>> td.set("y", torch.ones(5), inplace=True) # raises an exception as shapes mismatch
- set_(key: NestedKey, item: Tensor, *, non_blocking: bool = False) T ¶
在保留原始存储的同时,将值设置为现有键。
- 参数:
key (str) – 值的名称
item (torch.Tensor 或兼容类型, TensorDictBase) – 要存储在 TensorDict 中的值
- 关键字参数:
non_blocking (bool, 可选) – 如果为
True
并且此副本位于不同的设备之间,则副本可能相对于主机异步发生。- 返回值:
自身
示例
>>> td = TensorDict({}, batch_size[3, 4]) >>> x = torch.randn(3, 4) >>> td.set("x", x) >>> td.set_("x", torch.zeros_like(x)) >>> assert (x == 0).all()
- set_at_(key: NestedKey, value: Tensor, index: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], *, non_blocking: bool = False) T ¶
在由
index
指示的索引处就地设置值。- 参数:
key (str, str 的元组) – 要修改的键。
value (torch.Tensor) – 要在索引 index 处设置的值。
- 关键字参数:
non_blocking (bool, 可选) – 如果为
True
并且此副本位于不同的设备之间,则副本可能相对于主机异步发生。- 返回值:
自身
示例
>>> td = TensorDict({}, batch_size[3, 4]) >>> x = torch.randn(3, 4) >>> td.set("x", x) >>> td.set_at_("x", value=torch.ones(1, 4), index=slice(1)) >>> assert (x[0] == 1).all()
- set_non_tensor(key: NestedKey, value: Any)¶
使用
tensordict.tensorclass.NonTensorData
在 tensordict 中注册非张量值。可以使用
TensorDictBase.get_non_tensor()
或直接使用 get 检索该值,这将返回tensordict.tensorclass.NonTensorData
对象。返回:自身
示例
>>> data = TensorDict({}, batch_size=[]) >>> data.set_non_tensor(("nested", "the string"), "a string!") >>> assert data.get_non_tensor(("nested", "the string")) == "a string!" >>> # regular `get` works but returns a NonTensorData object >>> data.get(("nested", "the string")) NonTensorData( data='a string!', batch_size=torch.Size([]), device=None, is_shared=False)
- setdefault(key: NestedKey, default: Tensor, inplace: bool = False) Tensor ¶
如果
key
不在 tensordict 中,则使用default
的值为key
条目插入。如果
key
在 tensordict 中,则返回key
的值,否则返回default
。- 参数:
key (str 或 嵌套键) – 值的名称。
default (torch.Tensor 或 兼容类型, TensorDictBase) – 如果键不存在,则存储在 tensordict 中的值。
- 返回值:
tensordict 中 key 的值。如果 key 之前未设置,则为 default。
示例
>>> td = TensorDict({}, batch_size=[3, 4]) >>> val = td.setdefault("a", torch.zeros(3, 4)) >>> assert (val == 0).all() >>> val = td.setdefault("a", torch.ones(3, 4)) >>> assert (val == 0).all() # output is still 0
- property shape: Size¶
参见
batch_size
。
将所有张量放置在共享内存中。
然后 TensorDict 被锁定,这意味着任何非就地写入操作都将抛出异常(例如,重命名、设置或删除条目)。相反,一旦 tensordict 解锁,share_memory 属性将变为
False
,因为跨进程身份不再保证。- 返回值:
自身
- sigmoid() T ¶
计算 TensorDict 中每个元素的
sigmoid()
值。
- sigmoid_() T ¶
就地计算 TensorDict 中每个元素的
sigmoid()
值。
- sign() T ¶
计算 TensorDict 中每个元素的
sign()
值。
- sign_() T ¶
就地计算 TensorDict 中每个元素的
sign()
值。
- sin() T ¶
计算 TensorDict 中每个元素的
sin()
值。
- sin_() T ¶
就地计算 TensorDict 中每个元素的
sin()
值。
- sinh() T ¶
计算 TensorDict 中每个元素的
sinh()
值。
- sinh_() T ¶
就地计算 TensorDict 中每个元素的
sinh()
值。
- size(dim: int | None = None) torch.Size | int ¶
返回由
dim
指示的维度的尺寸。如果未指定
dim
,则返回 TensorDict 的batch_size
属性。
- 属性 sorted_keys: list[NestedKey]¶
返回按字母顺序排序的键。
不支持额外的参数。
如果 TensorDict 被锁定,则键会被缓存,直到 tensordict 解锁以加快执行速度。
- split(split_size: int | list[int], dim: int = 0) list[TensorDictBase] ¶
使用给定维度上的指定大小分割 TensorDict 中的每个张量,类似于 torch.split。
返回一个
TensorDict
实例列表,其中包含分割后的数据块的视图。- 参数:
- 返回值:
一个在给定维度上具有指定大小的 TensorDict 列表。
示例
>>> td = TensorDict({ ... 'x': torch.arange(12).reshape(3, 4), ... }, batch_size=[3, 4]) >>> td0, td1 = td.split([1, 2], dim=0) >>> print(td0['x']) torch.Tensor([[0, 1, 2, 3]])
- split_keys(*key_sets, inplace=False, strict: bool = True, reproduce_struct: bool = False)¶
根据一个或多个键集将 tensordict 分割成子集。
该方法将返回
N+1
个 tensordict,其中N
是提供的参数数量。- 参数:
注意
None
非张量值将被忽略,不会返回。注意
该方法不检查提供的列表中是否存在重复项。
示例
>>> td = TensorDict( ... a=0, ... b=0, ... c=0, ... d=0, ... ) >>> td_a, td_bc, td_d = td.split_keys(["a"], ["b", "c"]) >>> print(td_bc)
- sqrt()¶
计算
self
的逐元素平方根。
- squeeze(*args, **kwargs)¶
压缩 -self.batch_dims+1 和 self.batch_dims-1 之间的维度中的所有张量,并在新的 tensordict 中返回它们。
- 参数:
dim (Optional[int]) – 要压缩的维度。如果 dim 为
None
,则所有单例维度都将被压缩。默认为None
。
示例
>>> td = TensorDict({ ... 'x': torch.arange(24).reshape(3, 1, 4, 2), ... }, batch_size=[3, 1, 4]) >>> td = td.squeeze() >>> td.shape torch.Size([3, 4]) >>> td.get("x").shape torch.Size([3, 4, 2])
此操作也可以用作上下文管理器。对原始 tensordict 的更改将发生在外部,即原始张量的内容不会被更改。这也假设 tensordict 未被锁定(否则,需要解锁 tensordict)。此功能不兼容隐式压缩。
>>> td = TensorDict({ ... 'x': torch.arange(24).reshape(3, 1, 4, 2), ... }, batch_size=[3, 1, 4]) >>> with td.squeeze(1) as tds: ... tds.set("y", torch.zeros(3, 4)) >>> assert td.get("y").shape == [3, 1, 4]
- 类方法 stack(input, dim=0, *, out=None)¶
沿给定维度将 tensordict 堆叠成单个 tensordict。
此调用等效于调用
torch.stack()
,但与 torch.compile 兼容。
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) OrderedDict[str, Any] ¶
从 tensordict 生成 state_dict。
state-dict 的结构仍然是嵌套的,除非
flatten
设置为True
。tensordict state-dict 包含重建 tensordict 所需的所有张量和元数据(目前不支持名称)。
- 参数:
destination (dict,可选) – 如果提供,tensordict 的状态将更新到字典中并返回相同的对象。否则,将创建一个
OrderedDict
并返回。默认值:None
。prefix (str,可选) – 添加到张量名称的前缀,用于组合 state_dict 中的键。默认值:
''
。keep_vars (bool,可选) – 默认情况下,state dict 中返回的
torch.Tensor
项目与 autograd 分离。如果将其设置为True
,则不会执行分离。默认值:False
。flatten (bool,可选) – 结构是否应使用
"."
字符进行扁平化。默认为False
。
示例
>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, []) >>> sd = data.state_dict() >>> print(sd) OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3', OrderedDict([('3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])), ('__batch_size', torch.Size([])), ('__device', None)]) >>> sd = data.state_dict(flatten=True) OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3.3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])
- std(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, correction: int = 1, reduce: bool | None = None) TensorDictBase | torch.Tensor ¶
返回输入 TensorDict 中所有元素的标准差值。
- 参数:
- 关键字参数:
- sub(other: TensorDictBase | float, *, alpha: float | None = None, default: str | CompatibleType | None = None)¶
从
self
中减去other
(乘以alpha
)。\[\text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i\]支持广播、类型提升以及整数、浮点数和复数输入。
- 参数:
other (TensorDict, Tensor 或 数字) – 要从
self
中减去的张量或数字。- 关键字参数:
alpha (数字) –
other
的乘数。default (torch.Tensor 或 str,可选) – 用于独占条目的默认值。如果没有提供,则两个 TensorDict 的键列表必须完全匹配。如果传递了
default="intersection"
,则只考虑相交的键集,其他键将被忽略。在所有其他情况下,default
将用于操作双方所有缺失的条目。
- sub_(other: TensorDictBase | float, alpha: float | None = None)¶
sub()
的就地版本。注意
就地
sub
不支持default
关键字参数。
- sum(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor ¶
返回输入 TensorDict 中所有元素的和值。
- 参数:
- 关键字参数:
dtype (torch.dtype, 可选) – 返回张量的所需数据类型。如果指定,则在执行操作之前将输入张量转换为 dtype。这对于防止数据类型溢出很有用。默认值:
None
。reduce (bool, 可选) – 如果为
True
,则缩减将在所有 TensorDict 值中发生,并将返回单个缩减后的张量。默认为False
。
- tan() T ¶
计算 TensorDict 中每个元素的
tan()
值。
- tan_() T ¶
就地计算 TensorDict 中每个元素的
tan()
值。
- tanh() T ¶
计算 TensorDict 中每个元素的
tanh()
值。
- tanh_() T ¶
就地计算 TensorDict 中每个元素的
tanh()
值。
- to(*args, **kwargs) T ¶
将 TensorDictBase 子类映射到另一个设备、数据类型或另一个 TensorDictBase 子类(如果允许)。
不允许将张量转换为新的数据类型,因为 TensorDict 不限于包含单个张量数据类型。
- 参数:
device (torch.device, 可选) – TensorDict 的目标设备。
dtype (torch.dtype, 可选) – TensorDict 的目标浮点或复数数据类型。
tensor (torch.Tensor, 可选) – 数据类型和设备为 TensorDict 中所有张量的目标数据类型和设备的张量。
- 关键字参数:
non_blocking (bool, 可选) – 操作是否应为阻塞操作。
memory_format (torch.memory_format, 可选) – TensorDict 中 4D 参数和缓冲区的目标内存格式。
batch_size (torch.Size, 可选) – 输出 TensorDict 的结果批大小。
other (TensorDictBase, 可选) –
TensorDict 实例,其 dtype 和 device 是此 TensorDict 中所有张量的所需 dtype 和 device。.. 注意:: 由于
TensorDictBase
实例没有dtype,dtype 是从示例叶子中收集的。如果有多个 dtype,则不会进行 dtype 转换。
non_blocking_pin (bool, 可选) –
如果
True
,则在将张量发送到设备之前将其固定。这将异步执行,但可以通过num_threads
参数进行控制。注意
调用
tensordict.pin_memory().to("cuda")
通常会比tensordict.to("cuda", non_blocking_pin=True)
慢得多,因为在第二种情况下,pin_memory 是异步调用的。如果张量很大且数量众多,多线程pin_memory
通常会有益:当要发送的张量太少时,生成线程和收集数据的开销会超过多线程的益处,并且如果张量很小,迭代长列表的开销也会过大。num_threads (int 或 None, 可选) – 如果
non_blocking_pin=True
,则用于pin_memory
的线程数。默认情况下,将生成max(1, torch.get_num_threads())
个线程。num_threads=0
将取消 pin_memory() 调用的任何多线程。
- 返回值:
如果设备与 tensordict 设备不同和/或传递了 dtype,则为一个新的 tensordict 实例。否则为相同的 tensordict。
batch_size
仅修改是在原地完成的。
注意
如果 TensorDict 已合并,则生成的 TensorDict 也将合并。每个新张量都将是合并存储的视图,并转换为所需的设备。
示例
>>> data = TensorDict({"a": 1.0}, [], device=None) >>> data_cuda = data.to("cuda:0") # casts to cuda >>> data_int = data.to(torch.int) # casts to int >>> data_cuda_int = data.to("cuda:0", torch.int) # multiple casting >>> data_cuda = data.to(torch.randn(3, device="cuda:0")) # using an example tensor >>> data_cuda = data.to(other=TensorDict({}, [], device="cuda:0")) # using a tensordict example
- to_h5(filename, **kwargs)¶
将 tensordict 转换为具有 h5 后端的 PersistentTensorDict。
- 参数:
filename (str 或 路径) – h5 文件的路径。
device (torch.device 或 兼容, 可选) – 返回张量后预期所在的设备。默认为
None
(默认在 cpu 上)。**kwargs – 要传递给
h5py.File.create_dataset()
的 kwargs。
- 返回值:
一个链接到新创建文件的
PersitentTensorDict
实例。
示例
>>> import tempfile >>> import timeit >>> >>> from tensordict import TensorDict, MemoryMappedTensor >>> td = TensorDict({ ... "a": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000)), ... "b": {"c": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000, 3))}, ... }, [1_000_000]) >>> >>> file = tempfile.NamedTemporaryFile() >>> td_h5 = td.to_h5(file.name, compression="gzip", compression_opts=9) >>> print(td_h5) PersistentTensorDict( fields={ a: Tensor(shape=torch.Size([1000000]), device=cpu, dtype=torch.float32, is_shared=False), b: PersistentTensorDict( fields={ c: Tensor(shape=torch.Size([1000000, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([1000000]), device=None, is_shared=False)}, batch_size=torch.Size([1000000]), device=None, is_shared=False)
- to_module(module: nn.Module, *, inplace: bool | None = None, return_swap: bool = True, swap_dest=None, use_state_dict: bool = False, non_blocking: bool = False, memo=None)¶
递归地将 TensorDictBase 实例的内容写入给定的 nn.Module 属性。
- 参数:
module (nn.Module) – 要将参数写入的模块。
- 关键字参数:
inplace (bool, 可选) – 如果
True
,则模块中的参数或张量将就地更新。默认为False
。return_swap (bool, 可选) – 如果
True
,则将返回旧的参数配置。默认为False
。swap_dest (TensorDictBase, 可选) – 如果
return_swap
为True
,则为应写入交换的 tensordict。use_state_dict (bool, 可选) – 如果
True
,则将使用 state-dict API 加载参数(包括 state-dict hook)。默认为False
。non_blocking (bool, 可选) – 如果为
True
并且此副本位于不同的设备之间,则副本可能相对于主机异步发生。
示例
>>> from torch import nn >>> module = nn.TransformerDecoder( ... decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4), ... num_layers=1) >>> params = TensorDict.from_module(module) >>> params.zero_() >>> params.to_module(module) >>> assert (module.layers[0].linear1.weight == 0).all()
- to_namedtuple(dest_cls: type | None = None)¶
将 tensordict 转换为 namedtuple。
- 参数:
dest_cls (类型, 可选) – 要使用的可选 namedtuple 类。
示例
>>> from tensordict import TensorDict >>> import torch >>> data = TensorDict({ ... "a_tensor": torch.zeros((3)), ... "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3]) >>> data.to_namedtuple() GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!'))
- to_padded_tensor(padding=0.0, mask_key: NestedKey | None = None)¶
将所有嵌套张量转换为填充版本,并相应地调整批大小。
- 参数:
padding (float) – tensordict 中张量的填充值。默认为
0.0
。mask_key (NestedKey, 可选) – 如果提供,则为写入有效值掩码的键。如果异构维度不是 tensordict 批大小的一部分,将导致错误。默认为
None
- to_pytree()¶
将 tensordict 转换为 PyTree。
如果 tensordict 不是从 pytree 创建的,则此方法仅返回
self
而不进行修改。有关更多信息和示例,请参见
from_pytree()
。
- to_tensordict() T ¶
从 TensorDictBase 返回一个常规的 TensorDict 实例。
- 返回值:
一个包含相同值的新 TensorDict 对象。
- transpose(dim0, dim1)¶
返回一个 tensordict,它是输入的转置版本。给定的维度
dim0
和dim1
已交换。转置 tensordict 的就地或非就地修改也将影响原始 tensordict,因为内存是共享的,并且操作会映射回原始 tensordict。
示例
>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4]) >>> tensordict_transpose = tensordict.transpose(0, 1) >>> print(tensordict_transpose.shape) torch.Size([4, 3]) >>> tensordict_transpose.set("b",, torch.randn(4, 3)) >>> print(tensordict.get("b").shape) torch.Size([3, 4])
- trunc() T ¶
计算 TensorDict 的每个元素的
trunc()
值。
- trunc_() T ¶
对 TensorDict 中每个元素的取整值进行就地计算。
- uint16()¶
将所有张量转换为
torch.uint16
。
- uint32()¶
将所有张量转换为
torch.uint32
。
- uint64()¶
将所有张量转换为
torch.uint64
。
- uint8()¶
将所有张量转换为
torch.uint8
。
- unbind(dim: int) tuple[T, ...] ¶
返回一个索引张量字典的元组,沿着指定的维度解绑。
示例
>>> td = TensorDict({ ... 'x': torch.arange(12).reshape(3, 4), ... }, batch_size=[3, 4]) >>> td0, td1, td2 = td.unbind(0) >>> td0['x'] tensor([0, 1, 2, 3]) >>> td1['x'] tensor([4, 5, 6, 7])
- unflatten(dim, unflattened_size)¶
展开一个张量字典的维度,将其扩展到所需的形状。
- 参数:
**dim** (int) – 指定要展开的输入张量的维度。
**unflattened_size** (shape) – 张量字典展开维度的新形状。
示例
>>> td = TensorDict({ ... "a": torch.arange(60).view(3, 4, 5), ... "b": torch.arange(12).view(3, 4)}, ... batch_size=[3, 4]) >>> td_flat = td.flatten(0, 1) >>> td_unflat = td_flat.unflatten(0, [3, 4]) >>> assert (td == td_unflat).all()
- unflatten_keys(separator: str = '.', inplace: bool = False) T ¶
将扁平化的张量字典递归地转换为嵌套的张量字典。
TensorDict 类型将丢失,结果将是一个简单的 TensorDict 实例。嵌套张量字典的元数据将从根节点推断:数据树中的所有实例将共享相同的批大小、维度名称和设备。
- 参数:
示例
>>> data = TensorDict({"a": 1, "b - c": 2, "e - f - g": 3}, batch_size=[]) >>> data.unflatten_keys(separator=" - ") TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), e: TensorDict( fields={ f: TensorDict( fields={ g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
此方法和
unflatten_keys()
在处理状态字典时特别有用,因为它们可以无缝地将扁平化的字典转换为模仿模型结构的数据结构。示例
>>> model = torch.nn.Sequential(torch.nn.Linear(3 ,4)) >>> ddp_model = torch.ao.quantization.QuantWrapper(model) >>> state_dict = TensorDict(ddp_model.state_dict(), batch_size=[]).unflatten_keys(".") >>> print(state_dict) TensorDict( fields={ module: TensorDict( fields={ 0: TensorDict( fields={ bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> model_state_dict = state_dict.get("module") >>> print(model_state_dict) TensorDict( fields={ 0: TensorDict( fields={ bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))
- unsqueeze(*args, **kwargs)¶
为位于 -td.batch_dims 和 td.batch_dims 之间的维度,对所有张量进行扩展,并在新的张量字典中返回它们。
- 参数:
**dim** (int) – 要扩展的维度。
示例
>>> td = TensorDict({ ... 'x': torch.arange(24).reshape(3, 4, 2), ... }, batch_size=[3, 4]) >>> td = td.unsqueeze(-2) >>> td.shape torch.Size([3, 1, 4]) >>> td.get("x").shape torch.Size([3, 1, 4, 2])
此操作也可以用作上下文管理器。对原始张量字典的更改将发生在非就地,即原始张量的内容不会被更改。这也假设张量字典未被锁定(否则,需要解锁张量字典)。
>>> td = TensorDict({ ... 'x': torch.arange(24).reshape(3, 4, 2), ... }, batch_size=[3, 4]) >>> with td.unsqueeze(-2) as tds: ... tds.set("y", torch.zeros(3, 1, 4)) >>> assert td.get("y").shape == [3, 4]
- update(input_dict_or_td: dict[str, CompatibleType] | T, clone: bool = False, inplace: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None, is_leaf: Callable[[Type], bool] | None = None) T ¶
使用字典或其他 TensorDict 中的值更新 TensorDict。
- 参数:
**input_dict_or_td** (TensorDictBase 或 dict) – 要写入 self 的输入数据。
**clone** (bool, 可选) – 输入(张量)字典中的张量是否在设置前应该被克隆。默认为
False
。**inplace** (bool, 可选) – 如果为
True
并且键与张量字典中已存在的键匹配,则对于该键值对,更新将在就地进行。如果找不到该条目,则会添加它。默认为False
。
- 关键字参数:
- 返回值:
自身
示例
>>> td = TensorDict({}, batch_size=[3]) >>> a = torch.randn(3) >>> b = torch.randn(3, 4) >>> other_td = TensorDict({"a": a, "b": b}, batch_size=[]) >>> td.update(other_td, inplace=True) # writes "a" and "b" even though they can't be found >>> assert td['a'] is other_td['a'] >>> other_td = other_td.clone().zero_() >>> td.update(other_td) >>> assert td['a'] is not other_td['a']
- update_(input_dict_or_td: dict[str, CompatibleType] | T, clone: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) T ¶
使用来自字典或另一个 TensorDict 的值就地更新 TensorDict。
与
update()
不同,如果键在self
中未知,此函数将抛出错误。- 参数:
**input_dict_or_td** (TensorDictBase 或 dict) – 要写入 self 的输入数据。
**clone** (bool, 可选) – 输入(张量)字典中的张量是否在设置前应该被克隆。默认为
False
。
- 关键字参数:
keys_to_update (嵌套键序列, 可选) – 如果提供,则仅更新
key_to_update
中的键列表。 旨在避免调用data_dest.update_(data_src.select(*keys_to_update))
。non_blocking (bool, 可选) – 如果为
True
并且此副本位于不同的设备之间,则副本可能相对于主机异步发生。
- 返回值:
自身
示例
>>> a = torch.randn(3) >>> b = torch.randn(3, 4) >>> td = TensorDict({"a": a, "b": b}, batch_size=[3]) >>> other_td = TensorDict({"a": a*0, "b": b*0}, batch_size=[]) >>> td.update_(other_td) >>> assert td['a'] is not other_td['a'] >>> assert (td['a'] == other_td['a']).all() >>> assert (td['a'] == 0).all()
- update_at_(input_dict_or_td: dict[str, CompatibleType] | T, idx: IndexType, clone: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) T ¶
使用来自字典或另一个 TensorDict 的值在指定索引处就地更新 TensorDict。
与 TensorDict.update 不同,如果键在 TensorDict 中未知,此函数将抛出错误。
- 参数:
**input_dict_or_td** (TensorDictBase 或 dict) – 要写入 self 的输入数据。
idx (int, torch.Tensor, 可迭代对象, 切片) – 应发生更新的 tensordict 的索引。
clone (布尔值, 可选) – 输入(张量)字典中的张量在设置之前是否应克隆。默认为 False。
- 关键字参数:
keys_to_update (嵌套键序列, 可选) – 如果提供,则仅更新
key_to_update
中的键列表。non_blocking (bool, 可选) – 如果为
True
并且此副本位于不同的设备之间,则副本可能相对于主机异步发生。
- 返回值:
自身
示例
>>> td = TensorDict({ ... 'a': torch.zeros(3, 4, 5), ... 'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4]) >>> td.update_at_( ... TensorDict({ ... 'a': torch.ones(1, 4, 5), ... 'b': torch.ones(1, 4, 10)}, batch_size=[1, 4]), ... slice(1, 2)) TensorDict( fields={ a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32), b: Tensor(torch.Size([3, 4, 10]), dtype=torch.float32)}, batch_size=torch.Size([3, 4]), device=None, is_shared=False) >>> assert (td[1] == 1).all()
- values(include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None) Iterator[tuple[str, CompatibleType]] ¶
返回一个表示 tensordict 值的生成器。
- var(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, correction: int = 1, reduce: bool | None = None) TensorDictBase | torch.Tensor ¶
返回输入 tensordict 中所有元素的方差值。
- 参数:
- 关键字参数:
- view(*shape: int, size: list | tuple | torch.Size | None = None, batch_size: torch.Size | None = None)¶
返回一个 TensorDict,其中包含根据新形状对张量进行视图,兼容 TensorDict 的 batch_size。
或者,可以提供一个 dtype 作为第一个未命名参数。在这种情况下,所有张量都将使用相应的 dtype 进行视图。请注意,这假设新形状将与提供的 dtype 兼容。有关 dtype 视图的更多信息,请参阅
view()
。- 参数:
*shape (int) – 生成的tensordict的新形状。
dtype (torch.dtype) – 或者,用于表示张量内容的 dtype。
size – 可迭代对象
- 关键字参数:
batch_size (torch.Size, 可选) – 如果提供了 dtype,则可以使用此关键字参数重置 batch-size。如果
view
使用形状调用,则此参数无效。- 返回值:
一个具有所需 batch_size 的新 TensorDict。
示例
>>> td = TensorDict(source={'a': torch.zeros(3,4,5), ... 'b': torch.zeros(3,4,10,1)}, batch_size=torch.Size([3, 4])) >>> td_view = td.view(12) >>> print(td_view.get("a").shape) # torch.Size([12, 5]) >>> print(td_view.get("b").shape) # torch.Size([12, 10, 1]) >>> td_view = td.view(-1, 4, 3) >>> print(td_view.get("a").shape) # torch.Size([1, 4, 3, 5]) >>> print(td_view.get("b").shape) # torch.Size([1, 4, 3, 10, 1])
- where(condition, other, *, out=None, pad=None)¶
返回一个
TensorDict
,其中包含从 self 或 other 中选择的元素,具体取决于 condition。- 参数:
condition (BoolTensor) – 当
True
(非零)时,生成self
,否则生成other
。other (TensorDictBase 或 Scalar) – 值(如果
other
是标量)或在 condition 为False
的索引处选择的元素。
- 关键字参数:
out (TensorDictBase, 可选) – 输出
TensorDictBase
实例。pad (标量, 可选) – 如果提供,则源或目标 TensorDict 中缺少的键将被写入为 torch.where(mask, self, pad) 或 torch.where(mask, pad, other)。默认为
None
,即不允许缺少键。
- zero_() T ¶
将 TensorDict 中的所有张量就地清零。