Note
Click here to download the full example code
CTC forced alignment API tutorial¶
Author: Xiaohui Zhang, Moto Hira
Forced alignment is the process of aligning a transcript to speech. This tutorial shows how to align transcripts to speech using torchaudio.functional.forced_align(), which was developed as part of the work on Scaling Speech Technology to 1,000+ Languages.
forced_align() has custom CPU and CUDA implementations that are more performant than the vanilla Python implementation above, and are more accurate. It can also handle missing transcripts with the special <star> token.
There is also a high-level API, torchaudio.pipelines.Wav2Vec2FABundle, which wraps the pre-/post-processing explained in this tutorial and makes it easy to run forced alignment. Forced alignment for multilingual data uses this API to illustrate how to align non-English transcripts.
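As a quick taste of that high-level API, the following is a minimal sketch of how it can be used. It is an illustration under assumptions rather than part of this tutorial: it presumes the get_tokenizer() and get_aligner() helpers of Wav2Vec2FABundle and a hypothetical 16 kHz mono file "speech.wav".
import torch
import torchaudio

# Minimal sketch of the high-level pipeline API (assumed helpers: get_tokenizer()/get_aligner()).
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model()          # acoustic model producing frame-wise emissions
tokenizer = bundle.get_tokenizer()  # maps a list of words to lists of token IDs
aligner = bundle.get_aligner()      # wraps forced alignment and token merging

waveform, _ = torchaudio.load("speech.wav")  # hypothetical file sampled at bundle.sample_rate
transcript = "i had that curiosity beside me at this moment".split()

with torch.inference_mode():
    emission, _ = model(waveform)
    word_spans = aligner(emission[0], tokenizer(transcript))  # one list of TokenSpans per word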
Preparation¶
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
2.6.0
2.6.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
import IPython
import matplotlib.pyplot as plt
import torchaudio.functional as F
First, we prepare the speech data and the transcript we are going to use.
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
waveform, _ = torchaudio.load(SPEECH_FILE)
TRANSCRIPT = "i had that curiosity beside me at this moment".split()
Generating emissions¶
forced_align() takes emission and token sequences and outputs timestamps of the tokens and their scores.
An emission represents the frame-wise probability distribution over tokens, and it can be obtained by passing a waveform to an acoustic model.
Tokens are numerical expressions of transcripts. There are many ways to tokenize transcripts, but here we simply map alphabets into integers, which is how the labels were constructed when the acoustic model we are going to use was trained.
We will use the pre-trained Wav2Vec2 model torchaudio.pipelines.MMS_FA to obtain the emission and tokenize the transcript.
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model(with_star=False).to(device)
with torch.inference_mode():
    emission, _ = model(waveform.to(device))
Downloading: "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" to /root/.cache/torch/hub/checkpoints/model.pt
100%|##########| 1.18G/1.18G [00:03<00:00, 369MB/s]

Tokenize the transcript¶
We create a dictionary, which maps each label to a token.
LABELS = bundle.get_labels(star=None)
DICTIONARY = bundle.get_dict(star=None)
for k, v in DICTIONARY.items():
    print(f"{k}: {v}")
-: 0
a: 1
i: 2
e: 3
n: 4
o: 5
u: 6
t: 7
s: 8
r: 9
m: 10
k: 11
l: 12
d: 13
g: 14
h: 15
y: 16
b: 17
p: 18
w: 19
c: 20
v: 21
j: 22
z: 23
f: 24
': 25
q: 26
x: 27
Converting the transcript to tokens is as simple as the following:
tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word]
for t in tokenized_transcript:
    print(t, end=" ")
print()
2 15 1 13 7 15 1 7 20 6 9 2 5 8 2 7 16 17 3 8 2 13 3 10 3 1 7 7 15 2 8 10 5 10 3 4 7
Computing alignments¶
Frame-level alignments¶
Now we call TorchAudio's forced alignment API to compute the frame-level alignment. For details of the function signature, please refer to forced_align().
def align(emission, tokens):
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
    alignments, scores = F.forced_align(emission, targets, blank=0)
    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
    scores = scores.exp()  # convert back to probability
    return alignments, scores
aligned_tokens, alignment_scores = align(emission, tokenized_transcript)
Now let's look at the output.
for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
    print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
0: 0 [-], 1.00
1: 0 [-], 1.00
2: 0 [-], 1.00
3: 0 [-], 1.00
4: 0 [-], 1.00
5: 0 [-], 1.00
6: 0 [-], 1.00
7: 0 [-], 1.00
8: 0 [-], 1.00
9: 0 [-], 1.00
10: 0 [-], 1.00
11: 0 [-], 1.00
12: 0 [-], 1.00
13: 0 [-], 1.00
14: 0 [-], 1.00
15: 0 [-], 1.00
16: 0 [-], 1.00
17: 0 [-], 1.00
18: 0 [-], 1.00
19: 0 [-], 1.00
20: 0 [-], 1.00
21: 0 [-], 1.00
22: 0 [-], 1.00
23: 0 [-], 1.00
24: 0 [-], 1.00
25: 0 [-], 1.00
26: 0 [-], 1.00
27: 0 [-], 1.00
28: 0 [-], 1.00
29: 0 [-], 1.00
30: 0 [-], 1.00
31: 0 [-], 1.00
32: 2 [i], 1.00
33: 0 [-], 1.00
34: 0 [-], 1.00
35: 15 [h], 1.00
36: 15 [h], 0.93
37: 1 [a], 1.00
38: 0 [-], 0.96
39: 0 [-], 1.00
40: 0 [-], 1.00
41: 13 [d], 1.00
42: 0 [-], 1.00
43: 0 [-], 0.97
44: 7 [t], 1.00
45: 15 [h], 1.00
46: 0 [-], 0.98
47: 1 [a], 1.00
48: 0 [-], 1.00
49: 0 [-], 1.00
50: 7 [t], 1.00
51: 0 [-], 1.00
52: 0 [-], 1.00
53: 0 [-], 1.00
54: 20 [c], 1.00
55: 0 [-], 1.00
56: 0 [-], 1.00
57: 0 [-], 1.00
58: 6 [u], 1.00
59: 6 [u], 0.96
60: 0 [-], 1.00
61: 0 [-], 1.00
62: 0 [-], 0.53
63: 9 [r], 1.00
64: 0 [-], 1.00
65: 2 [i], 1.00
66: 0 [-], 1.00
67: 0 [-], 1.00
68: 0 [-], 1.00
69: 0 [-], 1.00
70: 0 [-], 1.00
71: 0 [-], 0.96
72: 5 [o], 1.00
73: 0 [-], 1.00
74: 0 [-], 1.00
75: 0 [-], 1.00
76: 0 [-], 1.00
77: 0 [-], 1.00
78: 0 [-], 1.00
79: 8 [s], 1.00
80: 0 [-], 1.00
81: 0 [-], 1.00
82: 0 [-], 0.99
83: 2 [i], 1.00
84: 0 [-], 1.00
85: 7 [t], 1.00
86: 0 [-], 1.00
87: 0 [-], 1.00
88: 16 [y], 1.00
89: 0 [-], 1.00
90: 0 [-], 1.00
91: 0 [-], 1.00
92: 0 [-], 1.00
93: 17 [b], 1.00
94: 0 [-], 1.00
95: 3 [e], 1.00
96: 0 [-], 1.00
97: 0 [-], 1.00
98: 0 [-], 1.00
99: 0 [-], 1.00
100: 0 [-], 1.00
101: 8 [s], 1.00
102: 0 [-], 1.00
103: 0 [-], 1.00
104: 0 [-], 1.00
105: 0 [-], 1.00
106: 0 [-], 1.00
107: 0 [-], 1.00
108: 0 [-], 1.00
109: 0 [-], 0.64
110: 2 [i], 1.00
111: 0 [-], 1.00
112: 0 [-], 1.00
113: 13 [d], 1.00
114: 3 [e], 0.85
115: 0 [-], 1.00
116: 10 [m], 1.00
117: 0 [-], 1.00
118: 0 [-], 1.00
119: 3 [e], 1.00
120: 0 [-], 1.00
121: 0 [-], 1.00
122: 0 [-], 1.00
123: 0 [-], 1.00
124: 1 [a], 1.00
125: 0 [-], 1.00
126: 0 [-], 1.00
127: 7 [t], 1.00
128: 0 [-], 1.00
129: 7 [t], 1.00
130: 15 [h], 1.00
131: 0 [-], 0.79
132: 2 [i], 1.00
133: 0 [-], 1.00
134: 0 [-], 1.00
135: 0 [-], 1.00
136: 8 [s], 1.00
137: 0 [-], 1.00
138: 0 [-], 1.00
139: 0 [-], 1.00
140: 0 [-], 1.00
141: 10 [m], 1.00
142: 0 [-], 1.00
143: 0 [-], 1.00
144: 5 [o], 1.00
145: 0 [-], 1.00
146: 0 [-], 1.00
147: 0 [-], 1.00
148: 10 [m], 1.00
149: 0 [-], 1.00
150: 0 [-], 1.00
151: 3 [e], 1.00
152: 0 [-], 1.00
153: 4 [n], 1.00
154: 0 [-], 1.00
155: 7 [t], 1.00
156: 0 [-], 1.00
157: 0 [-], 1.00
158: 0 [-], 1.00
159: 0 [-], 1.00
160: 0 [-], 1.00
161: 0 [-], 1.00
162: 0 [-], 1.00
163: 0 [-], 1.00
164: 0 [-], 1.00
165: 0 [-], 1.00
166: 0 [-], 1.00
167: 0 [-], 1.00
168: 0 [-], 1.00
Note
The alignment is expressed in the frame coordinates of the emission, which is different from the original waveform (a small sketch for converting frame indices to seconds follows the example below).
It contains blank tokens and repeated tokens. The following is an interpretation of the non-blank tokens.
31: 0 [-], 1.00
32: 2 [i], 1.00 "i" starts and ends
33: 0 [-], 1.00
34: 0 [-], 1.00
35: 15 [h], 1.00 "h" starts
36: 15 [h], 0.93 "h" ends
37: 1 [a], 1.00 "a" starts and ends
38: 0 [-], 0.96
39: 0 [-], 1.00
40: 0 [-], 1.00
41: 13 [d], 1.00 "d" starts and ends
42: 0 [-], 1.00
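Since the indices above are emission-frame coordinates, converting them to seconds in the original waveform only requires the ratio between waveform samples and emission frames. The following is a small sketch using the variables already defined above (the preview_word helper later in this tutorial performs the same conversion):
ratio = waveform.size(1) / emission.size(1)  # audio samples per emission frame

def frame_to_seconds(frame_index, sample_rate=bundle.sample_rate):
    # Map an emission-frame index to a time in the original waveform.
    return frame_index * ratio / sample_rate

# Frame 32, where "i" appears above, lands at roughly 0.64 sec in this clip.
print(f"{frame_to_seconds(32):.3f} sec")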
Note
When the same token appears after a blank token, it is not treated as a repeat but as a new occurrence, as the examples and the sketch below illustrate.
a a a b -> a b
a - - b -> a b
a a - b -> a b
a - a b -> a a b
^^^ ^^^
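To make these rules concrete, here is a minimal sketch of the collapse logic they describe. It is for illustration only; torchaudio.functional.merge_tokens() used in the next section is the actual API:
def ctc_collapse(tokens, blank=0):
    # Drop repeats of the immediately preceding token, then drop blanks.
    out, prev = [], None
    for t in tokens:
        if t != prev and t != blank:
            out.append(t)
        prev = t
    return out

# With a=1, b=2 and blank=0:
print(ctc_collapse([1, 1, 0, 2]))  # "a a - b" -> [1, 2]     ("a b")
print(ctc_collapse([1, 0, 1, 2]))  # "a - a b" -> [1, 1, 2]  ("a a b")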
Token-level alignments¶
The next step is to resolve the repetitions, so that each alignment does not depend on previous alignments. torchaudio.functional.merge_tokens() computes TokenSpan objects, which represent which token from the transcript is present at what time span.
token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
print("Token\tTime\tScore")
for s in token_spans:
    print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
Token Time Score
i [ 32, 33) 1.00
h [ 35, 37) 0.96
a [ 37, 38) 1.00
d [ 41, 42) 1.00
t [ 44, 45) 1.00
h [ 45, 46) 1.00
a [ 47, 48) 1.00
t [ 50, 51) 1.00
c [ 54, 55) 1.00
u [ 58, 60) 0.98
r [ 63, 64) 1.00
i [ 65, 66) 1.00
o [ 72, 73) 1.00
s [ 79, 80) 1.00
i [ 83, 84) 1.00
t [ 85, 86) 1.00
y [ 88, 89) 1.00
b [ 93, 94) 1.00
e [ 95, 96) 1.00
s [101, 102) 1.00
i [110, 111) 1.00
d [113, 114) 1.00
e [114, 115) 0.85
m [116, 117) 1.00
e [119, 120) 1.00
a [124, 125) 1.00
t [127, 128) 1.00
t [129, 130) 1.00
h [130, 131) 1.00
i [132, 133) 1.00
s [136, 137) 1.00
m [141, 142) 1.00
o [144, 145) 1.00
m [148, 149) 1.00
e [151, 152) 1.00
n [153, 154) 1.00
t [155, 156) 1.00
Word-level alignments¶
Now let's group the token-level alignments into word-level alignments.
def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret
word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])
Audio previews¶
# Compute average score weighted by the span length
def _score(spans):
    return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)
def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / num_frames
    x0 = int(ratio * spans[0].start)
    x1 = int(ratio * spans[-1].end)
    print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)
num_frames = emission.size(1)
# Generate the audio for each segment
print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE)
['i', 'had', 'that', 'curiosity', 'beside', 'me', 'at', 'this', 'moment']
preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])
i (1.00): 0.644 - 0.664 sec
preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])
had (0.98): 0.704 - 0.845 sec
preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])
that (1.00): 0.885 - 1.026 sec
preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])
curiosity (1.00): 1.086 - 1.790 sec
preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])
beside (0.97): 1.871 - 2.314 sec
preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])
me (1.00): 2.334 - 2.414 sec
preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])
at (1.00): 2.495 - 2.575 sec
preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])
this (1.00): 2.595 - 2.756 sec
preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])
moment (1.00): 2.837 - 3.138 sec
Visualization¶
Now let's look at the alignment results and segment the original speech into words.
def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / emission.size(1) / sample_rate

    fig, axes = plt.subplots(2, 1)
    axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
    axes[0].set_title("Emission")
    axes[0].set_xticks([])

    axes[1].specgram(waveform[0], Fs=sample_rate)
    for t_spans, chars in zip(token_spans, transcript):
        t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
        axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
        axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
        axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

        for span, char in zip(t_spans, chars):
            t0 = span.start * ratio
            axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)

    axes[1].set_xlabel("time [second]")
    axes[1].set_xlim([0, None])
    fig.tight_layout()
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)

Inconsistent treatment of the blank token¶
When splitting the token-level alignments into words, you will notice that some blank tokens are treated differently, which makes the interpretation of the result somewhat ambiguous.
This is easy to see when we plot the scores. The following figure shows the word regions and the non-word regions, along with the frame-level scores of the non-blank tokens.
def plot_scores(word_spans, scores):
    fig, ax = plt.subplots()
    span_xs, span_hs = [], []
    ax.axvspan(word_spans[0][0].start - 0.05, word_spans[-1][-1].end + 0.05, facecolor="paleturquoise", edgecolor="none", zorder=-1)
    for t_span in word_spans:
        for span in t_span:
            for t in range(span.start, span.end):
                span_xs.append(t + 0.5)
                span_hs.append(scores[t].item())
            ax.annotate(LABELS[span.token], (span.start, -0.07))
        ax.axvspan(t_span[0].start - 0.05, t_span[-1].end + 0.05, facecolor="mistyrose", edgecolor="none", zorder=-1)
    ax.bar(span_xs, span_hs, color="lightsalmon", edgecolor="coral")
    ax.set_title("Frame-level scores and word segments")
    ax.set_ylim(-0.1, None)
    ax.grid(True, axis="y")
    ax.axhline(0, color="black")
    fig.tight_layout()
plot_scores(word_spans, alignment_scores)

In the plot, the blank tokens are the highlighted areas without a vertical bar. You can see that some blank tokens are interpreted as part of a word (highlighted red), while others (highlighted blue) are not.
One reason for this is that the model was trained without labels for word boundaries. Blank tokens are treated not only as repetitions but also as silence between words.
But then a question arises: should the frames immediately after, or near the end of, a word be treated as silence or as repeats?
In the example above, if you go back to the previous plot of the spectrogram and word regions, you can see that after the "y" in "curiosity", there is still some activity in multiple frequency buckets.
Would it be more accurate if that frame were included in the word?
Unfortunately, CTC does not provide a comprehensive solution to this. Models trained with CTC are known to exhibit a "peaky" response, that is, they tend to spike when a label appears, but the spike does not last for the duration of the label. (Note: the pre-trained Wav2Vec2 models tend to spike at the beginning of a label's appearance, but this is not always the case.)
[Zeyer et al., 2021] gives an in-depth analysis of the peaky behavior of CTC. We encourage those interested in learning more to refer to the paper. The following quote from the paper describes exactly the issue we are facing here.
Peaky behavior can be problematic in certain cases, e.g. when an application requires not using the blank label, e.g. to get meaningful time-accurate alignments of phonemes to a transcription.
Advanced: Handling transcripts with the <star> token¶
Now let's look at how to improve alignment quality when a part of the transcript is missing, using the <star> token, which is capable of modeling any token.
Here we use the same English example as above, but we remove the beginning text "i had that curiosity beside me at" from the transcript. Aligning audio to such a transcript results in wrong alignments of the existing word "this". However, this issue can be mitigated by using the <star> token to model the missing text.
First, we extend the dictionary to include the <star> token.
DICTIONARY["*"] = len(DICTIONARY)
Next, we extend the emission tensor with an extra dimension corresponding to the <star> token.
star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)
emission = torch.cat((emission, star_dim), 2)
assert len(DICTIONARY) == emission.shape[2]
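The plot_emission() helper called below is not defined in this excerpt. A minimal sketch of what such a helper might look like (an assumption added for completeness, not necessarily the tutorial's exact implementation) is:
def plot_emission(emission):
    # Visualize the frame-wise class probabilities: time on the x-axis, labels on the y-axis.
    fig, ax = plt.subplots()
    ax.imshow(emission.detach().cpu().T, aspect="auto")
    ax.set_title("Frame-wise class probabilities")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.tight_layout()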
plot_emission(emission[0])

The following function combines all the processes above, and computes word segments from an emission in one go.
def compute_alignments(emission, transcript, dictionary):
    tokens = [dictionary[char] for word in transcript for char in word]
    alignment, scores = align(emission, tokens)
    token_spans = F.merge_tokens(alignment, scores)
    word_spans = unflatten(token_spans, [len(word) for word in transcript])
    return word_spans
Full transcript¶
word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)

Partial transcript with the <star> token¶
Now we replace the first part of the transcript with the <star> token.
transcript = "* this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)

preview_word(waveform, word_spans[0], num_frames, transcript[0])
* (1.00): 0.000 - 2.595 sec
preview_word(waveform, word_spans[1], num_frames, transcript[1])
this (1.00): 2.595 - 2.756 sec
preview_word(waveform, word_spans[2], num_frames, transcript[2])
moment (1.00): 2.837 - 3.138 sec
Partial transcript without the <star> token¶
As a comparison, the following aligns the partial transcript without using the <star> token. It demonstrates the effect of the <star> token for dealing with deletion errors.
transcript = "this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)

Conclusion¶
In this tutorial, we looked at how to use torchaudio's forced alignment API to align and segment speech files, and demonstrated one advanced usage: how introducing a <star> token can improve alignment accuracy when transcription errors exist.
Acknowledgement¶
Thanks to Vineel Pratap and Zhaoheng Ni for developing and open-sourcing the forced aligner API.
Total running time of the script: (0 minutes 6.927 seconds)