注意
单击 此处 下载完整的示例代码
CTC 强制对齐 API 教程¶
作者: Xiaohui Zhang, Moto Hira
强制对齐是将转录与语音对齐的过程。本教程展示了如何使用 torchaudio.functional.forced_align()
将转录与语音对齐,该函数是在 将语音技术扩展到 1,000 多种语言 的工作中开发的。
forced_align()
具有自定义 CPU 和 CUDA 实现,它们比上面的普通 Python 实现性能更高,并且更准确。它还可以使用特殊的 <star>
符号处理缺失的转录。
还有一个高级 API,torchaudio.pipelines.Wav2Vec2FABundle
,它封装了本教程中解释的预处理/后处理,并简化了强制对齐的运行。 多语言数据的强制对齐 使用此 API 来说明如何对齐非英语转录。
准备¶
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
2.5.0
2.5.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
import IPython
import matplotlib.pyplot as plt
import torchaudio.functional as F
首先,我们准备将要使用的语音数据和转录。
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
waveform, _ = torchaudio.load(SPEECH_FILE)
TRANSCRIPT = "i had that curiosity beside me at this moment".split()
生成发射¶
forced_align()
接受发射和符号序列,并输出符号的时间戳及其得分。
发射代表了对符号的帧级概率分布,可以通过将波形传递到声学模型来获得。
符号是转录的数字表示。有多种方法可以对转录进行标记,但这里,我们简单地将字母映射到整数,这与我们将在下面用到的声学模型在训练时构建标签的方式相同。
我们将使用预训练的 Wav2Vec2 模型,torchaudio.pipelines.MMS_FA
,来获得发射并对转录进行标记。
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model(with_star=False).to(device)
with torch.inference_mode():
emission, _ = model(waveform.to(device))
Downloading: "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" to /root/.cache/torch/hub/checkpoints/model.pt
0%| | 0.00/1.18G [00:00<?, ?B/s]
3%|2 | 30.5M/1.18G [00:00<00:03, 319MB/s]
5%|5 | 61.0M/1.18G [00:00<00:03, 317MB/s]
8%|7 | 91.8M/1.18G [00:00<00:03, 319MB/s]
10%|# | 122M/1.18G [00:00<00:03, 308MB/s]
13%|#3 | 161M/1.18G [00:00<00:03, 342MB/s]
16%|#6 | 195M/1.18G [00:00<00:03, 347MB/s]
19%|#9 | 233M/1.18G [00:00<00:02, 365MB/s]
22%|##2 | 268M/1.18G [00:00<00:02, 333MB/s]
26%|##5 | 308M/1.18G [00:00<00:02, 357MB/s]
29%|##8 | 348M/1.18G [00:01<00:02, 375MB/s]
32%|###2 | 390M/1.18G [00:01<00:02, 392MB/s]
36%|###5 | 427M/1.18G [00:01<00:02, 376MB/s]
39%|###8 | 464M/1.18G [00:01<00:02, 364MB/s]
41%|####1 | 499M/1.18G [00:01<00:02, 360MB/s]
44%|####4 | 534M/1.18G [00:01<00:01, 364MB/s]
47%|####7 | 569M/1.18G [00:01<00:01, 360MB/s]
50%|##### | 607M/1.18G [00:01<00:01, 369MB/s]
53%|#####3 | 642M/1.18G [00:01<00:01, 368MB/s]
57%|#####6 | 680M/1.18G [00:01<00:01, 376MB/s]
59%|#####9 | 716M/1.18G [00:02<00:01, 375MB/s]
63%|######2 | 753M/1.18G [00:02<00:01, 377MB/s]
66%|######5 | 789M/1.18G [00:02<00:01, 374MB/s]
69%|######8 | 828M/1.18G [00:02<00:01, 383MB/s]
72%|#######1 | 864M/1.18G [00:02<00:00, 382MB/s]
75%|#######5 | 906M/1.18G [00:02<00:00, 397MB/s]
79%|#######8 | 947M/1.18G [00:02<00:00, 408MB/s]
82%|########1 | 986M/1.18G [00:02<00:00, 408MB/s]
85%|########5 | 1.00G/1.18G [00:02<00:00, 397MB/s]
89%|########8 | 1.04G/1.18G [00:03<00:00, 417MB/s]
92%|#########2| 1.08G/1.18G [00:03<00:00, 372MB/s]
95%|#########5| 1.12G/1.18G [00:03<00:00, 370MB/s]
98%|#########8| 1.15G/1.18G [00:03<00:00, 362MB/s]
100%|##########| 1.18G/1.18G [00:03<00:00, 369MB/s]
对转录进行标记¶
我们创建一个字典,将每个标签映射到符号。
LABELS = bundle.get_labels(star=None)
DICTIONARY = bundle.get_dict(star=None)
for k, v in DICTIONARY.items():
print(f"{k}: {v}")
-: 0
a: 1
i: 2
e: 3
n: 4
o: 5
u: 6
t: 7
s: 8
r: 9
m: 10
k: 11
l: 12
d: 13
g: 14
h: 15
y: 16
b: 17
p: 18
w: 19
c: 20
v: 21
j: 22
z: 23
f: 24
': 25
q: 26
x: 27
将转录转换为符号就像
tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word]
for t in tokenized_transcript:
print(t, end=" ")
print()
2 15 1 13 7 15 1 7 20 6 9 2 5 8 2 7 16 17 3 8 2 13 3 10 3 1 7 7 15 2 8 10 5 10 3 4 7
计算对齐¶
帧级对齐¶
现在我们调用 TorchAudio 的强制对齐 API 来计算帧级对齐。有关函数签名的详细信息,请参考 forced_align()
。
def align(emission, tokens):
targets = torch.tensor([tokens], dtype=torch.int32, device=device)
alignments, scores = F.forced_align(emission, targets, blank=0)
alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity
scores = scores.exp() # convert back to probability
return alignments, scores
aligned_tokens, alignment_scores = align(emission, tokenized_transcript)
现在让我们看一下输出。
for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
0: 0 [-], 1.00
1: 0 [-], 1.00
2: 0 [-], 1.00
3: 0 [-], 1.00
4: 0 [-], 1.00
5: 0 [-], 1.00
6: 0 [-], 1.00
7: 0 [-], 1.00
8: 0 [-], 1.00
9: 0 [-], 1.00
10: 0 [-], 1.00
11: 0 [-], 1.00
12: 0 [-], 1.00
13: 0 [-], 1.00
14: 0 [-], 1.00
15: 0 [-], 1.00
16: 0 [-], 1.00
17: 0 [-], 1.00
18: 0 [-], 1.00
19: 0 [-], 1.00
20: 0 [-], 1.00
21: 0 [-], 1.00
22: 0 [-], 1.00
23: 0 [-], 1.00
24: 0 [-], 1.00
25: 0 [-], 1.00
26: 0 [-], 1.00
27: 0 [-], 1.00
28: 0 [-], 1.00
29: 0 [-], 1.00
30: 0 [-], 1.00
31: 0 [-], 1.00
32: 2 [i], 1.00
33: 0 [-], 1.00
34: 0 [-], 1.00
35: 15 [h], 1.00
36: 15 [h], 0.93
37: 1 [a], 1.00
38: 0 [-], 0.96
39: 0 [-], 1.00
40: 0 [-], 1.00
41: 13 [d], 1.00
42: 0 [-], 1.00
43: 0 [-], 0.97
44: 7 [t], 1.00
45: 15 [h], 1.00
46: 0 [-], 0.98
47: 1 [a], 1.00
48: 0 [-], 1.00
49: 0 [-], 1.00
50: 7 [t], 1.00
51: 0 [-], 1.00
52: 0 [-], 1.00
53: 0 [-], 1.00
54: 20 [c], 1.00
55: 0 [-], 1.00
56: 0 [-], 1.00
57: 0 [-], 1.00
58: 6 [u], 1.00
59: 6 [u], 0.96
60: 0 [-], 1.00
61: 0 [-], 1.00
62: 0 [-], 0.53
63: 9 [r], 1.00
64: 0 [-], 1.00
65: 2 [i], 1.00
66: 0 [-], 1.00
67: 0 [-], 1.00
68: 0 [-], 1.00
69: 0 [-], 1.00
70: 0 [-], 1.00
71: 0 [-], 0.96
72: 5 [o], 1.00
73: 0 [-], 1.00
74: 0 [-], 1.00
75: 0 [-], 1.00
76: 0 [-], 1.00
77: 0 [-], 1.00
78: 0 [-], 1.00
79: 8 [s], 1.00
80: 0 [-], 1.00
81: 0 [-], 1.00
82: 0 [-], 0.99
83: 2 [i], 1.00
84: 0 [-], 1.00
85: 7 [t], 1.00
86: 0 [-], 1.00
87: 0 [-], 1.00
88: 16 [y], 1.00
89: 0 [-], 1.00
90: 0 [-], 1.00
91: 0 [-], 1.00
92: 0 [-], 1.00
93: 17 [b], 1.00
94: 0 [-], 1.00
95: 3 [e], 1.00
96: 0 [-], 1.00
97: 0 [-], 1.00
98: 0 [-], 1.00
99: 0 [-], 1.00
100: 0 [-], 1.00
101: 8 [s], 1.00
102: 0 [-], 1.00
103: 0 [-], 1.00
104: 0 [-], 1.00
105: 0 [-], 1.00
106: 0 [-], 1.00
107: 0 [-], 1.00
108: 0 [-], 1.00
109: 0 [-], 0.64
110: 2 [i], 1.00
111: 0 [-], 1.00
112: 0 [-], 1.00
113: 13 [d], 1.00
114: 3 [e], 0.85
115: 0 [-], 1.00
116: 10 [m], 1.00
117: 0 [-], 1.00
118: 0 [-], 1.00
119: 3 [e], 1.00
120: 0 [-], 1.00
121: 0 [-], 1.00
122: 0 [-], 1.00
123: 0 [-], 1.00
124: 1 [a], 1.00
125: 0 [-], 1.00
126: 0 [-], 1.00
127: 7 [t], 1.00
128: 0 [-], 1.00
129: 7 [t], 1.00
130: 15 [h], 1.00
131: 0 [-], 0.79
132: 2 [i], 1.00
133: 0 [-], 1.00
134: 0 [-], 1.00
135: 0 [-], 1.00
136: 8 [s], 1.00
137: 0 [-], 1.00
138: 0 [-], 1.00
139: 0 [-], 1.00
140: 0 [-], 1.00
141: 10 [m], 1.00
142: 0 [-], 1.00
143: 0 [-], 1.00
144: 5 [o], 1.00
145: 0 [-], 1.00
146: 0 [-], 1.00
147: 0 [-], 1.00
148: 10 [m], 1.00
149: 0 [-], 1.00
150: 0 [-], 1.00
151: 3 [e], 1.00
152: 0 [-], 1.00
153: 4 [n], 1.00
154: 0 [-], 1.00
155: 7 [t], 1.00
156: 0 [-], 1.00
157: 0 [-], 1.00
158: 0 [-], 1.00
159: 0 [-], 1.00
160: 0 [-], 1.00
161: 0 [-], 1.00
162: 0 [-], 1.00
163: 0 [-], 1.00
164: 0 [-], 1.00
165: 0 [-], 1.00
166: 0 [-], 1.00
167: 0 [-], 1.00
168: 0 [-], 1.00
注意
对齐是用发射的帧坐标表示的,这与原始波形不同。
它包含空符号和重复符号。以下是对非空符号的解释。
31: 0 [-], 1.00
32: 2 [i], 1.00 "i" starts and ends
33: 0 [-], 1.00
34: 0 [-], 1.00
35: 15 [h], 1.00 "h" starts
36: 15 [h], 0.93 "h" ends
37: 1 [a], 1.00 "a" starts and ends
38: 0 [-], 0.96
39: 0 [-], 1.00
40: 0 [-], 1.00
41: 13 [d], 1.00 "d" starts and ends
42: 0 [-], 1.00
注意
当相同符号出现在空符号之后时,它不被视为重复,而是被视为新的出现。
a a a b -> a b
a - - b -> a b
a a - b -> a b
a - a b -> a a b
^^^ ^^^
符号级对齐¶
下一步是解决重复问题,以便每个对齐不依赖于先前的对齐。 torchaudio.functional.merge_tokens()
计算 TokenSpan
对象,它表示转录中的哪个符号在哪个时间跨度内出现。
token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
print("Token\tTime\tScore")
for s in token_spans:
print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
Token Time Score
i [ 32, 33) 1.00
h [ 35, 37) 0.96
a [ 37, 38) 1.00
d [ 41, 42) 1.00
t [ 44, 45) 1.00
h [ 45, 46) 1.00
a [ 47, 48) 1.00
t [ 50, 51) 1.00
c [ 54, 55) 1.00
u [ 58, 60) 0.98
r [ 63, 64) 1.00
i [ 65, 66) 1.00
o [ 72, 73) 1.00
s [ 79, 80) 1.00
i [ 83, 84) 1.00
t [ 85, 86) 1.00
y [ 88, 89) 1.00
b [ 93, 94) 1.00
e [ 95, 96) 1.00
s [101, 102) 1.00
i [110, 111) 1.00
d [113, 114) 1.00
e [114, 115) 0.85
m [116, 117) 1.00
e [119, 120) 1.00
a [124, 125) 1.00
t [127, 128) 1.00
t [129, 130) 1.00
h [130, 131) 1.00
i [132, 133) 1.00
s [136, 137) 1.00
m [141, 142) 1.00
o [144, 145) 1.00
m [148, 149) 1.00
e [151, 152) 1.00
n [153, 154) 1.00
t [155, 156) 1.00
词级对齐¶
现在让我们将符号级对齐分组到词级对齐。
def unflatten(list_, lengths):
assert len(list_) == sum(lengths)
i = 0
ret = []
for l in lengths:
ret.append(list_[i : i + l])
i += l
return ret
word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])
音频预览¶
# Compute average score weighted by the span length
def _score(spans):
return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)
def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / num_frames
x0 = int(ratio * spans[0].start)
x1 = int(ratio * spans[-1].end)
print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
segment = waveform[:, x0:x1]
return IPython.display.Audio(segment.numpy(), rate=sample_rate)
num_frames = emission.size(1)
# Generate the audio for each segment
print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE)
['i', 'had', 'that', 'curiosity', 'beside', 'me', 'at', 'this', 'moment']
preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])
i (1.00): 0.644 - 0.664 sec
preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])
had (0.98): 0.704 - 0.845 sec
preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])
that (1.00): 0.885 - 1.026 sec
preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])
curiosity (1.00): 1.086 - 1.790 sec
preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])
beside (0.97): 1.871 - 2.314 sec
preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])
me (1.00): 2.334 - 2.414 sec
preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])
at (1.00): 2.495 - 2.575 sec
preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])
this (1.00): 2.595 - 2.756 sec
preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])
moment (1.00): 2.837 - 3.138 sec
可视化¶
现在让我们看一下对齐结果,并将原始语音分割成词。
def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / emission.size(1) / sample_rate
fig, axes = plt.subplots(2, 1)
axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
axes[0].set_title("Emission")
axes[0].set_xticks([])
axes[1].specgram(waveform[0], Fs=sample_rate)
for t_spans, chars in zip(token_spans, transcript):
t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)
for span, char in zip(t_spans, chars):
t0 = span.start * ratio
axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)
axes[1].set_xlabel("time [second]")
axes[1].set_xlim([0, None])
fig.tight_layout()
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
对 blank
符号的不一致处理¶
在将符号级对齐拆分为词时,你会注意到某些空符号的处理方式不同,这使得对结果的解释有些含糊。
当我们绘制得分时,这一点很容易看到。下图显示了词区域和非词区域,以及非空符号的帧级得分。
def plot_scores(word_spans, scores):
fig, ax = plt.subplots()
span_xs, span_hs = [], []
ax.axvspan(word_spans[0][0].start - 0.05, word_spans[-1][-1].end + 0.05, facecolor="paleturquoise", edgecolor="none", zorder=-1)
for t_span in word_spans:
for span in t_span:
for t in range(span.start, span.end):
span_xs.append(t + 0.5)
span_hs.append(scores[t].item())
ax.annotate(LABELS[span.token], (span.start, -0.07))
ax.axvspan(t_span[0].start - 0.05, t_span[-1].end + 0.05, facecolor="mistyrose", edgecolor="none", zorder=-1)
ax.bar(span_xs, span_hs, color="lightsalmon", edgecolor="coral")
ax.set_title("Frame-level scores and word segments")
ax.set_ylim(-0.1, None)
ax.grid(True, axis="y")
ax.axhline(0, color="black")
fig.tight_layout()
plot_scores(word_spans, alignment_scores)
在此图中,空符号是那些没有垂直条的突出显示区域。您可以看到,有一些空符号被解释为词的一部分(以红色突出显示),而其他空符号(以蓝色突出显示)则没有。
造成这种情况的一个原因是,模型在训练时没有词边界标签。空符号不仅被视为重复,还被视为词之间的静默。
但随后,一个问题出现了。词结束后的帧或词结束附近的帧应该被视为静默还是重复?
在上例中,如果你回到之前的光谱图和词区域的图,你会看到在“curiosity”中的“y”之后,在多个频率桶中仍然存在一些活动。
如果该帧包含在该词中,会更准确吗?
不幸的是,CTC 无法提供此问题的全面解决方案。使用 CTC 训练的模型已知会表现出“尖峰”响应,也就是说,它们倾向于在标签出现时出现峰值,但峰值不会持续整个标签的持续时间。(注意:预训练的 Wav2Vec2 模型倾向于在标签出现时开始出现峰值,但这并不总是这样。)
[Zeyer 等人,2021] 对 CTC 的尖峰行为进行了深入的分析。我们鼓励有兴趣了解更多信息的人参考这篇论文。以下是论文中的一段引文,它正是我们在这里面临的具体问题。
在某些情况下,尖峰行为可能是有问题的, 例如,当应用程序要求不使用空白标签时, 例如,为了获得语音的有效时间准确对齐 到转录文本。
高级:处理带有 <star>
标记的转录文本¶
现在让我们看看当转录文本部分缺失时,如何使用 <star>
标记来提高对齐质量,该标记能够对任何标记进行建模。
在这里,我们使用与上面相同的英文示例。但我们从转录文本中删除了开头文本 “i had that curiosity beside me at”
。将音频与这种转录文本对齐会导致现有的单词 “this” 的对齐错误。但是,这个问题可以通过使用 <star>
标记来对缺失文本进行建模来缓解。
首先,我们将字典扩展为包括 <star>
标记。
DICTIONARY["*"] = len(DICTIONARY)
接下来,我们将发射张量扩展为对应于 <star>
标记的额外维度。
star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)
emission = torch.cat((emission, star_dim), 2)
assert len(DICTIONARY) == emission.shape[2]
plot_emission(emission[0])
以下函数将所有流程结合起来,并一次性从发射中计算单词片段。
def compute_alignments(emission, transcript, dictionary):
tokens = [dictionary[char] for word in transcript for char in word]
alignment, scores = align(emission, tokens)
token_spans = F.merge_tokens(alignment, scores)
word_spans = unflatten(token_spans, [len(word) for word in transcript])
return word_spans
完整转录文本¶
word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
带有 <star>
标记的部分转录文本¶
现在我们将转录文本的第一部分替换为 <star>
标记。
transcript = "* this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
* (1.00): 0.000 - 2.595 sec
preview_word(waveform, word_spans[1], num_frames, transcript[1])
this (1.00): 2.595 - 2.756 sec
preview_word(waveform, word_spans[2], num_frames, transcript[2])
moment (1.00): 2.837 - 3.138 sec
不带 <star>
标记的部分转录文本¶
作为对比,以下内容是对齐不使用 <star>
标记的部分转录文本。它展示了 <star>
标记在处理删除错误方面的效果。
transcript = "this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)
结论¶
在本教程中,我们学习了如何使用 torchaudio 的强制对齐 API 对齐和分割语音文件,并演示了其中一种高级用法:如何引入 <star>
标记可以提高转录错误存在时的对齐精度。
鸣谢¶
感谢 Vineel Pratap 和 Zhaoheng Ni 开发和开源强制对齐器 API。
脚本的总运行时间:(0 分钟 6.927 秒)