复数¶
复数是可以表示为 形式的数,其中 a 和 b 是实数,j 是虚数单位,满足方程 。复数在数学和工程领域中频繁出现,尤其是在信号处理等主题中。传统上,许多用户和库(例如 TorchAudio)通过使用形状为 的浮点张量来表示数据,其中最后一个维度包含实部和虚部值来处理复数。
复数数据类型的张量在使用复数时提供了更自然的用户体验。对复数张量执行的操作(例如,torch.mv()
, torch.matmul()
)可能比模拟它们的浮点张量操作更快、更节省内存。PyTorch 中涉及复数的运算已优化,以使用向量化汇编指令和专用内核(例如 LAPACK、cuBlas)。
注意
torch.fft 模块中的谱运算支持原生复数张量。
警告
复数张量是 Beta 特性,可能会有所更改。
创建复数张量¶
我们支持两种复数数据类型:torch.cfloat 和 torch.cdouble
>>> x = torch.randn(2,2, dtype=torch.cfloat)
>>> x
tensor([[-0.4621-0.0303j, -0.2438-0.5874j],
[ 0.7706+0.1421j, 1.2110+0.1918j]])
注意
复数张量的默认数据类型由默认浮点数据类型决定。如果默认浮点数据类型是 torch.float64,则推断复数的数据类型为 torch.complex128,否则假定其数据类型为 torch.complex64。
除 torch.linspace()
、torch.logspace()
和 torch.arange()
外,所有工厂函数都支持复数张量。
从旧表示转换¶
目前使用形状为 的实数张量来解决缺少复数张量问题的用户,可以使用 torch.view_as_complex()
和 torch.view_as_real()
轻松地在其代码中切换到使用复数张量。请注意,这些函数不执行任何复制,并返回输入张量的视图。
>>> x = torch.randn(3, 2)
>>> x
tensor([[ 0.6125, -0.1681],
[-0.3773, 1.3487],
[-0.0861, -0.7981]])
>>> y = torch.view_as_complex(x)
>>> y
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
>>> torch.view_as_real(y)
tensor([[ 0.6125, -0.1681],
[-0.3773, 1.3487],
[-0.0861, -0.7981]])
访问实部和虚部¶
可以使用 real
和 imag
属性访问复数张量的实部和虚部值。
注意
访问 real 和 imag 属性不会分配任何内存,并且对 real 和 imag 张量进行原地更新将更新原始复数张量。此外,返回的 real 和 imag 张量不是连续的。
>>> y.real
tensor([ 0.6125, -0.3773, -0.0861])
>>> y.imag
tensor([-0.1681, 1.3487, -0.7981])
>>> y.real.mul_(2)
tensor([ 1.2250, -0.7546, -0.1722])
>>> y
tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j])
>>> y.real.stride()
(2,)
幅角和模长¶
可以使用 torch.angle()
和 torch.abs()
计算复数张量的幅角和绝对值。
>>> x1=torch.tensor([3j, 4+4j])
>>> x1.abs()
tensor([3.0000, 5.6569])
>>> x1.angle()
tensor([1.5708, 0.7854])
线性代数¶
许多线性代数运算,例如 torch.matmul()
、torch.linalg.svd()
、torch.linalg.solve()
等,都支持复数。如果您希望请求我们目前不支持的运算,请搜索是否已提交过相关问题,如果尚未提交,请提交一个。
序列化¶
复数张量可以被序列化,从而允许数据以复数形式保存。
>>> torch.save(y, 'complex_tensor.pt')
>>> torch.load('complex_tensor.pt')
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
自动微分¶
PyTorch 支持复数张量的自动微分。计算出的梯度是 Conjugate Wirtinger 导数,其负数恰好是梯度下降算法中使用的最速下降方向。因此,所有现有优化器都可以直接用于复数参数。有关更多详细信息,请查看注释 复数自动微分。
优化器¶
从语义上讲,我们定义对具有复数参数的 PyTorch 优化器执行步骤等同于对这些复数参数的 torch.view_as_real()
等效项执行相同的优化器步骤。更具体地说,
>>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)]
>>> real_params = [torch.view_as_real(p) for p in params]
>>> complex_optim = torch.optim.AdamW(params)
>>> real_optim = torch.optim.AdamW(real_params)
real_optim 和 complex_optim 将计算对参数的相同更新,尽管两个优化器之间可能存在细微的数值差异,类似于 foreach 与 forloop 优化器以及 capturable 与默认优化器之间的数值差异。有关更多详细信息,请参阅 https://pytorch.ac.cn/docs/stable/notes/numerical_accuracy.html。
具体来说,虽然您可以认为我们的优化器处理复数张量与分别对它们的 p.real 和 p.imag 部分进行优化是相同的,但实现细节并非完全如此。请注意,torch.view_as_real()
等效项会将复数张量转换为形状为 的实数张量,而将复数张量拆分为两个张量是 2 个大小为 的张量。这种区别对逐点优化器(如 AdamW)没有影响,但会导致执行全局归约(如 LBFGS)的优化器产生细微差异。我们目前没有执行逐张量归约的优化器,因此尚未定义此行为。如果您有需要精确定义此行为的用例,请提交一个问题。
我们尚未完全支持以下子系统
量化
JIT
稀疏张量
分布式