
torchaudio.models.wav2vec2.utils.import_fairseq_model

torchaudio.models.wav2vec2.utils.import_fairseq_model(original: Module) → Wav2Vec2Model

Builds a Wav2Vec2Model from the corresponding model object of fairseq.

Parameters:

original (torch.nn.Module) – An instance of a fairseq Wav2Vec2.0 or HuBERT model. One of fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder, fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model or fairseq.models.hubert.hubert_asr.HubertEncoder.

Returns:

The imported model.

Return type:

Wav2Vec2Model
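The examples below pass model[0] for a pretrain-only checkpoint but model[0].w2v_encoder for a fine-tuned one. The following is a minimal sketch, not part of torchaudio, of a hypothetical helper select_importable_module that picks the right object by comparing class names against the three classes listed under Parameters. It assumes that a fine-tuned wrapper keeps its encoder in a w2v_encoder attribute, as the fine-tuned example does.

```python
from typing import Any

# Class names accepted by import_fairseq_model, as listed under Parameters.
SUPPORTED_CLASS_NAMES = {"Wav2VecEncoder", "Wav2Vec2Model", "HubertEncoder"}


def select_importable_module(model: Any) -> Any:
    """Hypothetical helper: return the object to pass to import_fairseq_model.

    Assumption: a fine-tuned fairseq wrapper stores its encoder in a
    ``w2v_encoder`` attribute; a pretrain-only model is passed through as-is.
    """
    # Pretrain-only model (or an encoder passed directly): use it unchanged.
    if type(model).__name__ in SUPPORTED_CLASS_NAMES:
        return model
    # Fine-tuned wrapper: unwrap the encoder if it is one of the supported classes.
    encoder = getattr(model, "w2v_encoder", None)
    if encoder is not None and type(encoder).__name__ in SUPPORTED_CLASS_NAMES:
        return encoder
    raise ValueError(f"Unsupported model class: {type(model).__name__}")
```

With such a helper, both examples could call import_fairseq_model(select_importable_module(model[0])). Dispatching on the class name avoids importing fairseq in the helper itself.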

Example - Loading a pretrain-only model
>>> import fairseq.checkpoint_utils
>>> import torch
>>> import torchaudio
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq; eval() disables dropout so outputs are deterministic
>>> model_file = 'wav2vec_small.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0].eval()
>>> imported = import_fairseq_model(original).eval()
>>>
>>> # Perform feature extraction with the convolutional front end (expects 16 kHz audio)
>>> waveform, _ = torchaudio.load('audio.wav')
>>> features, _ = imported.feature_extractor(waveform, None)
>>>
>>> # Compare result with the original model from fairseq
>>> reference = original.feature_extractor(waveform).transpose(1, 2)
>>> torch.testing.assert_close(features, reference)
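The reference is transposed because fairseq's feature extractor returns (batch, channel, frame), whereas torchaudio orders its features as (batch, frame, channel). The number of frames follows from the kernel sizes and strides of the convolutional layers. The sketch below computes it, assuming the default wav2vec 2.0 layout used by fairseq, [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 as (channels, kernel, stride); other checkpoints may be configured differently.

```python
# Assumed default wav2vec 2.0 conv layout: (out_channels, kernel_size, stride).
CONV_LAYERS = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2


def num_frames(num_samples: int) -> int:
    """Frames produced by the unpadded conv stack for num_samples input samples."""
    length = num_samples
    for _, kernel, stride in CONV_LAYERS:
        # Output length of a 1-D convolution without padding.
        length = (length - kernel) // stride + 1
    return length


# One second of 16 kHz audio yields 49 feature frames (roughly one per 20 ms).
print(num_frames(16000))  # 49
```

The smallest input that yields one frame under this layout is 400 samples (25 ms at 16 kHz), which is the receptive field of the stack.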
Example - Fine-tuned model
>>> import fairseq.checkpoint_utils
>>> import torch
>>> import torchaudio
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq; eval() disables dropout and masking
>>> model_file = 'wav2vec_small_960h.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0].eval()
>>> imported = import_fairseq_model(original.w2v_encoder).eval()
>>>
>>> # Perform encoding
>>> waveform, _ = torchaudio.load('audio.wav')
>>> emission, _ = imported(waveform)
>>>
>>> # Compare result with the encoder of the original fairseq model
>>> mask = torch.zeros_like(waveform, dtype=torch.bool)  # no padded samples
>>> reference = original.w2v_encoder(waveform, mask)['encoder_out'].transpose(0, 1)
>>> torch.testing.assert_close(emission, reference)
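The emission of the fine-tuned model holds per-frame label scores with shape (batch, frame, num_labels). One common way to turn it into text is greedy CTC decoding: take the best label on each frame, collapse consecutive repeats, then drop the blank symbol. The sketch below is a pure-Python illustration, not a torchaudio API; the label list and the blank index 0 are assumptions and must match the dictionary used to fine-tune the checkpoint.

```python
from typing import List, Optional, Sequence


def greedy_ctc_decode(
    emission: Sequence[Sequence[float]], labels: Sequence[str], blank: int = 0
) -> str:
    """Greedy CTC decoding of a single utterance.

    emission: scores of shape (frame, num_labels).
    labels: hypothetical label list; labels[blank] is the CTC blank (assumed index 0).
    """
    # Highest-scoring label index on each frame.
    best = [max(range(len(frame)), key=frame.__getitem__) for frame in emission]
    tokens: List[str] = []
    previous: Optional[int] = None
    for index in best:
        # Keep a label only when it differs from the previous frame and is not blank.
        if index != previous and index != blank:
            tokens.append(labels[index])
        previous = index
    return "".join(tokens)
```

For instance, greedy_ctc_decode(emission[0].tolist(), labels) would decode the first utterance in the batch. A blank between two identical labels keeps them as separate characters, which is why repeats are collapsed before blanks are removed.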
