RNNTBeamSearch¶
- class torchaudio.models.RNNTBeamSearch(model: RNNT, blank: int, temperature: float = 1.0, hypo_sort_key: Optional[Callable[[Tuple[List[int], Tensor, List[List[Tensor]], float]], float]] = None, step_max_tokens: int = 100)[source]¶
RNN-T 模型的 Beam Search 解码器。
另请参阅
torchaudio.pipelines.RNNTBundle
: 使用预训练模型的 ASR 管道。
- 参数:
model (RNNT) – 要使用的 RNN-T 模型。
blank (int) – 词汇表中 blank 标记的索引。
temperature (float, 可选) – 应用于联合网络输出的温度。值越大,样本越均匀。(默认值:1.0)
hypo_sort_key (Callable[[Hypothesis], float] 或 None, 可选) – 用于计算给定假设得分以对假设进行排名的可调用对象。如果
None
,则默认为根据标记序列长度归一化假设得分的可调用对象。(默认值:None)step_max_tokens (int, 可选) – 每个输入时间步最多发射的标记数量。(默认值:100)
- 使用
RNNTBeamSearch
的教程
方法¶
forward¶
- RNNTBeamSearch.forward(input: Tensor, length: Tensor, beam_width: int) List[Tuple[List[int], Tensor, List[List[Tensor]], float]] [source]¶
对给定的输入序列执行 beam search。
T: 帧数;D: 每帧的特征维度。
- 参数:
input (torch.Tensor) – 输入帧序列,形状为 (T, D) 或 (1, T, D)。
length (torch.Tensor) – 输入序列中的有效帧数,形状为 () 或 (1,)。
beam_width (int) – 搜索期间使用的 beam 大小。
- 返回:
beam search 找到的 top-
beam_width
个假设。- 返回类型:
List[Hypothesis]
infer¶
- RNNTBeamSearch.infer(input: Tensor, length: Tensor, beam_width: int, state: Optional[List[List[Tensor]]] = None, hypothesis: Optional[List[Tuple[List[int], Tensor, List[List[Tensor]], float]]] = None) Tuple[List[Tuple[List[int], Tensor, List[List[Tensor]], float]], List[List[Tensor]]] [source]¶
在流式模式下对给定输入序列执行 beam search。
T: 帧数;D: 每帧的特征维度。
- 参数:
input (torch.Tensor) – 输入帧序列,形状为 (T, D) 或 (1, T, D)。
length (torch.Tensor) – 输入序列中的有效帧数,形状为 () 或 (1,)。
beam_width (int) – 搜索期间使用的 beam 大小。
state (List[List[torch.Tensor]] 或 None, 可选) – 表示前一次调用中生成的转录网络内部状态的张量列表的列表。(默认值:
None
)hypothesis (List[Hypothesis] 或 None) – 用于种子搜索的,来自前一次调用的假设。(默认值:
None
)
- 返回:
- List[Hypothesis]
beam search 找到的 top-
beam_width
个假设。- List[List[torch.Tensor]]
表示当前调用中生成的转录网络内部状态的张量列表的列表。
- 返回类型:
(List[Hypothesis], List[List[torch.Tensor]])