快捷方式

pad

class tensordict.pad(tensordict: T, pad_size: Sequence[int], value: _​float = 0.0)

使用常数值沿批处理维度填充 tensordict 中的所有张量,并返回一个新的 tensordict。

参数:
  • tensordict (TensorDict) – 要填充的 tensordict

  • pad_size (Sequence[int]) – 用于填充 tensordict 的部分批处理维度的填充大小,从第一个维度开始向前。批处理大小的 [len(pad_size) / 2] 个维度将被填充。例如,仅填充第一个维度时,pad 的形式为 (padding_left, padding_right)。填充两个维度时,形式为 (padding_left, padding_right, padding_top, padding_bottom),依此类推。pad_size 必须为偶数,且小于或等于批处理维度的两倍。

  • value (float, optional) – 用于填充的填充值,默认为 0.0

返回:

沿批处理维度填充后的新 TensorDict

示例

>>> from tensordict import TensorDict, pad
>>> import torch
>>> td = TensorDict({'a': torch.ones(3, 4, 1),
...     'b': torch.ones(3, 4, 1, 1)}, batch_size=[3, 4])
>>> dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2]
>>> padded_td = pad(td, [dim0_left, dim0_right, dim1_left, dim1_right], value=0.0)
>>> print(padded_td.batch_size)
torch.Size([4, 6])
>>> print(padded_td.get("a").shape)
torch.Size([4, 6, 1])
>>> print(padded_td.get("b").shape)
torch.Size([4, 6, 1, 1])

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源