快捷方式

torch.masked

简介

动机

警告

masked 张量的 PyTorch API 仍处于原型阶段,未来可能会或可能不会更改。

MaskedTensor 作为 torch.Tensor 的扩展,为用户提供以下能力:

  • 使用任何 masked 语义(例如,可变长度张量、nan* 算子等)

  • 区分 0 和 NaN 梯度

  • 各种稀疏应用(见下面的教程)

“已指定”和“未指定”在 PyTorch 中有很长的历史,但没有正式的语义,当然也没有一致性;实际上,MaskedTensor 的诞生源于 vanilla torch.Tensor 类无法妥善解决的诸多问题的积累。因此,MaskedTensor 的主要目标是成为 PyTorch 中“已指定”和“未指定”值的真理来源,在 PyTorch 中,它们是头等公民,而不是事后才考虑的。反过来,这应该进一步释放 稀疏性 的潜力,实现更安全、更一致的算子,并为用户和开发者提供更流畅、更直观的体验。

什么是 MaskedTensor?

MaskedTensor 是张量子类,由 1) 输入(数据)和 2) 掩码组成。掩码告诉我们应该包含或忽略来自输入的哪些条目。

举例来说,假设我们想要掩盖所有等于 0 的值(以灰色表示)并取最大值

_images/tensor_comparison.jpg

顶部是 vanilla 张量示例,底部是 MaskedTensor,其中所有的 0 都被掩盖。这清楚地表明,我们是否有掩码会产生不同的结果,但这种灵活的结构允许用户系统地忽略他们在计算过程中想要忽略的任何元素。

我们已经编写了许多现有的教程来帮助用户入门,例如:

支持的算子

一元算子

一元算子是只包含单个输入的算子。将它们应用于 MaskedTensor 相对简单:如果在给定索引处的数据被掩盖,我们应用该算子,否则我们将继续掩盖数据。

可用的一元算子有

abs

计算 input 中每个元素的绝对值。

absolute

torch.abs() 的别名

acos

计算 input 中每个元素的反余弦。

arccos

torch.acos() 的别名。

acosh

返回一个新张量,其中包含 input 元素的反双曲余弦。

arccosh

torch.acosh() 的别名。

angle

计算给定 input 张量的逐元素角度(以弧度为单位)。

asin

返回一个新张量,其中包含 input 元素的反正弦。

arcsin

torch.asin() 的别名。

asinh

返回一个新张量,其中包含 input 元素的反双曲正弦。

arcsinh

torch.asinh() 的别名。

atan

返回一个新张量,其中包含 input 元素的反正切。

arctan

torch.atan() 的别名。

atanh

返回一个新张量,其中包含 input 元素的反双曲正切。

arctanh

torch.atanh() 的别名。

bitwise_not

计算给定输入张量的按位 NOT。

ceil

返回一个新张量,其中包含 input 元素的 ceil 值,即大于或等于每个元素的最小整数。

clamp

input 中的所有元素夹紧到 [ min, max ] 范围内。

clip

torch.clamp() 的别名。

conj_physical

计算给定 input 张量的逐元素共轭复数。

cos

返回一个新张量,其中包含 input 元素的余弦。

cosh

返回一个新张量,其中包含 input 元素的双曲余弦。

deg2rad

返回一个新张量,其中包含 input 的每个元素,这些元素从角度(度)转换为弧度。

digamma

torch.special.digamma() 的别名。

erf

torch.special.erf() 的别名。

erfc

torch.special.erfc() 的别名。

erfinv

torch.special.erfinv() 的别名。

exp

返回一个新张量,其中包含输入张量 input 元素的指数。

exp2

torch.special.exp2() 的别名。

expm1

torch.special.expm1() 的别名。

fix

torch.trunc() 的别名

floor

返回一个新张量,其中包含 input 元素的 floor 值,即小于或等于每个元素的最大整数。

frac

计算 input 中每个元素的小数部分。

lgamma

计算 input 上伽马函数绝对值的自然对数。

log

返回一个新张量,其中包含 input 元素的自然对数。

log10

返回一个新张量,其中包含 input 元素的以 10 为底的对数。

log1p

返回一个新张量,其中包含 (1 + input) 的自然对数。

log2

返回一个新张量,其中包含 input 元素的以 2 为底的对数。

logit

torch.special.logit() 的别名。

i0

torch.special.i0() 的别名。

isnan

返回一个新张量,其中包含布尔元素,表示 input 的每个元素是否为 NaN。

nan_to_num

NaN、正无穷大和负无穷大值替换为 input 中由 nanposinfneginf 指定的值。

neg

返回一个新张量,其中包含 input 元素的负数。

negative

torch.neg() 的别名

positive

返回 input

pow

使用 exponentinput 中的每个元素求幂,并返回包含结果的张量。

rad2deg

返回一个新张量,其中包含 input 的每个元素,这些元素从角度(弧度)转换为度。

reciprocal

返回一个新张量,其中包含 input 元素的倒数

round

input 的元素四舍五入到最接近的整数。

rsqrt

返回一个新张量,其中包含 input 的每个元素的平方根的倒数。

sigmoid

torch.special.expit() 的别名。

sign

返回一个新张量,其中包含 input 元素的符号。

sgn

此函数是 torch.sign() 到复数张量的扩展。

signbit

测试 input 的每个元素的符号位是否已设置。

sin

返回一个新张量,其中包含 input 元素的正弦。

sinc

torch.special.sinc() 的别名。

sinh

返回一个新张量,其中包含 input 元素的双曲正弦。

sqrt

返回一个新张量,其中包含 input 元素的平方根。

square

返回一个新张量,其中包含 input 元素的平方。

tan

返回一个新张量,其中包含 input 元素的正切。

tanh

返回一个新张量,其中包含 input 元素的双曲正切。

trunc

返回一个新张量,其中包含 input 元素的截断整数值。

可用的一元原位算子与上述所有算子相同,除了

angle

计算给定 input 张量的逐元素角度(以弧度为单位)。

positive

返回 input

signbit

测试 input 的每个元素的符号位是否已设置。

isnan

返回一个新张量,其中包含布尔元素,表示 input 的每个元素是否为 NaN。

二元算子

正如您在教程中可能看到的那样,MaskedTensor 也实现了二元运算,但需要注意的是,两个 MaskedTensor 中的掩码必须匹配,否则会引发错误。正如错误中指出的那样,如果您需要对特定算子的支持,或者对它们应该如何表现有建议的语义,请在 GitHub 上打开一个 issue。目前,我们已决定采用最保守的实现,以确保用户确切地知道发生了什么,并且有意识地做出关于 masked 语义的决策。

可用的二元算子有

add

将按 alpha 缩放的 other 添加到 input

atan2

逐元素计算 inputi/otheri\text{input}_{i} / \text{other}_{i} 的反正切,并考虑象限。

arctan2

torch.atan2() 的别名。

bitwise_and

计算 inputother 的按位 AND。

bitwise_or

计算 inputother 的按位 OR。

bitwise_xor

计算 inputother 的按位 XOR。

bitwise_left_shift

计算 input 左移 other 位的算术左移。

bitwise_right_shift

计算 input 右移 other 位的算术右移。

div

将输入 input 的每个元素除以 other 的相应元素。

divide

torch.div() 的别名。

floor_divide

fmod

逐元素应用 C++ 的 std::fmod

logaddexp

输入指数之和的对数。

logaddexp2

以 2 为底的输入指数之和的对数。

mul

input 乘以 other

multiply

torch.mul() 的别名。

nextafter

逐元素返回 input 之后朝 other 方向的下一个浮点值。

remainder

逐元素计算 Python 的模运算

sub

input 中减去按 alpha 缩放的 other

subtract

torch.sub() 的别名。

true_divide

torch.div() 的别名,其中 rounding_mode=None

eq

逐元素计算相等性

ne

逐元素计算 inputother\text{input} \neq \text{other}

le

计算 inputother\text{input} \leq \text{other} 逐元素比较。

ge

计算 inputother\text{input} \geq \text{other} 逐元素比较。

greater

别名:torch.gt()

greater_equal

别名:torch.ge()

gt

计算 input>other\text{input} > \text{other} 逐元素比较。

less_equal

别名:torch.le()

lt

计算 input<other\text{input} < \text{other} 逐元素比较。

less

别名:torch.lt()

maximum

计算 inputother 逐元素最大值。

minimum

计算 inputother 逐元素最小值。

fmax

计算 inputother 逐元素最大值。

fmin

计算 inputother 逐元素最小值。

not_equal

别名:torch.ne()

可用的就地二元运算符是以上所有运算符,除了

logaddexp

输入指数之和的对数。

logaddexp2

以 2 为底的输入指数之和的对数。

equal

如果两个张量具有相同的大小和元素,则为 True,否则为 False

fmin

计算 inputother 逐元素最小值。

minimum

计算 inputother 逐元素最小值。

fmax

计算 inputother 逐元素最大值。

归约

以下归约可用(具有自动微分支持)。 更多信息,概述 教程详细介绍了一些归约示例,而 高级语义 教程更深入地讨论了我们如何决定某些归约语义。

sum

返回 input 张量中所有元素的和。

mean

amin

返回给定维度 diminput 张量每个切片的最小值。

amax

返回给定维度 diminput 张量每个切片的最大值。

argmin

返回扁平化张量或沿维度最小值的索引

argmax

返回 input 张量中所有元素最大值的索引。

prod

返回 input 张量中所有元素的乘积。

all

测试 input 中的所有元素是否评估为 True

norm

返回给定张量的矩阵范数或向量范数。

var

计算由 dim 指定的维度上的方差。

std

计算由 dim 指定的维度上的标准差。

视图和选择函数

我们还包括了许多视图和选择函数; 直观地,这些运算符将应用于数据和掩码,然后将结果包装在 MaskedTensor 中。 例如,考虑 select()

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

目前支持以下操作

atleast_1d

返回每个零维输入张量的一维视图。

broadcast_tensors

根据 广播语义 广播给定的张量。

broadcast_to

input 广播到形状 shape

cat

在给定维度中连接 tensors 中给定的张量序列。

chunk

尝试将张量拆分为指定数量的块。

column_stack

通过水平堆叠 tensors 中的张量来创建新张量。

dsplit

根据 indices_or_sectionsinput (一个三维或更多维的张量) 深度方向拆分为多个张量。

flatten

通过将 input 重塑为一维张量来将其扁平化。

hsplit

根据 indices_or_sectionsinput (一个一维或更多维的张量) 水平方向拆分为多个张量。

hstack

按水平顺序 (列方向) 堆叠张量。

kron

计算 Kronecker 积,表示为 \otimes,即 inputother 的 Kronecker 积。

meshgrid

创建由 attr:tensors 中 1D 输入指定的坐标网格。

narrow

返回一个新的张量,它是 input 张量的缩小版本。

nn.functional.unfold

从批量输入张量中提取滑动局部块。

ravel

返回连续的扁平化张量。

select

沿给定索引处的选定维度切片 input 张量。

split

将张量拆分为块。

stack

沿新维度连接张量序列。

t

期望 input 为 <= 2 维张量,并转置维度 0 和 1。

transpose

返回作为 input 转置版本的张量。

vsplit

根据 indices_or_sectionsinput (一个二维或更多维的张量) 垂直方向拆分为多个张量。

vstack

按垂直顺序 (行方向) 堆叠张量。

Tensor.expand

返回 self 张量的新视图,其中单例维度扩展为更大的尺寸。

Tensor.expand_as

将此张量扩展为与 other 相同的大小。

Tensor.reshape

返回一个张量,该张量与 self 具有相同的数据和元素数量,但具有指定的形状。

Tensor.reshape_as

返回与 other 形状相同的此张量。

Tensor.unfold

返回原始张量的视图,其中包含维度 dimensionself 张量的所有大小为 size 的切片。

Tensor.view

返回与 self 张量数据相同但 shape 不同的新张量。

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源