复数¶
Complex numbers are numbers that can be expressed in the form , where a and b are real numbers, and j is called the imaginary unit, which satisfies the equation . Complex numbers frequently occur in mathematics and engineering, especially in topics like signal processing. Traditionally many users and libraries (e.g., TorchAudio) have handled complex numbers by representing the data in float tensors with shape where the last dimension contains the real and imaginary values.
复数类型的张量在处理复数时提供更自然的体验。对复数张量的操作(例如,torch.mv()
,torch.matmul()
)可能比模拟它们的浮点张量上的操作更快,并且内存效率更高。PyTorch 中涉及复数的操作经过优化,可以使用矢量化汇编指令和专用内核(例如 LAPACK、cuBlas)。
注意
torch.fft 模块 中的光谱操作支持原生复数张量。
警告
复数张量是测试版功能,可能会发生变化。
创建复数张量¶
我们支持两种复数类型:torch.cfloat 和 torch.cdouble
>>> x = torch.randn(2,2, dtype=torch.cfloat)
>>> x
tensor([[-0.4621-0.0303j, -0.2438-0.5874j],
[ 0.7706+0.1421j, 1.2110+0.1918j]])
注意
复数张量的默认类型由默认浮点类型决定。如果默认浮点类型是 torch.float64,则复数被推断为具有 torch.complex128 类型,否则它们被假定为具有 torch.complex64 类型。
除了 torch.linspace()
、torch.logspace()
和 torch.arange()
之外,所有工厂函数都支持复数张量。
从旧表示过渡¶
目前使用形状为 的实数张量来解决复数张量缺失问题的用户,可以使用 torch.view_as_complex()
和 torch.view_as_real()
轻松地在代码中切换到复数张量。请注意,这些函数不会执行任何复制操作,而是返回输入张量的视图。
>>> x = torch.randn(3, 2)
>>> x
tensor([[ 0.6125, -0.1681],
[-0.3773, 1.3487],
[-0.0861, -0.7981]])
>>> y = torch.view_as_complex(x)
>>> y
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
>>> torch.view_as_real(y)
tensor([[ 0.6125, -0.1681],
[-0.3773, 1.3487],
[-0.0861, -0.7981]])
访问实部和虚部¶
可以使用 real
和 imag
访问复数张量的实部和虚部。
注意
访问 real 和 imag 属性不会分配任何内存,并且对 real 和 imag 张量的就地更新将更新原始复数张量。此外,返回的 real 和 imag 张量不是连续的。
>>> y.real
tensor([ 0.6125, -0.3773, -0.0861])
>>> y.imag
tensor([-0.1681, 1.3487, -0.7981])
>>> y.real.mul_(2)
tensor([ 1.2250, -0.7546, -0.1722])
>>> y
tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j])
>>> y.real.stride()
(2,)
角度和绝对值¶
可以使用 torch.angle()
和 torch.abs()
计算复张量的角度和绝对值。
>>> x1=torch.tensor([3j, 4+4j])
>>> x1.abs()
tensor([3.0000, 5.6569])
>>> x1.angle()
tensor([1.5708, 0.7854])
线性代数¶
许多线性代数运算,如 torch.matmul()
、torch.linalg.svd()
、torch.linalg.solve()
等,都支持复数。如果您想请求我们目前不支持的运算,请 搜索 是否已提交问题,如果没有,请 提交问题。
序列化¶
复张量可以序列化,允许将数据保存为复数值。
>>> torch.save(y, 'complex_tensor.pt')
>>> torch.load('complex_tensor.pt')
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])