快速傅里叶变换(FFT)能在 O(n log n) 时间内计算离散傅里叶变换。它是各种数值算法和信号处理技术的基础,因为它使得在信号的“频域”中进行处理与在空间或时域中一样简便。
作为 PyTorch 支持硬件加速深度学习和科学计算目标的一部分,我们致力于改进 FFT 支持。在 PyTorch 1.8 中,我们发布了 torch.fft 模块。该模块实现了与 NumPy 的 np.fft 模块相同的函数,但增加了对 GPU 等加速器的支持,并内置了自动微分(autograd)功能。
入门指南
无论你是否熟悉 NumPy 的 np.fft 模块,上手新的 torch.fft 模块都非常简单。虽然该模块中每个函数的完整文档可以在这里找到,但其提供的功能概览如下:
fft:计算单维复数 FFT;ifft:其逆变换- 更通用的
fftn和ifftn:支持多维变换 - “实数”FFT 函数:
rfft、irfft、rfftn、irfftn,专为处理时域中为实数值的信号而设计 - “埃尔米特”(Hermitian)FFT 函数:
hfft和ihfft,专为处理频域中为实数值的信号而设计 - 辅助函数:如
fftfreq、rfftfreq、fftshift、ifftshift,使信号处理更加简便
我们认为这些函数提供了一个直接的 FFT 功能接口(已通过 NumPy 社区验证),尽管如此,我们始终欢迎大家的反馈和建议!
为了更好地说明从 NumPy 的 np.fft 模块迁移到 PyTorch 的 torch.fft 模块有多简单,让我们看看一个用于低通滤波的 NumPy 实现,它能够去除二维图像中高频方差,从而起到降噪或模糊的效果:
import numpy as np
import numpy.fft as fft
def lowpass_np(input, limit):
pass1 = np.abs(fft.rfftfreq(input.shape[-1])) < limit
pass2 = np.abs(fft.fftfreq(input.shape[-2])) < limit
kernel = np.outer(pass2, pass1)
fft_input = fft.rfft2(input)
return fft.irfft2(fft_input * kernel, s=input.shape[-2:])
现在,让我们看看在 PyTorch 中实现的相同滤波器:
import torch
import torch.fft as fft
def lowpass_torch(input, limit):
pass1 = torch.abs(fft.rfftfreq(input.shape[-1])) < limit
pass2 = torch.abs(fft.fftfreq(input.shape[-2])) < limit
kernel = torch.outer(pass2, pass1)
fft_input = fft.rfft2(input)
return fft.irfft2(fft_input * kernel, s=input.shape[-2:])
目前使用 NumPy np.fft 模块的代码不仅可以直接转换为 torch.fft,torch.fft 的操作还支持加速器(如 GPU)上的张量以及自动微分。这使得(除其他事项外)开发利用 FFT 的新型神经网络模块成为可能。
性能
torch.fft 模块不仅易于使用,而且运行速度非常快!PyTorch 原生支持 Intel CPU 上的 Intel MKL-FFT 库以及 CUDA 设备上的 NVIDIA cuFFT 库,我们仔细优化了这些库的使用方式,以最大限度地提高性能。虽然实际结果取决于你的 CPU 和 CUDA 硬件,但在 CUDA 设备上计算快速傅里叶变换的速度可能比在 CPU 上快很多倍,尤其是在处理大型信号时。
未来,我们可能会增加对更多数学库的支持,以支持更多硬件。请参阅下文了解如何请求额外的硬件支持。
从旧版本 PyTorch 更新
一些 PyTorch 用户可能知道旧版本的 PyTorch 也通过 torch.fft() 函数提供了 FFT 功能。遗憾的是,该函数必须移除,因为它的名称与新模块名称冲突,并且我们认为新功能是 PyTorch 中使用快速傅里叶变换的最佳方式。特别指出的是,torch.fft() 是在 PyTorch 支持复数张量之前开发的,而 torch.fft 模块则是专门为处理复数张量而设计的。
PyTorch 还有一个“短时傅里叶变换”函数 torch.stft 及其逆变换 torch.istft。这些函数将被保留,但已更新以支持复数张量。
未来展望
如上所述,PyTorch 1.8 提供了 torch.fft 模块,它使得在加速器上使用快速傅里叶变换(FFT)并支持自动微分变得非常简单。我们鼓励大家试用!
虽然目前该模块是参照 NumPy 的 np.fft 模块建模的,但我们不会止步于此。我们非常渴望听到社区的声音,了解你们还需要哪些与 FFT 相关的功能。我们鼓励大家在我们的论坛 https://discuss.pytorch.org/ 上发帖,或在我们的 Github 上提交 Issue,反馈你们的意见和需求。例如,早期采用者已经开始询问离散余弦变换(Discrete Cosine Transforms)以及对更多硬件平台的支持,我们目前正在调研这些功能。
我们期待听到你们的反馈,并期待看到社区如何利用 PyTorch 的新 FFT 功能!