
CTC forced alignment API tutorial

Author: Xiaohui Zhang, Moto Hira

Forced alignment is a process of aligning a transcript with speech. This tutorial shows how to align transcripts to speech using torchaudio.functional.forced_align(), which was developed as part of the work Scaling Speech Technology to 1,000+ Languages.

forced_align() has custom CPU and CUDA implementations that are more performant and more accurate than the vanilla Python implementation above. It can also handle missing transcript segments with a special <star> token.

There is also a high-level API, torchaudio.pipelines.Wav2Vec2FABundle, which wraps the pre-/post-processing explained in this tutorial and makes it easy to run forced alignment. Forced alignment for multilingual data uses this API to illustrate how to align non-English transcripts.
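For reference, the following is a minimal sketch of how that high-level API is used, following the pattern of the multilingual data tutorial; treat the exact call signatures as assumptions and check the Wav2Vec2FABundle documentation ("speech.wav" is a placeholder path).

import torch
import torchaudio

bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model()          # acoustic model that produces emissions
tokenizer = bundle.get_tokenizer()  # maps each word to a token ID sequence
aligner = bundle.get_aligner()      # wraps forced alignment and span merging

# "speech.wav" is a placeholder; the audio must be at bundle.sample_rate.
waveform, _ = torchaudio.load("speech.wav")
transcript = "i had that curiosity beside me".split()
with torch.inference_mode():
    emission, _ = model(waveform)
    token_spans = aligner(emission[0], tokenizer(transcript))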

Preparation

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)
2.3.0
2.3.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
import IPython
import matplotlib.pyplot as plt

import torchaudio.functional as F

First, we prepare the speech data and the transcript we are going to use.

SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
waveform, _ = torchaudio.load(SPEECH_FILE)
TRANSCRIPT = "i had that curiosity beside me at this moment".split()

Generating emissions

forced_align() takes emission and token sequences, and outputs the timestamps of the tokens along with their scores.

An emission represents the frame-wise probability distribution over tokens, and it can be obtained by passing a waveform to an acoustic model.

Tokens are numerical expressions of the transcript. There are many ways to tokenize a transcript, but here we simply map letters to integers, which is how the labels were constructed when the acoustic model we are about to use was trained.

We will use a pre-trained Wav2Vec2 model, torchaudio.pipelines.MMS_FA, to obtain the emission and tokenize the transcript.

bundle = torchaudio.pipelines.MMS_FA

model = bundle.get_model(with_star=False).to(device)
with torch.inference_mode():
    emission, _ = model(waveform.to(device))
Downloading: "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" to /root/.cache/torch/hub/checkpoints/model.pt

100%|##########| 1.18G/1.18G [00:04<00:00, 282MB/s]
def plot_emission(emission):
    fig, ax = plt.subplots()
    ax.imshow(emission.cpu().T)
    ax.set_title("Frame-wise class probabilities")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.tight_layout()


plot_emission(emission[0])
[Figure: Frame-wise class probabilities]

Tokenize the transcript

We create a dictionary that maps each label to a token.

LABELS = bundle.get_labels(star=None)
DICTIONARY = bundle.get_dict(star=None)
for k, v in DICTIONARY.items():
    print(f"{k}: {v}")
-: 0
a: 1
i: 2
e: 3
n: 4
o: 5
u: 6
t: 7
s: 8
r: 9
m: 10
k: 11
l: 12
d: 13
g: 14
h: 15
y: 16
b: 17
p: 18
w: 19
c: 20
v: 21
j: 22
z: 23
f: 24
': 25
q: 26
x: 27

Converting the transcript to tokens is as simple as:

tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word]

for t in tokenized_transcript:
    print(t, end=" ")
print()
2 15 1 13 7 15 1 7 20 6 9 2 5 8 2 7 16 17 3 8 2 13 3 10 3 1 7 7 15 2 8 10 5 10 3 4 7

Computing alignments

Frame-level alignments

Now we call TorchAudio's forced alignment API to compute the frame-level alignment. For details of the function signature, please refer to forced_align().

def align(emission, tokens):
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
    alignments, scores = F.forced_align(emission, targets, blank=0)

    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
    scores = scores.exp()  # convert back to probability
    return alignments, scores


aligned_tokens, alignment_scores = align(emission, tokenized_transcript)

Now let's look at the output.

for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
    print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
  0:     0 [-], 1.00
  1:     0 [-], 1.00
  2:     0 [-], 1.00
  3:     0 [-], 1.00
  4:     0 [-], 1.00
  5:     0 [-], 1.00
  6:     0 [-], 1.00
  7:     0 [-], 1.00
  8:     0 [-], 1.00
  9:     0 [-], 1.00
 10:     0 [-], 1.00
 11:     0 [-], 1.00
 12:     0 [-], 1.00
 13:     0 [-], 1.00
 14:     0 [-], 1.00
 15:     0 [-], 1.00
 16:     0 [-], 1.00
 17:     0 [-], 1.00
 18:     0 [-], 1.00
 19:     0 [-], 1.00
 20:     0 [-], 1.00
 21:     0 [-], 1.00
 22:     0 [-], 1.00
 23:     0 [-], 1.00
 24:     0 [-], 1.00
 25:     0 [-], 1.00
 26:     0 [-], 1.00
 27:     0 [-], 1.00
 28:     0 [-], 1.00
 29:     0 [-], 1.00
 30:     0 [-], 1.00
 31:     0 [-], 1.00
 32:     2 [i], 1.00
 33:     0 [-], 1.00
 34:     0 [-], 1.00
 35:    15 [h], 1.00
 36:    15 [h], 0.93
 37:     1 [a], 1.00
 38:     0 [-], 0.96
 39:     0 [-], 1.00
 40:     0 [-], 1.00
 41:    13 [d], 1.00
 42:     0 [-], 1.00
 43:     0 [-], 0.97
 44:     7 [t], 1.00
 45:    15 [h], 1.00
 46:     0 [-], 0.98
 47:     1 [a], 1.00
 48:     0 [-], 1.00
 49:     0 [-], 1.00
 50:     7 [t], 1.00
 51:     0 [-], 1.00
 52:     0 [-], 1.00
 53:     0 [-], 1.00
 54:    20 [c], 1.00
 55:     0 [-], 1.00
 56:     0 [-], 1.00
 57:     0 [-], 1.00
 58:     6 [u], 1.00
 59:     6 [u], 0.96
 60:     0 [-], 1.00
 61:     0 [-], 1.00
 62:     0 [-], 0.53
 63:     9 [r], 1.00
 64:     0 [-], 1.00
 65:     2 [i], 1.00
 66:     0 [-], 1.00
 67:     0 [-], 1.00
 68:     0 [-], 1.00
 69:     0 [-], 1.00
 70:     0 [-], 1.00
 71:     0 [-], 0.96
 72:     5 [o], 1.00
 73:     0 [-], 1.00
 74:     0 [-], 1.00
 75:     0 [-], 1.00
 76:     0 [-], 1.00
 77:     0 [-], 1.00
 78:     0 [-], 1.00
 79:     8 [s], 1.00
 80:     0 [-], 1.00
 81:     0 [-], 1.00
 82:     0 [-], 0.99
 83:     2 [i], 1.00
 84:     0 [-], 1.00
 85:     7 [t], 1.00
 86:     0 [-], 1.00
 87:     0 [-], 1.00
 88:    16 [y], 1.00
 89:     0 [-], 1.00
 90:     0 [-], 1.00
 91:     0 [-], 1.00
 92:     0 [-], 1.00
 93:    17 [b], 1.00
 94:     0 [-], 1.00
 95:     3 [e], 1.00
 96:     0 [-], 1.00
 97:     0 [-], 1.00
 98:     0 [-], 1.00
 99:     0 [-], 1.00
100:     0 [-], 1.00
101:     8 [s], 1.00
102:     0 [-], 1.00
103:     0 [-], 1.00
104:     0 [-], 1.00
105:     0 [-], 1.00
106:     0 [-], 1.00
107:     0 [-], 1.00
108:     0 [-], 1.00
109:     0 [-], 0.64
110:     2 [i], 1.00
111:     0 [-], 1.00
112:     0 [-], 1.00
113:    13 [d], 1.00
114:     3 [e], 0.85
115:     0 [-], 1.00
116:    10 [m], 1.00
117:     0 [-], 1.00
118:     0 [-], 1.00
119:     3 [e], 1.00
120:     0 [-], 1.00
121:     0 [-], 1.00
122:     0 [-], 1.00
123:     0 [-], 1.00
124:     1 [a], 1.00
125:     0 [-], 1.00
126:     0 [-], 1.00
127:     7 [t], 1.00
128:     0 [-], 1.00
129:     7 [t], 1.00
130:    15 [h], 1.00
131:     0 [-], 0.79
132:     2 [i], 1.00
133:     0 [-], 1.00
134:     0 [-], 1.00
135:     0 [-], 1.00
136:     8 [s], 1.00
137:     0 [-], 1.00
138:     0 [-], 1.00
139:     0 [-], 1.00
140:     0 [-], 1.00
141:    10 [m], 1.00
142:     0 [-], 1.00
143:     0 [-], 1.00
144:     5 [o], 1.00
145:     0 [-], 1.00
146:     0 [-], 1.00
147:     0 [-], 1.00
148:    10 [m], 1.00
149:     0 [-], 1.00
150:     0 [-], 1.00
151:     3 [e], 1.00
152:     0 [-], 1.00
153:     4 [n], 1.00
154:     0 [-], 1.00
155:     7 [t], 1.00
156:     0 [-], 1.00
157:     0 [-], 1.00
158:     0 [-], 1.00
159:     0 [-], 1.00
160:     0 [-], 1.00
161:     0 [-], 1.00
162:     0 [-], 1.00
163:     0 [-], 1.00
164:     0 [-], 1.00
165:     0 [-], 1.00
166:     0 [-], 1.00
167:     0 [-], 1.00
168:     0 [-], 1.00

Note

The alignment is expressed in the frame coordinates of the emission, which are different from those of the original waveform.

It contains blank tokens and repeated tokens. The following illustrates how they are interpreted.

31:     0 [-], 1.00
32:     2 [i], 1.00  "i" starts and ends
33:     0 [-], 1.00
34:     0 [-], 1.00
35:    15 [h], 1.00  "h" starts
36:    15 [h], 0.93  "h" ends
37:     1 [a], 1.00  "a" starts and ends
38:     0 [-], 0.96
39:     0 [-], 1.00
40:     0 [-], 1.00
41:    13 [d], 1.00  "d" starts and ends
42:     0 [-], 1.00
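Following up on the note above, a frame coordinate can be converted back to the time coordinate of the original waveform. Here is a minimal sketch using only variables defined earlier in this tutorial.

# Sketch: convert an emission frame index to seconds in the waveform.
# `ratio` is the number of audio samples covered by one emission frame.
ratio = waveform.size(1) / emission.size(1)
frame_index = 32  # the frame where "i" appears in the output above
time_sec = frame_index * ratio / bundle.sample_rate
print(f"frame {frame_index} ~ {time_sec:.3f} sec")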

Note

When the same token appears after a blank token, it is not treated as a repeat, but as a new occurrence.

a a a b -> a b
a - - b -> a b
a a - b -> a b
a - a b -> a a b
  ^^^       ^^^
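To make this rule concrete, here is a small hypothetical helper (not part of torchaudio) that collapses a frame-level token sequence the way CTC interprets it: repeats are merged unless a blank separates them.

def ctc_collapse(tokens, blank=0):
    # Keep a token only when it differs from the previous frame's token,
    # so a repeat separated by a blank counts as a new occurrence.
    output = []
    previous = None
    for token in tokens:
        if token != previous and token != blank:
            output.append(token)
        previous = token
    return output


# With a=1, b=2 and blank=0:
print(ctc_collapse([1, 1, 1, 2]))  # [1, 2]     (a a a b -> a b)
print(ctc_collapse([1, 0, 0, 2]))  # [1, 2]     (a - - b -> a b)
print(ctc_collapse([1, 0, 1, 2]))  # [1, 1, 2]  (a - a b -> a a b)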

Token-level alignments

The next step is to resolve the repetitions, so that each alignment does not depend on the previous ones. torchaudio.functional.merge_tokens() computes TokenSpan objects, which represent which token from the transcript is present at what time span.

token_spans = F.merge_tokens(aligned_tokens, alignment_scores)

print("Token\tTime\tScore")
for s in token_spans:
    print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
Token   Time    Score
i       [ 32,  33)      1.00
h       [ 35,  37)      0.96
a       [ 37,  38)      1.00
d       [ 41,  42)      1.00
t       [ 44,  45)      1.00
h       [ 45,  46)      1.00
a       [ 47,  48)      1.00
t       [ 50,  51)      1.00
c       [ 54,  55)      1.00
u       [ 58,  60)      0.98
r       [ 63,  64)      1.00
i       [ 65,  66)      1.00
o       [ 72,  73)      1.00
s       [ 79,  80)      1.00
i       [ 83,  84)      1.00
t       [ 85,  86)      1.00
y       [ 88,  89)      1.00
b       [ 93,  94)      1.00
e       [ 95,  96)      1.00
s       [101, 102)      1.00
i       [110, 111)      1.00
d       [113, 114)      1.00
e       [114, 115)      0.85
m       [116, 117)      1.00
e       [119, 120)      1.00
a       [124, 125)      1.00
t       [127, 128)      1.00
t       [129, 130)      1.00
h       [130, 131)      1.00
i       [132, 133)      1.00
s       [136, 137)      1.00
m       [141, 142)      1.00
o       [144, 145)      1.00
m       [148, 149)      1.00
e       [151, 152)      1.00
n       [153, 154)      1.00
t       [155, 156)      1.00

Word-level alignments

Now let's group the token-level alignments into word-level alignments.

def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret


word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])

Audio previews

# Compute average score weighted by the span length
def _score(spans):
    return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)


def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / num_frames
    x0 = int(ratio * spans[0].start)
    x1 = int(ratio * spans[-1].end)
    print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)


num_frames = emission.size(1)
# Generate the audio for each segment
print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE)
['i', 'had', 'that', 'curiosity', 'beside', 'me', 'at', 'this', 'moment']


preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])
i (1.00): 0.644 - 0.664 sec


preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])
had (0.98): 0.704 - 0.845 sec


preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])
that (1.00): 0.885 - 1.026 sec


preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])
curiosity (1.00): 1.086 - 1.790 sec


preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])
beside (0.97): 1.871 - 2.314 sec


preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])
me (1.00): 2.334 - 2.414 sec


preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])
at (1.00): 2.495 - 2.575 sec


preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])
this (1.00): 2.595 - 2.756 sec


preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])
moment (1.00): 2.837 - 3.138 sec


Visualization

Now let's look at the alignment result and segment the original speech into words.

def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / emission.size(1) / sample_rate

    fig, axes = plt.subplots(2, 1)
    axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
    axes[0].set_title("Emission")
    axes[0].set_xticks([])

    axes[1].specgram(waveform[0], Fs=sample_rate)
    for t_spans, chars in zip(token_spans, transcript):
        t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
        axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
        axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
        axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

        for span, char in zip(t_spans, chars):
            t0 = span.start * ratio
            axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)

    axes[1].set_xlabel("time [second]")
    axes[1].set_xlim([0, None])
    fig.tight_layout()
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
[Figure: Emission]

Inconsistent treatment of blank tokens

When splitting the token-level alignments into words, you will notice that some blank tokens are treated differently, which makes the interpretation of the result somewhat ambiguous.

This is easy to see when we plot the scores. The following figure shows the word regions and non-word regions, along with the frame-level scores of the non-blank tokens.

def plot_scores(word_spans, scores):
    fig, ax = plt.subplots()
    span_xs, span_hs = [], []
    ax.axvspan(word_spans[0][0].start - 0.05, word_spans[-1][-1].end + 0.05, facecolor="paleturquoise", edgecolor="none", zorder=-1)
    for t_span in word_spans:
        for span in t_span:
            for t in range(span.start, span.end):
                span_xs.append(t + 0.5)
                span_hs.append(scores[t].item())
            ax.annotate(LABELS[span.token], (span.start, -0.07))
        ax.axvspan(t_span[0].start - 0.05, t_span[-1].end + 0.05, facecolor="mistyrose", edgecolor="none", zorder=-1)
    ax.bar(span_xs, span_hs, color="lightsalmon", edgecolor="coral")
    ax.set_title("Frame-level scores and word segments")
    ax.set_ylim(-0.1, None)
    ax.grid(True, axis="y")
    ax.axhline(0, color="black")
    fig.tight_layout()


plot_scores(word_spans, alignment_scores)
[Figure: Frame-level scores and word segments]

In this plot, the blank tokens are the highlighted regions without a vertical bar. You can see that some blank tokens are interpreted as part of a word (highlighted in red), while others (highlighted in blue) are not.

One reason for this is that the model was not trained with labels for word boundaries: blank tokens are treated not only as repetitions but also as silence between words.

But this raises a question: should the frames at or near the end of a word be treated as silence or as a repeat?

In the example above, if you go back to the earlier plot of the spectrogram and word regions, you can see that after the "y" in "curiosity" there is still some activity in multiple frequency buckets.

Would it be more accurate if that frame were included in the word?

Unfortunately, CTC does not provide a comprehensive solution for this. Models trained with CTC are known to exhibit "peaky" responses; that is, they tend to spike at the occurrence of a label, but the spike does not last for the duration of the label. (Note: the pre-trained Wav2Vec2 models tend to spike at the beginning of a label's occurrence, but this is not always the case.)

[Zeyer et al., 2021] conducted an in-depth analysis of the peaky behavior of CTC. We encourage those interested in learning more to refer to the paper. The following quote from the paper describes the exact issue we are facing here:

Peaky behavior can be problematic in certain cases, e.g., when an application requires to not use the blank label, e.g., to get meaningful time accurate alignments of phonemes to a transcription.
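One way to observe this peaky behavior in our example is to plot, for each frame, the probability the model assigns to its most likely non-blank token. This is a minimal sketch that assumes `emission` holds log-probabilities (consistent with the `.exp()` conversion used earlier).

# Sketch: visualize peaky CTC responses (assumes log-probability emission).
probs = emission[0].exp().cpu()
peak = probs[:, 1:].max(dim=1).values  # best non-blank token per frame

fig, ax = plt.subplots()
ax.plot(peak)
ax.set_xlabel("Frame")
ax.set_ylabel("Max non-blank probability")
ax.set_title("Peaky CTC responses")
fig.tight_layout()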

Advanced: Handling transcripts with the <star> token

Now let's look at how we can use the <star> token, which can model any token, to improve alignment quality when the transcript is partially missing.

Here we use the same English example as above, but remove the beginning text "i had that curiosity beside me at" from the transcript. Aligning the audio against such a transcript results in wrong alignments of the existing word "this". However, this issue can be mitigated by modeling the missing text with the <star> token.

First, we extend the dictionary to include the <star> token.

DICTIONARY["*"] = len(DICTIONARY)

Next, we extend the emission tensor with an extra dimension corresponding to the <star> token.

star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)
emission = torch.cat((emission, star_dim), 2)

assert len(DICTIONARY) == emission.shape[2]

plot_emission(emission[0])
[Figure: Frame-wise class probabilities]

The following function combines all the above steps and computes word segments from the emission in one go.

def compute_alignments(emission, transcript, dictionary):
    tokens = [dictionary[char] for word in transcript for char in word]
    alignment, scores = align(emission, tokens)
    token_spans = F.merge_tokens(alignment, scores)
    word_spans = unflatten(token_spans, [len(word) for word in transcript])
    return word_spans

Full transcript

word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
[Figure: Emission]

Partial transcript with the <star> token

Now we replace the first part of the transcript with the <star> token.

transcript = "* this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)
[Figure: Emission]
preview_word(waveform, word_spans[0], num_frames, transcript[0])
* (1.00): 0.000 - 2.595 sec


preview_word(waveform, word_spans[1], num_frames, transcript[1])
this (1.00): 2.595 - 2.756 sec


preview_word(waveform, word_spans[2], num_frames, transcript[2])
moment (1.00): 2.837 - 3.138 sec


Partial transcript without the <star> token

For comparison, the following aligns the partial transcript without using the <star> token. It demonstrates the effect of the <star> token for dealing with deletion errors.

transcript = "this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)
[Figure: Emission]

Conclusion

In this tutorial, we looked at how to use torchaudio's forced alignment API to align and segment speech files, and demonstrated one advanced usage: how introducing the <star> token can improve alignment accuracy when transcription errors exist.

Acknowledgement

Thanks to Vineel Pratap and Zhaoheng Ni for developing and open-sourcing the forced aligner API.

Total running time of the script: (0 minutes 7.789 seconds)

