跳转到主要内容
博客

torch.fft 模块:PyTorch 中带有自动微分的加速快速傅里叶变换

作者: 2021 年 3 月 3 日2024 年 11 月 16 日暂无评论

快速傅里叶变换 (FFT) 以 O(n log n) 的时间计算离散傅里叶变换。它是各种数值算法和信号处理技术的基础,因为它使在信号的“频域”中工作变得像在它们的空间或时域中工作一样容易。

作为 PyTorch 支持硬件加速深度学习和科学计算目标的一部分,我们投入了改进 FFT 支持,并随 PyTorch 1.8 发布了 torch.fft 模块。该模块实现了与 NumPy 的 np.fft 模块相同的功能,但支持加速器(如 GPU)和自动求导。

入门

无论您是否熟悉 NumPy 的 np.fft 模块,都可以轻松开始使用新的 torch.fft 模块。虽然该模块中每个功能的完整文档可以在此处找到,但其提供的功能细分如下:

  • fft,用于计算单个维度上的复数 FFT,及其逆变换 ifft
  • 更通用的 fftnifftn,支持多个维度
  • “实数”FFT 函数,rfftirfftrfftnirfftn,旨在处理时域中为实数值的信号
  • “厄米”FFT 函数,hfftihfft,旨在处理频域中为实数值的信号
  • 辅助函数,如 fftfreqrfftfreqfftshiftifftshift,使信号操作更容易

我们认为这些功能为 FFT 功能提供了直接的接口,并得到了 NumPy 社区的认可,尽管我们总是对反馈和建议感兴趣!

为了更好地说明从 NumPy 的 np.fft 模块迁移到 PyTorch 的 torch.fft 模块是多么容易,让我们看一个简单低通滤波器的 NumPy 实现,该滤波器从 2 维图像中去除高频方差,这是一种降噪或模糊的形式。

import numpy as np
import numpy.fft as fft

def lowpass_np(input, limit):
    pass1 = np.abs(fft.rfftfreq(input.shape[-1])) < limit
    pass2 = np.abs(fft.fftfreq(input.shape[-2])) < limit
    kernel = np.outer(pass2, pass1)
    
    fft_input = fft.rfft2(input)
    return fft.irfft2(fft_input * kernel, s=input.shape[-2:])

现在让我们看看 PyTorch 中实现的相同过滤器

import torch
import torch.fft as fft

def lowpass_torch(input, limit):
    pass1 = torch.abs(fft.rfftfreq(input.shape[-1])) < limit
    pass2 = torch.abs(fft.fftfreq(input.shape[-2])) < limit
    kernel = torch.outer(pass2, pass1)
    
    fft_input = fft.rfft2(input)
    return fft.irfft2(fft_input * kernel, s=input.shape[-2:])

NumPy 的 np.fft 模块的当前用途不仅可以直接转换为 torch.fft,而且 torch.fft 操作还支持加速器(如 GPU)上的张量和自动求导。这使得(除其他外)可以使用 FFT 开发新的神经网络模块。

性能

torch.fft 模块不仅易于使用,而且速度很快!PyTorch 原生支持 Intel CPU 上的 Intel MKL-FFT 库和 CUDA 设备上的 NVIDIA cuFFT 库,我们精心优化了这些库的使用方式以最大限度地提高性能。虽然您自己的结果将取决于您的 CPU 和 CUDA 硬件,但在 CUDA 设备上计算快速傅里叶变换可以比在 CPU 上计算快很多倍,特别是对于更大的信号。

将来,我们可能会添加对其他数学库的支持,以支持更多硬件。请参阅下文,了解您可以请求其他硬件支持的位置。

从旧版 PyTorch 更新

一些 PyTorch 用户可能知道,旧版 PyTorch 还提供了带有 torch.fft() 函数的 FFT 功能。不幸的是,此函数不得不被删除,因为它的名称与新模块的名称冲突,我们认为新功能是在 PyTorch 中使用快速傅里叶变换的最佳方式。特别是,torch.fft() 是在 PyTorch 支持复数张量之前开发的,而 torch.fft 模块旨在与它们一起使用。

PyTorch 还有一个“短时傅里叶变换”,torch.stft,及其逆变换 torch.istft。这些函数被保留并更新以支持复数张量。

未来

如前所述,PyTorch 1.8 提供了 torch.fft 模块,这使得在加速器上使用快速傅里叶变换 (FFT) 并支持自动求导变得容易。我们鼓励您尝试一下!

虽然该模块目前已仿照 NumPy 的 np.fft 模块,但我们不止于此。我们渴望听取您(我们的社区)关于您需要的 FFT 相关功能的意见,我们鼓励您在我们的论坛 https://discuss.pytorch.org/ 上发帖,或在我们的 Github 上提交问题,提供您的反馈和请求。例如,早期采用者已经开始询问离散余弦变换和对更多硬件平台的支撑,我们现在正在研究这些功能。

我们期待收到您的来信,并看到社区将如何利用 PyTorch 的新 FFT 功能!