快捷方式

复数

Complex numbers are numbers that can be expressed in the form a+bja + bj, where a and b are real numbers, and j is called the imaginary unit, which satisfies the equation j2=1j^2 = -1. Complex numbers frequently occur in mathematics and engineering, especially in topics like signal processing. Traditionally many users and libraries (e.g., TorchAudio) have handled complex numbers by representing the data in float tensors with shape (...,2)(..., 2) where the last dimension contains the real and imaginary values.

复数类型的张量在处理复数时提供更自然的体验。对复数张量的操作(例如,torch.mv()torch.matmul())可能比模拟它们的浮点张量上的操作更快,并且内存效率更高。PyTorch 中涉及复数的操作经过优化,可以使用矢量化汇编指令和专用内核(例如 LAPACK、cuBlas)。

注意

torch.fft 模块 中的光谱操作支持原生复数张量。

警告

复数张量是测试版功能,可能会发生变化。

创建复数张量

我们支持两种复数类型:torch.cfloattorch.cdouble

>>> x = torch.randn(2,2, dtype=torch.cfloat)
>>> x
tensor([[-0.4621-0.0303j, -0.2438-0.5874j],
     [ 0.7706+0.1421j,  1.2110+0.1918j]])

注意

复数张量的默认类型由默认浮点类型决定。如果默认浮点类型是 torch.float64,则复数被推断为具有 torch.complex128 类型,否则它们被假定为具有 torch.complex64 类型。

除了 torch.linspace()torch.logspace()torch.arange() 之外,所有工厂函数都支持复数张量。

从旧表示过渡

目前使用形状为 (...,2)(..., 2) 的实数张量来解决复数张量缺失问题的用户,可以使用 torch.view_as_complex()torch.view_as_real() 轻松地在代码中切换到复数张量。请注意,这些函数不会执行任何复制操作,而是返回输入张量的视图。

>>> x = torch.randn(3, 2)
>>> x
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])
>>> y = torch.view_as_complex(x)
>>> y
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
>>> torch.view_as_real(y)
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])

访问实部和虚部

可以使用 realimag 访问复数张量的实部和虚部。

注意

访问 realimag 属性不会分配任何内存,并且对 realimag 张量的就地更新将更新原始复数张量。此外,返回的 realimag 张量不是连续的。

>>> y.real
tensor([ 0.6125, -0.3773, -0.0861])
>>> y.imag
tensor([-0.1681,  1.3487, -0.7981])

>>> y.real.mul_(2)
tensor([ 1.2250, -0.7546, -0.1722])
>>> y
tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j])
>>> y.real.stride()
(2,)

角度和绝对值

可以使用 torch.angle()torch.abs() 计算复张量的角度和绝对值。

>>> x1=torch.tensor([3j, 4+4j])
>>> x1.abs()
tensor([3.0000, 5.6569])
>>> x1.angle()
tensor([1.5708, 0.7854])

线性代数

许多线性代数运算,如 torch.matmul()torch.linalg.svd()torch.linalg.solve() 等,都支持复数。如果您想请求我们目前不支持的运算,请 搜索 是否已提交问题,如果没有,请 提交问题

序列化

复张量可以序列化,允许将数据保存为复数值。

>>> torch.save(y, 'complex_tensor.pt')
>>> torch.load('complex_tensor.pt')
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])

自动微分

PyTorch 支持复张量的自动微分。计算的梯度是共轭 Wirtinger 导数,其负值正是梯度下降算法中使用的最速下降方向。因此,所有现有的优化器都可以直接使用复参数。有关更多详细信息,请查看笔记 复数的自动微分

我们不完全支持以下子系统

  • 量化

  • JIT

  • 稀疏张量

  • 分布式

如果这些对您的用例有所帮助,请 搜索 是否已提交问题,如果没有,请 提交问题

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源