Wav2Vec2FABundle¶
- class torchaudio.pipelines.Wav2Vec2FABundle[source]¶
用于捆绑相关信息的数据类,以使用预训练的
Wav2Vec2Model
进行强制对齐。此类提供用于实例化预训练模型的接口,以及检索预训练权重和与模型一起使用的其他必要数据的信息。
Torchaudio 库实例化此类的对象,每个对象代表一个不同的预训练模型。客户端代码应通过这些实例访问预训练模型。
请参见下文以了解用法和可用值。
- 示例 - 特征提取
>>> import torchaudio >>> >>> bundle = torchaudio.pipelines.MMS_FA >>> >>> # Build the model and load pretrained weight. >>> model = bundle.get_model() Downloading: 100%|███████████████████████████████| 1.18G/1.18G [00:05<00:00, 216MB/s] >>> >>> # Resample audio to the expected sampling rate >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate) >>> >>> # Estimate the probability of token distribution >>> emission, _ = model(waveform) >>> >>> # Generate frame-wise alignment >>> alignment, scores = torchaudio.functional.forced_align( >>> emission, targets, input_lengths, target_lengths, blank=0) >>>
- 使用
Wav2Vec2FABundle
的教程
属性¶
采样率¶
方法¶
获取对齐器¶
获取字典¶
- Wav2Vec2FABundle.get_dict(star: Optional[str] = '*', blank: str = '-') Dict[str, int] [source]¶
获取从标记到索引(在发射特征维度中)的映射
- 参数:
- 返回值:
对于在 ASR 上微调的模型,返回表示输出类标签的字符串元组。
- 返回类型:
Tuple[str, …]
- 示例
>>> from torchaudio.pipelines import MMS_FA as bundle >>> bundle.get_dict() {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '*': 28} >>> bundle.get_dict(star=None) {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27}
获取标签¶
- Wav2Vec2FABundle.get_labels(star: Optional[str] = '*', blank: str = '-') Tuple[str, ...] [source]¶
获取与发射特征维度相对应的标签。
第一个是空白标记,并且可以自定义。
- 参数:
- 返回值:
对于在 ASR 上微调的模型,返回表示输出类标签的字符串元组。
- 返回类型:
Tuple[str, …]
- 示例
>>> from torchaudio.pipelines import MMS_FA as bundle >>> bundle.get_labels() ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '*') >>> bundle.get_labels(star=None) ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x')
获取模型¶
- Wav2Vec2FABundle.get_model(with_star: bool = True, *, dl_kwargs=None) Module [source]¶
构建模型并加载预训练权重。
权重文件从互联网下载并使用
torch.hub.load_state_dict_from_url()
进行缓存。- 参数:
with_star (bool, 可选) – 如果启用,输出层的最后一维将扩展一维,对应于star标记。
dl_kwargs (关键字参数字典) – 传递给
torch.hub.load_state_dict_from_url()
。
- 返回值:
Wav2Vec2Model
的变体。注意
使用此方法创建的模型返回对数域中的概率(即应用了
torch.nn.functional.log_softmax()
),而其他Wav2Vec2模型返回logit。