RNNT¶
- 类 torchaudio.models.RNNT[source]¶
循环神经网络换能器 (RNN-T) 模型。
注意
要构建模型,请使用工厂函数之一。
另请参阅
torchaudio.pipelines.RNNTBundle
: 带有预训练模型的 ASR pipeline。- 参数:
transcriber (torch.nn.Module) – 转录网络。
predictor (torch.nn.Module) – 预测网络。
joiner (torch.nn.Module) – 连接网络。
方法¶
forward¶
- RNNT.forward(sources: Tensor, source_lengths: Tensor, targets: Tensor, target_lengths: Tensor, predictor_state: Optional[List[List[Tensor]]] = None) Tuple[Tensor, Tensor, Tensor, List[List[Tensor]]] [source]¶
用于训练的正向传播。
B: 批量大小;T: 批次中源序列的最大长度;U: 批次中目标序列的最大长度;D: 每个源序列元素的特征维度。
- 参数:
sources (torch.Tensor) – 右侧用右上下文填充的源帧序列,形状为 (B, T, D)。
source_lengths (torch.Tensor) – 形状为 (B,),其中第 i 个元素表示
sources
中第 i 个批次元素的有效帧数。targets (torch.Tensor) – 目标序列,形状为 (B, U),每个元素映射到一个目标符号。
target_lengths (torch.Tensor) – 形状为 (B,),其中第 i 个元素表示
targets
中第 i 个批次元素的有效帧数。predictor_state (List[List[torch.Tensor]] 或 None, 可选) – 张量列表的列表,表示在先前调用
forward
中生成的预测网络内部状态。(默认值:None
)
- 返回值:
- torch.Tensor
连接网络的输出,形状为 (B, 最大输出源长度, 最大输出目标长度, 输出维度 (目标符号数))。
- torch.Tensor
输出源长度,形状为 (B,),其中第 i 个元素表示连接网络输出中第 i 个批次元素沿维度 1 的有效元素数。
- torch.Tensor
输出目标长度,形状为 (B,),其中第 i 个元素表示连接网络输出中第 i 个批次元素沿维度 2 的有效元素数。
- List[List[torch.Tensor]]
输出状态;张量列表的列表,表示在当前调用
forward
中生成的预测网络内部状态。
- 返回类型:
(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
transcribe_streaming¶
- RNNT.transcribe_streaming(sources: Tensor, source_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]] [source]¶
在流模式下将转录网络应用于源。
B: 批量大小;T: 批次中源序列段的最大长度;D: 每个源序列帧的特征维度。
- 参数:
sources (torch.Tensor) – 右侧用右上下文填充的源帧序列段,形状为 (B, T + 右上下文长度, D)。
source_lengths (torch.Tensor) – 形状为 (B,),其中第 i 个元素表示
sources
中第 i 个批次元素的有效帧数。state (List[List[torch.Tensor]] 或 None) – 张量列表的列表,表示在先前调用
transcribe_streaming
中生成的转录网络内部状态。
- 返回值:
- torch.Tensor
输出帧序列,形状为 (B, T // time_reduction_stride, 输出维度)。
- torch.Tensor
输出长度,形状为 (B,),其中第 i 个元素表示输出中第 i 个批次元素的有效元素数。
- List[List[torch.Tensor]]
输出状态;张量列表的列表,表示在当前调用
transcribe_streaming
中生成的转录网络内部状态。
- 返回类型:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
transcribe¶
- RNNT.transcribe(sources: Tensor, source_lengths: Tensor) Tuple[Tensor, Tensor] [source]¶
在非流模式下将转录网络应用于源。
B: 批量大小;T: 批次中源序列的最大长度;D: 每个源序列帧的特征维度。
- 参数:
sources (torch.Tensor) – 右侧用右上下文填充的源帧序列,形状为 (B, T + 右上下文长度, D)。
source_lengths (torch.Tensor) – 形状为 (B,),其中第 i 个元素表示
sources
中第 i 个批次元素的有效帧数。
- 返回值:
- torch.Tensor
输出帧序列,形状为 (B, T // time_reduction_stride, 输出维度)。
- torch.Tensor
输出长度,形状为 (B,),其中第 i 个元素表示输出帧序列中第 i 个批次元素的有效元素数。
- 返回类型:
predict¶
- RNNT.predict(targets: Tensor, target_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]] [source]¶
将预测网络应用于目标。
B: 批量大小;U: 批次中目标序列的最大长度;D: 每个目标序列帧的特征维度。
- 参数:
targets (torch.Tensor) – 目标序列,形状为 (B, U),每个元素映射到一个目标符号,即在范围 [0, num_symbols) 内。
target_lengths (torch.Tensor) – 形状为 (B,),其中第 i 个元素表示
targets
中第 i 个批次元素的有效帧数。state (List[List[torch.Tensor]] 或 None) – 张量列表的列表,表示在先前调用
predict
中生成的内部状态。
- 返回值:
- torch.Tensor
输出帧序列,形状为 (B, U, 输出维度)。
- torch.Tensor
输出长度,形状为 (B,),其中第 i 个元素表示输出中第 i 个批次元素的有效元素数。
- List[List[torch.Tensor]]
输出状态;张量列表的列表,表示在当前调用
predict
中生成的内部状态。
- 返回类型:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
join¶
- RNNT.join(source_encodings: Tensor, source_lengths: Tensor, target_encodings: Tensor, target_lengths: Tensor) Tuple[Tensor, Tensor, Tensor] [source]¶
将连接网络应用于源编码和目标编码。
B: 批量大小;T: 批次中源序列的最大长度;U: 批次中目标序列的最大长度;D: 每个源序列和目标序列编码的维度。
- 参数:
source_encodings (torch.Tensor) – 源编码序列,形状为 (B, T, D)。
source_lengths (torch.Tensor) – 形状为 (B,),其中第 i 个元素表示
source_encodings
中第 i 个批次元素的有效序列长度。target_encodings (torch.Tensor) – 目标编码序列,形状为 (B, U, D)。
target_lengths (torch.Tensor) – 形状为 (B,),其中第 i 个元素表示
target_encodings
中第 i 个批次元素的有效序列长度。
- 返回值:
- torch.Tensor
连接网络的输出,形状为 (B, T, U, 输出维度)。
- torch.Tensor
输出源长度,形状为 (B,),其中第 i 个元素表示连接网络输出中第 i 个批次元素沿维度 1 的有效元素数。
- torch.Tensor
输出目标长度,形状为 (B,),其中第 i 个元素表示连接网络输出中第 i 个批次元素沿维度 2 的有效元素数。
- 返回类型:
工厂函数¶
构建基于 Emformer 的 |
|
构建基于 Emformer 的 |
原型工厂函数¶
构建基于 Conformer 的循环神经网络换能器 (RNN-T) 模型。 |
|
构建 Conformer RNN-T 模型的基本版本。 |