快捷方式

torch

torch 包含用于多维张量的数据结构,并定义了对这些张量的数学运算。此外,它还提供了许多实用程序,用于高效地序列化张量和任意类型,以及其他有用的实用程序。

它有一个 CUDA 对应物,使您能够在计算能力 >= 3.0 的 NVIDIA GPU 上运行张量计算。

张量

is_tensor

如果 obj 是 PyTorch 张量,则返回 True。

is_storage

如果 obj 是 PyTorch 存储对象,则返回 True。

is_complex

如果 input 的数据类型是复数数据类型,即 torch.complex64torch.complex128 中的一种,则返回 True。

is_conj

如果 input 是共轭张量,即其共轭位设置为 True,则返回 True。

is_floating_point

如果 input 的数据类型是浮点数据类型,即 torch.float64torch.float32torch.float16torch.bfloat16 中的一种,则返回 True。

is_nonzero

如果 input 是一个单元素张量,并且在类型转换后不等于零,则返回 True。

set_default_dtype

将默认浮点数据类型设置为 d

get_default_dtype

获取当前默认浮点 torch.dtype

set_default_device

设置要在 device 上分配的默认 torch.Tensor

get_default_device

获取要在 device 上分配的默认 torch.Tensor

set_default_tensor_type

numel

返回 input 张量中的元素总数。

set_printoptions

设置打印选项。

set_flush_denormal

在 CPU 上禁用非规格化浮点数。

创建操作

注意

随机采样创建操作列在随机采样下,包括:torch.rand() torch.rand_like() torch.randn() torch.randn_like() torch.randint() torch.randint_like() torch.randperm() 您还可以将 torch.empty()原地随机采样方法一起使用,以创建从更广泛的分布中采样值的 torch.Tensor

tensor

通过复制 data 构造一个没有自动求导历史的张量(也称为“叶张量”,请参阅自动求导机制)。

sparse_coo_tensor

以指定的 indices 处的值构造一个COO(坐标)格式的稀疏张量

sparse_csr_tensor

使用给定的 crow_indicescol_indices 构造一个 CSR(压缩稀疏行)格式的稀疏张量

sparse_csc_tensor

使用给定的 ccol_indicesrow_indices 构造一个 CSC(压缩稀疏列)格式的稀疏张量

sparse_bsr_tensor

使用给定的 crow_indicescol_indices 构造一个 BSR(块压缩稀疏行)格式的稀疏张量,其中包含指定的二维块。

sparse_bsc_tensor

使用给定的 ccol_indicesrow_indices 构造一个 BSC(块压缩稀疏列)格式的稀疏张量,其中包含指定的二维块。

asarray

obj 转换为张量。

as_tensor

data 转换为张量,如果可能,共享数据并保留自动求导历史记录。

as_strided

使用指定的 sizestridestorage_offset 创建现有 torch.Tensor input 的视图。

from_file

创建一个 CPU 张量,其存储由内存映射文件支持。

from_numpy

numpy.ndarray 创建一个 Tensor

from_dlpack

将外部库中的张量转换为 torch.Tensor

frombuffer

从实现 Python 缓冲区协议的对象创建一个一维 Tensor

zeros

返回一个填充了标量值 0 的张量,其形状由可变参数 size 定义。

zeros_like

返回一个填充了标量值 0 的张量,其大小与 input 相同。

ones

返回一个填充了标量值 1 的张量,其形状由可变参数 size 定义。

ones_like

返回一个填充了标量值 1 的张量,其大小与 input 相同。

arange

返回一个大小为 endstartstep\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil 的一维张量,其值从区间 [start, end) 中以公差 stepstart 开始取值。

range

返回一个大小为 endstartstep+1\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1 的一维张量,其值从 startend 以步长 step 取值。

linspace

创建一个大小为 steps 的一维张量,其值在 startend 之间(包括端点)均匀分布。

logspace

创建一个大小为 steps 的一维张量,其值在以 base 为底的对数刻度上从 basestart{{\text{{base}}}}^{{\text{{start}}}}baseend{{\text{{base}}}}^{{\text{{end}}}} 之间(包括端点)均匀分布。

eye

返回一个二维张量,对角线上为 1,其他位置为 0。

empty

返回一个填充了未初始化数据的张量。

empty_like

返回一个与 input 大小相同的未初始化张量。

empty_strided

创建一个具有指定 sizestride 的张量,并填充未定义的数据。

full

创建一个大小为 size 并填充了 fill_value 的张量。

full_like

返回一个与 input 大小相同并填充了 fill_value 的张量。

quantize_per_tensor

将浮点张量转换为具有给定比例和零点的量化张量。

quantize_per_channel

将浮点张量转换为具有给定比例和零点的按通道量化张量。

dequantize

通过反量化量化张量返回一个 fp32 张量。

complex

构造一个复张量,其实部等于 real,虚部等于 imag

polar

构造一个复张量,其元素是与极坐标对应的笛卡尔坐标,绝对值为 abs,角度为 angle

heaviside

计算 input 中每个元素的 Heaviside 阶跃函数。

索引、切片、连接、修改操作

adjoint

返回张量的共轭视图,并交换最后两个维度。

argwhere

返回一个张量,其中包含 input 中所有非零元素的索引。

cat

在给定维度上连接给定的 seq 张量序列。

concat

torch.cat() 的别名。

连接(concatenate)

torch.cat() 的别名。

共轭(conj)

返回一个 input 的视图,其共轭位已翻转。

块(chunk)

尝试将张量拆分为指定数量的块。

深度分割(dsplit)

根据 indices_or_sections,将具有三个或更多维度的张量 input 在深度方向上拆分为多个张量。

列堆叠(column_stack)

通过水平堆叠 tensors 中的张量来创建一个新的张量。

深度堆叠(dstack)

按深度方向(沿第三个轴)依次堆叠张量。

收集(gather)

沿着 dim 指定的轴收集值。

水平分割(hsplit)

根据 indices_or_sections,将具有一维或多维的张量 input 水平拆分为多个张量。

水平堆叠(hstack)

按水平方向(列方向)依次堆叠张量。

索引加法(index_add)

函数描述请参见 index_add_()

索引复制(index_copy)

函数描述请参见 index_add_()

索引约简(index_reduce)

函数描述请参见 index_reduce_()

索引选择(index_select)

返回一个新的张量,该张量使用 LongTensor 类型的 index 中的条目,沿着维度 diminput 张量进行索引。

掩码选择(masked_select)

返回一个新的 1 维张量,该张量根据布尔掩码 maskBoolTensor 类型)对 input 张量进行索引。

移动维度(movedim)

inputsource 位置处的维度移动到 destination 位置处。

移动轴(moveaxis)

torch.movedim() 的别名。

缩小范围(narrow)

返回一个新的张量,它是 input 张量的缩小范围版本。

缩小范围复制(narrow_copy)

Tensor.narrow() 相同,但返回的是副本而不是共享存储。

非零元素(nonzero)

排列(permute)

返回原始张量 input 的一个视图,其维度已进行排列。

调整形状(reshape)

返回一个与 input 具有相同数据和元素数量的张量,但具有指定的形状。

行堆叠(row_stack)

torch.vstack() 的别名。

选择(select)

沿着选定维度,在给定索引处对 input 张量进行切片。

分散(scatter)

torch.Tensor.scatter_() 的非原地版本

对角线分散(diagonal_scatter)

src 张量的值嵌入到 input 的对角线元素中,相对于 dim1dim2

选择分散(select_scatter)

src 张量的值嵌入到 input 中的给定索引处。

切片分散(slice_scatter)

src 张量的值嵌入到 input 中的给定维度处。

分散加法(scatter_add)

torch.Tensor.scatter_add_() 的非原地版本

分散约简(scatter_reduce)

torch.Tensor.scatter_reduce_() 的非原地版本

分割(split)

将张量分割成多个块。

压缩(squeeze)

返回一个张量,其中 input 中所有指定维度的大小为 1 的维度都被移除。

堆叠(stack)

沿着新维度连接一系列张量。

交换轴(swapaxes)

torch.transpose() 的别名。

交换维度(swapdims)

torch.transpose() 的别名。

转置(t)

期望 input 是 <= 2 维张量,并转置维度 0 和 1。

提取(take)

返回一个新的张量,其中包含 input 在给定索引处的元素。

沿维度提取(take_along_dim)

沿着给定的 dim,从 input 中选择 indices 中的一维索引处的值。

张量分割(tensor_split)

沿着维度 dim,根据 indices_or_sections 指定的索引或段数,将张量分割为多个子张量,所有子张量都是 input 的视图。

平铺(tile)

通过重复 input 的元素来构造张量。

转置(transpose)

返回 input 的转置版本张量。

解除绑定(unbind)

移除一个张量维度。

展开索引(unravel_index)

将平面索引的张量转换为坐标张量的元组,该元组索引到指定形状的任意张量中。

取消压缩(unsqueeze)

返回一个新的张量,在指定位置插入一个大小为 1 的维度。

垂直分割(vsplit)

根据 indices_or_sections,将具有二维或多维的张量 input 垂直拆分为多个张量。

垂直堆叠(vstack)

按垂直方向(行方向)依次堆叠张量。

条件选择(where)

根据 condition,返回从 inputother 中选择的元素张量。

生成器

生成器(Generator)

创建并返回一个生成器对象,该对象管理生成伪随机数的算法的状态。

随机采样

设置种子(seed)

在所有设备上将生成随机数的种子设置为非确定性随机数。

手动设置种子(manual_seed)

在所有设备上设置生成随机数的种子。

初始种子(initial_seed)

以 Python long 类型返回生成随机数的初始种子。

获取随机数生成器状态(get_rng_state)

torch.ByteTensor 类型返回随机数生成器状态。

设置随机数生成器状态(set_rng_state)

设置随机数生成器状态。

torch.default_generator 返回默认的 CPU torch.Generator

伯努利分布(bernoulli)

从伯努利分布中抽取二进制随机数(0 或 1)。

多项式分布(multinomial)

返回一个张量,其中每行包含从多项式(更严格的定义是多变量,有关详细信息,请参阅 torch.distributions.multinomial.Multinomial)概率分布中采样的 num_samples 个索引,该分布位于张量 input 的对应行中。

正态分布(normal)

返回一个张量,其中包含从独立正态分布中抽取的随机数,这些正态分布的均值和标准差是给定的。

泊松分布(poisson)

返回一个与 input 大小相同的张量,其中每个元素都从泊松分布中采样,速率参数由 input 中的对应元素给出,即:

随机数(rand)

返回一个张量,其中填充了从区间 [0,1)[0, 1) 上的均匀分布生成的随机数。

类似随机数(rand_like)

返回一个与 input 大小相同的张量,其中填充了从区间 [0,1)[0, 1) 上的均匀分布生成的随机数。

随机整数(randint)

返回一个张量,其中填充了在 low(包含)和 high(不包含)之间均匀生成的随机整数。

类似随机整数(randint_like)

返回一个与张量 input 形状相同的张量,其中填充了在 low(包含)和 high(不包含)之间均匀生成的随机整数。

标准正态分布随机数(randn)

返回一个张量,该张量填充了来自均值为 0、方差为 1 的正态分布(也称为标准正态分布)的随机数。

randn_like

返回一个与 input 大小相同的张量,该张量填充了来自均值为 0、方差为 1 的正态分布的随机数。

randperm

返回一个整数的随机排列,范围从 0n - 1

原地随机采样

还有一些在张量上定义的原地随机采样函数。点击可参考其文档

准随机采样

quasirandom.SobolEngine

torch.quasirandom.SobolEngine 是用于生成(加扰)Sobol 序列的引擎。

序列化

save

将对象保存到磁盘文件。

load

从文件加载使用 torch.save() 保存的对象。

并行

get_num_threads

返回用于 CPU 操作并行化的线程数

set_num_threads

设置用于 CPU 上操作内并行化的线程数。

get_num_interop_threads

返回用于 CPU 上操作间并行化的线程数(例如,

set_num_interop_threads

设置用于操作间并行化的线程数(例如,

局部禁用梯度计算

上下文管理器 torch.no_grad()torch.enable_grad()torch.set_grad_enabled() 有助于局部禁用和启用梯度计算。有关其用法的更多详细信息,请参阅局部禁用梯度计算。这些上下文管理器是线程本地的,因此如果您使用 threading 模块等将工作发送到另一个线程,它们将不起作用。

示例

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False

>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False

>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

no_grad

用于禁用梯度计算的上下文管理器。

enable_grad

用于启用梯度计算的上下文管理器。

autograd.grad_mode.set_grad_enabled

用于打开或关闭梯度计算的上下文管理器。

is_grad_enabled

如果当前启用了梯度模式,则返回 True。

autograd.grad_mode.inference_mode

用于启用或禁用推理模式的上下文管理器。

is_inference_mode_enabled

如果当前启用了推理模式,则返回 True。

数学运算

逐点运算

abs

计算 input 中每个元素的绝对值。

absolute

torch.abs() 的别名

acos

计算 input 中每个元素的反余弦。

arccos

torch.acos() 的别名。

acosh

返回一个新张量,其中包含 input 元素的反双曲余弦。

arccosh

torch.acosh() 的别名。

add

other 乘以 alpha 后加到 input 上。

addcdiv

执行 tensor1tensor2 的逐元素除法,将结果乘以标量 value,然后将其加到 input

addcmul

执行 tensor1tensor2 的逐元素乘法,将结果乘以标量 value,然后将其加到 input

angle

计算给定 input 张量的逐元素角度(以弧度为单位)。

asin

返回一个新张量,其中包含 input 元素的反正弦。

arcsin

torch.asin() 的别名。

asinh

返回一个新张量,其中包含 input 元素的反双曲正弦。

arcsinh

torch.asinh() 的别名。

atan

返回一个新张量,其中包含 input 元素的反正切。

arctan

torch.atan() 的别名。

atanh

返回一个新张量,其中包含 input 元素的反双曲正切。

arctanh

torch.atanh() 的别名。

atan2

inputi/otheri\text{input}_{i} / \text{other}_{i} 的逐元素反正切,同时考虑象限。

arctan2

torch.atan2() 的别名。

bitwise_not

计算给定输入张量的按位非。

bitwise_and

计算 inputother 的按位与。

bitwise_or

计算 inputother 的按位或。

bitwise_xor

计算 inputother 的按位异或。

bitwise_left_shift

计算 input 左移 other 位的算术左移。

bitwise_right_shift

计算 input 右移 other 位的算术右移。

ceil

返回一个新张量,其中包含 input 元素的上限,即大于或等于每个元素的最小整数。

clamp

input 中的所有元素钳制到范围 [ min, max ] 内。

clip

torch.clamp() 的别名。

conj_physical

计算给定 input 张量的逐元素共轭。

copysign

创建一个新的浮点张量,其幅度为 input,符号为 other,逐元素进行。

cos

返回一个新张量,其中包含 input 元素的余弦值。

cosh

返回一个新张量,其中包含 input 元素的双曲余弦值。

deg2rad

返回一个新张量,其中 input 的每个元素都从角度转换为弧度。

div

将输入 input 的每个元素除以 other 的对应元素。

divide

torch.div() 的别名。

digamma

torch.special.digamma() 的别名。

erf

torch.special.erf() 的别名。

erfc

torch.special.erfc() 的别名。

erfinv

torch.special.erfinv() 的别名。

exp

返回一个新张量,其中包含输入张量 input 元素的指数。

exp2

torch.special.exp2() 的别名。

expm1

torch.special.expm1() 的别名。

fake_quantize_per_channel_affine

返回一个新张量,其中 input 中的数据使用 scalezero_pointquant_minquant_max 沿 axis 指定的通道进行每通道伪量化。

fake_quantize_per_tensor_affine

返回一个新张量,其中 input 中的数据使用 scalezero_pointquant_minquant_max 进行伪量化。

fix

torch.trunc() 的别名

float_power

input 逐元素地提高到 exponent 的幂,使用双精度。

floor

返回一个新张量,其中包含 input 元素的向下取整值,即小于或等于每个元素的最大整数。

floor_divide

fmod

逐元素地应用 C++ 的 std::fmod

frac

计算 input 中每个元素的小数部分。

frexp

input 分解为尾数和指数张量,使得 input=mantissa×2exponent\text{input} = \text{mantissa} \times 2^{\text{exponent}}.

gradient

使用 二阶精确中心差分法 估计函数 g:RnRg : \mathbb{R}^n \rightarrow \mathbb{R} 在一维或多维上的梯度,并在边界处使用一阶或二阶估计。

imag

返回一个包含 self 张量虚部的新张量。

ldexp

input 乘以 2 ** other

lerp

根据标量或张量 weight 对两个张量 start(由 input 给出)和 end 进行线性插值,并返回结果 out 张量。

lgamma

计算 input 上伽马函数绝对值的自然对数。

log

返回一个新张量,其中包含 input 元素的自然对数。

log10

返回一个新张量,其中包含 input 元素的以 10 为底的对数。

log1p

返回一个新张量,其中包含 (1 + input) 的自然对数。

log2

返回一个新张量,其中包含 input 元素的以 2 为底的对数。

logaddexp

输入的指数之和的对数。

logaddexp2

输入的以 2 为底的指数之和的对数。

logical_and

计算给定输入张量的逐元素逻辑与。

logical_not

计算给定输入张量的逐元素逻辑非。

logical_or

计算给定输入张量的逐元素逻辑或。

logical_xor

计算给定输入张量的逐元素逻辑异或。

logit

torch.special.logit() 的别名。

hypot

给定直角三角形的两条直角边,返回其斜边。

i0

torch.special.i0() 的别名。

igamma

torch.special.gammainc() 的别名。

igammac

torch.special.gammaincc() 的别名。

mul

input 乘以 other

multiply

torch.mul() 的别名。

mvlgamma

torch.special.multigammaln() 的别名。

nan_to_num

input 中的 NaN、正无穷和负无穷值分别替换为 nanposinfneginf 指定的值。

neg

返回一个新张量,其元素是 input 中元素的负数。

negative

torch.neg() 的别名。

nextafter

返回 input 在元素级上朝向 other 的下一个浮点值。

polygamma

torch.special.polygamma() 的别名。

positive

返回 input

pow

使用 exponentinput 中的每个元素求幂,并返回一个包含结果的张量。

quantized_batch_norm

对 4D (NCHW) 量化张量应用批归一化。

quantized_max_pool1d

对由多个输入平面组成的输入量化张量应用一维最大池化。

quantized_max_pool2d

对由多个输入平面组成的输入量化张量应用二维最大池化。

rad2deg

返回一个新张量,其中 input 中的每个元素都从弧度转换为度数。

real

返回一个包含 self 张量实值的新张量。

reciprocal

返回一个新张量,其中包含 input 中元素的倒数。

remainder

逐元素计算 Python 的模运算

round

input 的元素舍入到最接近的整数。

rsqrt

返回一个新张量,其中包含 input 中每个元素平方根的倒数。

sigmoid

torch.special.expit() 的别名。

sign

返回一个包含 input 中元素符号的新张量。

sgn

此函数是 torch.sign() 对复数张量的扩展。

signbit

测试 input 中的每个元素是否设置了符号位。

sin

返回一个包含 input 中元素正弦值的新张量。

sinc

torch.special.sinc() 的别名。

sinh

返回一个包含 input 中元素双曲正弦值的新张量。

softmax

torch.nn.functional.softmax() 的别名。

sqrt

返回一个包含 input 中元素平方根的新张量。

square

返回一个包含 input 中元素平方的张量。

sub

input 中减去 other 乘以 alpha 的结果。

subtract

torch.sub() 的别名。

tan

返回一个包含 input 中元素正切值的新张量。

tanh

返回一个包含 input 中元素双曲正切值的新张量。

true_divide

torch.div()rounding_mode=None 时的别名。

trunc

返回一个新张量,其中包含 input 中元素的截断整数值。

xlogy

torch.special.xlogy() 的别名。

归约操作

argmax

返回 input 张量中所有元素最大值的索引。

argmin

返回展平张量或沿某个维度最小值的索引。

amax

返回 input 张量在给定维度 dim 中每个切片的最大值。

amin

返回 input 张量在给定维度 dim 中每个切片的最小值。

aminmax

计算 input 张量的最小值和最大值。

all

测试 input 中的所有元素是否都为 True

any

测试 input 中是否有任何元素为 True

max

返回 input 张量中所有元素的最大值。

min

返回 input 张量中所有元素的最小值。

dist

返回 (input - other) 的 p 范数。

logsumexp

返回 input 张量在给定维度 dim 中每一行的指数和的对数。

mean

返回 input 张量中所有元素的平均值。

nanmean

计算沿指定维度所有 非 NaN 元素的平均值。

median

返回 input 中值的中位数。

nanmedian

返回 input 中值的中位数,忽略 NaN 值。

mode

返回一个命名元组 (values, indices),其中 valuesinput 张量在给定维度 dim 中每一行的众数值,即该行中出现次数最多的值,indices 是找到的每个众数值的索引位置。

norm

返回给定张量的矩阵范数或向量范数。

nansum

返回所有元素的总和,将非数字 (NaN) 视为零。

prod

返回 input 张量中所有元素的乘积。

quantile

计算 input 张量沿维度 dim 每一行的第 q 个分位数。

nanquantile

这是 torch.quantile() 的一个变体,“忽略”了 NaN 值,计算分位数 q,就好像 input 中不存在 NaN 值一样。

std

计算 dim 指定维度上的标准差。

std_mean

计算 dim 指定维度上的标准差和均值。

sum

返回 input 张量中所有元素的总和。

unique

返回输入张量的唯一元素。

unique_consecutive

从每个连续的等效元素组中删除除第一个元素以外的所有元素。

var

计算 dim 指定维度上的方差。

var_mean

计算 dim 指定维度上的方差和均值。

count_nonzero

计算张量 input 中沿给定维度 dim 的非零值的数量。

比较操作

allclose

此函数检查 inputother 是否满足以下条件:

argsort

返回按值升序对张量沿给定维度排序的索引。

eq

执行逐元素相等比较。

equal

如果两个张量的大小和元素相同,则为 True,否则为 False

ge

按元素计算 inputother\text{input} \geq \text{other}

greater_equal

torch.ge() 的别名。

gt

按元素计算 input>other\text{input} > \text{other}

greater

torch.gt() 的别名。

isclose

返回一个新的张量,其中包含布尔元素,表示 input 的每个元素是否“接近”于 other 中的对应元素。

isfinite

返回一个新的张量,其中包含布尔元素,表示每个元素是否为 有限

isin

测试 elements 的每个元素是否在 test_elements 中。

isinf

测试 input 的每个元素是否为无穷大(正无穷大或负无穷大)。

isposinf

测试 input 的每个元素是否为正无穷大。

isneginf

测试 input 的每个元素是否为负无穷大。

isnan

返回一个新的张量,其中包含布尔元素,表示 input 的每个元素是否为 NaN。

isreal

返回一个新的张量,其中包含布尔元素,表示 input 的每个元素是否为实值。

kthvalue

返回一个命名元组 (values, indices),其中 valuesinput 张量在给定维度 dim 中每行的第 k 个最小元素。

le

按元素计算 inputother\text{input} \leq \text{other}

less_equal

torch.le() 的别名。

lt

按元素计算 input<other\text{input} < \text{other}

less

torch.lt() 的别名。

maximum

计算 inputother 中的按元素最大值。

minimum

计算 inputother 中的按元素最小值。

fmax

计算 inputother 中的按元素最大值。

fmin

计算 inputother 中的按元素最小值。

ne

按元素计算 inputother\text{input} \neq \text{other}

not_equal

torch.ne() 的别名。

sort

按值对 input 张量的元素沿给定维度进行升序排序。

topk

返回给定 input 张量沿给定维度的 k 个最大元素。

msort

按值对 input 张量的元素沿其第一个维度进行升序排序。

频谱运算

stft

短时傅立叶变换 (STFT)。

istft

逆短时傅立叶变换。

bartlett_window

Bartlett 窗函数。

blackman_window

Blackman 窗函数。

hamming_window

Hamming 窗函数。

hann_window

Hann 窗函数。

kaiser_window

使用窗口长度 window_length 和形状参数 beta 计算 Kaiser 窗。

其他操作

atleast_1d

返回每个具有零维度的输入张量的一维视图。

atleast_2d

返回每个具有零维度的输入张量的二维视图。

atleast_3d

返回每个具有零维度的输入张量的三维视图。

bincount

统计非负整数数组中每个值的频率。

block_diag

从提供的张量创建块对角矩阵。

broadcast_tensors

根据 广播语义 广播给定的张量。

broadcast_to

input 广播到形状 shape

broadcast_shapes

broadcast_tensors() 类似,但用于形状。

bucketize

返回 input 中每个值所属的桶的索引,其中桶的边界由 boundaries 设置。

cartesian_prod

对给定的张量序列进行笛卡尔积。

cdist

计算两组行向量中每对向量之间的批量 p 范数距离。

clone

返回 input 的副本。

combinations

计算给定张量的长度为 rr 的组合。

corrcoef

估计由 input 矩阵给出的变量的 Pearson 积矩相关系数矩阵,其中行是变量,列是观测值。

cov

估计由 input 矩阵给出的变量的协方差矩阵,其中行是变量,列是观测值。

cross

返回 inputother 在维度 dim 中的向量叉积。

cummax

返回一个命名元组 (values, indices),其中 valuesinput 在维度 dim 中的元素的累积最大值。

cummin

返回一个命名元组 (values, indices),其中 valuesinput 在维度 dim 中的元素的累积最小值。

cumprod

返回 input 在维度 dim 中的元素的累积乘积。

cumsum

返回 input 在维度 dim 中的元素的累积和。

diag

  • 如果 input 是向量(一维张量),则返回一个二维方阵

diag_embed

创建一个张量,其某些二维平面(由 dim1dim2 指定)的对角线由 input 填充。

diagflat

  • 如果 input 是向量(一维张量),则返回一个二维方阵

diagonal

返回 input 的偏视图,其对角线元素相对于 dim1dim2 附加为形状末尾的维度。

diff

沿给定维度计算第 n 个前向差分。

einsum

使用基于爱因斯坦求和约定的符号,沿指定维度对输入 operands 元素的乘积求和。

flatten

通过将 input 重塑为一维张量来将其展平。

flip

沿 dims 中给定的轴反转 n 维张量的顺序。

fliplr

在左/右方向翻转张量,返回一个新的张量。

flipud

在向上/向下方向翻转张量,返回一个新的张量。

kron

计算 inputother 的克罗内克积,用 \otimes 表示。

rot90

在由 dims 轴指定的平面内将 n 维张量旋转 90 度。

gcd

计算 inputother 的逐元素最大公约数 (GCD)。

histc

计算张量的直方图。

histogram

计算张量中值的直方图。

histogramdd

计算张量中值的多维直方图。

meshgrid

创建由 attr:tensors 中的一维输入指定的坐标网格。

lcm

计算 inputother 的逐元素最小公倍数 (LCM)。

logcumsumexp

返回 input 中元素在维度 dim 上的指数的累积和的对数。

ravel

返回一个连续的扁平化张量。

renorm

返回一个张量,其中 input 沿维度 dim 的每个子张量都被归一化,使得子张量的 p-范数低于值 maxnorm

repeat_interleave

重复张量的元素。

roll

沿给定维度滚动张量 input

searchsorted

sorted_sequence 的*最内层*维度中查找索引,以便在排序后,如果将 values 中的对应值插入到索引之前,则 sorted_sequence 中对应*最内层*维度的顺序将保持不变。

tensordot

返回 a 和 b 在多个维度上的缩并。

trace

返回输入二维矩阵对角线元素的总和。

tril

返回矩阵(二维张量)或矩阵批次 input 的下三角部分,结果张量 out 的其他元素设置为 0。

tril_indices

以 2×N 张量的形式返回 row×col 矩阵的下三角部分的索引,其中第一行包含所有索引的行坐标,第二行包含列坐标。

triu

返回矩阵(二维张量)或矩阵批次 input 的上三角部分,结果张量 out 的其他元素设置为 0。

triu_indices

以 2×N 张量的形式返回 row×col 矩阵的上三角部分的索引,其中第一行包含所有索引的行坐标,第二行包含列坐标。

unflatten

在多个维度上扩展输入张量的一个维度。

vander

生成一个范德蒙矩阵。

view_as_real

返回 input 作为实张量的视图。

view_as_complex

返回 input 作为复张量的视图。

resolve_conj

如果 input 的共轭位设置为 True,则返回一个新的张量,其中包含已实现的共轭,否则返回 input

resolve_neg

如果 input 的负号位设置为 True,则返回一个新的张量,其中包含已实现的负号,否则返回 input

BLAS 和 LAPACK 运算

addbmm

对存储在 batch1batch2 中的矩阵执行批处理矩阵-矩阵乘积,并减少加法步骤(所有矩阵乘法都沿第一个维度累加)。

addmm

执行矩阵 mat1mat2 的矩阵乘法。

addmv

执行矩阵 mat 和向量 vec 的矩阵-向量乘积。

addr

执行向量 vec1vec2 的外积,并将其添加到矩阵 input 中。

baddbmm

执行 batch1batch2 中矩阵的批处理矩阵-矩阵乘积。

bmm

执行存储在 inputmat2 中的矩阵的批处理矩阵-矩阵乘积。

chain_matmul

返回 NN 个二维张量的矩阵乘积。

cholesky

计算对称正定矩阵 AA 或对称正定矩阵批次的 Cholesky 分解。

cholesky_inverse

计算复数 Hermitian 或实数对称正定矩阵的逆矩阵,给定其 Cholesky 分解。

cholesky_solve

计算具有复数 Hermitian 或实数对称正定左侧的线性方程组的解,给定其 Cholesky 分解。

dot

计算两个一维张量的点积。

geqrf

这是一个用于直接调用 LAPACK 的 geqrf 的底层函数。

ger

torch.outer() 的别名。

inner

计算一维张量的点积。

inverse

torch.linalg.inv() 的别名

det

torch.linalg.det() 的别名

logdet

计算方阵或方阵批次的行列式的对数。

slogdet

torch.linalg.slogdet() 的别名

lu

计算矩阵或矩阵批次 A 的 LU 分解。

lu_solve

使用 lu_factor() 中 A 的部分旋转 LU 分解,返回线性系统 Ax=bAx = b 的 LU 解。

lu_unpack

lu_factor() 返回的 LU 分解解包到 P、L、U 矩阵中。

matmul

两个张量的矩阵乘积。

matrix_power

torch.linalg.matrix_power() 的别名

matrix_exp

torch.linalg.matrix_exp() 的别名。

mm

执行矩阵 inputmat2 的矩阵乘法。

mv

执行矩阵 input 和向量 vec 的矩阵-向量乘积。

orgqr

torch.linalg.householder_product() 的别名。

ormqr

计算一系列 Householder 矩阵与一般矩阵的矩阵乘积。

outer

inputvec2 的外积。

pinverse

torch.linalg.pinv() 的别名

qr

计算矩阵或一批矩阵 input 的 QR 分解,并返回一个命名元组 (Q, R),其中张量满足 input=QR\text{input} = Q R,其中 QQ 是正交矩阵或一批正交矩阵,RR 是上三角矩阵或一批上三角矩阵。

svd

计算矩阵或一批矩阵 input 的奇异值分解。

svd_lowrank

返回矩阵、一批矩阵或稀疏矩阵 AA 的奇异值分解 (U, S, V),使得 AUdiag(S)VHA \approx U \operatorname{diag}(S) V^{\text{H}}.

pca_lowrank

对低秩矩阵、一批此类矩阵或稀疏矩阵执行线性主成分分析 (PCA)。

lobpcg

使用无矩阵 LOBPCG 方法找到对称正定广义特征值问题的 k 个最大(或最小)特征值和对应的特征向量。

trapz

torch.trapezoid() 的别名。

trapezoid

沿着 dim 计算 梯形法则

cumulative_trapezoid

沿着 dim 累积计算 梯形法则

triangular_solve

用一个方形上三角或下三角可逆矩阵 AA 和多个右手边 bb 求解方程组。

vdot

计算两个一维向量沿一个维度的点积。

Foreach 操作

警告

此 API 处于测试阶段,未来可能会发生变化。不支持前向模式 AD。

_foreach_abs

对输入列表中的每个张量应用 torch.abs()

_foreach_abs_

对输入列表中的每个张量应用 torch.abs()

_foreach_acos

对输入列表中的每个张量应用 torch.acos()

_foreach_acos_

对输入列表中的每个张量应用 torch.acos()

_foreach_asin

对输入列表中的每个张量应用 torch.asin()

_foreach_asin_

对输入列表中的每个张量应用 torch.asin()

_foreach_atan

对输入列表中的每个张量应用 torch.atan()

_foreach_atan_

对输入列表中的每个张量应用 torch.atan()

_foreach_ceil

对输入列表中的每个张量应用 torch.ceil()

_foreach_ceil_

对输入列表中的每个张量应用 torch.ceil()

_foreach_cos

对输入列表中的每个张量应用 torch.cos()

_foreach_cos_

对输入列表中的每个张量应用 torch.cos()

_foreach_cosh

对输入列表中的每个张量应用 torch.cosh()

_foreach_cosh_

对输入列表中的每个张量应用 torch.cosh()

_foreach_erf

对输入列表中的每个张量应用 torch.erf()

_foreach_erf_

对输入列表中的每个张量应用 torch.erf()

_foreach_erfc

对输入列表中的每个张量应用 torch.erfc()

_foreach_erfc_

对输入列表中的每个张量应用 torch.erfc()

_foreach_exp

对输入列表中的每个张量应用 torch.exp()

_foreach_exp_

对输入列表中的每个张量应用 torch.exp()

_foreach_expm1

对输入列表中的每个张量应用 torch.expm1()

_foreach_expm1_

对输入列表中的每个张量应用 torch.expm1()

_foreach_floor

对输入列表中的每个张量应用 torch.floor()

_foreach_floor_

对输入列表中的每个张量应用 torch.floor()

_foreach_log

对输入列表中的每个张量应用 torch.log()

_foreach_log_

对输入列表中的每个张量应用 torch.log()

_foreach_log10

对输入列表中的每个张量应用 torch.log10()

_foreach_log10_

对输入列表中的每个张量应用 torch.log10()

_foreach_log1p

对输入列表中的每个张量应用 torch.log1p()

_foreach_log1p_

对输入列表中的每个张量应用 torch.log1p()

_foreach_log2

对输入列表中的每个张量应用 torch.log2()

_foreach_log2_

对输入列表中的每个张量应用 torch.log2()

_foreach_neg

对输入列表中的每个张量应用 torch.neg()

_foreach_neg_

对输入列表中的每个张量应用 torch.neg()

_foreach_tan

对输入列表中的每个张量应用 torch.tan()

_foreach_tan_

对输入列表中的每个张量应用 torch.tan()

_foreach_sin

对输入列表中的每个张量应用 torch.sin()

_foreach_sin_

对输入列表中的每个张量应用 torch.sin()

_foreach_sinh

对输入列表中的每个张量应用 torch.sinh()

_foreach_sinh_

对输入列表中的每个张量应用 torch.sinh()

_foreach_round

对输入列表中的每个张量应用 torch.round()

_foreach_round_

对输入列表中的每个张量应用 torch.round()

_foreach_sqrt

对输入列表中的每个张量应用 torch.sqrt()

_foreach_sqrt_

对输入列表中的每个张量应用 torch.sqrt()

_foreach_lgamma

对输入列表中的每个张量应用 torch.lgamma()

_foreach_lgamma_

对输入列表中的每个张量应用 torch.lgamma()

_foreach_frac

对输入列表中的每个张量应用 torch.frac()

_foreach_frac_

对输入列表中的每个张量应用 torch.frac()

_foreach_reciprocal

对输入列表中的每个张量应用 torch.reciprocal()

_foreach_reciprocal_

对输入列表中的每个张量应用 torch.reciprocal()

_foreach_sigmoid

对输入列表中的每个张量应用 torch.sigmoid()

_foreach_sigmoid_

对输入列表中的每个张量应用 torch.sigmoid()

_foreach_trunc

对输入列表中的每个张量应用 torch.trunc()

_foreach_trunc_

对输入列表中的每个张量应用 torch.trunc()

_foreach_zero_

对输入列表中的每个张量应用 torch.zero()

实用程序

compiled_with_cxx11_abi

返回 PyTorch 是否使用 _GLIBCXX_USE_CXX11_ABI=1 构建

result_type

返回对提供的输入张量执行算术运算后得到的 torch.dtype

can_cast

根据类型提升 文档 中描述的 PyTorch 转换规则,确定是否允许类型转换。

promote_types

返回尺寸最小且标量类型不小于 type1type2torch.dtype

use_deterministic_algorithms

设置 PyTorch 操作是否必须使用“确定性”算法。

are_deterministic_algorithms_enabled

如果全局确定性标志已打开,则返回 True。

is_deterministic_algorithms_warn_only_enabled

如果全局确定性标志设置为仅警告,则返回 True。

set_deterministic_debug_mode

设置确定性操作的调试模式。

get_deterministic_debug_mode

返回确定性操作的调试模式的当前值。

set_float32_matmul_precision

设置 float32 矩阵乘法的内部精度。

get_float32_matmul_precision

返回 float32 矩阵乘法精度的当前值。

set_warn_always

如果此标志为 False(默认值),则某些 PyTorch 警告可能每个进程只出现一次。

get_device_module

返回与给定设备关联的模块(例如,torch.device('cuda')、“mtia:0”、“xpu”等)。

is_warn_always_enabled

如果全局 warn_always 标志已打开,则返回 True。

vmap

vmap 是向量化映射;vmap(func) 返回一个新函数,该函数将 func 映射到输入的某个维度。

_assert

Python 断言的包装器,可进行符号跟踪。

符号数

torch.SymInt(node)[源代码]

类似于 int(包括魔术方法),但会重定向对包装节点的所有操作。这尤其用于在符号形状工作流中符号化地记录操作。

torch.SymFloat(node)[源代码]

类似于 float(包括魔术方法),但会重定向对包装节点的所有操作。这尤其用于在符号形状工作流中符号化地记录操作。

is_integer()[源代码]

如果浮点数是整数,则返回 True。

torch.SymBool(node)[源代码]

类似于 bool(包括魔术方法),但会重定向对包装节点的所有操作。这尤其用于在符号形状工作流中符号化地记录操作。

与常规布尔值不同,常规布尔运算符将强制执行额外的保护,而不是进行符号化求值。请改用按位运算符来处理这种情况。

sym_float

支持 SymInt 的 float 转换实用程序。

sym_int

支持 SymInt 的 int 转换实用程序。

sym_max

支持 SymInt 的 max 实用程序,避免在 a < b 时进行分支。

sym_min

支持 SymInt 的 min() 实用程序。

sym_not

支持 SymInt 的逻辑非实用程序。

sym_ite

导出路径

警告

此功能尚处于原型阶段,将来可能会出现不兼容的更改。

导出 generated/exportdb/index

控制流

警告

此功能尚处于原型阶段,将来可能会出现不兼容的更改。

cond

有条件地应用 true_fnfalse_fn

优化

compile

使用 TorchDynamo 和指定的后端优化给定模型/函数。

torch.compile 文档

运算符标签

torch.Tag

成员

core

data_dependent_output

dynamic_output_shape

generated

inplace_view

needs_fixed_stride_order

nondeterministic_bitwise

nondeterministic_seeded

pointwise

pt2_compliant_tag

view_copy

属性 name

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获取您的问题的答案

查看资源