快捷方式

torch.masked

简介

动机

警告

掩码张量的 PyTorch API 处于原型阶段,将来可能会发生变化。

MaskedTensor 是 torch.Tensor 的扩展,它为用户提供了以下功能:

  • 使用任何掩码语义(例如,可变长度张量、nan* 运算符等)

  • 区分 0 和 NaN 梯度

  • 各种稀疏应用(参见下面的教程)

“指定”和“未指定”在 PyTorch 中有着悠久的历史,但没有正式的语义,当然也没有一致性;事实上,MaskedTensor 的诞生源于 vanilla torch.Tensor 类无法正确解决的一系列问题。因此,MaskedTensor 的主要目标是成为 PyTorch 中“指定”和“未指定”值的真相来源,使它们成为一等公民,而不是事后诸葛亮。反过来,这将进一步释放 稀疏性 的潜力,使运算符更安全、更一致,并为用户和开发者提供更流畅、更直观的体验。

什么是 MaskedTensor?

MaskedTensor 是一个张量子类,它包含 1) 输入(数据)和 2) 掩码。掩码告诉我们应该包含或忽略输入中的哪些条目。

例如,假设我们想要屏蔽所有等于 0 的值(用灰色表示)并取最大值

_images/tensor_comparison.jpg

上面是 vanilla 张量示例,而下面是 MaskedTensor,其中所有 0 都被屏蔽了。这显然会产生不同的结果,具体取决于我们是否拥有掩码,但这种灵活的结构允许用户在计算过程中系统地忽略他们想要的任何元素。

我们已经编写了许多现有的教程来帮助用户入门,例如

支持的运算符

一元运算符

一元运算符是仅包含单个输入的运算符。将它们应用于 MaskedTensor 相对简单:如果数据在给定索引处被屏蔽,我们将应用运算符,否则我们将继续屏蔽数据。

可用的单目运算符是

abs

计算 input 中每个元素的绝对值。

absolute

torch.abs() 的别名

acos

计算 input 中每个元素的反余弦。

arccos

torch.acos() 的别名。

acosh

返回一个新张量,其中包含 input 元素的反双曲余弦。

arccosh

torch.acosh() 的别名。

angle

计算给定 input 张量的元素级角度(以弧度为单位)。

asin

返回一个新张量,其中包含 input 元素的反正弦。

arcsin

torch.asin() 的别名。

asinh

返回一个新张量,其中包含 input 元素的反双曲正弦。

arcsinh

torch.asinh() 的别名。

atan

返回一个新张量,其中包含 input 元素的反正切。

arctan

torch.atan() 的别名。

atanh

返回一个新张量,其中包含 input 元素的反双曲正切。

arctanh

torch.atanh() 的别名。

bitwise_not

计算给定输入张量的按位非。

ceil

返回一个新张量,其中包含 input 元素的上限,即大于或等于每个元素的最小整数。

clamp

input 中的所有元素钳制到范围 [ min, max ] 内。

clip

torch.clamp() 的别名。

conj_physical

计算给定 input 张量的元素级共轭。

cos

返回一个新张量,其中包含 input 元素的余弦。

cosh

返回一个新张量,其中包含 input 元素的双曲余弦。

deg2rad

返回一个新张量,其中 input 的每个元素都从度数转换为弧度。

digamma

torch.special.digamma() 的别名。

erf

torch.special.erf() 的别名。

erfc

torch.special.erfc() 的别名。

erfinv

torch.special.erfinv() 的别名。

exp

返回一个新张量,其中包含输入张量 input 元素的指数。

exp2

torch.special.exp2() 的别名。

expm1

torch.special.expm1() 的别名。

fix

torch.trunc() 的别名。

floor

返回一个新张量,其中包含 input 元素的向下取整值,即小于或等于每个元素的最大整数。

frac

计算 input 中每个元素的小数部分。

lgamma

计算 input 上伽马函数绝对值的自然对数。

log

返回一个新张量,其中包含 input 元素的自然对数。

log10

返回一个新张量,其中包含 input 元素以 10 为底的对数。

log1p

返回一个新张量,其中包含 (1 + input) 的自然对数。

log2

返回一个新张量,其中包含 input 元素以 2 为底的对数。

logit

torch.special.logit() 的别名。

i0

torch.special.i0() 的别名。

isnan

返回一个新张量,其中包含布尔元素,表示 input 的每个元素是否为 NaN。

nan_to_num

NaN、正无穷大和负无穷大值分别替换为 nanposinfneginf 指定的值。

neg

返回一个新张量,其中包含 input 元素的负值。

negative

torch.neg() 的别名。

positive

返回 input

pow

input 中每个元素的幂次方与 exponent 相乘,并返回一个包含结果的张量。

rad2deg

返回一个新张量,其中包含 input 的每个元素从弧度转换为度数。

reciprocal

返回一个新张量,其中包含 input 元素的倒数。

round

input 中的元素四舍五入到最接近的整数。

rsqrt

返回一个新的张量,其中包含 input 中每个元素的平方根的倒数。

sigmoid

torch.special.expit() 的别名。

sign

返回一个新的张量,其中包含 input 中元素的符号。

sgn

此函数是 torch.sign() 对复数张量的扩展。

signbit

测试 input 中每个元素的符号位是否被设置。

sin

返回一个新的张量,其中包含 input 中元素的正弦值。

sinc

torch.special.sinc() 的别名。

sinh

返回一个新的张量,其中包含 input 中元素的双曲正弦值。

sqrt

返回一个新的张量,其中包含 input 中元素的平方根。

square

返回一个新的张量,其中包含 input 中元素的平方。

tan

返回一个新的张量,其中包含 input 中元素的正切值。

tanh

返回一个新的张量,其中包含 input 中元素的双曲正切值。

trunc

返回一个新的张量,其中包含 input 中元素的截断整数。

可用的就地一元运算符是以上所有,**除了**

angle

计算给定 input 张量的元素级角度(以弧度为单位)。

positive

返回 input

signbit

测试 input 中每个元素的符号位是否被设置。

isnan

返回一个新张量,其中包含布尔元素,表示 input 的每个元素是否为 NaN。

二元运算符

正如您在教程中所见,MaskedTensor 也实现了二元运算,但前提是两个 MaskedTensor 中的掩码必须匹配,否则会引发错误。如错误中所述,如果您需要对特定运算符的支持,或者对它们的行为方式提出了语义建议,请在 GitHub 上提交问题。目前,我们决定采用最保守的实现方式,以确保用户准确了解正在发生的事情,并对使用掩码语义的决策保持谨慎。

可用的二元运算符有:

add

other(乘以 alpha)加到 input 上。

atan2

inputi/otheri\text{input}_{i} / \text{other}_{i} 进行逐元素反正切运算,并考虑象限。

arctan2

torch.atan2() 的别名。

bitwise_and

计算 inputother 的按位与运算。

bitwise_or

计算 inputother 的按位或运算。

按位异或

计算 inputother 的按位异或。

按位左移

计算 input 左移 other 位。

按位右移

计算 input 右移 other 位。

除法

将输入 input 的每个元素除以 other 的对应元素。

除法

torch.div() 的别名。

向下取整除法

取模

逐元素应用 C++ 的 std::fmod

对数加指数

输入的指数之和的对数。

对数加指数2

输入的指数之和的以 2 为底的对数。

乘法

input 乘以 other

乘法

torch.mul() 的别名。

下一个浮点数

逐元素返回 input 之后,朝向 other 的下一个浮点数。

求余

逐元素计算 Python 的取模运算

减法

input 中减去 other,并乘以 alpha

减法

torch.sub() 的别名。

真除法

torch.div() 的别名,其中 rounding_mode=None

相等

计算元素级相等

ne

计算 inputother\text{input} \neq \text{other} 元素级。

le

计算 inputother\text{input} \leq \text{other} 元素级。

ge

计算 inputother\text{input} \geq \text{other} 元素级。

greater

torch.gt() 的别名。

greater_equal

torch.ge() 的别名。

gt

计算 input>other\text{input} > \text{other} 元素级。

less_equal

torch.le() 的别名。

lt

按元素计算 input<other\text{input} < \text{other}

less

torch.lt() 的别名。

maximum

计算 inputother 的按元素最大值。

minimum

计算 inputother 的按元素最小值。

fmax

计算 inputother 的按元素最大值。

fmin

计算 inputother 的按元素最小值。

not_equal

torch.ne() 的别名。

可用的就地二元运算符是以上所有,**除了**

对数加指数

输入的指数之和的对数。

对数加指数2

输入的指数之和的以 2 为底的对数。

equal

如果两个张量大小和元素相同,则为 True,否则为 False

fmin

计算 inputother 的按元素最小值。

minimum

计算 inputother 的按元素最小值。

fmax

计算 inputother 的按元素最大值。

归约

以下归约可用(支持自动微分)。有关更多信息,概述 教程详细介绍了一些归约示例,而 高级语义 教程对我们如何决定某些归约语义进行了更深入的讨论。

sum

返回 input 张量中所有元素的总和。

mean

返回 input 张量中所有元素的平均值。

amin

返回 input 张量在给定维度(s) dim 中每个切片的最小值。

amax

返回给定维度 diminput 张量的每个切片的最大值。

argmin

返回扁平化张量或沿某个维度上的最小值索引。

argmax

返回 input 张量中所有元素的最大值的索引。

prod

返回 input 张量中所有元素的乘积。

all

测试 input 中的所有元素是否都计算为 True

norm

返回给定张量的矩阵范数或向量范数。

var

计算由 dim 指定的维度上的方差。

std

计算由 dim 指定的维度上的标准差。

视图和选择函数

我们还包含了一些视图和选择函数;直观地说,这些运算符将同时应用于数据和掩码,然后将结果包装在 MaskedTensor 中。举个简单的例子,考虑 select()

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

目前支持以下操作

atleast_1d

返回每个零维输入张量的 1 维视图。

broadcast_tensors

根据 广播语义 广播给定的张量。

broadcast_to

input 广播到形状 shape

cat

将给定的 seq 张量序列在给定维度上连接起来。

chunk

尝试将张量拆分为指定数量的块。

column_stack

通过水平堆叠 tensors 中的张量来创建一个新的张量。

dsplit

根据 indices_or_sections,将具有三个或更多维度的张量 input 深度拆分为多个张量。

flatten

通过将其重塑为一维张量来展平 input

hsplit

根据 indices_or_sections,将具有一个或多个维度的张量 input 水平拆分为多个张量。

hstack

按顺序水平(按列)堆叠张量。

kron

计算 inputother 的克罗内克积,用 \otimes 表示。

meshgrid

创建由 attr:tensors 中的 1D 输入指定的坐标网格。

narrow

返回一个新的张量,它是 input 张量的缩小版本。

ravel

返回一个连续的扁平张量。

select

在给定索引处沿选定维度对 input 张量进行切片。

split

将张量拆分为块。

t

期望 input 为 <= 2-D 张量,并转置维度 0 和 1。

transpose

返回一个张量,它是 input 的转置版本。

vsplit

根据 indices_or_sections,将具有两个或更多维度的张量 input 垂直拆分为多个张量。

vstack

将张量按垂直方向(按行)依次堆叠。

Tensor.expand

返回一个新的 self 张量的视图,其中单一维度扩展到更大的尺寸。

Tensor.expand_as

将此张量扩展到与 other 相同的尺寸。

Tensor.reshape

返回一个与 self 具有相同数据和元素数量的张量,但具有指定的形状。

Tensor.reshape_as

返回此张量,其形状与 other 相同。

Tensor.view

返回一个新的张量,其数据与 self 张量相同,但具有不同的 shape

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源