快捷方式

torch.nested

引言

警告

nested tensors 的 PyTorch API 处于原型阶段,近期可能会发生变化。

嵌套张量允许将不规则形状的数据包含在一个张量中并对其进行操作。这些数据在底层以高效的紧凑表示存储,同时暴露标准 PyTorch 张量接口用于应用操作。

嵌套张量的一个常见应用是表示各种领域中变长序列数据的批次,例如句子长度、图像大小以及音频/视频剪辑长度的变化。传统上,这类数据是通过将序列填充到批次中的最大长度来处理的,然后对填充后的形式执行计算,并随后进行掩码以去除填充。这种方法效率低下且容易出错,而嵌套张量的存在正是为了解决这些问题。

调用嵌套张量上操作的 API 与普通 torch.Tensor 的 API 没有区别,这使得与现有模型能够无缝集成,主要区别在于 输入的构造

由于这是一项原型特性,目前 支持的操作 集有限,但正在不断增加。我们欢迎提出问题、特性请求和贡献。有关贡献的更多信息可以在 此 Readme 中找到。

构造

注意

PyTorch 中存在两种形式的嵌套张量,通过构造时指定的布局进行区分。布局可以是 torch.stridedtorch.jagged 之一。我们建议尽可能利用 torch.jagged 布局。虽然它目前仅支持一个不规则维度,但它具有更好的操作覆盖范围,正在积极开发中,并且与 torch.compile 集成良好。本文档遵循此建议,为简洁起见,将使用 torch.jagged 布局的嵌套张量称为“NJT”。

构造很简单,只需将张量列表传递给 torch.nested.nested_tensor 构造函数。使用 torch.jagged 布局的嵌套张量(也称为“NJT”)支持一个不规则维度。此构造函数将根据下面 data_layout 部分中描述的布局,将输入张量复制到连续的紧凑内存块中。

>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> print([component for component in nt])
[tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7])]

列表中的每个张量必须具有相同的维度数,但形状可以在单个维度上有所不同。如果输入组件的维度数不匹配,构造函数将引发错误。

>>> a = torch.randn(50, 128) # 2D tensor
>>> b = torch.randn(2, 50, 128) # 3D tensor
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
...
RuntimeError: When constructing a nested tensor, all tensors in list must have the same dim

在构造过程中,可以通过常用的关键字参数选择 dtype、device 以及是否需要梯度。

>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32, device="cuda", requires_grad=True)
>>> print([component for component in nt])
[tensor([0., 1., 2.], device='cuda:0',
       grad_fn=<UnbindBackwardAutogradNestedTensor0>), tensor([3., 4., 5., 6., 7.], device='cuda:0',
       grad_fn=<UnbindBackwardAutogradNestedTensor0>)]

torch.nested.as_nested_tensor 可用于保留传递给构造函数的张量的 autograd 历史记录。使用此构造函数时,梯度将流经嵌套张量返回到原始组件。请注意,此构造函数仍会将输入组件复制到连续的紧凑内存块中。

>>> a = torch.randn(12, 512, requires_grad=True)
>>> b = torch.randn(23, 512, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.sum().backward()
>>> a.grad
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])
>>> b.grad
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])

以上函数都创建连续的 NJT,其中分配了一块内存来存储底层组件的紧凑形式(详见下面的 data_layout 部分)。

还可以对预先存在的带有填充的密集张量创建非连续的 NJT 视图,从而避免内存分配和复制。torch.nested.narrow() 是实现此目的的工具。

>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt.is_contiguous()
False

请注意,嵌套张量充当原始填充密集张量的视图,引用相同的内存而不进行复制/分配。对非连续 NJT 的操作支持相对有限,因此如果遇到支持空白,始终可以使用 contiguous() 转换为连续 NJT。

数据布局与形状

为了提高效率,嵌套张量通常将其张量组件打包到连续的内存块中,并维护附加元数据以指定批次项边界。对于 torch.jagged 布局,连续内存块存储在 values 组件中,而 offsets 组件则划定不规则维度的批次项边界。

_images/njt_visual.png

必要时可以直接访问底层的 NJT 组件。

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.values().shape  # note the "packing" of the ragged dimension; no padding needed
torch.Size([82, 128])
>>> nt.offsets()
tensor([ 0, 50, 82])

直接从锯齿状 valuesoffsets 组件构造 NJT 也很有用;torch.nested.nested_tensor_from_jagged() 构造函数即可实现此目的。

>>> values = torch.randn(82, 128)
>>> offsets = torch.tensor([0, 50, 82], dtype=torch.int64)
>>> nt = torch.nested.nested_tensor_from_jagged(values=values, offsets=offsets)

NJT 有一个明确定义的形状,其维度比其组件多 1。不规则维度的底层结构由一个符号值表示(在下面的示例中为 j1)。

>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.dim()
3
>>> nt.shape
torch.Size([2, j1, 128])

NJT 必须具有相同的不规则结构才能彼此兼容。例如,要对两个 NJT 执行二元操作,它们的不规则结构必须匹配(即它们的形状必须具有相同的不规则形状符号)。具体来说,每个符号对应一个精确的 offsets 张量,因此两个 NJT 必须具有相同的 offsets 张量才能彼此兼容。

>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt1.offsets() is nt2.offsets()
False
>>> nt3 = nt1 + nt2
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)

在上面的示例中,即使两个 NJT 的概念形状相同,它们也没有共享对同一 offsets 张量的引用,因此它们的形状不同,且不兼容。我们认识到这种行为不直观,并且正在努力在嵌套张量的 beta 版本中放宽此限制。如需解决方法,请参阅本文档的 故障排除 部分。

除了 offsets 元数据外,NJT 还可以计算并缓存其组件的最小和最大序列长度,这对于调用特定内核(例如 SDPA)非常有用。目前没有公共 API 用于访问这些信息,但这将在 beta 版本中发生变化。

支持的操作

本节包含一份您可能会觉得有用的嵌套张量上的常见操作列表。它不是一份完整的列表,因为 PyTorch 中有大约数千个操作。虽然目前嵌套张量支持其中相当一部分操作,但完全支持是一项巨大的任务。嵌套张量的理想状态是完全支持非嵌套张量可用的所有 PyTorch 操作。为了帮助我们实现此目标,请考虑

  • 在此处 请求您用例所需的特定操作,以帮助我们确定优先级。

  • 贡献!为给定的 PyTorch 操作添加嵌套张量支持并不太难;详见下面的 贡献 部分。

查看嵌套张量组件

unbind() 允许您获取嵌套张量组件的视图。

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.unbind()
(tensor([[-0.9916, -0.3363, -0.2799],
        [-2.3520, -0.5896, -0.4374]]), tensor([[-2.0969, -1.0104,  1.4841],
        [ 2.0952,  0.2973,  0.2516],
        [ 0.9035,  1.3623,  0.2026]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[ 3.6858, -3.7030, -4.4525],
        [-2.3481,  2.0236,  0.1975]])
>>> nt.unbind()
(tensor([[-2.9747, -1.0089, -0.8396],
        [-7.0561, -1.7688, -1.3122]]), tensor([[-2.0969, -1.0104,  1.4841],
        [ 2.0952,  0.2973,  0.2516],
        [ 0.9035,  1.3623,  0.2026]]))

请注意,nt.unbind()[0] 不是一个副本,而是底层内存的一个切片,它表示嵌套张量的第一个条目或组件。

与填充张量的转换

torch.nested.to_padded_tensor() 将 NJT 转换为指定填充值的填充密集张量。不规则维度将被填充到最大序列长度的大小。

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(6, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> padded = torch.nested.to_padded_tensor(nt, padding=4.2)
>>> padded
tensor([[[ 1.6107,  0.5723,  0.3913],
         [ 0.0700, -0.4954,  1.8663],
         [ 4.2000,  4.2000,  4.2000],
         [ 4.2000,  4.2000,  4.2000],
         [ 4.2000,  4.2000,  4.2000],
         [ 4.2000,  4.2000,  4.2000]],
        [[-0.0479, -0.7610, -0.3484],
         [ 1.1345,  1.0556,  0.3634],
         [-1.7122, -0.5921,  0.0540],
         [-0.5506,  0.7608,  2.0606],
         [ 1.5658, -1.1934,  0.3041],
         [ 0.1483, -1.1284,  0.6957]]])

这可以作为一个应急方案来解决 NJT 支持空白,但理想情况下应尽可能避免此类转换,以实现最佳内存使用和性能,因为更高效的嵌套张量布局不会具体化填充。

反向转换可以使用 torch.nested.narrow() 完成,它将不规则结构应用于给定的密集张量以生成 NJT。请注意,默认情况下,此操作不复制底层数据,因此输出 NJT 通常是非连续的。如果需要连续的 NJT,此处显式调用 contiguous() 可能会很有用。

>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt = nt.contiguous()
>>> nt.shape
torch.Size([3, j2, 4])

形状操作

嵌套张量支持广泛的形状操作,包括视图。

>>> a = torch.randn(2, 6)
>>> b = torch.randn(4, 6)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.shape
torch.Size([2, j1, 6])
>>> nt.unsqueeze(-1).shape
torch.Size([2, j1, 6, 1])
>>> nt.unflatten(-1, [2, 3]).shape
torch.Size([2, j1, 2, 3])
>>> torch.cat([nt, nt], dim=2).shape
torch.Size([2, j1, 12])
>>> torch.stack([nt, nt], dim=2).shape
torch.Size([2, j1, 2, 6])
>>> nt.transpose(-1, -2).shape
torch.Size([2, 6, j1])

注意力机制

由于变长序列是注意力机制的常见输入,嵌套张量支持重要的注意力运算符 Scaled Dot Product Attention (SDPA)FlexAttention。有关 NJT 与 SDPA 的用法示例,请参阅此处;有关 NJT 与 FlexAttention 的用法示例,请参阅此处

与 torch.compile 的用法

NJT 被设计用于与 torch.compile() 配合使用以实现最佳性能,并且我们始终建议在可能的情况下将 torch.compile() 与 NJT 一起使用。无论作为输入传递给已编译函数或模块,还是在函数内部内联实例化,NJT 都能即开即用且无图中断。

注意

如果您的用例无法使用 torch.compile(),使用 NJT 仍然可能受益于性能和内存使用,但这并不那么明确。重要的是,被操作的张量足够大,这样性能提升才不会被 python 张量子类的开销所抵消。

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
>>> output.shape
torch.Size([2, j1, 3])
>>> def g(values, offsets): return torch.nested.nested_tensor_from_jagged(values, offsets) * 2.
...
>>> compiled_g = torch.compile(g, fullgraph=True)
>>> output2 = compiled_g(nt.values(), nt.offsets())
>>> output2.shape
torch.Size([2, j1, 3])

请注意,NJT 支持 动态形状,以避免因不规则结构变化而导致不必要的重新编译。

>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> c = torch.randn(5, 3)
>>> d = torch.randn(6, 3)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt2 = torch.nested.nested_tensor([c, d], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output1 = compiled_f(nt1)
>>> output2 = compiled_f(nt2)  # NB: No recompile needed even though ragged structure differs

如果在将 NJT 与 torch.compile 结合使用时遇到问题或晦涩难懂的错误,请向 PyTorch 提交问题。在 torch.compile 中实现完全的子类支持是一项长期工作,目前可能存在一些不完善之处。

故障排除

本节包含在使用嵌套张量时可能遇到的一些常见错误,以及这些错误的原因和解决建议。

未实现的操作

随着嵌套张量操作支持的增长,此错误正变得越来越少见,但鉴于 PyTorch 中有数千个操作,目前仍有可能遇到此错误。

NotImplementedError: aten.view_as_real.default

这个错误很直观;我们还没有为这个特定操作添加操作支持。如果您愿意,可以自己贡献一个实现,或者只需请求我们在未来的 PyTorch 版本中添加对此操作的支持。

不规则结构不兼容

RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)

当调用对具有不兼容不规则结构的多个 NJT 进行操作的函数时,会发生此错误。目前,要求输入的 NJT 具有完全相同的 offsets 组件,以具有相同的符号化不规则结构符号(例如 j1)。

作为此情况的解决方法,可以直接从 valuesoffsets 组件构造 NJT。由于两个 NJT 都引用相同的 offsets 组件,它们被认为具有相同的不规则结构,因此是兼容的。

>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor_from_jagged(values=torch.randn(82, 128), offsets=nt1.offsets())
>>> nt3 = nt1 + nt2
>>> nt3.shape
torch.Size([2, j1, 128])

torch.compile 中的数据依赖操作

torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True

当在 torch.compile 中调用执行数据依赖操作的函数时,会发生此错误;这通常发生在需要检查 NJT 的 offsets 值以确定输出形状的函数中。例如

>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> def f(nt): return nt.chunk(2, dim=0)[0]
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)

在此示例中,在 NJT 的批次维度上调用 chunk() 需要检查 NJT 的 offsets 数据,以划定紧凑不规则维度中的批次项边界。作为解决方法,可以设置几个 torch.compile 标志

>>> torch._dynamo.config.capture_dynamic_output_shape_ops = True
>>> torch._dynamo.config.capture_scalar_outputs = True

如果在设置这些标志后仍然看到数据依赖运算符错误,请向 PyTorch 提交问题。此 torch.compile() 区域仍在积极开发中,对 NJT 的某些支持可能尚不完善。

贡献

如果您想为嵌套张量的开发做出贡献,最有影响力的方式之一是为当前不支持的 PyTorch 操作添加嵌套张量支持。这个过程通常包括几个简单的步骤

  1. 确定要添加的操作名称;这应该类似于 aten.view_as_real.default。此操作的签名可以在 aten/src/ATen/native/native_functions.yaml 中找到。

  2. 按照其他操作在该文件中建立的模式,在 torch/nested/_internal/ops.py 中注册一个操作实现。使用 native_functions.yaml 中的签名进行 schema 验证。

实现操作最常见的方法是将 NJT 解包成其组件,在底层的 values 缓冲区上重新分派操作,并将相关的 NJT 元数据(包括 offsets)传播到新的输出 NJT。如果操作的输出预计与输入具有不同的形状,则必须计算新的 offsets 等元数据。

当操作应用于批次或不规则维度时,这些技巧可以帮助快速获得一个可用的实现

  • 对于 非批次 操作,基于 unbind() 的回退方法应该有效。

  • 对于不规则维度上的操作,考虑转换为带有适当选择的填充值的填充密集张量,该填充值不会对输出产生负面偏差,然后运行操作,再转换回 NJT。在 torch.compile 中,这些转换可以融合,以避免具体化填充中间结果。

构造和转换函数的详细文档

torch.nested.nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False)[源代码][源代码]

tensor_list(一个张量列表)构造一个没有 autograd 历史记录的嵌套张量(也称为“叶张量”,详见 Autograd 机制)。

参数
  • tensor_list (List[array_like]) – 张量列表,或任何可传递给 torch.tensor 的对象,

  • 其中列表中的每个元素具有相同的维度。

关键字参数
  • dtype (torch.dtype, 可选) – 返回的嵌套张量所需的类型。默认值:如果为 None,则与列表中最左侧张量具有相同的 torch.dtype

  • layout (torch.layout, 可选) – 返回的嵌套张量所需的布局。仅支持 strided 和 jagged 布局。默认值:如果为 None,则为 strided 布局。

  • device (torch.device, 可选) – 返回的嵌套张量所需的设备。默认值:如果为 None,则与列表中最左侧张量具有相同的 torch.device

  • requires_grad (bool, 可选) – 如果 autograd 应该记录对返回的嵌套张量的操作。默认值:False

  • pin_memory (bool, 可选) – 如果设置,返回的嵌套张量将分配在固定内存中。仅适用于 CPU 张量。默认值:False

返回类型

Tensor

示例

>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
>>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
>>> nt.is_leaf
True
torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None, jagged_dim=None, min_seqlen=None, max_seqlen=None)[源代码][源代码]

从给定的锯齿状组件构造一个锯齿布局的嵌套张量。锯齿布局包含一个必需的 values 缓冲区,其锯齿维度被打包成一个维度。offsets / lengths 元数据确定此维度如何拆分为批次元素,并期望与 values 缓冲区分配在同一设备上。

预期的元数据格式
  • offsets:打包维度内的索引,将其分割成大小不均匀的批次元素。示例:[0, 2, 3, 6] 表示一个大小为 6 的打包锯齿状维度在概念上应该被分割成长度为 [2, 1, 3] 的批次元素。请注意,为了方便核函数(kernel),需要起始和结束偏移量(即形状为 batch_size + 1)。

  • lengths:各个批次元素的长度;形状 == batch_size。示例:[2, 1, 3] 表示一个大小为 6 的打包锯齿状维度在概念上应该被分割成长度为 [2, 1, 3] 的批次元素。

请注意,同时提供偏移量和长度可能很有用。这描述了一个带有“空洞”的嵌套张量,其中偏移量指示每个批次项的起始位置,长度指定元素的总数(见下例)。

返回的锯齿状布局嵌套张量将是输入值张量的一个视图(view)。

参数
  • values (torch.Tensor) – 形状为 (sum_B(*), D_1, …, D_N) 的底层缓冲区。锯齿状维度被打包成一个单一维度,使用偏移量/长度元数据来区分批次元素。

  • offsets (可选的 torch.Tensor) – 形状为 B + 1 的锯齿状维度内的偏移量。

  • lengths (可选的 torch.Tensor) – 形状为 B 的批次元素的长度。

  • jagged_dim (可选的 python:int) – 指示 values 中哪个维度是打包的锯齿状维度。如果为 None,则将其设置为 dim=1(即紧跟在批次维度之后的维度)。默认值:None

  • min_seqlen (可选的 python:int) – 如果设置,则使用指定的值作为返回的嵌套张量的缓存最小序列长度。这可以作为按需计算该值的有用替代方案,可能避免 GPU -> CPU 同步。默认值:None

  • max_seqlen (可选的 python:int) – 如果设置,则使用指定的值作为返回的嵌套张量的缓存最大序列长度。这可以作为按需计算该值的有用替代方案,可能避免 GPU -> CPU 同步。默认值:None

返回类型

Tensor

示例

>>> values = torch.randn(12, 5)
>>> offsets = torch.tensor([0, 3, 5, 6, 10, 12])
>>> nt = nested_tensor_from_jagged(values, offsets)
>>> # 3D shape with the middle dimension jagged
>>> nt.shape
torch.Size([5, j2, 5])
>>> # Length of each item in the batch:
>>> offsets.diff()
tensor([3, 2, 1, 4, 2])

>>> values = torch.randn(6, 5)
>>> offsets = torch.tensor([0, 2, 3, 6])
>>> lengths = torch.tensor([1, 1, 2])
>>> # NT with holes
>>> nt = nested_tensor_from_jagged(values, offsets, lengths)
>>> a, b, c = nt.unbind()
>>> # Batch item 1 consists of indices [0, 1)
>>> torch.equal(a, values[0:1, :])
True
>>> # Batch item 2 consists of indices [2, 3)
>>> torch.equal(b, values[2:3, :])
True
>>> # Batch item 3 consists of indices [3, 5)
>>> torch.equal(c, values[3:5, :])
True
torch.nested.as_nested_tensor(ts, dtype=None, device=None, layout=None)[source][source]

从张量或张量列表/元组构建一个保留 autograd 历史的嵌套张量。

如果传入的是嵌套张量,则除非 device / dtype / layout 不同,否则将直接返回。请注意,转换 device / dtype 会导致复制,而当前此函数不支持转换 layout。

如果传入的是非嵌套张量,则会将其视为一批大小一致的组成部分。如果传入的 device / dtype 与输入的 device / dtype 不同,或者输入是非连续的,则会发生复制。否则,将直接使用输入的存储。

如果提供的是张量列表,则在构建嵌套张量时,列表中的张量总是会被复制。

参数

ts (TensorList[Tensor] 或 Tuple[Tensor])– 一个要视为嵌套张量的张量,或者具有相同 ndim 的张量列表/元组。

关键字参数
  • dtype (torch.dtype, 可选) – 返回的嵌套张量所需的类型。默认值:如果为 None,则与列表中最左侧张量具有相同的 torch.dtype

  • device (torch.device, 可选) – 返回的嵌套张量所需的设备。默认值:如果为 None,则与列表中最左侧张量具有相同的 torch.device

  • layout (torch.layout, 可选) – 返回的嵌套张量所需的布局。仅支持 strided 和 jagged 布局。默认值:如果为 None,则为 strided 布局。

返回类型

Tensor

示例

>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b])
>>> nt.is_leaf
False
>>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)])
>>> nt.backward(fake_grad)
>>> a.grad
tensor([1., 1., 1.])
>>> b.grad
tensor([0., 0., 0., 0., 0.])
>>> c = torch.randn(3, 5, requires_grad=True)
>>> nt2 = torch.nested.as_nested_tensor(c)
torch.nested.to_padded_tensor(input, padding, output_size=None, out=None) Tensor

通过填充 input 嵌套张量来返回一个新的(非嵌套)张量。前面的条目将填充嵌套数据,而后面的条目将填充(pad)。

警告

to_padded_tensor() 总是复制底层数据,因为嵌套张量和非嵌套张量在内存布局上不同。

参数

padding (float)– 用于填充后面条目的值。

关键字参数
  • output_size (Tuple[int])– 输出张量的大小。如果给定,它必须足够大以包含所有嵌套数据;否则,将通过取每个嵌套子张量在每个维度上的最大大小来推断。

  • out (Tensor, 可选的)– 输出张量。

示例

>>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))])
nested_tensor([
  tensor([[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
          [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995]]),
  tensor([[-1.8546, -0.7194, -0.2918, -0.1846],
          [ 0.2773,  0.8793, -0.5183, -0.6447],
          [ 1.8009,  1.8468, -0.9832, -1.5272]])
])
>>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0)
tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
         [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
        [[-1.8546, -0.7194, -0.2918, -0.1846,  0.0000],
         [ 0.2773,  0.8793, -0.5183, -0.6447,  0.0000],
         [ 1.8009,  1.8468, -0.9832, -1.5272,  0.0000]]])
>>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6))
tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276,  1.0000],
         [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995,  1.0000],
         [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
         [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]],
        [[-1.8546, -0.7194, -0.2918, -0.1846,  1.0000,  1.0000],
         [ 0.2773,  0.8793, -0.5183, -0.6447,  1.0000,  1.0000],
         [ 1.8009,  1.8468, -0.9832, -1.5272,  1.0000,  1.0000],
         [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]]])
>>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.
torch.nested.masked_select(tensor, mask)[source][source]

给定一个跨步(strided)张量输入和一个跨步掩码,构建一个嵌套张量,结果的锯齿状布局嵌套张量将保留掩码为 True 的位置的值。掩码的维度得以保留并由偏移量表示,这与 masked_select() 不同,后者的输出被展平为 1D 张量。

参数: tensor (torch.Tensor):用于构建锯齿状布局嵌套张量的跨步张量。 mask (torch.Tensor):应用于输入张量的跨步掩码张量。

示例

>>> tensor = torch.randn(3, 3)
>>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]])
>>> nt = torch.nested.masked_select(tensor, mask)
>>> nt.shape
torch.Size([3, j4])
>>> # Length of each item in the batch:
>>> nt.offsets().diff()
tensor([1, 2, 1])

>>> tensor = torch.randn(6, 5)
>>> mask = torch.tensor([False])
>>> nt = torch.nested.masked_select(tensor, mask)
>>> nt.shape
torch.Size([6, j5])
>>> # Length of each item in the batch:
>>> nt.offsets().diff()
tensor([0, 0, 0, 0, 0, 0])
返回类型

Tensor

torch.nested.narrow(tensor, dim, start, length, layout=torch.strided)[source][source]

从跨步张量 tensor 构建一个嵌套张量(可能是视图)。这遵循与 torch.Tensor.narrow 类似的语义,其中在新嵌套张量的 dim 维度中仅显示区间 [start, start+length) 内的元素。由于嵌套表示允许在该维度的每一“行”中使用不同的 startlength,因此 startlength 也可以是形状为 tensor.shape[0] 的张量。

取决于用于嵌套张量的布局,存在一些差异。如果使用 strided 布局,torch.narrow 会将窄化后的数据复制到一个具有 strided 布局的连续 NT 中,而 jagged 布局的 narrow() 会创建原始跨步张量的一个非连续视图。这种特定的表示形式对于在 Transformer 模型中表示 kv-caches 非常有用,因为专门的 SDPA 核函数可以轻松处理这种格式,从而提高性能。

参数
  • tensor (torch.Tensor)– 一个跨步张量,如果使用锯齿状布局,它将用作嵌套张量的底层数据;如果使用跨步布局,则会复制。

  • dim (int)– 应用窄化的维度。锯齿状布局仅支持 dim=1,而 strided 支持所有 dim。

  • start (Union[int, torch.Tensor])– 窄化操作的起始元素。

  • length (Union[int, torch.Tensor])– 窄化操作期间获取的元素数量。

关键字参数

layout (torch.layout, 可选) – 返回的嵌套张量所需的布局。仅支持 strided 和 jagged 布局。默认值:如果为 None,则为 strided 布局。

返回类型

Tensor

示例

>>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
>>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
>>> narrow_base = torch.randn(5, 10, 20)
>>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged)
>>> nt_narrowed.is_contiguous()
False

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并解答您的问题

查看资源