torchaudio.models.wav2vec2.utils.import_fairseq_model
- torchaudio.models.wav2vec2.utils.import_fairseq_model(original: Module) → Wav2Vec2Model

Builds a Wav2Vec2Model from the corresponding fairseq model object.

- Parameters:
original (torch.nn.Module) – An instance of a fairseq wav2vec 2.0 or HuBERT model. One of fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder, fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model, or fairseq.models.hubert.hubert_asr.HubertEncoder.
- Returns:
The imported model.
- Return type:
Wav2Vec2Model
- Example - Loading a pretraining model
>>> import fairseq
>>> import torch
>>> import torchaudio
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original)
>>>
>>> # Perform feature extraction
>>> waveform, _ = torchaudio.load('audio.wav')
>>> features, _ = imported.extract_features(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> reference = original.feature_extractor(waveform).transpose(1, 2)
>>> torch.testing.assert_allclose(features, reference)
- Example - Loading a fine-tuned model
>>> import fairseq
>>> import torch
>>> import torchaudio
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small_960h.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original.w2v_encoder)
>>>
>>> # Perform encoding
>>> waveform, _ = torchaudio.load('audio.wav')
>>> emission, _ = imported(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> mask = torch.zeros_like(waveform)
>>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1)
>>> torch.testing.assert_allclose(emission, reference)
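Note the difference between the two examples: a pretraining checkpoint is passed to import_fairseq_model directly, while a fine-tuned ASR checkpoint contributes its w2v_encoder submodule. A minimal sketch of a hypothetical helper (not part of torchaudio) that selects the right module based on the presence of that attribute:

```python
def fairseq_module_to_import(original):
    """Pick the module to hand to import_fairseq_model.

    Hypothetical convenience helper, not part of torchaudio:
    fine-tuned fairseq ASR models wrap the encoder in a
    ``w2v_encoder`` attribute, while pretraining models
    (Wav2Vec2Model / HubertModel) are passed as-is.
    """
    # Fall back to the object itself when no w2v_encoder attribute exists.
    return getattr(original, "w2v_encoder", original)
```

With this helper, both kinds of checkpoint can go through the same code path, e.g. `imported = import_fairseq_model(fairseq_module_to_import(model[0]))`.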