• 文档 >
  • CTC 强制对齐 API 教程 >
  • 旧版本 (稳定)
快捷方式

CTC 强制对齐 API 教程

作者: Xiaohui Zhang, Moto Hira

强制对齐是将转录与语音对齐的过程。本教程展示了如何使用 torchaudio.functional.forced_align() 将转录与语音对齐,该函数是在 将语音技术扩展到 1,000 多种语言 的工作中开发的。

forced_align() 具有自定义 CPU 和 CUDA 实现,它们比上面的普通 Python 实现性能更高,并且更准确。它还可以使用特殊的 <star> 符号处理缺失的转录。

还有一个高级 API,torchaudio.pipelines.Wav2Vec2FABundle,它封装了本教程中解释的预处理/后处理,并简化了强制对齐的运行。 多语言数据的强制对齐 使用此 API 来说明如何对齐非英语转录。

准备

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)
2.5.0
2.5.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
import IPython
import matplotlib.pyplot as plt

import torchaudio.functional as F

首先,我们准备将要使用的语音数据和转录。

SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
waveform, _ = torchaudio.load(SPEECH_FILE)
TRANSCRIPT = "i had that curiosity beside me at this moment".split()

生成发射

forced_align() 接受发射和符号序列,并输出符号的时间戳及其得分。

发射代表了对符号的帧级概率分布,可以通过将波形传递到声学模型来获得。

符号是转录的数字表示。有多种方法可以对转录进行标记,但这里,我们简单地将字母映射到整数,这与我们将在下面用到的声学模型在训练时构建标签的方式相同。

我们将使用预训练的 Wav2Vec2 模型,torchaudio.pipelines.MMS_FA,来获得发射并对转录进行标记。

bundle = torchaudio.pipelines.MMS_FA

model = bundle.get_model(with_star=False).to(device)
with torch.inference_mode():
    emission, _ = model(waveform.to(device))
Downloading: "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" to /root/.cache/torch/hub/checkpoints/model.pt

  0%|          | 0.00/1.18G [00:00<?, ?B/s]
  3%|2         | 30.5M/1.18G [00:00<00:03, 319MB/s]
  5%|5         | 61.0M/1.18G [00:00<00:03, 317MB/s]
  8%|7         | 91.8M/1.18G [00:00<00:03, 319MB/s]
 10%|#         | 122M/1.18G [00:00<00:03, 308MB/s]
 13%|#3        | 161M/1.18G [00:00<00:03, 342MB/s]
 16%|#6        | 195M/1.18G [00:00<00:03, 347MB/s]
 19%|#9        | 233M/1.18G [00:00<00:02, 365MB/s]
 22%|##2       | 268M/1.18G [00:00<00:02, 333MB/s]
 26%|##5       | 308M/1.18G [00:00<00:02, 357MB/s]
 29%|##8       | 348M/1.18G [00:01<00:02, 375MB/s]
 32%|###2      | 390M/1.18G [00:01<00:02, 392MB/s]
 36%|###5      | 427M/1.18G [00:01<00:02, 376MB/s]
 39%|###8      | 464M/1.18G [00:01<00:02, 364MB/s]
 41%|####1     | 499M/1.18G [00:01<00:02, 360MB/s]
 44%|####4     | 534M/1.18G [00:01<00:01, 364MB/s]
 47%|####7     | 569M/1.18G [00:01<00:01, 360MB/s]
 50%|#####     | 607M/1.18G [00:01<00:01, 369MB/s]
 53%|#####3    | 642M/1.18G [00:01<00:01, 368MB/s]
 57%|#####6    | 680M/1.18G [00:01<00:01, 376MB/s]
 59%|#####9    | 716M/1.18G [00:02<00:01, 375MB/s]
 63%|######2   | 753M/1.18G [00:02<00:01, 377MB/s]
 66%|######5   | 789M/1.18G [00:02<00:01, 374MB/s]
 69%|######8   | 828M/1.18G [00:02<00:01, 383MB/s]
 72%|#######1  | 864M/1.18G [00:02<00:00, 382MB/s]
 75%|#######5  | 906M/1.18G [00:02<00:00, 397MB/s]
 79%|#######8  | 947M/1.18G [00:02<00:00, 408MB/s]
 82%|########1 | 986M/1.18G [00:02<00:00, 408MB/s]
 85%|########5 | 1.00G/1.18G [00:02<00:00, 397MB/s]
 89%|########8 | 1.04G/1.18G [00:03<00:00, 417MB/s]
 92%|#########2| 1.08G/1.18G [00:03<00:00, 372MB/s]
 95%|#########5| 1.12G/1.18G [00:03<00:00, 370MB/s]
 98%|#########8| 1.15G/1.18G [00:03<00:00, 362MB/s]
100%|##########| 1.18G/1.18G [00:03<00:00, 369MB/s]
def plot_emission(emission):
    fig, ax = plt.subplots()
    ax.imshow(emission.cpu().T)
    ax.set_title("Frame-wise class probabilities")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.tight_layout()


plot_emission(emission[0])
Frame-wise class probabilities

对转录进行标记

我们创建一个字典,将每个标签映射到符号。

LABELS = bundle.get_labels(star=None)
DICTIONARY = bundle.get_dict(star=None)
for k, v in DICTIONARY.items():
    print(f"{k}: {v}")
-: 0
a: 1
i: 2
e: 3
n: 4
o: 5
u: 6
t: 7
s: 8
r: 9
m: 10
k: 11
l: 12
d: 13
g: 14
h: 15
y: 16
b: 17
p: 18
w: 19
c: 20
v: 21
j: 22
z: 23
f: 24
': 25
q: 26
x: 27

将转录转换为符号就像

tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word]

for t in tokenized_transcript:
    print(t, end=" ")
print()
2 15 1 13 7 15 1 7 20 6 9 2 5 8 2 7 16 17 3 8 2 13 3 10 3 1 7 7 15 2 8 10 5 10 3 4 7

计算对齐

帧级对齐

现在我们调用 TorchAudio 的强制对齐 API 来计算帧级对齐。有关函数签名的详细信息,请参考 forced_align()

def align(emission, tokens):
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
    alignments, scores = F.forced_align(emission, targets, blank=0)

    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
    scores = scores.exp()  # convert back to probability
    return alignments, scores


aligned_tokens, alignment_scores = align(emission, tokenized_transcript)

现在让我们看一下输出。

for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
    print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
  0:     0 [-], 1.00
  1:     0 [-], 1.00
  2:     0 [-], 1.00
  3:     0 [-], 1.00
  4:     0 [-], 1.00
  5:     0 [-], 1.00
  6:     0 [-], 1.00
  7:     0 [-], 1.00
  8:     0 [-], 1.00
  9:     0 [-], 1.00
 10:     0 [-], 1.00
 11:     0 [-], 1.00
 12:     0 [-], 1.00
 13:     0 [-], 1.00
 14:     0 [-], 1.00
 15:     0 [-], 1.00
 16:     0 [-], 1.00
 17:     0 [-], 1.00
 18:     0 [-], 1.00
 19:     0 [-], 1.00
 20:     0 [-], 1.00
 21:     0 [-], 1.00
 22:     0 [-], 1.00
 23:     0 [-], 1.00
 24:     0 [-], 1.00
 25:     0 [-], 1.00
 26:     0 [-], 1.00
 27:     0 [-], 1.00
 28:     0 [-], 1.00
 29:     0 [-], 1.00
 30:     0 [-], 1.00
 31:     0 [-], 1.00
 32:     2 [i], 1.00
 33:     0 [-], 1.00
 34:     0 [-], 1.00
 35:    15 [h], 1.00
 36:    15 [h], 0.93
 37:     1 [a], 1.00
 38:     0 [-], 0.96
 39:     0 [-], 1.00
 40:     0 [-], 1.00
 41:    13 [d], 1.00
 42:     0 [-], 1.00
 43:     0 [-], 0.97
 44:     7 [t], 1.00
 45:    15 [h], 1.00
 46:     0 [-], 0.98
 47:     1 [a], 1.00
 48:     0 [-], 1.00
 49:     0 [-], 1.00
 50:     7 [t], 1.00
 51:     0 [-], 1.00
 52:     0 [-], 1.00
 53:     0 [-], 1.00
 54:    20 [c], 1.00
 55:     0 [-], 1.00
 56:     0 [-], 1.00
 57:     0 [-], 1.00
 58:     6 [u], 1.00
 59:     6 [u], 0.96
 60:     0 [-], 1.00
 61:     0 [-], 1.00
 62:     0 [-], 0.53
 63:     9 [r], 1.00
 64:     0 [-], 1.00
 65:     2 [i], 1.00
 66:     0 [-], 1.00
 67:     0 [-], 1.00
 68:     0 [-], 1.00
 69:     0 [-], 1.00
 70:     0 [-], 1.00
 71:     0 [-], 0.96
 72:     5 [o], 1.00
 73:     0 [-], 1.00
 74:     0 [-], 1.00
 75:     0 [-], 1.00
 76:     0 [-], 1.00
 77:     0 [-], 1.00
 78:     0 [-], 1.00
 79:     8 [s], 1.00
 80:     0 [-], 1.00
 81:     0 [-], 1.00
 82:     0 [-], 0.99
 83:     2 [i], 1.00
 84:     0 [-], 1.00
 85:     7 [t], 1.00
 86:     0 [-], 1.00
 87:     0 [-], 1.00
 88:    16 [y], 1.00
 89:     0 [-], 1.00
 90:     0 [-], 1.00
 91:     0 [-], 1.00
 92:     0 [-], 1.00
 93:    17 [b], 1.00
 94:     0 [-], 1.00
 95:     3 [e], 1.00
 96:     0 [-], 1.00
 97:     0 [-], 1.00
 98:     0 [-], 1.00
 99:     0 [-], 1.00
100:     0 [-], 1.00
101:     8 [s], 1.00
102:     0 [-], 1.00
103:     0 [-], 1.00
104:     0 [-], 1.00
105:     0 [-], 1.00
106:     0 [-], 1.00
107:     0 [-], 1.00
108:     0 [-], 1.00
109:     0 [-], 0.64
110:     2 [i], 1.00
111:     0 [-], 1.00
112:     0 [-], 1.00
113:    13 [d], 1.00
114:     3 [e], 0.85
115:     0 [-], 1.00
116:    10 [m], 1.00
117:     0 [-], 1.00
118:     0 [-], 1.00
119:     3 [e], 1.00
120:     0 [-], 1.00
121:     0 [-], 1.00
122:     0 [-], 1.00
123:     0 [-], 1.00
124:     1 [a], 1.00
125:     0 [-], 1.00
126:     0 [-], 1.00
127:     7 [t], 1.00
128:     0 [-], 1.00
129:     7 [t], 1.00
130:    15 [h], 1.00
131:     0 [-], 0.79
132:     2 [i], 1.00
133:     0 [-], 1.00
134:     0 [-], 1.00
135:     0 [-], 1.00
136:     8 [s], 1.00
137:     0 [-], 1.00
138:     0 [-], 1.00
139:     0 [-], 1.00
140:     0 [-], 1.00
141:    10 [m], 1.00
142:     0 [-], 1.00
143:     0 [-], 1.00
144:     5 [o], 1.00
145:     0 [-], 1.00
146:     0 [-], 1.00
147:     0 [-], 1.00
148:    10 [m], 1.00
149:     0 [-], 1.00
150:     0 [-], 1.00
151:     3 [e], 1.00
152:     0 [-], 1.00
153:     4 [n], 1.00
154:     0 [-], 1.00
155:     7 [t], 1.00
156:     0 [-], 1.00
157:     0 [-], 1.00
158:     0 [-], 1.00
159:     0 [-], 1.00
160:     0 [-], 1.00
161:     0 [-], 1.00
162:     0 [-], 1.00
163:     0 [-], 1.00
164:     0 [-], 1.00
165:     0 [-], 1.00
166:     0 [-], 1.00
167:     0 [-], 1.00
168:     0 [-], 1.00

注意

对齐是用发射的帧坐标表示的,这与原始波形不同。

它包含空符号和重复符号。以下是对非空符号的解释。

31:     0 [-], 1.00
32:     2 [i], 1.00  "i" starts and ends
33:     0 [-], 1.00
34:     0 [-], 1.00
35:    15 [h], 1.00  "h" starts
36:    15 [h], 0.93  "h" ends
37:     1 [a], 1.00  "a" starts and ends
38:     0 [-], 0.96
39:     0 [-], 1.00
40:     0 [-], 1.00
41:    13 [d], 1.00  "d" starts and ends
42:     0 [-], 1.00

注意

当相同符号出现在空符号之后时,它不被视为重复,而是被视为新的出现。

a a a b -> a b
a - - b -> a b
a a - b -> a b
a - a b -> a a b
  ^^^       ^^^

符号级对齐

下一步是解决重复问题,以便每个对齐不依赖于先前的对齐。 torchaudio.functional.merge_tokens() 计算 TokenSpan 对象,它表示转录中的哪个符号在哪个时间跨度内出现。

token_spans = F.merge_tokens(aligned_tokens, alignment_scores)

print("Token\tTime\tScore")
for s in token_spans:
    print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
Token   Time    Score
i       [ 32,  33)      1.00
h       [ 35,  37)      0.96
a       [ 37,  38)      1.00
d       [ 41,  42)      1.00
t       [ 44,  45)      1.00
h       [ 45,  46)      1.00
a       [ 47,  48)      1.00
t       [ 50,  51)      1.00
c       [ 54,  55)      1.00
u       [ 58,  60)      0.98
r       [ 63,  64)      1.00
i       [ 65,  66)      1.00
o       [ 72,  73)      1.00
s       [ 79,  80)      1.00
i       [ 83,  84)      1.00
t       [ 85,  86)      1.00
y       [ 88,  89)      1.00
b       [ 93,  94)      1.00
e       [ 95,  96)      1.00
s       [101, 102)      1.00
i       [110, 111)      1.00
d       [113, 114)      1.00
e       [114, 115)      0.85
m       [116, 117)      1.00
e       [119, 120)      1.00
a       [124, 125)      1.00
t       [127, 128)      1.00
t       [129, 130)      1.00
h       [130, 131)      1.00
i       [132, 133)      1.00
s       [136, 137)      1.00
m       [141, 142)      1.00
o       [144, 145)      1.00
m       [148, 149)      1.00
e       [151, 152)      1.00
n       [153, 154)      1.00
t       [155, 156)      1.00

词级对齐

现在让我们将符号级对齐分组到词级对齐。

def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret


word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])

音频预览

# Compute average score weighted by the span length
def _score(spans):
    return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)


def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / num_frames
    x0 = int(ratio * spans[0].start)
    x1 = int(ratio * spans[-1].end)
    print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)


num_frames = emission.size(1)
# Generate the audio for each segment
print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE)
['i', 'had', 'that', 'curiosity', 'beside', 'me', 'at', 'this', 'moment']


preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])
i (1.00): 0.644 - 0.664 sec


preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])
had (0.98): 0.704 - 0.845 sec


preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])
that (1.00): 0.885 - 1.026 sec


preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])
curiosity (1.00): 1.086 - 1.790 sec


preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])
beside (0.97): 1.871 - 2.314 sec


preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])
me (1.00): 2.334 - 2.414 sec


preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])
at (1.00): 2.495 - 2.575 sec


preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])
this (1.00): 2.595 - 2.756 sec


preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])
moment (1.00): 2.837 - 3.138 sec


可视化

现在让我们看一下对齐结果,并将原始语音分割成词。

def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / emission.size(1) / sample_rate

    fig, axes = plt.subplots(2, 1)
    axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
    axes[0].set_title("Emission")
    axes[0].set_xticks([])

    axes[1].specgram(waveform[0], Fs=sample_rate)
    for t_spans, chars in zip(token_spans, transcript):
        t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
        axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
        axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
        axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

        for span, char in zip(t_spans, chars):
            t0 = span.start * ratio
            axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)

    axes[1].set_xlabel("time [second]")
    axes[1].set_xlim([0, None])
    fig.tight_layout()
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
Emission

blank 符号的不一致处理

在将符号级对齐拆分为词时,你会注意到某些空符号的处理方式不同,这使得对结果的解释有些含糊。

当我们绘制得分时,这一点很容易看到。下图显示了词区域和非词区域,以及非空符号的帧级得分。

def plot_scores(word_spans, scores):
    fig, ax = plt.subplots()
    span_xs, span_hs = [], []
    ax.axvspan(word_spans[0][0].start - 0.05, word_spans[-1][-1].end + 0.05, facecolor="paleturquoise", edgecolor="none", zorder=-1)
    for t_span in word_spans:
        for span in t_span:
            for t in range(span.start, span.end):
                span_xs.append(t + 0.5)
                span_hs.append(scores[t].item())
            ax.annotate(LABELS[span.token], (span.start, -0.07))
        ax.axvspan(t_span[0].start - 0.05, t_span[-1].end + 0.05, facecolor="mistyrose", edgecolor="none", zorder=-1)
    ax.bar(span_xs, span_hs, color="lightsalmon", edgecolor="coral")
    ax.set_title("Frame-level scores and word segments")
    ax.set_ylim(-0.1, None)
    ax.grid(True, axis="y")
    ax.axhline(0, color="black")
    fig.tight_layout()


plot_scores(word_spans, alignment_scores)
Frame-level scores and word segments

在此图中,空符号是那些没有垂直条的突出显示区域。您可以看到,有一些空符号被解释为词的一部分(以红色突出显示),而其他空符号(以蓝色突出显示)则没有。

造成这种情况的一个原因是,模型在训练时没有词边界标签。空符号不仅被视为重复,还被视为词之间的静默。

但随后,一个问题出现了。词结束后的帧或词结束附近的帧应该被视为静默还是重复?

在上例中,如果你回到之前的光谱图和词区域的图,你会看到在“curiosity”中的“y”之后,在多个频率桶中仍然存在一些活动。

如果该帧包含在该词中,会更准确吗?

不幸的是,CTC 无法提供此问题的全面解决方案。使用 CTC 训练的模型已知会表现出“尖峰”响应,也就是说,它们倾向于在标签出现时出现峰值,但峰值不会持续整个标签的持续时间。(注意:预训练的 Wav2Vec2 模型倾向于在标签出现时开始出现峰值,但这并不总是这样。)

[Zeyer 等人,2021] 对 CTC 的尖峰行为进行了深入的分析。我们鼓励有兴趣了解更多信息的人参考这篇论文。以下是论文中的一段引文,它正是我们在这里面临的具体问题。

在某些情况下,尖峰行为可能是有问题的, 例如,当应用程序要求不使用空白标签时, 例如,为了获得语音的有效时间准确对齐 到转录文本。

高级:处理带有 <star> 标记的转录文本

现在让我们看看当转录文本部分缺失时,如何使用 <star> 标记来提高对齐质量,该标记能够对任何标记进行建模。

在这里,我们使用与上面相同的英文示例。但我们从转录文本中删除了开头文本 “i had that curiosity beside me at”。将音频与这种转录文本对齐会导致现有的单词 “this” 的对齐错误。但是,这个问题可以通过使用 <star> 标记来对缺失文本进行建模来缓解。

首先,我们将字典扩展为包括 <star> 标记。

DICTIONARY["*"] = len(DICTIONARY)

接下来,我们将发射张量扩展为对应于 <star> 标记的额外维度。

star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)
emission = torch.cat((emission, star_dim), 2)

assert len(DICTIONARY) == emission.shape[2]

plot_emission(emission[0])
Frame-wise class probabilities

以下函数将所有流程结合起来,并一次性从发射中计算单词片段。

def compute_alignments(emission, transcript, dictionary):
    tokens = [dictionary[char] for word in transcript for char in word]
    alignment, scores = align(emission, tokens)
    token_spans = F.merge_tokens(alignment, scores)
    word_spans = unflatten(token_spans, [len(word) for word in transcript])
    return word_spans

完整转录文本

word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
Emission

带有 <star> 标记的部分转录文本

现在我们将转录文本的第一部分替换为 <star> 标记。

transcript = "* this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)
Emission
preview_word(waveform, word_spans[0], num_frames, transcript[0])
* (1.00): 0.000 - 2.595 sec


preview_word(waveform, word_spans[1], num_frames, transcript[1])
this (1.00): 2.595 - 2.756 sec


preview_word(waveform, word_spans[2], num_frames, transcript[2])
moment (1.00): 2.837 - 3.138 sec


不带 <star> 标记的部分转录文本

作为对比,以下内容是对齐不使用 <star> 标记的部分转录文本。它展示了 <star> 标记在处理删除错误方面的效果。

transcript = "this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)
Emission

结论

在本教程中,我们学习了如何使用 torchaudio 的强制对齐 API 对齐和分割语音文件,并演示了其中一种高级用法:如何引入 <star> 标记可以提高转录错误存在时的对齐精度。

鸣谢

感谢 Vineel PratapZhaoheng Ni 开发和开源强制对齐器 API。

脚本的总运行时间:(0 分钟 6.927 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源