快捷方式

RNNT

class torchaudio.models.RNNT[source]

循环神经网络转录器 (RNN-T) 模型。

注意

要构建模型,请使用其中一个工厂函数。

另请参阅

torchaudio.pipelines.RNNTBundle:具有预训练模型的 ASR 流水线。

参数:

方法

forward

RNNT.forward(sources: Tensor, source_lengths: Tensor, targets: Tensor, target_lengths: Tensor, predictor_state: Optional[List[List[Tensor]]] = None) Tuple[Tensor, Tensor, Tensor, List[List[Tensor]]][source]

训练的前向传递。

B:批次大小;T:批次中最大源序列长度;U:批次中最大目标序列长度;D:每个源序列元素的特征维度。

参数:
  • sources (torch.Tensor) – 右填充有右上下文信息的源帧序列,形状为 (B, T, D)

  • source_lengths (torch.Tensor) – 形状为 (B,),第 i 个元素表示 sources 中第 i 个批次元素的有效帧数。

  • targets (torch.Tensor) – 目标序列,形状为 (B, U),每个元素映射到一个目标符号。

  • target_lengths (torch.Tensor) – 形状为 (B,),第 i 个元素表示 targets 中第 i 个批次元素的有效帧数。

  • predictor_state (List[List[torch.Tensor]] or None, optional) – 预测网络内部状态的张量列表,在先前 forward 调用中生成。(默认值:None)

返回值:

torch.Tensor

连接网络输出,形状为 (B, max output source length, max output target length, output_dim (number of target symbols))

torch.Tensor

输出源长度,形状为 (B,),第 i 个元素表示连接网络输出中第 i 个批次元素沿维度 1 的有效元素数。

torch.Tensor

输出目标长度,形状为 (B,),第 i 个元素表示连接网络输出中第 i 个批次元素沿维度 2 的有效元素数。

List[List[torch.Tensor]]

输出状态;预测网络内部状态的张量列表,在当前 forward 调用中生成。

返回类型:

(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

流式转录

RNNT.transcribe_streaming(sources: Tensor, source_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]][source]

以流式模式将转录网络应用于音频源。

B:批次大小;T:批次中最大音频源序列段长度;D:每个音频源序列帧的特征维度。

参数:
  • sources (torch.Tensor) – 右填充右上下文音频源帧序列段,形状为 (B, T + 右上下文长度, D)

  • source_lengths (torch.Tensor) – 形状为 (B,),第 i 个元素表示 sources 中第 i 个批次元素的有效帧数。

  • state (List[List[torch.Tensor]] or None) – 张量列表的列表,表示在先前调用 transcribe_streaming 时生成的转录网络内部状态。

返回值:

torch.Tensor

输出帧序列,形状为 (B, T // time_reduction_stride, output_dim)

torch.Tensor

输出长度,形状为 (B,),第 i 个元素表示输出中第 i 个批次元素的有效元素数量。

List[List[torch.Tensor]]

输出状态;张量列表的列表,表示在当前调用 transcribe_streaming 时生成的转录网络内部状态。

返回类型:

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

转录

RNNT.transcribe(sources: Tensor, source_lengths: Tensor) Tuple[Tensor, Tensor][source]

以非流式模式将转录网络应用于音频源。

B:批次大小;T:批次中最大音频源序列长度;D:每个音频源序列帧的特征维度。

参数:
  • sources (torch.Tensor) – 右填充右上下文音频源帧序列,形状为 (B, T + 右上下文长度, D)

  • source_lengths (torch.Tensor) – 形状为 (B,),第 i 个元素表示 sources 中第 i 个批次元素的有效帧数。

返回值:

torch.Tensor

输出帧序列,形状为 (B, T // time_reduction_stride, output_dim)

torch.Tensor

输出长度,形状为 (B,),第 i 个元素表示输出帧序列中第 i 个批次元素的有效元素数量。

返回类型:

(torch.Tensor, torch.Tensor)

预测

RNNT.predict(targets: Tensor, target_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]][source]

将预测网络应用于目标。

B:批次大小;U:批次中最大目标序列长度;D:每个目标序列帧的特征维度。

参数:
  • targets (torch.Tensor) – 目标序列,形状为 (B, U),每个元素映射到一个目标符号,即在范围 [0, num_symbols) 内。

  • target_lengths (torch.Tensor) – 形状为 (B,),第 i 个元素表示 targets 中第 i 个批次元素的有效帧数。

  • state (List[List[torch.Tensor]] or None) – 张量列表的列表,表示在先前调用 predict 时生成的内部状态。

返回值:

torch.Tensor

输出帧序列,形状为 (B, U, output_dim)

torch.Tensor

输出长度,形状为 (B,),第 i 个元素表示输出中第 i 个批次元素的有效元素数量。

List[List[torch.Tensor]]

输出状态;张量列表的列表,表示在当前调用 predict 时生成的内部状态。

返回类型:

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

连接

RNNT.join(source_encodings: Tensor, source_lengths: Tensor, target_encodings: Tensor, target_lengths: Tensor) Tuple[Tensor, Tensor, Tensor][source]

将联合网络应用于音频源和目标编码。

B:批次大小;T:批次中最大音频源序列长度;U:批次中最大目标序列长度;D:每个音频源和目标序列编码的维度。

参数:
  • source_encodings (torch.Tensor) – 音频源编码序列,形状为 (B, T, D)

  • source_lengths (torch.Tensor) – 形状为 (B,),第 i 个元素表示 source_encodings 中第 i 个批次元素的有效序列长度。

  • target_encodings (torch.Tensor) – 目标编码序列,形状为 (B, U, D)

  • target_lengthstorch.Tensor) - 形状为 (B,),其中第 i 个元素表示 target_encodings 中第 i 个批次元素的有效序列长度。

返回值:

torch.Tensor

联合网络输出,形状为 (B, T, U, output_dim)

torch.Tensor

输出源长度,形状为 (B,),第 i 个元素表示连接网络输出中第 i 个批次元素沿维度 1 的有效元素数。

torch.Tensor

输出目标长度,形状为 (B,),第 i 个元素表示连接网络输出中第 i 个批次元素沿维度 2 的有效元素数。

返回类型:

torch.Tensortorch.Tensortorch.Tensor

工厂函数

emformer_rnnt_model

构建基于 Emformer 的 RNNT

emformer_rnnt_base

构建基于 Emformer 的 RNNT 的基本版本。

原型工厂函数

conformer_rnnt_model

构建基于 Conformer 的递归神经网络换能器 (RNN-T) 模型。

conformer_rnnt_base

构建 Conformer RNN-T 模型的基本版本。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获取问题的解答

查看资源