注意
点击这里下载完整的示例代码
音频重采样¶
作者: Caroline Chen, Moto Hira
本教程展示如何使用 torchaudio 的重采样 API。
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
print(torch.__version__)
print(torchaudio.__version__)
2.6.0
2.6.0
准备工作¶
首先,我们导入模块并定义辅助函数。
import math
import timeit
import librosa
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import pandas as pd
import resampy
from IPython.display import Audio
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
DEFAULT_OFFSET = 201
def _get_log_freq(sample_rate, max_sweep_rate, offset):
"""Get freqs evenly spaced out in log-scale, between [0, max_sweep_rate // 2]
offset is used to avoid negative infinity `log(offset + x)`.
"""
start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
def _get_inverse_log_freq(freq, sample_rate, offset):
"""Find the time where the given frequency is given by _get_log_freq"""
half = sample_rate // 2
return sample_rate * (math.log(1 + freq / offset) / math.log(1 + half / offset))
def _get_freq_ticks(sample_rate, offset, f_max):
# Given the original sample rate used for generating the sweep,
# find the x-axis value where the log-scale major frequency values fall in
times, freq = [], []
for exp in range(2, 5):
for v in range(1, 10):
f = v * 10**exp
if f < sample_rate // 2:
t = _get_inverse_log_freq(f, sample_rate, offset) / sample_rate
times.append(t)
freq.append(f)
t_max = _get_inverse_log_freq(f_max, sample_rate, offset) / sample_rate
times.append(t_max)
freq.append(f_max)
return times, freq
def get_sine_sweep(sample_rate, offset=DEFAULT_OFFSET):
max_sweep_rate = sample_rate
freq = _get_log_freq(sample_rate, max_sweep_rate, offset)
delta = 2 * math.pi * freq / sample_rate
cummulative = torch.cumsum(delta, dim=0)
signal = torch.sin(cummulative).unsqueeze(dim=0)
return signal
def plot_sweep(
waveform,
sample_rate,
title,
max_sweep_rate=48000,
offset=DEFAULT_OFFSET,
):
x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2]
y_ticks = [1000, 5000, 10000, 20000, sample_rate // 2]
time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2)
freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq]
freq_y = [f for f in freq if f in y_ticks and 1000 <= f <= sample_rate // 2]
figure, axis = plt.subplots(1, 1)
_, _, _, cax = axis.specgram(waveform[0].numpy(), Fs=sample_rate)
plt.xticks(time, freq_x)
plt.yticks(freq_y, freq_y)
axis.set_xlabel("Original Signal Frequency (Hz, log scale)")
axis.set_ylabel("Waveform Frequency (Hz)")
axis.xaxis.grid(True, alpha=0.67)
axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax)
重采样概述¶
要将音频波形从一个频率重采样到另一个频率,您可以使用 torchaudio.transforms.Resample
或 torchaudio.functional.resample()
。transforms.Resample
预先计算并缓存用于重采样的内核,而 functional.resample
则动态计算它,因此当使用相同的参数重采样多个波形时,使用 torchaudio.transforms.Resample
将会加快速度(请参阅基准测试部分)。
两种重采样方法都使用 带限 sinc 插值 来计算任意时间步长的信号值。该实现涉及卷积,因此我们可以利用 GPU / 多线程来提高性能。
注意
当在多个子进程中使用重采样时,例如使用多个工作进程进行数据加载,您的应用程序可能会创建比系统有效处理能力更多的线程。在这种情况下,设置 torch.set_num_threads(1)
可能会有所帮助。
由于有限数量的样本只能表示有限数量的频率,因此重采样不会产生完美的结果,并且可以使用各种参数来控制其质量和计算速度。我们通过重采样对数正弦波扫描来演示这些特性,对数正弦波扫描是一种随时间频率呈指数增长的正弦波。
下面的频谱图显示了信号的频率表示,其中 x 轴对应于原始波形的频率(以对数刻度),y 轴对应于绘制的波形的频率,颜色强度对应于幅度。
sample_rate = 48000
waveform = get_sine_sweep(sample_rate)
plot_sweep(waveform, sample_rate, title="Original Waveform")
Audio(waveform.numpy()[0], rate=sample_rate)
data:image/s3,"s3://crabby-images/f151a/f151a4785d676652b9d8a188b60a9213b4e55fa0" alt="Original Waveform (sample rate: 48000 Hz)"
现在我们重采样(下采样)它。
我们看到在重采样波形的频谱图中,存在原始波形中不存在的伪影。这种效应称为混叠。此页面解释了它是如何发生的,以及为什么它看起来像反射。
resample_rate = 32000
resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype)
resampled_waveform = resampler(waveform)
plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform")
Audio(resampled_waveform.numpy()[0], rate=resample_rate)
data:image/s3,"s3://crabby-images/3648a/3648ad83eaf3fbe7b3e30012da13246e2bc963a6" alt="Resampled Waveform (sample rate: 32000 Hz)"
使用参数控制重采样质量¶
低通滤波器宽度¶
由于用于插值的滤波器无限延伸,因此 lowpass_filter_width
参数用于控制用于窗口化插值的滤波器宽度。它也称为过零点数,因为插值在每个时间单位都通过零点。使用较大的 lowpass_filter_width
可提供更清晰、更精确的滤波器,但计算成本更高。
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")
data:image/s3,"s3://crabby-images/3bf6f/3bf6fa1e64a7953fa567ba47f97b4b7a78ea0144" alt="lowpass_filter_width=6 (sample rate: 32000 Hz)"
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")
data:image/s3,"s3://crabby-images/d1b6c/d1b6c8c095108d417a04ffffdb2ae5d4dce735e4" alt="lowpass_filter_width=128 (sample rate: 32000 Hz)"
滚降¶
rolloff
参数表示为奈奎斯特频率的一部分,奈奎斯特频率是给定有限采样率可表示的最大频率。rolloff
确定低通滤波器截止频率并控制混叠程度,当频率高于奈奎斯特频率时,混叠会将频率映射到较低频率。因此,较低的滚降将减少混叠量,但也会减少一些较高的频率。
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99")
data:image/s3,"s3://crabby-images/9c44a/9c44a9662d75d416fae9dd6c3c1af2b1661b551e" alt="rolloff=0.99 (sample rate: 32000 Hz)"
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")
data:image/s3,"s3://crabby-images/98d75/98d757da98c49064ec99ac29eadba04637713d52" alt="rolloff=0.8 (sample rate: 32000 Hz)"
窗口函数¶
默认情况下,torchaudio
的重采样使用 Hann 窗口滤波器,它是一个加权余弦函数。它还额外支持 Kaiser 窗口,这是一个接近最优的窗口函数,其中包含一个额外的 beta
参数,该参数允许设计滤波器的平滑度和脉冲宽度。这可以使用 resampling_method
参数进行控制。
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_hann")
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")
data:image/s3,"s3://crabby-images/bf53a/bf53abae62ea32af0d52dabf077973453d2f9eef" alt="Hann Window Default (sample rate: 32000 Hz)"
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_kaiser")
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
data:image/s3,"s3://crabby-images/19339/1933992037ddd0580f2e8758d7535b6c7cb97d79" alt="Kaiser Window Default (sample rate: 32000 Hz)"
与 librosa 的比较¶
torchaudio
的重采样函数可用于产生类似于 librosa (resampy) 的 kaiser 窗口重采样的结果,但带有一些噪声
sample_rate = 48000
resample_rate = 32000
kaiser_best¶
resampled_waveform = F.resample(
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="sinc_interp_kaiser",
beta=14.769656459379492,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")
data:image/s3,"s3://crabby-images/485b2/485b2cb4fcd3ef49ceacab14f1ed580b92228301" alt="Kaiser Window Best (torchaudio) (sample rate: 32000 Hz)"
librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_best")
).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)")
data:image/s3,"s3://crabby-images/3328b/3328bfb221a79dce0467f6986a26b3f5f47a058d" alt="Kaiser Window Best (librosa) (sample rate: 32000 Hz)"
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse)
torchaudio and librosa kaiser best MSE: 2.0806901153660115e-06
kaiser_fast¶
resampled_waveform = F.resample(
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="sinc_interp_kaiser",
beta=8.555504641634386,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")
data:image/s3,"s3://crabby-images/25b18/25b18438095fd83e91cbcd4f4c5a1ae52ad5d511" alt="Kaiser Window Fast (torchaudio) (sample rate: 32000 Hz)"
librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_fast")
).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)")
data:image/s3,"s3://crabby-images/551af/551af5b56b337dd50c8016744b6ca29d92e0c452" alt="Kaiser Window Fast (librosa) (sample rate: 32000 Hz)"
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse)
torchaudio and librosa kaiser fast MSE: 2.5200744248601437e-05
性能基准测试¶
以下是两个采样率对之间波形下采样和上采样的基准测试。我们演示了 lowpass_filter_width
、窗口类型和采样率可能对性能产生的影响。此外,我们提供了与 librosa
的 kaiser_best
和 kaiser_fast
的比较,使用它们在 torchaudio
中对应的参数。
print(f"torchaudio: {torchaudio.__version__}")
print(f"librosa: {librosa.__version__}")
print(f"resampy: {resampy.__version__}")
torchaudio: 2.6.0
librosa: 0.10.0
resampy: 0.2.2
def benchmark_resample_functional(
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=6,
rolloff=0.99,
resampling_method="sinc_interp_hann",
beta=None,
iters=5,
):
return (
timeit.timeit(
stmt="""
torchaudio.functional.resample(
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=lowpass_filter_width,
rolloff=rolloff,
resampling_method=resampling_method,
beta=beta,
)
""",
setup="import torchaudio",
number=iters,
globals=locals(),
)
* 1000
/ iters
)
def benchmark_resample_transforms(
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=6,
rolloff=0.99,
resampling_method="sinc_interp_hann",
beta=None,
iters=5,
):
return (
timeit.timeit(
stmt="resampler(waveform)",
setup="""
import torchaudio
resampler = torchaudio.transforms.Resample(
sample_rate,
resample_rate,
lowpass_filter_width=lowpass_filter_width,
rolloff=rolloff,
resampling_method=resampling_method,
dtype=waveform.dtype,
beta=beta,
)
resampler.to(waveform.device)
""",
number=iters,
globals=locals(),
)
* 1000
/ iters
)
def benchmark_resample_librosa(
waveform,
sample_rate,
resample_rate,
res_type=None,
iters=5,
):
waveform_np = waveform.squeeze().numpy()
return (
timeit.timeit(
stmt="""
librosa.resample(
waveform_np,
orig_sr=sample_rate,
target_sr=resample_rate,
res_type=res_type,
)
""",
setup="import librosa",
number=iters,
globals=locals(),
)
* 1000
/ iters
)
def benchmark(sample_rate, resample_rate):
times, rows = [], []
waveform = get_sine_sweep(sample_rate).to(torch.float32)
args = (waveform, sample_rate, resample_rate)
# sinc 64 zero-crossings
f_time = benchmark_resample_functional(*args, lowpass_filter_width=64)
t_time = benchmark_resample_transforms(*args, lowpass_filter_width=64)
times.append([None, f_time, t_time])
rows.append("sinc (width 64)")
# sinc 6 zero-crossings
f_time = benchmark_resample_functional(*args, lowpass_filter_width=16)
t_time = benchmark_resample_transforms(*args, lowpass_filter_width=16)
times.append([None, f_time, t_time])
rows.append("sinc (width 16)")
# kaiser best
kwargs = {
"lowpass_filter_width": 64,
"rolloff": 0.9475937167399596,
"resampling_method": "sinc_interp_kaiser",
"beta": 14.769656459379492,
}
lib_time = benchmark_resample_librosa(*args, res_type="kaiser_best")
f_time = benchmark_resample_functional(*args, **kwargs)
t_time = benchmark_resample_transforms(*args, **kwargs)
times.append([lib_time, f_time, t_time])
rows.append("kaiser_best")
# kaiser fast
kwargs = {
"lowpass_filter_width": 16,
"rolloff": 0.85,
"resampling_method": "sinc_interp_kaiser",
"beta": 8.555504641634386,
}
lib_time = benchmark_resample_librosa(*args, res_type="kaiser_fast")
f_time = benchmark_resample_functional(*args, **kwargs)
t_time = benchmark_resample_transforms(*args, **kwargs)
times.append([lib_time, f_time, t_time])
rows.append("kaiser_fast")
df = pd.DataFrame(times, columns=["librosa", "functional", "transforms"], index=rows)
return df
def plot(df):
print(df.round(2))
ax = df.plot(kind="bar")
plt.ylabel("Time Elapsed [ms]")
plt.xticks(rotation=0, fontsize=10)
for cont, col, color in zip(ax.containers, df.columns, mcolors.TABLEAU_COLORS):
label = ["N/A" if v != v else str(v) for v in df[col].round(2)]
ax.bar_label(cont, labels=label, color=color, fontweight="bold", fontsize="x-small")
下采样 (48 -> 44.1 kHz)¶
df = benchmark(48_000, 44_100)
plot(df)
data:image/s3,"s3://crabby-images/6a4fb/6a4fb2c863cac21192f7c5b7385ae78b557740b5" alt="audio resampling tutorial"
librosa functional transforms
sinc (width 64) NaN 0.90 0.40
sinc (width 16) NaN 0.72 0.35
kaiser_best 83.91 1.21 0.38
kaiser_fast 7.89 0.95 0.34
下采样 (16 -> 8 kHz)¶
df = benchmark(16_000, 8_000)
plot(df)
data:image/s3,"s3://crabby-images/b6c3d/b6c3d92dc3ce27c7c11259afad006e965d61d908" alt="audio resampling tutorial"
librosa functional transforms
sinc (width 64) NaN 1.29 1.10
sinc (width 16) NaN 0.54 0.37
kaiser_best 11.29 1.36 1.17
kaiser_fast 3.14 0.67 0.41
上采样 (44.1 -> 48 kHz)¶
df = benchmark(44_100, 48_000)
plot(df)
data:image/s3,"s3://crabby-images/588bb/588bb4f9a419a1272ca4b96ed60cfe820e7a5909" alt="audio resampling tutorial"
librosa functional transforms
sinc (width 64) NaN 0.87 0.36
sinc (width 16) NaN 0.70 0.34
kaiser_best 32.74 1.14 0.38
kaiser_fast 7.88 0.94 0.34
上采样 (8 -> 16 kHz)¶
df = benchmark(8_000, 16_000)
plot(df)
data:image/s3,"s3://crabby-images/a653b/a653b4e248c66ae4a2edf400464031bd4dfd25b4" alt="audio resampling tutorial"
librosa functional transforms
sinc (width 64) NaN 0.70 0.46
sinc (width 16) NaN 0.38 0.22
kaiser_best 11.24 0.71 0.48
kaiser_fast 2.99 0.41 0.24
总结¶
详细说明结果
较大的
lowpass_filter_width
会导致更大的重采样内核,因此会增加内核计算和卷积的计算时间使用
sinc_interp_kaiser
会导致比默认sinc_interp_hann
更长的计算时间,因为它计算中间窗口值更复杂采样率和重采样率之间较大的最大公约数 (GCD) 将导致简化,从而允许更小的内核和更快的内核计算。
脚本的总运行时间: (0 分 3.361 秒)