• 教程 >
  • (原型) MaskedTensor 概述
快捷方式

(原型) MaskedTensor 概述

本教程旨在作为使用 MaskedTensors 的起点,并讨论其掩码语义。

MaskedTensor 是对 torch.Tensor 的扩展,它为用户提供了以下功能:

  • 使用任何掩码语义(例如,可变长度张量、nan* 运算符等)

  • 区分 0 和 NaN 梯度

  • 各种稀疏应用(参见下面的教程)

有关 MaskedTensors 的更详细介绍,请参阅 torch.masked 文档

使用 MaskedTensor

在本节中,我们将讨论如何使用 MaskedTensor,包括如何构建、访问数据和掩码,以及索引和切片。

准备

我们将从为教程进行必要的设置开始

import torch
from torch.masked import masked_tensor, as_masked_tensor
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

构建

有几种不同的方法可以构建 MaskedTensor

  • 第一种方法是直接调用 MaskedTensor 类

  • 第二种方法(也是我们推荐的方法)是使用 masked.masked_tensor()masked.as_masked_tensor() 工厂函数,它们类似于 torch.tensor()torch.as_tensor()

在本教程中,我们将假定导入行:from torch.masked import masked_tensor

访问数据和掩码

可以通过以下方式访问 MaskedTensor 中的底层字段:

  • MaskedTensor.get_data() 函数

  • the MaskedTensor.get_mask() 函数。回顾一下,True 表示“指定”或“有效”,而 False 表示“未指定”或“无效”。

一般来说,返回的底层数据在未指定的条目中可能无效,因此我们建议用户在需要没有掩码条目的张量时使用 MaskedTensor.to_tensor()(如上所示)返回一个带有填充值的张量。

索引和切片

MaskedTensor 是张量子类,这意味着它继承了与 torch.Tensor 相同的索引和切片语义。以下是一些常见索引和切片模式的示例

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data:\n", data)
print("mask:\n", mask)
data:
 tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
mask:
 tensor([[[ True, False,  True, False],
         [ True, False,  True, False],
         [ True, False,  True, False]],

        [[ True, False,  True, False],
         [ True, False,  True, False],
         [ True, False,  True, False]]])
# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])
mt[0]:
 MaskedTensor(
  [
    [  0.0000,       --,   2.0000,       --],
    [  4.0000,       --,   6.0000,       --],
    [  8.0000,       --,  10.0000,       --]
  ]
)
mt[:, :, 2:4]:
 MaskedTensor(
  [
    [
      [  2.0000,       --],
      [  6.0000,       --],
      [ 10.0000,       --]
    ],
    [
      [ 14.0000,       --],
      [ 18.0000,       --],
      [ 22.0000,       --]
    ]
  ]
)

为什么 MaskedTensor 有用?

由于 MaskedTensor 将指定值和未指定值视为一等公民,而不是事后诸葛亮(带有填充值、NaN 等),它能够解决常规张量无法解决的几个缺点;事实上,MaskedTensor 在很大程度上正是由于这些反复出现的问题而诞生的。

下面,我们将讨论一些 PyTorch 当今仍然未解决的最常见问题,并说明 MaskedTensor 如何解决这些问题。

区分 0 和 NaN 梯度

torch.Tensor 遇到的一个问题是无法区分未定义的梯度(NaN)和实际为 0 的梯度。由于 PyTorch 没有方法标记一个值是指定/有效还是未指定/无效,它被迫依赖 NaN 或 0(取决于用例),从而导致语义不可靠,因为许多操作并不意味着可以正确处理 NaN 值。更令人困惑的是,有时取决于操作的顺序,梯度可能会发生变化(例如,取决于 NaN 值在操作链中出现的早晚)。

MaskedTensor 是解决这个问题的完美方案!

torch.where

Issue 10729 中,我们注意到在使用 torch.where() 时操作顺序可能会有影响,因为我们难以区分 0 是真正的 0 还是来自未定义梯度的 0。因此,我们保持一致,并掩盖结果

当前结果

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad
tensor([4.5400e-05, 6.7379e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan])

MaskedTensor 结果

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad
MaskedTensor(
  [  0.0000,   0.0067,       --,       --,       --,       --,       --,       --,       --,       --,       --]
)

这里的梯度只提供给选定的子集。实际上,这将 where 的梯度更改为掩盖元素,而不是将它们设置为零。

另一个 torch.where

Issue 52248 是另一个例子。

当前结果

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
torch.where(b, a/0, c):
 tensor(1., grad_fn=<WhereBackward0>)
torch.autograd.grad(torch.where(b, a/0, c), a):
 (tensor(nan),)

MaskedTensor 结果

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
torch.where(b, a/0, c):
 MaskedTensor(  1.0000, True)
torch.autograd.grad(torch.where(b, a/0, c), a):
 (MaskedTensor(--, False),)

这个问题类似(甚至链接到下面的下一个问题),因为它表达了对意外行为的沮丧,因为无法区分“无梯度”与“零梯度”,这反过来使得使用其他操作难以推理。

使用掩码时,x/0 会产生 NaN 梯度

Issue 4132 中,用户建议 x.grad 应该是 [0, 1] 而不是 [nan, 1],而 MaskedTensor 通过完全掩盖梯度来使这一点非常清楚。

当前结果

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0)  # => mask is [0, 1]
y[mask].backward()
x.grad
tensor([nan, 1.])

MaskedTensor 结果

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad
MaskedTensor(
  [      --,   1.0000]
)

torch.nansum()torch.nanmean()

Issue 67180 中,梯度没有被正确计算(一个长期存在的问题),而 MaskedTensor 正确地处理了它。

当前结果

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
tensor(nan)

MaskedTensor 结果

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
MaskedTensor(  3.0000, True)

安全 Softmax

安全 softmax 是 一个问题 的另一个很好的例子,这个问题经常出现。简而言之,如果有一个完整的批次被“掩盖”或完全由填充组成(在 softmax 的情况下,转换为设置为 -inf),那么这会导致 NaN,这会导致训练发散。

幸运的是,MaskedTensor 已经解决了这个问题。考虑这个设置

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)
x:
 tensor([[ 0.2345,    -inf,    -inf],
        [-0.1863,    -inf, -0.6380],
        [   -inf,    -inf,    -inf]])
mt:
 MaskedTensor(
  [
    [  0.2345,       --,       --],
    [ -0.1863,       --,  -0.6380],
    [      --,       --,       --]
  ]
)

例如,我们想要沿着 dim=0 计算 softmax。注意,第二列是“不安全的”(即完全被掩盖),所以当计算 softmax 时,结果将产生 0/0 = nan,因为 exp(-inf) = 0。但是,我们真正想要的是梯度被掩盖,因为它们是未指定的,对于训练来说是无效的。

PyTorch 结果

x.softmax(0)
tensor([[0.6037,    nan, 0.0000],
        [0.3963,    nan, 1.0000],
        [0.0000,    nan, 0.0000]])

MaskedTensor 结果

mt.softmax(0)
MaskedTensor(
  [
    [  0.6037,       --,       --],
    [  0.3963,       --,   1.0000],
    [      --,       --,       --]
  ]
)

实现缺失的 torch.nan* 运算符

Issue 61474 中,有一个请求添加额外的运算符以涵盖各种 torch.nan* 应用,例如 torch.nanmaxtorch.nanmin 等。

一般来说,这些问题更自然地适用于掩码语义,因此,我们建议使用 MaskedTensor 而不是引入额外的运算符。由于 nanmean 已经落地,我们可以使用它作为比较点

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan'))  # we want to get the mean of y when ignoring the zeros
print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)
y:
 tensor([ 0.,  1.,  4.,  9.,  0.,  5., 12., 21.,  0.,  9., 20., 33.,  0., 13.,
        28., 45.])
z:
 tensor([nan,  1.,  4.,  9., nan,  5., 12., 21., nan,  9., 20., 33., nan, 13.,
        28., 45.])
print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))
y.mean():
 tensor(12.5000)
z.nanmean():
 tensor(16.6667)
torch.mean(masked_tensor(y, y != 0)):
 MaskedTensor( 16.6667, True)

在上面的示例中,我们构造了一个 y,并且想要计算该序列的平均值,同时忽略零。可以使用 torch.nanmean 来实现这一点,但我们没有针对其他 torch.nan* 操作的实现。MaskedTensor 通过能够使用基础操作来解决这个问题,并且我们已经支持了问题中列出的其他操作。例如

torch.argmin(masked_tensor(y, y != 0))
MaskedTensor(  1.0000, True)

实际上,在忽略 0 的情况下,最小参数的索引是索引 1 中的 1。

MaskedTensor 还可以支持数据被完全掩盖时的归约,这等效于上面数据张量完全为 nan 的情况。nanmean 将返回 nan(一个模糊的返回值),而 MaskedTensor 将更准确地指示一个被掩盖的结果。

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))
x:
 tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
torch.nanmean(x):
 tensor(nan)
torch.nanmean via maskedtensor:
 MaskedTensor(--, False)

这是一个与安全 softmax 类似的问题,其中 0/0 = nan,而我们真正想要的是一个未定义的值。

结论

在本教程中,我们介绍了 MaskedTensor 是什么,演示了如何使用它们,并通过一系列示例和问题来证明它们的价值,这些问题表明它们有助于解决这些问题。

进一步阅读

若要继续学习,您可以找到我们的 MaskedTensor 稀疏性教程,以了解 MaskedTensor 如何实现稀疏性以及我们目前支持的不同存储格式。

脚本的总运行时间:(0 分钟 0.047 秒)

Sphinx-Gallery 生成的画廊

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源