快捷方式

torch

torch 包包含用于多维张量的数据结构,并定义了对这些张量的数学运算。此外,它还提供了许多用于高效序列化张量和任意类型的实用程序,以及其他有用的实用程序。

它有一个 CUDA 对应版本,使您能够在具有 compute capability >= 3.0 的 NVIDIA GPU 上运行张量计算。

张量

is_tensor

如果obj 是 PyTorch 张量,则返回 True。

is_storage

如果obj 是 PyTorch 存储对象,则返回 True。

is_complex

如果input 的数据类型是复数数据类型(即 torch.complex64torch.complex128 中的一个),则返回 True。

is_conj

如果input 是共轭张量,即其共轭位设置为True,则返回 True。

is_floating_point

如果input 的数据类型是浮点数据类型(即 torch.float64torch.float32torch.float16torch.bfloat16 中的一个),则返回 True。

is_nonzero

如果input 是一个单元素张量,并且在类型转换后不等于零,则返回 True。

set_default_dtype

将默认浮点 dtype 设置为d

get_default_dtype

获取当前默认浮点 torch.dtype

set_default_device

将默认 torch.Tensor 分配到 device 上。

get_default_device

获取默认 torch.Tensor 分配到的 device

set_default_tensor_type

numel

返回input 张量中的元素总数。

set_printoptions

设置打印选项。

set_flush_denormal

禁用 CPU 上的非规格化浮点数。

创建操作

注意

随机采样创建操作列在 随机采样 下,包括:torch.rand() torch.rand_like() torch.randn() torch.randn_like() torch.randint() torch.randint_like() torch.randperm() 您也可以使用 torch.empty()就地随机采样 方法来创建具有从更广泛的分布中采样值的 torch.Tensor

tensor

通过复制data构建一个没有自动梯度历史的张量(也称为“叶张量”,请参阅 Autograd 机制)。

sparse_coo_tensor

使用给定indices处的指定值构建 COO(rdinate) 格式的稀疏张量

sparse_csr_tensor

使用给定crow_indicescol_indices处的指定值构建 CSR (Compressed Sparse Row) 格式的稀疏张量

sparse_csc_tensor

使用给定ccol_indicesrow_indices处的指定值构建 CSC (Compressed Sparse Column) 格式的稀疏张量

sparse_bsr_tensor

使用给定crow_indicescol_indices处的指定二维块构建 BSR (Block Compressed Sparse Row) 格式的稀疏张量

sparse_bsc_tensor

使用指定的二维块构建一个 BSC(块压缩稀疏列)格式的稀疏张量,这些块位于给定的 ccol_indicesrow_indices

asarray

obj 转换为张量。

as_tensor

data 转换为张量,如果可能,共享数据并保留自动梯度历史。

as_strided

创建现有 torch.Tensor input 的视图,具有指定的 sizestridestorage_offset

from_file

创建一个 CPU 张量,其存储由内存映射文件支持。

from_numpy

numpy.ndarray 创建一个 Tensor

from_dlpack

将来自外部库的张量转换为 torch.Tensor

frombuffer

从实现 Python 缓冲协议的对象创建一个一维 Tensor

zeros

返回一个填充标量值 0 的张量,其形状由可变参数 size 定义。

zeros_like

返回一个填充标量值 0 的张量,其大小与 input 相同。

ones

返回一个填充标量值 1 的张量,其形状由可变参数 size 定义。

ones_like

返回一个填充标量值 1 的张量,其大小与 input 相同。

arange

返回大小为 endstartstep\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil 的一维张量,其值取自区间 [start, end),公差为 step,从 start 开始。

range

返回大小为 endstartstep+1\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1 的一维张量,其值从 startend,步长为 step

linspace

创建一个大小为 steps 的一维张量,其值从 startend(包含端点)均匀分布。

logspace

创建一个大小为 steps 的一维张量,其值从 basestart{{\text{{base}}}}^{{\text{{start}}}}baseend{{\text{{base}}}}^{{\text{{end}}}}(包含端点)均匀分布,以对数刻度表示,底数为 base

eye

返回一个二维张量,对角线上的元素为 1,其他元素为 0。

empty

返回一个填充未初始化数据的张量。

empty_like

返回一个未初始化的张量,其大小与 input 相同。

empty_strided

创建一个具有指定 sizestride 的张量,并填充未定义数据。

full

创建一个大小为 size 的张量,并填充 fill_value

full_like

返回一个与 input 大小相同的张量,并填充 fill_value

quantize_per_tensor

将浮点张量转换为具有给定比例和零点的量化张量。

quantize_per_channel

将浮点张量转换为具有给定比例和零点的按通道量化张量。

dequantize

通过反量化量化张量返回一个 fp32 张量

complex

构建一个复数张量,其实部等于 real,其虚部等于 imag

polar

构建一个复数张量,其元素是对应于极坐标的笛卡尔坐标,其中绝对值为 abs,角度为 angle

heaviside

计算 input 中每个元素的 Heaviside 阶跃函数。

索引、切片、连接、变异操作

adjoint

返回张量的共轭视图,并将最后两个维度进行转置。

argwhere

返回一个张量,包含 input 中所有非零元素的索引。

cat

在给定的维度上连接给定的 seq 张量序列。

concat

torch.cat() 的别名。

concatenate

torch.cat() 的别名。

conj

返回 input 的视图,其中共轭位被翻转。

chunk

尝试将张量拆分为指定数量的块。

dsplit

根据 indices_or_sections,将具有三个或更多维度的张量 input 深度拆分为多个张量。

column_stack

通过水平堆叠 tensors 中的张量创建一个新的张量。

dstack

按深度顺序(沿第三轴)堆叠张量。

gather

沿由 dim 指定的轴收集值。

hsplit

根据 indices_or_sections,将具有一个或多个维度的张量 input 水平拆分为多个张量。

hstack

水平(列方向)顺序堆叠张量。

index_add

有关函数描述,请参阅 index_add_()

index_copy

有关函数描述,请参阅 index_add_()

index_reduce

有关函数描述,请参见 index_reduce_()

index_select

返回一个新的张量,该张量使用 LongTensor 类型且位于 index 中的条目,沿维度 diminput 张量进行索引。

masked_select

返回一个新的 1 维张量,该张量根据布尔掩码 maskBoolTensor 类型)对 input 张量进行索引。

movedim

input 中位于 source 位置的维度移动到 destination 位置。

moveaxis

torch.movedim() 的别名。

narrow

返回一个新的张量,它是 input 张量的缩小版本。

narrow_copy

Tensor.narrow() 相同,除了它返回一个副本而不是共享存储。

nonzero

permute

返回原始张量 input 的一个视图,其维度已置换。

reshape

返回一个与 input 具有相同数据和元素数量的张量,但具有指定的形状。

row_stack

torch.vstack() 的别名。

select

沿给定索引处的选定维度切片 input 张量。

scatter

torch.Tensor.scatter_() 的非就地版本。

diagonal_scatter

src 张量的值嵌入到 input 中,相对于 dim1dim2,沿着 input 的对角线元素。

select_scatter

src 张量的值嵌入到 input 中,位于给定的索引处。

slice_scatter

src 张量的值嵌入到 input 中,位于给定的维度处。

scatter_add

torch.Tensor.scatter_add_() 的非就地版本。

scatter_reduce

torch.Tensor.scatter_reduce_() 的非就地版本。

split

将张量分割成块。

squeeze

返回一个张量,其中删除了 input 的所有大小为 1 的指定维度。

stack

沿着一个新维度连接一系列张量。

swapaxes

torch.transpose() 的别名。

swapdims

torch.transpose() 的别名。

t

期望 input 为 <= 2 维张量,并转置维度 0 和 1。

take

返回一个新的张量,其中包含给定索引处 input 的元素。

take_along_dim

input 中选择值,这些值来自沿给定 dimindices 中的一维索引。

tensor_split

将张量拆分为多个子张量,所有这些子张量都是 input 的视图,沿着维度 dim,根据 indices_or_sections 指定的索引或部分数量进行拆分。

tile

通过重复 input 的元素来构造一个张量。

transpose

返回一个张量,它是 input 的转置版本。

unbind

删除张量维度。

unravel_index

将平面索引的张量转换为坐标张量的元组,这些坐标张量索引到指定形状的任意张量中。

unsqueeze

返回一个新的张量,在指定位置插入一个大小为一的维度。

vsplit

根据 indices_or_sections,将具有两个或多个维度的张量 input 垂直拆分为多个张量。

vstack

按顺序垂直(行方式)堆叠张量。

where

根据 condition,返回一个从 inputother 中选择的元素的张量。

加速器

在 PyTorch 代码库中,我们将“加速器”定义为 torch.device,它与 CPU 一起使用以加速计算。这些设备使用异步执行方案,使用 torch.Streamtorch.Event 作为它们执行同步的主要方式。我们还假设在给定主机上一次只能有一个这样的加速器可用。这使我们能够将当前加速器用作相关概念(如固定内存、Stream 设备类型、FSDP 等)的默认设备。

截至今天,加速器设备是(无特定顺序)“CUDA”“MTIA”“XPU” 和 PrivateUse1(许多设备不在 PyTorch 代码库本身中)。

Stream

一个按顺序执行各自任务的异步队列,按照先进先出 (FIFO) 顺序。

Event

查询和记录 Stream 状态,以识别或控制跨 Stream 的依赖关系并测量时间。

生成器

Generator

创建并返回一个生成器对象,该对象管理生成伪随机数的算法的状态。

随机采样

seed

将所有设备上生成随机数的种子设置为非确定性随机数。

manual_seed

设置所有设备上生成随机数的种子。

initial_seed

将生成随机数的初始种子作为 Python long 返回。

get_rng_state

将随机数生成器状态作为 torch.ByteTensor 返回。

set_rng_state

设置随机数生成器状态。

torch.default_generator 返回 默认 CPU torch.Generator

bernoulli

从伯努利分布中抽取二进制随机数(0 或 1)。

multinomial

返回一个张量,其中每一行包含从多项式(更严格的定义是多元的,有关更多详细信息,请参阅 torch.distributions.multinomial.Multinomial)概率分布中采样的 num_samples 个索引,该概率分布位于张量 input 的对应行中。

normal

返回一个随机数张量,这些随机数是从均值和标准差给定的单独正态分布中抽取的。

poisson

返回一个与 input 大小相同的张量,其中每个元素都从泊松分布中采样,其速率参数由 input 中的对应元素给出,即

rand

返回一个张量,其中填充了区间 [0,1)[0, 1) 上均匀分布的随机数。

rand_like

返回一个与 input 大小相同的张量,其中填充了区间 [0,1)[0, 1) 上均匀分布的随机数。

randint

返回一个张量,其中填充了在low(包含)和high(不包含)之间均匀生成的随机整数。

randint_like

返回一个与张量input形状相同的张量,其中填充了在low(包含)和high(不包含)之间均匀生成的随机整数。

randn

返回一个张量,其中填充了来自均值为0,方差为1(也称为标准正态分布)的正态分布的随机数。

randn_like

返回一个与input大小相同的张量,其中填充了来自均值为0,方差为1的正态分布的随机数。

randperm

返回从0n - 1的整数的随机排列。

就地随机采样

在张量上还定义了一些其他的就地随机采样函数。点击查看它们的文档

准随机采样

quasirandom.SobolEngine

torch.quasirandom.SobolEngine是用于生成(扰动的)Sobol序列的引擎。

序列化

save

将对象保存到磁盘文件。

load

从文件中加载使用torch.save()保存的对象。

并行性

get_num_threads

返回用于并行化 CPU 操作的线程数。

set_num_threads

设置用于 CPU 上的运算内并行性的线程数。

get_num_interop_threads

返回用于 CPU 上的运算间并行性的线程数(例如,

set_num_interop_threads

设置用于运算间并行性的线程数(例如,

局部禁用梯度计算

上下文管理器torch.no_grad()torch.enable_grad()torch.set_grad_enabled()有助于局部禁用和启用梯度计算。有关其用法的更多详细信息,请参阅局部禁用梯度计算。这些上下文管理器是线程局部的,因此如果您使用threading模块等将工作发送到另一个线程,它们将不起作用。

示例

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False

>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False

>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

no_grad

禁用梯度计算的上下文管理器。

enable_grad

启用梯度计算的上下文管理器。

autograd.grad_mode.set_grad_enabled

设置梯度计算打开或关闭的上下文管理器。

is_grad_enabled

如果当前启用了 grad 模式,则返回 True。

autograd.grad_mode.inference_mode

启用或禁用推理模式的上下文管理器。

is_inference_mode_enabled

如果当前启用了推理模式,则返回 True。

数学运算

逐元素运算

abs

计算input中每个元素的绝对值。

absolute

torch.abs()的别名

acos

计算input中每个元素的反余弦。

arccos

torch.acos()的别名。

acosh

返回一个新的张量,其中包含input元素的反双曲余弦。

arccosh

torch.acosh()的别名。

add

other(乘以alpha)添加到input

addcdiv

执行tensor1除以tensor2的逐元素运算,将结果乘以标量value并将其添加到input

addcmul

执行tensor1乘以tensor2的逐元素运算,将结果乘以标量value并将其添加到input

angle

计算给定input张量的逐元素角度(以弧度为单位)。

asin

返回一个新的张量,其中包含input元素的反正弦。

arcsin

torch.asin()的别名。

asinh

返回一个新的张量,其中包含input元素的反双曲正弦。

arcsinh

torch.asinh()的别名。

atan

返回一个新的张量,其中包含input元素的反正切。

arctan

torch.atan()的别名。

atanh

返回一个新的张量,其中包含input元素的反双曲正切。

arctanh

torch.atanh()的别名。

atan2

inputi/otheri\text{input}_{i} / \text{other}_{i}的逐元素反正切,考虑象限。

arctan2

torch.atan2()的别名。

bitwise_not

计算给定输入张量的按位非。

bitwise_and

计算inputother的按位与。

bitwise_or

计算inputother的按位或。

bitwise_xor

计算inputother的按位异或。

bitwise_left_shift

计算input左移other位的算术移位。

bitwise_right_shift

计算input右移other位的算术移位。

ceil

返回一个新的张量,其中包含input元素的上取整,即大于或等于每个元素的最小整数。

clamp

input中的所有元素限制在范围[ min, max ]内。

clip

torch.clamp()的别名。

conj_physical

计算给定input张量的逐元素共轭。

copysign

创建一个新的浮点张量,其大小与input相同,符号与other相同,逐元素计算。

cos

返回一个新的张量,其中包含input元素的余弦值。

cosh

返回一个新的张量,其中包含input元素的双曲余弦值。

deg2rad

返回一个新的张量,其中input的每个元素都从角度(度)转换为弧度。

div

将输入input的每个元素除以other的对应元素。

divide

torch.div()的别名。

digamma

torch.special.digamma()的别名。

erf

torch.special.erf()的别名。

erfc

torch.special.erfc()的别名。

erfinv

torch.special.erfinv()的别名。

exp

返回一个新的张量,其中包含输入张量input元素的指数值。

exp2

torch.special.exp2()的别名。

expm1

torch.special.expm1()的别名。

fake_quantize_per_channel_affine

返回一个新的张量,其中input中的数据使用scalezero_pointquant_minquant_max进行逐通道的伪量化,跨由axis指定的通道。

fake_quantize_per_tensor_affine

返回一个新的张量,其中input中的数据使用scalezero_pointquant_minquant_max进行伪量化。

fix

torch.trunc()的别名。

float_power

input的每个元素提升到exponent次幂,以双精度进行逐元素计算。

floor

返回一个新的张量,其中包含input元素的下取整值,即小于或等于每个元素的最大整数。

floor_divide

fmod

逐元素应用C++的std::fmod

frac

计算input中每个元素的小数部分。

frexp

input分解为尾数和指数张量,使得input=mantissa×2exponent\text{input} = \text{mantissa} \times 2^{\text{exponent}}.

gradient

使用二阶精确中心差分法和边界处的一阶或二阶估计,估算函数g:RnRg : \mathbb{R}^n \rightarrow \mathbb{R}在一维或多维中的梯度。

imag

返回一个新的张量,其中包含self张量的虚数值。

ldexp

input乘以2 ** other

lerp

根据标量或张量weight,对两个张量start(由input给出)和end进行线性插值,并返回结果张量out

lgamma

计算input上伽马函数绝对值的自然对数。

log

返回一个新的张量,其中包含input元素的自然对数。

log10

返回一个新的张量,其中包含input元素以10为底的对数。

log1p

返回一个新的张量,其中包含(1 + input)的自然对数。

log2

返回一个新的张量,其中包含input元素以2为底的对数。

logaddexp

输入指数和的对数。

logaddexp2

以2为底的输入指数和的对数。

logical_and

计算给定输入张量的逐元素逻辑与。

logical_not

计算给定输入张量的逐元素逻辑非。

logical_or

计算给定输入张量的逐元素逻辑或。

logical_xor

计算给定输入张量的逐元素逻辑异或。

logit

torch.special.logit()的别名。

hypot

给定直角三角形的两条直角边,返回其斜边。

i0

torch.special.i0()的别名。

igamma

torch.special.gammainc()的别名。

igammac

torch.special.gammaincc()的别名。

mul

input乘以other

multiply

torch.mul()的别名。

mvlgamma

torch.special.multigammaln()的别名。

nan_to_num

input中的NaN、正无穷大和负无穷大值分别替换为nanposinfneginf指定的值。

neg

返回一个新的张量,其中包含input元素的负值。

negative

torch.neg()的别名。

nextafter

逐元素返回紧随 input 且向 other 方向的下一个浮点数。

polygamma

torch.special.polygamma() 的别名。

positive

返回 input

pow

使用 exponentinput 中每个元素进行幂运算,并返回一个包含结果的张量。

quantized_batch_norm

对 4D (NCHW) 量化张量应用批归一化。

quantized_max_pool1d

对由多个输入平面组成的输入量化张量应用 1D 最大池化。

quantized_max_pool2d

对由多个输入平面组成的输入量化张量应用 2D 最大池化。

rad2deg

返回一个新的张量,其中 input 中每个元素都从弧度转换为度数。

real

返回一个新的张量,包含 self 张量的实数值。

reciprocal

返回一个新的张量,包含 input 中每个元素的倒数。

remainder

逐元素计算 Python 的取模运算

round

input 中的元素四舍五入到最接近的整数。

rsqrt

返回一个新的张量,包含 input 中每个元素的平方根的倒数。

sigmoid

torch.special.expit() 的别名。

sign

返回一个新的张量,包含 input 中每个元素的符号。

sgn

此函数是 torch.sign() 对复数张量的扩展。

signbit

测试 input 中每个元素的符号位是否被设置。

sin

返回一个新的张量,包含 input 中每个元素的正弦值。

sinc

torch.special.sinc() 的别名。

sinh

返回一个新的张量,包含 input 中每个元素的双曲正弦值。

softmax

torch.nn.functional.softmax() 的别名。

sqrt

返回一个新的张量,包含 input 中每个元素的平方根。

square

返回一个新的张量,包含 input 中每个元素的平方。

sub

input 中减去 other,并乘以 alpha

subtract

torch.sub() 的别名。

tan

返回一个新的张量,包含 input 中每个元素的正切值。

tanh

返回一个新的张量,包含 input 中每个元素的双曲正切值。

true_divide

torch.div() 的别名,其中 rounding_mode=None

trunc

返回一个新的张量,包含 input 中每个元素的截断整数。

xlogy

torch.special.xlogy() 的别名。

约简操作

argmax

返回 input 张量中所有元素的最大值的索引。

argmin

返回扁平化张量或沿指定维度最小值(s)的索引。

amax

返回给定维度(s) diminput 张量每个切片的最大值。

amin

返回给定维度(s) diminput 张量每个切片的最小值。

aminmax

计算 input 张量的最小值和最大值。

all

测试 input 中所有元素是否都计算为 True

any

测试 input 中是否存在任何元素计算为 True

max

返回 input 张量中所有元素的最大值。

min

返回 input 张量中所有元素的最小值。

dist

返回 (input - other) 的 p 范数。

logsumexp

返回给定维度 diminput 张量每一行的指数和的对数。

mean

返回 input 张量中所有元素的平均值。

nanmean

计算沿指定维度所有非 NaN 元素的平均值。

median

返回 input 中值的中间值。

nanmedian

返回 input 中值的中间值,忽略 NaN 值。

mode

返回一个名为元组 (values, indices),其中 values 是给定维度 diminput 张量每一行的众数,即在该行中出现次数最多的值,而 indices 是找到的每个众数的值的索引位置。

norm

返回给定张量的矩阵范数或向量范数。

nansum

返回所有元素的总和,将非数字 (NaN) 视为零。

prod

返回 input 张量中所有元素的乘积。

quantile

计算沿维度 diminput 张量每一行的第 q 个分位数。

nanquantile

这是 torch.quantile() 的一个变体,它“忽略” NaN 值,计算分位数 q,就像 input 中不存在 NaN 值一样。

std

计算由 dim 指定的维度上的标准差。

std_mean

计算由 dim 指定的维度上的标准差和均值。

sum

返回 input 张量中所有元素的总和。

unique

返回输入张量的唯一元素。

unique_consecutive

消除每个连续的等价元素组中的第一个元素之外的所有元素。

var

计算由 dim 指定的维度上的方差。

var_mean

计算由 dim 指定的维度上的方差和均值。

count_nonzero

计算张量 input 沿给定 dim 的非零值的个数。

比较操作

allclose

此函数检查 inputother 是否满足条件

argsort

返回沿给定维度按值升序对张量排序的索引。

eq

计算逐元素相等性。

equal

如果两个张量具有相同的大小和元素,则为 True,否则为 False

ge

逐元素计算 inputother\text{input} \geq \text{other}

greater_equal

torch.ge() 的别名。

gt

逐元素计算 input>other\text{input} > \text{other}

大于

torch.gt() 的别名。

接近

返回一个新的张量,其布尔元素表示 input 的每个元素是否“接近” other 的对应元素。

有限

返回一个新的张量,其布尔元素表示每个元素是否为有限

包含

测试 elements 的每个元素是否在 test_elements 中。

无穷大

测试 input 的每个元素是否为无穷大(正无穷或负无穷)。

正无穷

测试 input 的每个元素是否为正无穷大。

负无穷

测试 input 的每个元素是否为负无穷大。

非数字

返回一个新的张量,其布尔元素表示 input 的每个元素是否为 NaN。

实数

返回一个新的张量,其布尔元素表示 input 的每个元素是否为实数。

第 k 个值

返回一个名为元组 (values, indices),其中 values 是给定维度 diminput 张量每一行的第 k 个最小元素。

小于等于

逐元素计算 inputother\text{input} \leq \text{other}

小于等于

torch.le() 的别名。

小于

逐元素计算 input<other\text{input} < \text{other}

小于

torch.lt() 的别名。

最大值

计算 inputother 的逐元素最大值。

最小值

计算 inputother 的逐元素最小值。

fmax

计算 inputother 的逐元素最大值。

fmin

计算 inputother 的逐元素最小值。

不相等

逐元素计算 inputother\text{input} \neq \text{other}

不相等

torch.ne() 的别名。

排序

沿给定维度对 input 张量的元素进行升序排序。

前 k 个值

沿给定维度返回给定 input 张量的 k 个最大元素。

多维排序

沿其第一维度对 input 张量的元素进行升序排序。

频谱运算

短时傅里叶变换

短时傅里叶变换 (STFT)。

逆短时傅里叶变换

逆短时傅里叶变换。

巴特利特窗

巴特利特窗函数。

布莱克曼窗

布莱克曼窗函数。

汉明窗

汉明窗函数。

汉宁窗

汉宁窗函数。

凯泽窗

计算窗口长度为 window_length 且形状参数为 beta 的凯泽窗。

其他运算

至少一维

返回每个零维输入张量的一维视图。

至少二维

返回每个零维输入张量的二维视图。

至少三维

返回每个零维输入张量的三维视图。

计数

计算非负整数数组中每个值的频率。

块对角线

从提供的张量创建一个块对角矩阵。

广播张量

根据 广播语义 广播给定的张量。

广播到

input 广播到形状 shape

广播形状

类似于 broadcast_tensors(),但用于形状。

分箱

返回 input 中每个值所属的箱的索引,其中箱的边界由 boundaries 设置。

笛卡尔积

对给定的张量序列执行笛卡尔积。

成对距离

计算两个行向量集合的每个对之间的 p 范数距离。

克隆

返回 input 的副本。

组合

计算给定张量的长度为 rr 的组合。

相关系数

估计由 input 矩阵给出的变量的皮尔逊积矩相关系数矩阵,其中行表示变量,列表示观测值。

协方差

估计由 input 矩阵给出的变量的协方差矩阵,其中行表示变量,列表示观测值。

叉积

返回 inputother 在维度 dim 上的向量的叉积。

累积最大值

返回一个名为元组 (values, indices),其中 valuesinput 在维度 dim 上的元素的累积最大值。

累积最小值

返回一个名为元组 (values, indices),其中 valuesinput 在维度 dim 上的元素的累积最小值。

累积积

返回 input 在维度 dim 上的元素的累积积。

累积和

返回 input 在维度 dim 上的元素的累积和。

对角线

  • 如果 input 是向量(一维张量),则返回一个二维方阵张量

对角线嵌入

创建一个张量,其某些二维平面(由 dim1dim2 指定)的对角线由 input 填充。

对角线展开

  • 如果 input 是向量(一维张量),则返回一个二维方阵张量

对角线

返回 input 的一个部分视图,其相对于 dim1dim2 的对角线元素作为维度附加到形状的末尾。

差分

沿给定维度计算第 n 个前向差分。

爱因斯坦求和

根据爱因斯坦求和约定,使用基于符号的表示法,沿指定维度对输入 operands 元素的乘积求和。

展平

通过将其重塑为一维张量来展平 input

翻转

沿 dims 中给定的轴反转 n 维张量的顺序。

fliplr

沿左右方向翻转张量,返回一个新的张量。

flipud

沿上下方向翻转张量,返回一个新的张量。

kron

计算克罗内克积,记为 \otimesinputother

rot90

将 n 维张量在由 dims 轴指定的平面内旋转 90 度。

gcd

计算 inputother 的逐元素最大公约数 (GCD)。

histc

计算张量的直方图。

histogram

计算张量中值的直方图。

histogramdd

计算张量中值的多分量直方图。

meshgrid

创建由 attr:tensors 中的 1D 输入指定的坐标网格。

lcm

计算 inputother 的逐元素最小公倍数 (LCM)。

logcumsumexp

返回 input 在维度 dim 上的元素指数的累积求和的对数。

ravel

返回一个连续的扁平化张量。

renorm

返回一个张量,其中 input 沿着维度 dim 的每个子张量的 p 范数都小于值 maxnorm

repeat_interleave

重复张量的元素。

roll

沿着给定维度(们)滚动张量 input

searchsorted

查找 sorted_sequence 最内层维度中的索引,使得如果将 values 中对应的值插入到这些索引之前,在排序后,sorted_sequence 中对应的最内层维度的顺序将保持不变。

tensordot

返回 a 和 b 在多个维度上的收缩。

trace

返回输入 2D 矩阵对角线元素的和。

tril

返回矩阵(2D 张量)或矩阵批次 input 的下三角部分,结果张量 out 的其他元素设置为 0。

tril_indices

在一个 2xN 的张量中返回 rowcol 列矩阵的下三角部分的索引,其中第一行包含所有索引的行坐标,第二行包含列坐标。

triu

返回矩阵(2D 张量)或矩阵批次 input 的上三角部分,结果张量 out 的其他元素设置为 0。

triu_indices

在一个 2xN 的张量中返回 rowcol 列矩阵的上三角部分的索引,其中第一行包含所有索引的行坐标,第二行包含列坐标。

unflatten

将输入张量的一个维度扩展到多个维度。

vander

生成范德蒙矩阵。

view_as_real

返回 input 作为实数张量的视图。

view_as_complex

返回 input 作为复数张量的视图。

resolve_conj

如果 input 的共轭位设置为 True,则返回一个具有具体化共轭的新张量,否则返回 input

resolve_neg

如果 input 的负位设置为 True,则返回一个具有具体化否定的新张量,否则返回 input

BLAS 和 LAPACK 操作

addbmm

执行存储在 batch1batch2 中的矩阵的批处理矩阵-矩阵乘积,并进行简化的加法步骤(所有矩阵乘法都沿着第一个维度累积)。

addmm

执行矩阵 mat1mat2 的矩阵乘法。

addmv

执行矩阵 mat 和向量 vec 的矩阵-向量乘积。

addr

执行向量 vec1vec2 的外积,并将其添加到矩阵 input 中。

baddbmm

执行 batch1batch2 中矩阵的批处理矩阵-矩阵乘积。

bmm

执行存储在 inputmat2 中的矩阵的批处理矩阵-矩阵乘积。

chain_matmul

返回 NN 个 2D 张量的矩阵乘积。

cholesky

计算对称正定矩阵 AA 或对称正定矩阵批次的 Cholesky 分解。

cholesky_inverse

给定其 Cholesky 分解,计算复厄米特或实对称正定矩阵的逆。

cholesky_solve

给定其 Cholesky 分解,计算具有复厄米特或实对称正定左端项的线性方程组的解。

dot

计算两个 1D 张量的点积。

geqrf

这是一个用于直接调用 LAPACK 的 geqrf 的低级函数。

ger

torch.outer() 的别名。

inner

计算 1D 张量的点积。

inverse

torch.linalg.inv() 的别名

det

torch.linalg.det() 的别名

logdet

计算方阵或方阵批次的 log 行列式。

slogdet

torch.linalg.slogdet() 的别名

lu

计算矩阵或矩阵批次 A 的 LU 分解。

lu_solve

使用 lu_factor() 中 A 的部分旋转 LU 分解,返回线性系统 Ax=bAx = b 的 LU 解。

lu_unpack

lu_factor() 返回的 LU 分解解包成 P, L, U 矩阵。

matmul

两个张量的矩阵乘积。

matrix_power

torch.linalg.matrix_power() 的别名

matrix_exp

torch.linalg.matrix_exp() 的别名。

mm

执行矩阵 inputmat2 的矩阵乘法。

mv

执行矩阵 input 和向量 vec 的矩阵-向量乘积。

orgqr

torch.linalg.householder_product() 的别名。

ormqr

计算 Householder 矩阵乘积与一般矩阵的矩阵-矩阵乘积。

outer

inputvec2 的外积。

pinverse

torch.linalg.pinv() 的别名

qr

计算矩阵或一批矩阵input的QR分解,并返回一个名为元组 (Q, R) 的张量,使得 input=QR\text{input} = Q R,其中QQ是正交矩阵或正交矩阵的批次,而RR是上三角矩阵或上三角矩阵的批次。

svd

计算矩阵或一批矩阵input的奇异值分解。

svd_lowrank

返回矩阵、矩阵批次或稀疏矩阵AA的奇异值分解(U, S, V),使得 AUdiag(S)VHA \approx U \operatorname{diag}(S) V^{\text{H}}.

pca_lowrank

对低秩矩阵、此类矩阵的批次或稀疏矩阵执行线性主成分分析 (PCA)。

lobpcg

使用无矩阵 LOBPCG 方法查找对称正定广义特征值问题的 k 个最大(或最小)特征值及其对应的特征向量。

trapz

别名,等价于 torch.trapezoid()

trapezoid

沿dim计算梯形法则

cumulative_trapezoid

沿dim累积计算梯形法则

triangular_solve

求解具有方形上三角或下三角可逆矩阵AA和多个右侧bb的方程组。

vdot

沿某个维度计算两个一维向量的点积。

Foreach 操作

警告

此 API 处于测试阶段,将来可能会发生变化。不支持前向模式 AD。

_foreach_abs

torch.abs()应用于输入列表中的每个张量。

_foreach_abs_

torch.abs()应用于输入列表中的每个张量。

_foreach_acos

torch.acos()应用于输入列表中的每个张量。

_foreach_acos_

torch.acos()应用于输入列表中的每个张量。

_foreach_asin

torch.asin()应用于输入列表中的每个张量。

_foreach_asin_

torch.asin()应用于输入列表中的每个张量。

_foreach_atan

torch.atan()应用于输入列表中的每个张量。

_foreach_atan_

torch.atan()应用于输入列表中的每个张量。

_foreach_ceil

torch.ceil()应用于输入列表中的每个张量。

_foreach_ceil_

torch.ceil()应用于输入列表中的每个张量。

_foreach_cos

torch.cos()应用于输入列表中的每个张量。

_foreach_cos_

torch.cos()应用于输入列表中的每个张量。

_foreach_cosh

torch.cosh()应用于输入列表中的每个张量。

_foreach_cosh_

torch.cosh()应用于输入列表中的每个张量。

_foreach_erf

torch.erf()应用于输入列表中的每个张量。

_foreach_erf_

torch.erf()应用于输入列表中的每个张量。

_foreach_erfc

torch.erfc()应用于输入列表中的每个张量。

_foreach_erfc_

torch.erfc()应用于输入列表中的每个张量。

_foreach_exp

torch.exp()应用于输入列表中的每个张量。

_foreach_exp_

torch.exp()应用于输入列表中的每个张量。

_foreach_expm1

torch.expm1()应用于输入列表中的每个张量。

_foreach_expm1_

torch.expm1()应用于输入列表中的每个张量。

_foreach_floor

torch.floor()应用于输入列表中的每个张量。

_foreach_floor_

torch.floor()应用于输入列表中的每个张量。

_foreach_log

torch.log()应用于输入列表中的每个张量。

_foreach_log_

torch.log()应用于输入列表中的每个张量。

_foreach_log10

torch.log10()应用于输入列表中的每个张量。

_foreach_log10_

torch.log10()应用于输入列表中的每个张量。

_foreach_log1p

torch.log1p()应用于输入列表中的每个张量。

_foreach_log1p_

torch.log1p()应用于输入列表中的每个张量。

_foreach_log2

torch.log2()应用于输入列表中的每个张量。

_foreach_log2_

torch.log2()应用于输入列表中的每个张量。

_foreach_neg

torch.neg()应用于输入列表中的每个张量。

_foreach_neg_

torch.neg()应用于输入列表中的每个张量。

_foreach_tan

torch.tan()应用于输入列表中的每个张量。

_foreach_tan_

torch.tan()应用于输入列表中的每个张量。

_foreach_sin

torch.sin()应用于输入列表中的每个张量。

_foreach_sin_

torch.sin()应用于输入列表中的每个张量。

_foreach_sinh

torch.sinh()应用于输入列表中的每个张量。

_foreach_sinh_

torch.sinh()应用于输入列表中的每个张量。

_foreach_round

torch.round()应用于输入列表中的每个张量。

_foreach_round_

torch.round()应用于输入列表中的每个张量。

_foreach_sqrt

torch.sqrt()应用于输入列表中的每个张量。

_foreach_sqrt_

torch.sqrt()应用于输入列表中的每个张量。

_foreach_lgamma

torch.lgamma()应用于输入列表中的每个张量。

_foreach_lgamma_

torch.lgamma()应用于输入列表中的每个张量。

_foreach_frac

torch.frac()应用于输入列表中的每个张量。

_foreach_frac_

torch.frac()应用于输入列表中的每个张量。

_foreach_reciprocal

torch.reciprocal()应用于输入列表中的每个张量。

_foreach_reciprocal_

torch.reciprocal()应用于输入列表中的每个张量。

_foreach_sigmoid

torch.sigmoid()应用于输入列表中的每个张量。

_foreach_sigmoid_

torch.sigmoid()应用于输入列表中的每个张量。

_foreach_trunc

torch.trunc()应用于输入列表中的每个张量。

_foreach_trunc_

torch.trunc()应用于输入列表中的每个张量。

_foreach_zero_

torch.zero()应用于输入列表中的每个张量。

实用程序

compiled_with_cxx11_abi

返回 PyTorch 是否使用 _GLIBCXX_USE_CXX11_ABI=1 构建。

result_type

返回对提供的输入张量执行算术运算将产生的torch.dtype

can_cast

确定在类型提升文档中描述的 PyTorch 转换规则下是否允许类型转换。

promote_types

返回大小最小且标量类型不低于type1type2torch.dtype

use_deterministic_algorithms

设置 PyTorch 操作是否必须使用“确定性”算法。

are_deterministic_algorithms_enabled

如果全局确定性标志已开启,则返回 True。

is_deterministic_algorithms_warn_only_enabled

如果全局确定性标志设置为仅警告,则返回 True。

set_deterministic_debug_mode

设置确定性操作的调试模式。

get_deterministic_debug_mode

返回确定性操作当前的调试模式值。

set_float32_matmul_precision

设置 float32 矩阵乘法的内部精度。

get_float32_matmul_precision

返回 float32 矩阵乘法当前的精度值。

set_warn_always

当此标志为 False(默认值)时,某些 PyTorch 警告可能每个进程仅出现一次。

get_device_module

返回与给定设备关联的模块(例如,torch.device('cuda')、"mtia:0"、"xpu" 等)。

is_warn_always_enabled

如果全局 warn_always 标志已开启,则返回 True。

vmap

vmap 是向量化映射;vmap(func) 返回一个新函数,该函数将 func 映射到输入的某些维度上。

_assert

Python 的 assert 的包装器,可进行符号跟踪。

符号数字

class torch.SymInt(node)[source]

类似于 int(包括魔术方法),但将包装节点上的所有操作重定向。这尤其用于在符号形状工作流中符号记录操作。

as_integer_ratio()[source]

将此 int 表示为精确的整数比率

返回类型

Tuple[SymInt, int]

class torch.SymFloat(node)[source]

类似于 float(包括魔术方法),但将包装节点上的所有操作重定向。这尤其用于在符号形状工作流中符号记录操作。

as_integer_ratio()[source]

将此 float 表示为精确的整数比率

返回类型

Tuple[int, int]

is_integer()[source]

如果 float 是整数,则返回 True。

class torch.SymBool(node)[source]

类似于 bool(包括魔术方法),但将包装节点上的所有操作重定向。这尤其用于在符号形状工作流中符号记录操作。

与常规布尔值不同,常规布尔运算符将强制执行额外的保护,而不是进行符号评估。使用按位运算符来处理此问题。

sym_float

SymInt 感知的 float 类型转换实用程序。

sym_int

SymInt 感知的 int 类型转换实用程序。

sym_max

SymInt 感知的 max 实用程序,避免在 a < b 上分支。

sym_min

SymInt 感知的 min() 实用程序。

sym_not

SymInt 感知的逻辑非实用程序。

sym_ite

导出路径

警告

此功能为原型,将来可能会发生与兼容性相关的重大更改。

导出 generated/exportdb/index

控制流

警告

此功能为原型,将来可能会发生与兼容性相关的重大更改。

cond

有条件地应用 true_fnfalse_fn

优化

compile

使用 TorchDynamo 和指定的后端优化给定的模型/函数。

torch.compile 文档

运算符标签

class torch.Tag

成员

core

data_dependent_output

dynamic_output_shape

flexible_layout

generated

inplace_view

needs_fixed_stride_order

nondeterministic_bitwise

nondeterministic_seeded

pointwise

pt2_compliant_tag

view_copy

property name

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源