• 教程 >
  • (原型) MaskedTensor 高级语义
快捷方式

(原型) MaskedTensor 高级语义

在学习本教程之前,请确保您已阅读我们的 MaskedTensor 概述教程 <https://pytorch.ac.cn/tutorials/prototype/maskedtensor_overview.html>

本教程的目的是帮助用户了解一些高级语义是如何工作的以及它们是如何产生的。我们将重点关注两个特定的语义

*. MaskedTensor 与 NumPy 的 MaskedArray 之间的差异 *. 归约语义

准备

import torch
from torch.masked import masked_tensor
import numpy as np
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

MaskedTensor 与 NumPy 的 MaskedArray

NumPy 的 MaskedArray 与 MaskedTensor 在一些基本语义方面存在差异。

*. 它们的工厂函数和基本定义反转了掩码(类似于 torch.nn.MHA);也就是说,MaskedTensor

使用 True 表示“指定”和 False 表示“未指定”,或“有效”/“无效”,而 NumPy 则相反。我们认为,我们的掩码定义不仅更直观,而且与 PyTorch 中整体的现有语义更一致。

*. 交集语义。在 NumPy 中,如果两个元素中的一个被屏蔽,则结果元素也会被

屏蔽 - 在实践中,它们应用逻辑或运算符.

data = torch.arange(5.)
mask = torch.tensor([True, True, False, True, False])
npm0 = np.ma.masked_array(data.numpy(), (~mask).numpy())
npm1 = np.ma.masked_array(data.numpy(), (mask).numpy())

print("npm0:\n", npm0)
print("npm1:\n", npm1)
print("npm0 + npm1:\n", npm0 + npm1)
npm0:
 [0.0 1.0 -- 3.0 --]
npm1:
 [-- -- 2.0 -- 4.0]
npm0 + npm1:
 [-- -- -- -- --]

与此同时,MaskedTensor 不支持带有不匹配掩码的加法或二元运算符 - 要了解原因,请查看 关于约简的章节.

mt0 = masked_tensor(data, mask)
mt1 = masked_tensor(data, ~mask)
print("mt0:\n", mt0)
print("mt1:\n", mt1)

try:
    mt0 + mt1
except ValueError as e:
    print ("mt0 + mt1 failed. Error: ", e)
mt0:
 MaskedTensor(
  [  0.0000,   1.0000,       --,   3.0000,       --]
)
mt1:
 MaskedTensor(
  [      --,       --,   2.0000,       --,   4.0000]
)
mt0 + mt1 failed. Error:  Input masks must match. If you need support for this, please open an issue on Github.

但是,如果需要这种行为,MaskedTensor 通过提供对数据和掩码的访问权限并方便地将 MaskedTensor 转换为使用 to_tensor() 填充掩码值的张量来支持这些语义。例如

t0 = mt0.to_tensor(0)
t1 = mt1.to_tensor(0)
mt2 = masked_tensor(t0 + t1, mt0.get_mask() & mt1.get_mask())

print("t0:\n", t0)
print("t1:\n", t1)
print("mt2 (t0 + t1):\n", mt2)
t0:
 tensor([0., 1., 0., 3., 0.])
t1:
 tensor([0., 0., 2., 0., 4.])
mt2 (t0 + t1):
 MaskedTensor(
  [      --,       --,       --,       --,       --]
)

请注意,掩码是 mt0.get_mask() & mt1.get_mask(),因为 MaskedTensor 的掩码是 NumPy 的掩码的逆。

约简语义

回想一下在 MaskedTensor 概述教程 中,我们讨论了“实现缺失的 torch.nan* 操作”。这些是约简的示例 - 从张量中删除一个(或多个)维度,然后聚合结果的运算符。在本节中,我们将使用约简语义来解释我们对上述匹配掩码的严格要求。

从根本上说,:class:`MaskedTensor` 执行相同的约简操作,同时忽略被屏蔽的(未指定)值。举个例子

data = torch.arange(12, dtype=torch.float).reshape(3, 4)
mask = torch.randint(2, (3, 4), dtype=torch.bool)
mt = masked_tensor(data, mask)

print("data:\n", data)
print("mask:\n", mask)
print("mt:\n", mt)
data:
 tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
mask:
 tensor([[False,  True, False, False],
        [False,  True, False, False],
        [False,  True, False, False]])
mt:
 MaskedTensor(
  [
    [      --,   1.0000,       --,       --],
    [      --,   5.0000,       --,       --],
    [      --,   9.0000,       --,       --]
  ]
)

现在,不同的约简(都在 dim=1 上)

print("torch.sum:\n", torch.sum(mt, 1))
print("torch.mean:\n", torch.mean(mt, 1))
print("torch.prod:\n", torch.prod(mt, 1))
print("torch.amin:\n", torch.amin(mt, 1))
print("torch.amax:\n", torch.amax(mt, 1))
torch.sum:
 MaskedTensor(
  [  1.0000,   5.0000,   9.0000]
)
torch.mean:
 MaskedTensor(
  [  1.0000,   5.0000,   9.0000]
)
torch.prod:
 MaskedTensor(
  [  1.0000,   5.0000,   9.0000]
)
torch.amin:
 MaskedTensor(
  [  1.0000,   5.0000,   9.0000]
)
torch.amax:
 MaskedTensor(
  [  1.0000,   5.0000,   9.0000]
)

需要注意的是,被屏蔽元素下的值不能保证具有任何特定的值,尤其是在行或列完全被屏蔽的情况下(归一化也是如此)。有关屏蔽语义的更多详细信息,您可以在此找到 RFC.

现在,我们可以重新审视这个问题:为什么我们要强制执行掩码对于二元运算符必须匹配的不变式?换句话说,为什么我们不使用与 np.ma.masked_array 相同的语义?考虑以下示例

data0 = torch.arange(10.).reshape(2, 5)
data1 = torch.arange(10.).reshape(2, 5) + 10
mask0 = torch.tensor([[True, True, False, False, False], [False, False, False, True, True]])
mask1 = torch.tensor([[False, False, False, True, True], [True, True, False, False, False]])
npm0 = np.ma.masked_array(data0.numpy(), (mask0).numpy())
npm1 = np.ma.masked_array(data1.numpy(), (mask1).numpy())

print("npm0:", npm0)
print("npm1:", npm1)
npm0: [[-- -- 2.0 3.0 4.0]
 [5.0 6.0 7.0 -- --]]
npm1: [[10.0 11.0 12.0 -- --]
 [-- -- 17.0 18.0 19.0]]

现在,让我们尝试加法

print("(npm0 + npm1).sum(0):\n", (npm0 + npm1).sum(0))
print("npm0.sum(0) + npm1.sum(0):\n", npm0.sum(0) + npm1.sum(0))
(npm0 + npm1).sum(0):
 [-- -- 38.0 -- --]
npm0.sum(0) + npm1.sum(0):
 [15.0 17.0 38.0 21.0 23.0]

和与加法显然应该是可结合的,但是使用 NumPy 的语义,它们不是,这对于用户来说肯定会令人困惑。

MaskedTensor 另一方面,由于 mask0 != mask1,将不允许此操作。也就是说,如果用户愿意,有一些方法可以解决这个问题(例如,使用 to_tensor() 将 MaskedTensor 的未定义元素填充为 0 值,如下所示),但用户现在必须更明确地表达他们的意图。

mt0 = masked_tensor(data0, ~mask0)
mt1 = masked_tensor(data1, ~mask1)

(mt0.to_tensor(0) + mt1.to_tensor(0)).sum(0)
tensor([15., 17., 38., 21., 23.])

结论

在本教程中,我们学习了 MaskedTensor 和 NumPy 的 MaskedArray 背后的不同设计决策,以及约简语义。总的来说,MaskedTensor 旨在避免歧义和令人困惑的语义(例如,我们尝试在二元运算中保留结合性),这反过来可能会要求用户有时更谨慎地编写代码,但我们认为这是更好的做法。如果您对此有任何想法,请 告诉我们

脚本总运行时间:(0 分钟 0.015 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获取您的问题的答案

查看资源