注意
单击此处下载完整的示例代码
(原型)MaskedTensor 概述¶
创建于:2022 年 10 月 28 日 | 最后更新:2022 年 10 月 28 日 | 最后验证:未验证
本教程旨在作为使用 MaskedTensor 的起点,并讨论其掩码语义。
MaskedTensor 是 torch.Tensor
的扩展,它为用户提供了以下能力
使用任何掩码语义(例如,可变长度张量、nan* 运算符等)
区分 0 和 NaN 梯度
各种稀疏应用(请参阅下面的教程)
有关 MaskedTensor 是什么的更详细介绍,请参阅 torch.masked 文档。
使用 MaskedTensor¶
在本节中,我们将讨论如何使用 MaskedTensor,包括如何构造、访问数据和掩码,以及索引和切片。
准备工作¶
我们将首先进行本教程的必要设置
import torch
from torch.masked import masked_tensor, as_masked_tensor
import warnings
# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)
构造¶
有几种不同的方法来构造 MaskedTensor
第一种方法是直接调用 MaskedTensor 类
第二种方法(也是我们推荐的方法)是使用
masked.masked_tensor()
和masked.as_masked_tensor()
工厂函数,它们类似于torch.tensor()
和torch.as_tensor()
在本教程中,我们将假设导入行:from torch.masked import masked_tensor。
访问数据和掩码¶
可以通过以下方式访问 MaskedTensor 中的底层字段
MaskedTensor.get_data()
函数MaskedTensor.get_mask()
函数。回想一下,True
表示“已指定”或“有效”,而False
表示“未指定”或“无效”。
一般来说,返回的底层数据在未指定的条目中可能无效,因此我们建议当用户需要没有掩码条目的张量时,他们应使用 MaskedTensor.to_tensor()
(如上所示)返回具有填充值的张量。
索引和切片¶
MaskedTensor
是 Tensor 子类,这意味着它继承了与 torch.Tensor
相同的索引和切片语义。以下是一些常见索引和切片模式的示例
data:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
mask:
tensor([[[ True, False, True, False],
[ True, False, True, False],
[ True, False, True, False]],
[[ True, False, True, False],
[ True, False, True, False],
[ True, False, True, False]]])
mt[0]:
MaskedTensor(
[
[ 0.0000, --, 2.0000, --],
[ 4.0000, --, 6.0000, --],
[ 8.0000, --, 10.0000, --]
]
)
mt[:, :, 2:4]:
MaskedTensor(
[
[
[ 2.0000, --],
[ 6.0000, --],
[ 10.0000, --]
],
[
[ 14.0000, --],
[ 18.0000, --],
[ 22.0000, --]
]
]
)
为什么 MaskedTensor 有用?¶
由于 MaskedTensor
将指定值和未指定值视为一等公民,而不是事后才考虑(使用填充值、nan 等),因此它能够解决常规张量无法解决的几个缺点;实际上,MaskedTensor
的诞生很大程度上是由于这些反复出现的问题。
下面,我们将讨论 PyTorch 今天仍然存在的一些最常见问题,并说明 MaskedTensor
如何解决这些问题。
区分 0 和 NaN 梯度¶
torch.Tensor
遇到的一个问题是无法区分未定义的梯度 (NaN) 与实际为 0 的梯度。由于 PyTorch 没有一种方法可以将值标记为已指定/有效与未指定/无效,因此它被迫依赖 NaN 或 0(取决于用例),从而导致不可靠的语义,因为许多操作并非旨在正确处理 NaN 值。更令人困惑的是,有时梯度可能会因操作顺序而异(例如,取决于 NaN 值在操作链中出现的时间)。
MaskedTensor
是解决此问题的完美方案!
torch.where¶
在 Issue 10729 中,我们注意到在使用 torch.where()
时,操作顺序可能会很重要,因为我们在区分 0 是真实的 0 还是来自未定义的梯度时遇到问题。因此,我们保持一致并屏蔽掉结果
当前结果
x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad
tensor([4.5400e-05, 6.7379e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, nan, nan])
MaskedTensor
结果
x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad
MaskedTensor(
[ 0.0000, 0.0067, --, --, --, --, --, --, --, --, --]
)
此处的梯度仅提供给选定的子集。实际上,这会将 where 的梯度更改为屏蔽掉元素而不是将它们设置为零。
另一个 torch.where¶
Issue 52248 是另一个示例。
当前结果
a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
torch.where(b, a/0, c):
tensor(1., grad_fn=<WhereBackward0>)
torch.autograd.grad(torch.where(b, a/0, c), a):
(tensor(nan),)
MaskedTensor
结果
a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
torch.where(b, a/0, c):
MaskedTensor( 1.0000, True)
torch.autograd.grad(torch.where(b, a/0, c), a):
(MaskedTensor(--, False),)
此问题类似(甚至链接到下面的下一个问题),因为它表达了对意外行为的沮丧,原因是无法区分“无梯度”与“零梯度”,这反过来又使得与其他操作的协作难以理解。
当使用掩码时,x/0 会产生 NaN 梯度¶
在 Issue 4132 中,用户建议 x.grad 应该为 [0, 1] 而不是 [nan, 1],而 MaskedTensor
通过完全屏蔽掉梯度来非常清楚地说明这一点。
当前结果
tensor([nan, 1.])
MaskedTensor
结果
MaskedTensor(
[ --, 1.0000]
)
torch.nansum()
和 torch.nanmean()
¶
在 Issue 67180 中,梯度未正确计算(长期存在的问题),而 MaskedTensor
可以正确处理它。
当前结果
a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
tensor(nan)
MaskedTensor
结果
a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
MaskedTensor( 3.0000, True)
安全 Softmax¶
安全 softmax 是 一个问题 的另一个很好的例子,该问题经常出现。简而言之,如果整个批次被“屏蔽掉”或完全由填充组成(在 softmax 情况下,这意味着设置为 -inf),那么这将导致 NaN,这可能导致训练发散。
幸运的是,MaskedTensor
解决了这个问题。考虑以下设置
data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)
x:
tensor([[ 0.2345, -inf, -inf],
[-0.1863, -inf, -0.6380],
[ -inf, -inf, -inf]])
mt:
MaskedTensor(
[
[ 0.2345, --, --],
[ -0.1863, --, -0.6380],
[ --, --, --]
]
)
例如,我们要沿 dim=0 计算 softmax。请注意,第二列是“不安全的”(即完全屏蔽掉),因此在计算 softmax 时,结果将产生 0/0 = nan,因为 exp(-inf) = 0。但是,我们真正希望的是屏蔽掉梯度,因为它们是未指定的,并且对于训练是无效的。
PyTorch 结果
x.softmax(0)
tensor([[0.6037, nan, 0.0000],
[0.3963, nan, 1.0000],
[0.0000, nan, 0.0000]])
MaskedTensor
结果
mt.softmax(0)
MaskedTensor(
[
[ 0.6037, --, --],
[ 0.3963, --, 1.0000],
[ --, --, --]
]
)
实现缺少的 torch.nan* 运算符¶
在 Issue 61474 中,有人请求添加额外的运算符来涵盖各种 torch.nan* 应用,例如 torch.nanmax
、torch.nanmin
等。
一般来说,这些问题更自然地适用于掩码语义,因此我们建议使用 MaskedTensor
而不是引入额外的运算符。由于 nanmean 已经落地,我们可以将其用作比较点
y:
tensor([ 0., 1., 4., 9., 0., 5., 12., 21., 0., 9., 20., 33., 0., 13.,
28., 45.])
z:
tensor([nan, 1., 4., 9., nan, 5., 12., 21., nan, 9., 20., 33., nan, 13.,
28., 45.])
print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))
y.mean():
tensor(12.5000)
z.nanmean():
tensor(16.6667)
torch.mean(masked_tensor(y, y != 0)):
MaskedTensor( 16.6667, True)
在上面的示例中,我们构造了一个 y,并希望在忽略零的情况下计算序列的平均值。torch.nanmean 可用于执行此操作,但我们没有其余 torch.nan* 操作的实现。MaskedTensor
通过能够使用基本操作来解决此问题,并且我们已经支持问题中列出的其他操作。例如
torch.argmin(masked_tensor(y, y != 0))
MaskedTensor( 1.0000, True)
实际上,忽略 0 时最小参数的索引是索引 1 中的 1。
当数据完全被屏蔽掉时,MaskedTensor
也可以支持归约,这等效于数据张量完全为 nan
的情况。nanmean
将返回 nan
(一个模棱两可的返回值),而 MaskedTensor 将更准确地指示一个屏蔽掉的结果。
x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))
x:
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
torch.nanmean(x):
tensor(nan)
torch.nanmean via maskedtensor:
MaskedTensor(--, False)
这与安全 softmax 问题类似,其中 0/0 = nan,而我们真正想要的是未定义的值。
结论¶
在本教程中,我们介绍了什么是 MaskedTensor,演示了如何使用它们,并通过一系列示例和它们帮助解决的问题激发了它们的价值。
进一步阅读¶
要继续学习更多内容,您可以找到我们的 MaskedTensor 稀疏性教程,了解 MaskedTensor 如何实现稀疏性以及我们目前支持的不同存储格式。
脚本的总运行时间: (0 分钟 0.049 秒)