快捷方式

torch.masked

简介

动机

警告

掩码张量的 PyTorch API 处于原型阶段,将来可能会更改。

MaskedTensor 是 torch.Tensor 的扩展,它允许用户

  • 使用任何掩码语义(例如可变长度张量、nan* 运算符等)

  • 区分 0 和 NaN 梯度

  • 各种稀疏应用(见下面的教程)

“指定”和“未指定”在 PyTorch 中有着悠久的历史,但没有正式的语义,当然也不一致;事实上,MaskedTensor 是由于 vanilla torch.Tensor 类无法正确解决的一系列问题而诞生的。因此,MaskedTensor 的主要目标是成为 PyTorch 中上述“指定”和“未指定”值的真相来源,使它们成为一等公民,而不是事后才想到的。反过来,这将进一步释放 稀疏性 的潜力,使运算符更安全、更一致,并为用户和开发人员提供更流畅、更直观的体验。

什么是 MaskedTensor?

MaskedTensor 是一个张量子类,它包含 1) 输入(数据)和 2) 掩码。掩码告诉我们应该包含或忽略输入中的哪些条目。

例如,假设我们想要屏蔽所有等于 0 的值(用灰色表示)并取最大值

_images/tensor_comparison.jpg

顶部是 vanilla 张量示例,底部是 MaskedTensor,其中所有 0 都被屏蔽。这清楚地表明,根据我们是否拥有掩码,结果会有所不同,但这种灵活的结构允许用户在计算过程中系统地忽略他们想要的任何元素。

我们已经编写了许多现有的教程来帮助用户入门,例如

支持的运算符

一元运算符

一元运算符是只包含单个输入的运算符。将它们应用于 MaskedTensor 相对简单:如果数据在给定索引处被屏蔽,我们将应用运算符,否则我们将继续屏蔽数据。

可用的 一元运算符 是

abs

计算 input 中每个元素的绝对值。

absolute

torch.abs() 的别名

acos

计算 input 中每个元素的反余弦。

arccos

torch.acos() 的别名。

acosh

返回一个新张量,包含 input 元素的反双曲余弦。

arccosh

torch.acosh() 的别名。

angle

计算给定 input 张量的逐元素角度(以弧度表示)。

asin

返回一个新张量,包含 input 元素的反正弦。

arcsin

torch.asin() 的别名。

asinh

返回一个新张量,包含 input 元素的反双曲正弦。

arcsinh

torch.asinh() 的别名。

atan

返回一个新的张量,包含 input 元素的反正切值。

arctan

torch.atan() 的别名。

atanh

返回一个新的张量,包含 input 元素的反双曲正切值。

arctanh

torch.atanh() 的别名。

bitwise_not

计算给定输入张量的按位非操作。

ceil

返回一个新的张量,包含 input 元素的上舍入值,即大于或等于每个元素的最小整数。

clamp

input 中的所有元素夹紧到范围 [ min, max ]

clip

torch.clamp() 的别名。

conj_physical

计算给定 input 张量的逐元素共轭。

cos

返回一个新的张量,包含 input 元素的余弦值。

cosh

返回一个新的张量,包含 input 元素的双曲余弦值。

deg2rad

返回一个新的张量,将 input 中每个元素从度数转换为弧度。

digamma

torch.special.digamma() 的别名。

erf

torch.special.erf() 的别名。

erfc

torch.special.erfc() 的别名。

erfinv

torch.special.erfinv() 的别名。

exp

返回一个新的张量,包含输入张量 input 元素的指数值。

exp2

torch.special.exp2() 的别名。

expm1

torch.special.expm1() 的别名。

fix

torch.trunc() 的别名。

floor

返回一个新的张量,包含 input 元素的下舍入值,即小于或等于每个元素的最大整数。

frac

计算 input 中每个元素的小数部分。

lgamma

计算 input 上伽马函数绝对值的自然对数。

log

返回一个新的张量,包含 input 元素的自然对数。

log10

返回一个新的张量,包含 input 元素以 10 为底的对数。

log1p

返回一个新的张量,包含 (1 + input) 的自然对数。

log2

返回一个新的张量,包含 input 元素以 2 为底的对数。

logit

torch.special.logit() 的别名。

i0

torch.special.i0() 的别名。

isnan

返回一个新的张量,包含布尔元素,表示 input 中每个元素是否为 NaN。

nan_to_num

NaN、正无穷大和负无穷大值替换为 input 中由 nanposinfneginf 指定的值。

neg

返回一个新的张量,包含 input 元素的负值。

negative

torch.neg() 的别名。

positive

返回 input

pow

input 中每个元素的幂次方与 exponent 相乘,并返回一个包含结果的张量。

rad2deg

返回一个新的张量,将 input 中每个元素从弧度转换为度数。

reciprocal

返回一个新的张量,包含 input 元素的倒数。

round

input 元素四舍五入到最接近的整数。

rsqrt

返回一个新的张量,包含 input 中每个元素的平方根的倒数。

sigmoid

torch.special.expit() 的别名。

sign

返回一个新的张量,包含 input 元素的符号。

sgn

此函数是 torch.sign() 对复数张量的扩展。

signbit

测试 input 中每个元素的符号位是否被设置。

sin

返回一个新的张量,包含 input 元素的正弦值。

sinc

torch.special.sinc() 的别名。

sinh

返回一个新的张量,包含 input 元素的双曲正弦值。

sqrt

返回一个新的张量,包含 input 元素的平方根。

square

返回一个新的张量,包含 input 元素的平方值。

tan

返回一个新的张量,包含 input 元素的正切值。

tanh

返回一个新的张量,包含 input 元素的双曲正切值。

trunc

返回一个新的张量,包含 input 元素的截断整数部分。

可用的就地一元运算符是以上所有运算符,**除外**

angle

计算给定 input 张量的逐元素角度(以弧度表示)。

positive

返回 input

signbit

测试 input 中每个元素的符号位是否被设置。

isnan

返回一个新的张量,包含布尔元素,表示 input 中每个元素是否为 NaN。

二元运算符

正如您可能在教程中看到的那样,MaskedTensor 也实现了二元运算,但前提是两个 MaskedTensor 中的掩码必须匹配,否则会抛出错误。如错误中所述,如果您需要对特定运算符的支持,或者对它们应该如何行为提出了语义建议,请在 GitHub 上打开一个问题。目前,我们决定采用最保守的实现,以确保用户确切地知道发生了什么,并在使用掩码语义时做出明智的决定。

可用的二元运算符是

add

other(乘以 alpha)添加到 input

atan2

逐元素反正切值 inputi/otheri\text{input}_{i} / \text{other}_{i},考虑象限。

arctan2

torch.atan2() 的别名。

bitwise_and

计算 inputother 的按位与操作。

bitwise_or

计算 inputother 的按位或操作。

bitwise_xor

计算 inputother 的按位异或操作。

bitwise_left_shift

计算 input 左侧算术移位 other 位。

bitwise_right_shift

计算 input 右侧算术移位 other 位。

div

将输入 input 的每个元素除以 other 的对应元素。

divide

torch.div() 的别名。

floor_divide

fmod

按元素应用 C++ 的 std::fmod

logaddexp

输入指数和的对数。

logaddexp2

输入指数和以 2 为底的对数。

mul

input 乘以 other

multiply

torch.mul() 的别名。

nextafter

返回 input 之后,朝向 other 的下一个浮点数,按元素进行。

remainder

按元素计算 Python 的模运算

sub

input 中减去 other,并乘以 alpha

subtract

torch.sub() 的别名。

true_divide

torch.div() 的别名,其中 rounding_mode=None

eq

计算按元素的相等性

ne

计算 inputother\text{input} \neq \text{other} 按元素进行。

le

计算 inputother\text{input} \leq \text{other} 按元素进行。

ge

计算 inputother\text{input} \geq \text{other} 按元素进行。

greater

torch.gt() 的别名。

greater_equal

torch.ge() 的别名。

gt

计算 input>other\text{input} > \text{other} 按元素进行。

less_equal

torch.le() 的别名。

lt

计算 input<other\text{input} < \text{other} 按元素进行。

less

torch.lt() 的别名。

maximum

计算 inputother 的按元素最大值。

minimum

计算 inputother 的按元素最小值。

fmax

计算 inputother 的按元素最大值。

fmin

计算 inputother 的按元素最小值。

not_equal

torch.ne() 的别名。

可用的就地二元运算符是所有上述运算符,**除了**

logaddexp

输入指数和的对数。

logaddexp2

输入指数和以 2 为底的对数。

equal

如果两个张量大小和元素相同,则为 True,否则为 False

fmin

计算 inputother 的按元素最小值。

minimum

计算 inputother 的按元素最小值。

fmax

计算 inputother 的按元素最大值。

缩减

以下缩减操作可用(支持自动梯度)。有关更多信息,概述 教程详细介绍了一些缩减操作示例,而 高级语义 教程更深入地讨论了我们如何决定某些缩减语义。

sum

返回 input 张量中所有元素的总和。

mean

返回 input 张量中所有元素的平均值。

amin

返回给定维度 diminput 张量每个切片的最小值。

amax

返回给定维度 diminput 张量每个切片的最大值。

argmin

返回扁平化张量或沿某个维度最小值(s) 的索引

argmax

返回 input 张量中所有元素的最大值的索引。

prod

返回 input 张量中所有元素的乘积。

all

测试 input 中的所有元素是否都计算为 True

norm

返回给定张量的矩阵范数或向量范数。

var

计算由 dim 指定的维度上的方差。

std

计算由 dim 指定的维度上的标准差。

视图和选择函数

我们还包含了一些视图和选择函数;直观地,这些运算符将同时应用于数据和掩码,然后将结果包装在 MaskedTensor 中。举个简单的例子,考虑 select()

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

目前支持以下操作

atleast_1d

返回每个维度为零的输入张量的 1 维视图。

broadcast_tensors

根据 广播语义 广播给定的张量。

broadcast_to

input 广播到形状 shape

cat

将给定序列的 seq 张量在给定维度上连接起来。

chunk

尝试将张量分成指定数量的块。

column_stack

通过水平堆叠 tensors 中的张量来创建一个新的张量。

dsplit

根据 indices_or_sections 深度地将 input(具有三个或更多维度的张量)分成多个张量。

flatten

通过将其重塑为一维张量来展平 input

hsplit

根据 indices_or_sections 水平地将 input(具有一个或多个维度的张量)分成多个张量。

hstack

按顺序水平堆叠(按列)张量。

kron

计算 inputother 的克罗内克积,记为 \otimes

meshgrid

创建由 attr:tensors 中的 1D 输入指定的坐标网格。

narrow

返回一个新的张量,它是 input 张量的缩小版本。

ravel

返回一个连续的扁平化张量。

select

沿着给定索引处的选定维度切片 input 张量。

split

将张量分成块。

t

期望 input 是 <= 2 维的张量,并转置维度 0 和 1。

transpose

返回一个张量,它是 input 的转置版本。

vsplit

根据 indices_or_sections 将具有两个或更多维度的张量 input 垂直拆分为多个张量。

vstack

将张量按顺序垂直(按行)堆叠。

Tensor.expand

返回一个新的 self 张量视图,其中单一维度扩展到更大的尺寸。

Tensor.expand_as

将此张量扩展到与 other 相同的尺寸。

Tensor.reshape

返回一个与 self 具有相同数据和元素数量的张量,但具有指定的形状。

Tensor.reshape_as

将此张量返回为与 other 相同的形状。

Tensor.view

返回一个新的张量,与 self 张量具有相同的数据,但具有不同的 shape

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的答案

查看资源