RNNTBundle
- class torchaudio.pipelines.RNNTBundle
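Dataclass that bundles the components for performing automatic speech recognition (ASR, speech-to-text) inference with an RNN-T model.
More specifically, the class provides methods that produce the featurization pipeline, a decoder wrapping the specified RNN-T model, and an output token post-processor. Together these constitute a complete end-to-end ASR inference pipeline that produces a text sequence from a raw waveform.
It supports both non-streaming (full-context) inference and streaming inference.
Users should not instantiate this class directly; instead, use the instances (representing pre-trained models) that exist within the module, e.g. torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH.
- Example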
>>> import torchaudio
>>> from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
>>> import torch
>>>
>>> # Non-streaming inference.
>>> # Build feature extractor, decoder with RNN-T model, and token processor.
>>> feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
100%|███████████████████████████████| 3.81k/3.81k [00:00<00:00, 4.22MB/s]
>>> decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
Downloading: "https://download.pytorch.org/torchaudio/models/emformer_rnnt_base_librispeech.pt"
100%|███████████████████████████████| 293M/293M [00:07<00:00, 42.1MB/s]
>>> token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()
100%|███████████████████████████████| 295k/295k [00:00<00:00, 25.4MB/s]
>>>
>>> # Instantiate LibriSpeech dataset; retrieve waveform for first sample.
>>> dataset = torchaudio.datasets.LIBRISPEECH("/home/librispeech", url="test-clean")
>>> waveform = next(iter(dataset))[0].squeeze()
>>>
>>> with torch.no_grad():
...     # Produce mel-scale spectrogram features.
...     features, length = feature_extractor(waveform)
...     # Generate top-10 hypotheses.
...     hypotheses = decoder(features, length, 10)
...
>>> # For top hypothesis, convert predicted tokens to text.
>>> text = token_processor(hypotheses[0][0])
>>> print(text)
he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to [...]
>>>
>>>
>>> # Streaming inference.
>>> hop_length = EMFORMER_RNNT_BASE_LIBRISPEECH.hop_length
>>> num_samples_segment = EMFORMER_RNNT_BASE_LIBRISPEECH.segment_length * hop_length
>>> num_samples_segment_right_context = (
...     num_samples_segment + EMFORMER_RNNT_BASE_LIBRISPEECH.right_context_length * hop_length
... )
>>>
>>> # Build streaming inference feature extractor.
>>> streaming_feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_streaming_feature_extractor()
>>>
>>> # Process same waveform as before, this time sequentially across overlapping segments
>>> # to simulate streaming inference. Note the usage of ``streaming_feature_extractor`` and ``decoder.infer``.
>>> state, hypothesis = None, None
>>> for idx in range(0, len(waveform), num_samples_segment):
...     segment = waveform[idx: idx + num_samples_segment_right_context]
...     segment = torch.nn.functional.pad(segment, (0, num_samples_segment_right_context - len(segment)))
...     with torch.no_grad():
...         features, length = streaming_feature_extractor(segment)
...         hypotheses, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
...     hypothesis = hypotheses[0]
...     transcript = token_processor(hypothesis[0])
...     if transcript:
...         print(transcript, end=" ", flush=True)
...
he hoped there would be stew for dinner turn ips and car rots and bru 'd oes and fat mut ton pieces to [...]
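The streaming loop above advances through the waveform by `num_samples_segment` samples per step while reading `num_samples_segment_right_context` samples each time, so consecutive windows overlap by the right-context length and the last window is zero-padded. The following is a minimal pure-Python sketch of that windowing on a plain list, independent of torchaudio; `iter_streaming_windows` is a hypothetical helper for illustration, not part of the library.

```python
from typing import Iterator, List, Sequence


def iter_streaming_windows(
    samples: Sequence[float], segment: int, right_context: int
) -> Iterator[List[float]]:
    """Yield overlapping, zero-padded windows in the order the streaming loop reads them.

    Hypothetical helper mirroring the doctest: the start index advances by
    ``segment`` samples, each window also includes ``right_context``
    look-ahead samples, and a short final window is padded with zeros.
    """
    window = segment + right_context
    for start in range(0, len(samples), segment):
        chunk = list(samples[start:start + window])
        # Zero-pad on the right, like torch.nn.functional.pad(segment, (0, n)).
        chunk.extend([0.0] * (window - len(chunk)))
        yield chunk


# Toy signal of 10 samples, segments of 4 samples, 2 samples of right context.
windows = list(iter_streaming_windows([float(i) for i in range(10)], segment=4, right_context=2))
for w in windows:
    print(w)
# [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
# [4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
# [8.0, 9.0, 0.0, 0.0, 0.0, 0.0]
```

Every window has the same length, which is why the doctest pads each segment to `num_samples_segment_right_context` before feature extraction.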
- Tutorials using RNNTBundle
  - Online ASR with Emformer RNN-T
  - Device ASR with Emformer RNN-T
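Properties
hop_length: Number of samples between successive frames in the input expected by the model.
n_fft: Size of the FFT window used by the model.
n_mels: Number of mel spectrogram features extracted from the input audio.
right_context_length: Number of frames in the right-context block available to the model.
sample_rate: Sample rate (in cycles per second) of the input waveforms.
segment_length: Number of frames per segment in the input expected by the model.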
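The streaming example derives its window sizes from these properties: `segment_length` and `right_context_length` count feature frames, and multiplying by `hop_length` (samples per frame) converts them to waveform samples; dividing by `sample_rate` then gives durations in seconds. Below is a minimal sketch under that reading. `streaming_sizes` and `StreamingSizes` are hypothetical helpers, and the stand-in object uses placeholder values chosen for illustration, not values read from a real bundle.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class StreamingSizes:
    segment_samples: int          # samples the loop advances per step
    window_samples: int           # samples read per step (segment + right context)
    segment_seconds: float        # duration of one segment
    right_context_seconds: float  # duration of the look-ahead


def streaming_sizes(bundle) -> StreamingSizes:
    """Derive streaming window sizes from an RNNTBundle-like object.

    Only the hop_length, segment_length, right_context_length and sample_rate
    attributes are read, so any object exposing them works.
    """
    segment_samples = bundle.segment_length * bundle.hop_length
    right_context_samples = bundle.right_context_length * bundle.hop_length
    return StreamingSizes(
        segment_samples=segment_samples,
        window_samples=segment_samples + right_context_samples,
        segment_seconds=segment_samples / bundle.sample_rate,
        right_context_seconds=right_context_samples / bundle.sample_rate,
    )


# Placeholder values for illustration only, not read from a real bundle.
# With torchaudio installed, pass torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH instead.
fake_bundle = SimpleNamespace(hop_length=160, segment_length=16, right_context_length=4, sample_rate=16000)
sizes = streaming_sizes(fake_bundle)
print(sizes.segment_samples, sizes.window_samples)  # 2560 3200
```

`right_context_seconds` is a rough measure of the look-ahead latency each streaming step waits for before it can emit output.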
Methods
get_decoder
- RNNTBundle.get_decoder() → RNNTBeamSearch
Constructs the RNN-T beam-search decoder wrapping the pre-trained model.
- Returns:
RNNTBeamSearch
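In the example, `hypotheses[0][0]` is the token list of the top-ranked hypothesis returned by the decoder. The sketch below makes that selection explicit by score. It assumes each hypothesis is a tuple whose first element is the list of token IDs and whose last element is its score (the layout of the `Hypothesis` type used by `RNNTBeamSearch`; treat this as an assumption to check against your installed version). `best_tokens` is a hypothetical helper, and the hypotheses in the usage are toy stand-ins.

```python
from typing import Any, List, Sequence, Tuple

# Assumed layout: (tokens, predictor_output, predictor_state, score).
Hypothesis = Tuple[List[int], Any, Any, float]


def best_tokens(hypotheses: Sequence[Hypothesis]) -> List[int]:
    """Return the token IDs of the highest-scoring hypothesis."""
    if not hypotheses:
        raise ValueError("decoder returned no hypotheses")
    # Select by score explicitly instead of relying on list order.
    tokens, _, _, _ = max(hypotheses, key=lambda hypo: hypo[-1])
    return tokens


# Toy stand-ins for decoder output; real hypotheses carry tensors in the middle slots.
toy = [([0, 5, 9], None, None, -3.2), ([0, 5, 7, 2], None, None, -1.4)]
print(best_tokens(toy))  # [0, 5, 7, 2]
```

Scores are log-probability-like values, so the highest (least negative) score wins.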
get_feature_extractor
- RNNTBundle.get_feature_extractor() → FeatureExtractor
Constructs the feature extractor for non-streaming (full-context) ASR.
- Returns:
FeatureExtractor
get_streaming_feature_extractor
- RNNTBundle.get_streaming_feature_extractor() → FeatureExtractor
Constructs the feature extractor for streaming (simultaneous) ASR, applied to one segment at a time.
- Returns:
FeatureExtractor
get_token_processor
- RNNTBundle.get_token_processor() → TokenProcessor
Constructs the token processor, which converts predicted token IDs into text.
- Returns:
TokenProcessor
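The token processor turns the decoder's token IDs into text. Assuming the bundled model uses SentencePiece subword pieces, where the marker `▁` (U+2581) stands for a preceding space, detokenization amounts to mapping IDs to pieces, dropping special tokens, joining, and replacing the marker with spaces. A minimal pure-Python sketch follows; the toy vocabulary, the special-token IDs, and the choice to drop the leading token as a start symbol are all assumptions for illustration, and `detokenize` is a hypothetical helper rather than the library's implementation.

```python
from typing import Dict, List, Set

WORD_BOUNDARY = "\u2581"  # SentencePiece "▁" marker for a preceding space


def detokenize(
    tokens: List[int], id_to_piece: Dict[int, str], special_ids: Set[int], lstrip: bool = True
) -> str:
    """Join subword pieces into text, SentencePiece style.

    Assumes tokens[0] is a start/blank symbol from the decoder and drops it,
    then filters out special IDs (e.g. unknown, end-of-sequence, padding).
    """
    pieces = [id_to_piece[t] for t in tokens[1:] if t not in special_ids]
    text = "".join(pieces).replace(WORD_BOUNDARY, " ")
    return text.lstrip() if lstrip else text


# Hypothetical toy vocabulary; the real bundle loads its own SentencePiece model.
vocab = {0: "<blank>", 1: "<unk>", 2: "\u2581he", 3: "\u2581turn", 4: "ips", 5: "\u2581and"}
print(detokenize([0, 2, 3, 4, 1, 5], vocab, special_ids={1}))  # he turnips and
```

Pieces without a leading marker (here `ips`) attach to the previous piece, which is how a single word such as "turnips" can span several tokens.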