快速傅里叶变换 (FFT) 以 O(n log n) 的时间计算离散傅里叶变换。它是各种数值算法和信号处理技术的基础,因为它使在信号“频域”中的操作与在空间或时域中的操作一样易于处理。
作为 PyTorch 支持硬件加速深度学习和科学计算目标的一部分,我们投入了改进 FFT 支持,并在 PyTorch 1.8 中发布了 torch.fft
模块。此模块实现了与 NumPy 的 np.fft
模块相同的功能,但支持加速器(如 GPU)和自动微分。
入门
无论您是否熟悉 NumPy 的 np.fft
模块,新 torch.fft
模块的入门都非常简单。您可以在此处找到模块中每个功能的完整文档,以下是其提供的功能概述:
fft
:在单个维度上计算复数 FFT,及其逆运算ifft
- 更通用的
fftn
和ifftn
:支持多个维度 - “实数”FFT 函数:
rfft
、irfft
、rfftn
、irfftn
,旨在处理时域中为实数值的信号 - “厄米特”FFT 函数:
hfft
和ihfft
,旨在处理频域中为实数值的信号 - 辅助函数:如
fftfreq
、rfftfreq
、fftshift
、ifftshift
,使信号操作更简单
我们认为这些函数提供了 FFT 功能的直观接口,并得到了 NumPy 社区的验证,尽管我们始终欢迎反馈和建议!
为了更好地说明从 NumPy 的 np.fft
模块迁移到 PyTorch 的 torch.fft
模块是多么容易,让我们看看一个简单的低通滤波器的 NumPy 实现,该滤波器从二维图像中去除高频方差,这是一种降噪或模糊的形式。
import numpy as np
import numpy.fft as fft
def lowpass_np(input, limit):
pass1 = np.abs(fft.rfftfreq(input.shape[-1])) < limit
pass2 = np.abs(fft.fftfreq(input.shape[-2])) < limit
kernel = np.outer(pass2, pass1)
fft_input = fft.rfft2(input)
return fft.irfft2(fft_input * kernel, s=input.shape[-2:])
现在让我们看看在 PyTorch 中实现相同的过滤器
import torch
import torch.fft as fft
def lowpass_torch(input, limit):
pass1 = torch.abs(fft.rfftfreq(input.shape[-1])) < limit
pass2 = torch.abs(fft.fftfreq(input.shape[-2])) < limit
kernel = torch.outer(pass2, pass1)
fft_input = fft.rfft2(input)
return fft.irfft2(fft_input * kernel, s=input.shape[-2:])
NumPy np.fft
模块的当前用法不仅可以直接转换为 torch.fft
,而且 torch.fft
操作还支持加速器(如 GPU)上的张量和自动微分。这使得(除其他外)可以使用 FFT 开发新的神经网络模块。
性能
torch.fft
模块不仅易于使用,而且速度很快!PyTorch 在 Intel CPU 上原生支持 Intel 的 MKL-FFT 库,在 CUDA 设备上原生支持 NVIDIA 的 cuFFT 库,我们精心优化了这些库的使用方式,以最大限度地提高性能。虽然您自己的结果将取决于您的 CPU 和 CUDA 硬件,但在 CUDA 设备上计算快速傅里叶变换比在 CPU 上计算快很多倍,特别是对于更大的信号。
将来,我们可能会添加对其他数学库的支持,以支持更多硬件。请参阅下文,了解如何请求其他硬件支持。
从旧版 PyTorch 更新
一些 PyTorch 用户可能知道,旧版 PyTorch 也提供了 FFT 功能,通过 torch.fft()
函数实现。不幸的是,此函数必须删除,因为其名称与新模块的名称冲突,并且我们认为新功能是 PyTorch 中使用快速傅里叶变换的最佳方式。特别是,torch.fft()
是在 PyTorch 支持复数张量之前开发的,而 torch.fft
模块是为支持复数张量而设计的。
PyTorch 还有一个“短时傅里叶变换”torch.stft
及其逆运算 torch.istft
。这些函数将被保留并更新以支持复数张量。
未来
如前所述,PyTorch 1.8 提供了 torch.fft 模块,使得在加速器上使用快速傅里叶变换 (FFT) 变得简单,并支持自动微分。我们鼓励您尝试一下!
虽然此模块迄今为止是按照 NumPy 的 np.fft
模块建模的,但我们不会止步于此。我们渴望听取您的意见,即我们的社区,了解您需要的 FFT 相关功能,并鼓励您在我们的论坛上发帖:https://discuss.pytorch.org/,或者在我们的 Github 上提交问题,提供您的反馈和请求。例如,早期使用者已经开始询问离散余弦变换和对更多硬件平台的支持,我们现在正在研究这些功能。
我们期待您的来信,并期待看到社区如何利用 PyTorch 新的 FFT 功能!