torchaudio.models.wav2vec2.utils.import_fairseq_model
- torchaudio.models.wav2vec2.utils.import_fairseq_model(original: Module) → Wav2Vec2Model [source]
Build a Wav2Vec2Model from the corresponding fairseq model object.
- Parameters:
original (torch.nn.Module) – An instance of a fairseq Wav2Vec2.0 or HuBERT model. It must be one of fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder, fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model or fairseq.models.hubert.hubert_asr.HubertEncoder.
- Returns:
The imported model.
- Return type:
Wav2Vec2Model
- Example - Loading a pretrain-only model
```python
>>> import fairseq.checkpoint_utils
>>> import torch
>>> import torchaudio
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original)
>>>
>>> # Perform feature extraction
>>> waveform, _ = torchaudio.load('audio.wav')
>>> features, _ = imported.extract_features(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> # (fairseq's feature extractor is channel-first, so swap the last two dims)
>>> reference = original.feature_extractor(waveform).transpose(1, 2)
>>> torch.testing.assert_allclose(features, reference)
```
- Example - Loading a fine-tuned model
```python
>>> import fairseq.checkpoint_utils
>>> import torch
>>> import torchaudio
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small_960h.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> # For a fine-tuned model, pass the encoder rather than the full model
>>> imported = import_fairseq_model(original.w2v_encoder)
>>>
>>> # Perform encoding
>>> waveform, _ = torchaudio.load('audio.wav')
>>> emission, _ = imported(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> # (all-zero padding mask: no samples are treated as padding;
>>> #  fairseq's encoder output is time-first, so swap the first two dims)
>>> mask = torch.zeros_like(waveform)
>>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1)
>>> torch.testing.assert_allclose(emission, reference)
```
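The pretrain-only and fine-tuned examples differ only in which fairseq object is passed to import_fairseq_model: the loaded model itself in the first case, and its w2v_encoder attribute in the second. The following is a minimal sketch of that choice as a small helper. The helper select_fairseq_module, the stand-in classes, and the duck-typed w2v_encoder lookup are hypothetical illustrations, not part of torchaudio or fairseq. The class-name check simply mirrors the three accepted types listed under Parameters, and the sketch runs without fairseq installed.

```python
# Hypothetical helper, NOT part of torchaudio or fairseq: picks the object to
# pass to import_fairseq_model, following the two examples above.
ACCEPTED_CLASS_NAMES = {
    "Wav2VecEncoder",  # fairseq.models.wav2vec.wav2vec2_asr
    "Wav2Vec2Model",   # fairseq.models.wav2vec.wav2vec2
    "HubertEncoder",   # fairseq.models.hubert.hubert_asr
}


def select_fairseq_module(model):
    """Return the object that should be passed to import_fairseq_model.

    Assumption: a fine-tuned fairseq model exposes its encoder as
    ``model.w2v_encoder`` (as in the fine-tuned example); any other model
    is passed through unchanged.
    """
    candidate = getattr(model, "w2v_encoder", model)
    name = type(candidate).__name__
    if name not in ACCEPTED_CLASS_NAMES:
        raise ValueError(
            f"Expected one of {sorted(ACCEPTED_CLASS_NAMES)}, got {name!r}"
        )
    return candidate


# Stand-in classes (hypothetical) so the sketch runs without fairseq.
class Wav2Vec2Model:
    """Stand-in for a pretrain-only fairseq model."""


class Wav2VecEncoder:
    """Stand-in for a fine-tuned fairseq encoder."""


class FineTunedStandIn:
    """Stand-in for a fine-tuned fairseq model that wraps an encoder."""

    def __init__(self):
        self.w2v_encoder = Wav2VecEncoder()


pretrained = Wav2Vec2Model()
fine_tuned = FineTunedStandIn()
print(type(select_fairseq_module(pretrained)).__name__)  # Wav2Vec2Model
print(type(select_fairseq_module(fine_tuned)).__name__)  # Wav2VecEncoder
```

With real fairseq objects, the helper's result would then be passed directly to import_fairseq_model. Unsupported objects raise early with a message listing the accepted class names, rather than failing later during conversion.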