
Music Source Separation with Hybrid Demucs

Author: Sean Kim

This tutorial shows how to use the Hybrid Demucs model to perform music source separation.

1. Overview

Performing music separation consists of the following steps:

  1. Build the Hybrid Demucs pipeline.

  2. Format the waveform into chunks of the expected size, then loop through the chunks (with overlap) and feed them into the pipeline.

  3. Collect the output chunks and combine them according to the way they overlap.

The Hybrid Demucs [Défossez, 2021] model is a developed version of the Demucs model, a waveform-based model that separates music into its respective sources, such as vocals, bass, and drums. Hybrid Demucs effectively uses spectrograms to learn in the frequency domain while also incorporating time-domain convolutions.

2. Preparation

First, we install the necessary dependencies. The first requirement is torchaudio and torch.

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

import matplotlib.pyplot as plt
2.6.0
2.6.0

In addition to torchaudio, mir_eval is required to perform signal-to-distortion ratio (SDR) calculations. To install mir_eval, run pip3 install mir_eval.
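
If you are unsure whether mir_eval is already available in your environment, a quick optional check like the sketch below can be used; the printed messages are only illustrative:

try:
    import mir_eval  # noqa: F401
    print("mir_eval is installed")
except ImportError:
    print("mir_eval is missing; install it with `pip3 install mir_eval`")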

from IPython.display import Audio
from mir_eval import separation
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.utils import download_asset

3. Construct the pipeline

Pre-trained model weights and related pipeline components are bundled as torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS. This is a torchaudio.models.HDemucs model trained on MUSDB18-HQ and additional internal training data. This particular model is suited for higher sample rates, around 44.1 kHz, and uses an nfft value of 4096 with a depth of 6 in the model implementation.

bundle = HDEMUCS_HIGH_MUSDB_PLUS

model = bundle.get_model()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)

sample_rate = bundle.sample_rate

print(f"Sample rate: {sample_rate}")
Sample rate: 44100
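
Optionally, the order of the sources the bundled model produces can be inspected directly; this is the order used later when pairing the outputs with their names (a small check, not part of the original pipeline code):

print(model.sources)  # ['drums', 'bass', 'other', 'vocals'] for this bundle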

4. Configure the application function

Because HDemucs is a large, memory-consuming model, it is difficult to have enough memory to apply the model to an entire song at once. To work around this limitation, the separated sources of a full song are obtained by chunking the song into smaller segments, running the model piece by piece, and then rearranging the pieces back together.

When doing this, it is important to ensure some overlap between the chunks to accommodate artifacts at the edges. Due to the nature of the model, the edges sometimes contain inaccurate or unwanted sounds.

We provide a sample implementation of chunking and arrangement below. This implementation takes an overlap of 1 second on each side and then performs a linear fade-in and fade-out on each side. Using the faded overlaps, we add the segments together to ensure a constant volume throughout. This accommodates the artifacts by using less of the edges of the model outputs.

Figure: https://download.pytorch.org/torchaudio/tutorial-assets/HDemucs_Drawing.jpg
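
As a quick sanity check of the constant-volume claim above, the linear fade-out applied to the tail of one chunk and the linear fade-in applied to the head of the next sum to one across the overlap region. A minimal toy illustration (the overlap length here is arbitrary, not the value used below):

overlap_len = 5
fade_out = torch.linspace(1.0, 0.0, overlap_len)  # tail of the previous chunk
fade_in = torch.linspace(0.0, 1.0, overlap_len)   # head of the next chunk
print(fade_out + fade_in)  # tensor([1., 1., 1., 1., 1.])
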
from torchaudio.transforms import Fade


def separate_sources(
    model,
    mix,
    segment=10.0,
    overlap=0.1,
    device=None,
):
    """
    Apply model to a given mixture. Use fade, and add segments together in order to add model segment by segment.

    Args:
        segment (int): segment length in seconds
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape

    chunk_len = int(sample_rate * segment * (1 + overlap))
    start = 0
    end = chunk_len
    overlap_frames = overlap * sample_rate
    fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")

    final = torch.zeros(batch, len(model.sources), channels, length, device=device)

    while start < length - overlap_frames:
        chunk = mix[:, :, start:end]
        with torch.no_grad():
            out = model.forward(chunk)
        out = fade(out)
        final[:, :, :, start:end] += out
        if start == 0:
            fade.fade_in_len = int(overlap_frames)
            start += int(chunk_len - overlap_frames)
        else:
            start += chunk_len
        end += chunk_len
        if end >= length:
            fade.fade_out_len = 0
    return final


def plot_spectrogram(stft, title="Spectrogram"):
    magnitude = stft.abs()
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    _, axis = plt.subplots(1, 1)
    axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
    axis.set_title(title)
    plt.tight_layout()

5. Run Model

Finally, we run the model and store the separated source files in a directory.

As a test song, we will use A Classic Education by NightOwl from MedleyDB (Creative Commons BY-NC-SA 4.0). It is also located within the train sources of the MUSDB18-HQ dataset.

To test with a different song, the variable names and URLs below, along with the parameters, can be changed to exercise the song separator in different ways, as sketched below.
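
For example, a local file could be loaded instead of the bundled asset along the lines of the commented sketch below ("my_song.wav" is a placeholder path; resampling is only needed if the file's rate differs from the bundle's 44.1 kHz):

# waveform, sr = torchaudio.load("my_song.wav")  # hypothetical local path
# if sr != bundle.sample_rate:
#     waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)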

# We download the audio file from our storage. Feel free to download another file and use audio from a specific path
SAMPLE_SONG = download_asset("tutorial-assets/hdemucs_mix.wav")
waveform, sample_rate = torchaudio.load(SAMPLE_SONG)  # replace SAMPLE_SONG with desired path for different song
waveform = waveform.to(device)
mixture = waveform

# parameters
segment: int = 10
overlap = 0.1

print("Separating track")

ref = waveform.mean(0)
waveform = (waveform - ref.mean()) / ref.std()  # normalization

sources = separate_sources(
    model,
    waveform[None],
    device=device,
    segment=segment,
    overlap=overlap,
)[0]
sources = sources * ref.std() + ref.mean()

sources_list = model.sources
sources = list(sources)

audios = dict(zip(sources_list, sources))
Separating track

5.1 Separate Track

The default set of pretrained weights that has been loaded separates the mixture into 4 sources: drums, bass, other, and vocals, in that order. They have been stored in the dict "audios" and can be accessed there. For each of the four sources there is a separate cell that creates the audio, the spectrogram, and also calculates the SDR score. SDR is the signal-to-distortion ratio, essentially a representation of the "quality" of an audio track.
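
As a standalone illustration of the SDR computation used below (toy signals only, not the tutorial's data), mir_eval's bss_eval_sources compares a reference source against an estimate and returns the SDR along with SIR, SAR, and the permutation:

import numpy as np

reference = np.random.randn(1, sample_rate)                   # one second of a "true" source
estimate = reference + 0.1 * np.random.randn(1, sample_rate)  # a noisy estimate of it
sdr, sir, sar, perm = separation.bss_eval_sources(reference, estimate)
print(sdr)  # higher values indicate a cleaner separation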

N_FFT = 4096
N_HOP = 4
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)

5.2 Audio Segmenting and Processing

Below are the processing steps for segmenting 5 seconds of each track, feeding them into the spectrogram, and calculating the respective SDR scores.

def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor, source: str):
    print(
        "SDR score is:",
        separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
    )
    plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
    return Audio(predicted_source, rate=sample_rate)


segment_start = 150
segment_end = 155

frame_start = segment_start * sample_rate
frame_end = segment_end * sample_rate

drums_original = download_asset("tutorial-assets/hdemucs_drums_segment.wav")
bass_original = download_asset("tutorial-assets/hdemucs_bass_segment.wav")
vocals_original = download_asset("tutorial-assets/hdemucs_vocals_segment.wav")
other_original = download_asset("tutorial-assets/hdemucs_other_segment.wav")

drums_spec = audios["drums"][:, frame_start:frame_end].cpu()
drums, sample_rate = torchaudio.load(drums_original)

bass_spec = audios["bass"][:, frame_start:frame_end].cpu()
bass, sample_rate = torchaudio.load(bass_original)

vocals_spec = audios["vocals"][:, frame_start:frame_end].cpu()
vocals, sample_rate = torchaudio.load(vocals_original)

other_spec = audios["other"][:, frame_start:frame_end].cpu()
other, sample_rate = torchaudio.load(other_original)

mix_spec = mixture[:, frame_start:frame_end].cpu()

5.3 Spectrograms and Audio

In the next 5 cells, you can see the spectrograms with their respective audio. The audio can be clearly visualized using the spectrograms.

The mixture clip comes from the original track, and the remaining tracks are the model output.

# Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate)
Spectrogram - Mixture


Drums SDR, Spectrogram, and Audio

# Drums Clip
output_results(drums, drums_spec, "drums")
Spectrogram - drums
SDR score is: 4.964477475897244


Bass SDR, Spectrogram, and Audio

# Bass Clip
output_results(bass, bass_spec, "bass")
Spectrogram - bass
SDR score is: 18.90589959575034


Vocals SDR, Spectrogram, and Audio

# Vocals Audio
output_results(vocals, vocals_spec, "vocals")
Spectrogram - vocals
SDR score is: 8.792372276328596


Other SDR, Spectrogram, and Audio

# Other Clip
output_results(other, other_spec, "other")
Spectrogram - other
SDR score is: 8.866964245665635


# Optionally, the full audio tracks can be heard by running the next 5
# cells. They will take a bit longer to load, so simply uncomment the
# ``Audio`` cell for the respective track to produce the audio for the
# full song.
#

# Full Audio
# Audio(mixture, rate=sample_rate)

# Drums Audio
# Audio(audios["drums"], rate=sample_rate)

# Bass Audio
# Audio(audios["bass"], rate=sample_rate)

# Vocals Audio
# Audio(audios["vocals"], rate=sample_rate)

# Other Audio
# Audio(audios["other"], rate=sample_rate)

Total running time of the script: (0 minutes 25.315 seconds)

Gallery generated by Sphinx-Gallery
