快捷方式

torch.masked

介绍

动机

警告

masked tensors 的 PyTorch API 处于原型阶段,未来可能会发生变化。

MaskedTensor 是 torch.Tensor 的扩展,它提供了以下能力:

  • 使用任何 masked 语义(例如,变长张量、nan* 算子等)

  • 区分 0 梯度和 NaN 梯度

  • 各种稀疏应用(见下方教程)

“指定”(Specified)和“未指定”(unspecified)在 PyTorch 中有着悠久的历史,但缺乏正式的语义和一致性;事实上,MaskedTensor 的诞生是为了解决普通 torch.Tensor 类无法妥善处理的一系列问题。因此,MaskedTensor 的主要目标是成为 PyTorch 中这些“指定”和“未指定”值的唯一真理来源,让它们成为一等公民而非事后补丁。这反过来应该进一步释放稀疏性的潜力,实现更安全、更一致的算子,并为用户和开发者提供更流畅、更直观的体验。

什么是 MaskedTensor?

MaskedTensor 是一种张量子类,由 1) 输入(数据)和 2) 掩码(mask)组成。掩码告诉我们应包含或忽略输入中的哪些条目。

举个例子,假设我们想屏蔽掉所有等于 0 的值(用灰色表示)并取最大值

_images/tensor_comparison.jpg

上方是普通张量的例子,下方是 MaskedTensor 的例子,其中所有的 0 都被屏蔽掉了。这显然会产生不同的结果,取决于我们是否有掩码,但这种灵活的结构允许用户在计算过程中系统地忽略他们希望忽略的任何元素。

我们已经编写了一些现有教程来帮助用户入门,例如

支持的算子

一元算子

一元算子是仅包含一个输入的算子。将其应用于 MaskedTensors 相对简单:如果在给定索引处数据被屏蔽,我们会应用该算子;否则,我们将继续屏蔽数据。

可用的一元算子有

abs

计算 input 中每个元素的绝对值。

absolute

torch.abs() 的别名

acos

计算 input 中每个元素的反余弦。

arccos

torch.acos() 的别名。

acosh

返回一个新张量,其中包含 input 中元素的反双曲余弦。

arccosh

torch.acosh() 的别名。

angle

计算给定 input 张量的逐元素角度(以弧度为单位)。

asin

返回一个新张量,其中包含 input 中元素的反正弦。

arcsin

torch.asin() 的别名。

asinh

返回一个新张量,其中包含 input 中元素的反双曲正弦。

arcsinh

torch.asinh() 的别名。

atan

返回一个新张量,其中包含 input 中元素的反正切。

arctan

torch.atan() 的别名。

atanh

返回一个新张量,其中包含 input 中元素的反双曲正切。

arctanh

torch.atanh() 的别名。

bitwise_not

计算给定输入张量的按位非。

ceil

返回一个新张量,其中包含 input 中元素的向上取整结果,即大于或等于每个元素的最小整数。

clamp

input 中的所有元素限制在 [ min, max ] 范围内。

clip

torch.clamp() 的别名。

conj_physical

计算给定 input 张量的逐元素共轭。

cos

返回一个新张量,其中包含 input 中元素的余弦。

cosh

返回一个新张量,其中包含 input 中元素的双曲余弦。

deg2rad

返回一个新张量,其中包含 input 中每个元素从角度(度)转换为弧度的结果。

digamma

torch.special.digamma() 的别名。

erf

torch.special.erf() 的别名。

erfc

torch.special.erfc() 的别名。

erfinv

torch.special.erfinv() 的别名。

exp

返回一个新张量,其中包含输入张量 input 中元素的指数。

exp2

torch.special.exp2() 的别名。

expm1

torch.special.expm1() 的别名。

fix

torch.trunc() 的别名

floor

返回一个新张量,其中包含 input 中元素的向下取整结果,即小于或等于每个元素的最大整数。

frac

计算 input 中每个元素的小数部分。

lgamma

计算 input 上伽马函数绝对值的自然对数。

log

返回一个新张量,其中包含 input 中元素的自然对数。

log10

返回一个新张量,其中包含 input 中元素以 10 为底的对数。

log1p

返回一个新张量,其中包含 (1 + input) 的自然对数。

log2

返回一个新张量,其中包含 input 中元素以 2 为底的对数。

logit

torch.special.logit() 的别名。

i0

torch.special.i0() 的别名。

isnan

返回一个新张量,其中包含布尔元素,表示 input 的每个元素是否为 NaN。

nan_to_num

NaN、正无穷和负无穷值在 input 中分别替换为由 nanposinfneginf 指定的值。

neg

返回一个新张量,其中包含 input 中元素的负值。

negative

torch.neg() 的别名

positive

返回 input

pow

计算 input 中每个元素的 exponent 次幂,并返回包含结果的张量。

rad2deg

返回一个新张量,其中包含 input 中每个元素从角度(弧度)转换为度的结果。

reciprocal

返回一个新张量,其中包含 input 中元素的倒数

round

input 中的元素四舍五入到最接近的整数。

rsqrt

返回一个新张量,其中包含 input 中每个元素平方根的倒数。

sigmoid

torch.special.expit() 的别名。

sign

返回一个新张量,其中包含 input 中元素的符号。

sgn

此函数是 torch.sign() 对于复数张量的扩展。

signbit

测试 input 的每个元素是否设置了符号位。

sin

返回一个新张量,其中包含 input 中元素的正弦。

sinc

torch.special.sinc() 的别名。

sinh

返回一个新张量,其中包含 input 中元素的双曲正弦。

sqrt

返回一个新张量,其中包含 input 中元素的平方根。

square

返回一个新张量,其中包含 input 中元素的平方。

tan

返回一个新张量,其中包含 input 中元素的正切。

tanh

返回一个新张量,其中包含 input 中元素的双曲正切。

trunc

返回一个新张量,其中包含 input 中元素的截断整数值。

可用的就地(inplace)一元算子包括上述所有算子,**除了**

angle

计算给定 input 张量的逐元素角度(以弧度为单位)。

positive

返回 input

signbit

测试 input 的每个元素是否设置了符号位。

isnan

返回一个新张量,其中包含布尔元素,表示 input 的每个元素是否为 NaN。

二元算子

如您在教程中可能看到的,MaskedTensor 也实现了二元操作,但需要注意的是,两个 MaskedTensors 中的掩码必须匹配,否则会引发错误。正如错误信息中指出的,如果您需要支持某个特定的算子,或者对它们应该如何表现有提议的语义,请在 GitHub 上开启一个 issue。目前,我们决定采用最保守的实现方式,以确保用户清楚地了解正在发生的事情,并慎重地对待 masked 语义相关的决策。

可用的二元算子有

add

将按 alpha 缩放的 other 添加到 input

atan2

逐元素计算 inputi/otheri\text{input}_{i} / \text{other}_{i} 的反正切,并考虑象限。

arctan2

torch.atan2() 的别名。

bitwise_and

计算 inputother 的按位与。

bitwise_or

计算 inputother 的按位或。

bitwise_xor

计算 inputother 的按位异或。

bitwise_left_shift

计算 inputother 位进行的左算术移位。

bitwise_right_shift

计算 inputother 位进行的右算术移位。

div

将输入 input 的每个元素除以 other 的对应元素。

divide

torch.div() 的别名。

floor_divide

fmod

逐元素应用 C++ 的 std::fmod

logaddexp

输入指数之和的对数。

logaddexp2

输入指数之和以 2 为底的对数。

mul

input 乘以 other

multiply

torch.mul() 的别名。

nextafter

逐元素返回 input 朝向 other 方向的下一个浮点值。

remainder

逐元素计算 Python 的模运算

sub

input 中减去按 alpha 缩放的 other

subtract

torch.sub() 的别名。

true_divide

torch.div() 的别名,其中 rounding_mode=None

eq

计算逐元素相等性

ne

逐元素计算 inputother\text{input} \neq \text{other}

le

逐元素计算 inputother\text{input} \leq \text{other}

ge

逐元素计算 inputother\text{input} \geq \text{other}

greater

torch.gt() 的别名。

大于等于

torch.ge() 的别名。

gt

按元素计算 input>other\text{input} > \text{other}

小于等于

torch.le() 的别名。

lt

按元素计算 input<other\text{input} < \text{other}

小于

torch.lt() 的别名。

最大值

按元素计算 inputother 的最大值。

最小值

按元素计算 inputother 的最小值。

fmax

按元素计算 inputother 的最大值。

fmin

按元素计算 inputother 的最小值。

不等于

torch.ne() 的别名。

可用的原地二元运算符包含以上所有,**除了**

logaddexp

输入指数之和的对数。

logaddexp2

输入指数之和以 2 为底的对数。

等于

如果两个张量具有相同的大小和元素,则为 True,否则为 False

fmin

按元素计算 inputother 的最小值。

最小值

按元素计算 inputother 的最小值。

fmax

按元素计算 inputother 的最大值。

规约

以下规约可用(支持 autograd)。更多信息请参阅 概述 教程,其中详细介绍了一些规约示例;而 高级语义 教程则对某些规约语义的决定方式进行了深入探讨。

求和

返回 input 张量中所有元素的和。

均值

最小值

返回 input 张量在给定维度 dim 中每个切片的最小值。

最大值

返回 input 张量在给定维度 dim 中每个切片的最大值。

最小值索引

返回展平张量或沿某个维度的最小值索引

最大值索引

返回 input 张量中所有元素的最大值索引。

乘积

返回 input 张量中所有元素的乘积。

全部为真

测试 input 中的所有元素是否都评估为 True

范数

返回给定张量的矩阵范数或向量范数。

方差

计算由 dim 指定的维度上的方差。

标准差

计算由 dim 指定的维度上的标准差。

视图和选择函数

我们还包含了一些视图和选择函数;直观上,这些操作符将同时应用于数据和掩码,然后将结果包装在 MaskedTensor 中。举个简单的例子,考虑 select()

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

当前支持以下操作:

至少一维

返回每个零维输入张量的一维视图。

广播张量

根据 广播语义 广播给定张量。

广播到

input 广播到形状 shape

连接

在给定维度连接 tensors 中的给定张量序列。

分块

尝试将张量分割成指定数量的块。

列堆叠

通过水平堆叠 tensors 中的张量创建一个新张量。

深度分割

根据 indices_or_sectionsinput (一个三维或更多维度的张量) 深度分割成多个张量。

展平

通过将其重塑为一维张量来展平 input

水平分割

根据 indices_or_sectionsinput (一个一维或更多维度的张量) 水平分割成多个张量。

水平堆叠

按顺序水平堆叠张量(按列)。

Kronecker 积

计算 inputother 的 Kronecker 积,记为 \otimes

网格化

根据 attr:tensors 中的一维输入创建坐标网格。

窄化

返回一个新张量,它是 input 张量的窄化版本。

nn.functional.unfold

从批量输入张量中提取滑动局部块。

展平

返回一个连续的展平张量。

选择

在给定索引处沿选定维度对 input 张量进行切片。

分割

将张量分割成块。

堆叠

沿新维度连接张量序列。

转置

要求 input 是 <= 2维张量,并转置维度 0 和 1。

转置

返回 input 张量的转置版本。

垂直分割

根据 indices_or_sectionsinput (一个二维或更多维度的张量) 垂直分割成多个张量。

垂直堆叠

按顺序垂直堆叠张量(按行)。

Tensor.expand

返回 self 张量的一个新视图,其中单例维度被扩展到更大尺寸。

Tensor.expand_as

将此张量扩展到与 other 相同的大小。

Tensor.reshape

返回一个与 self 具有相同数据和元素数量但具有指定形状的张量。

Tensor.reshape_as

返回此张量,使其形状与 other 相同。

Tensor.unfold

返回原始张量的一个视图,该视图包含 self 张量中维度 dimension 上所有大小为 size 的切片。

Tensor.view

返回一个新张量,其数据与 self 张量相同但具有不同的 shape

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源