Wav2Vec2FABundle¶
- 類 torchaudio.pipelines.Wav2Vec2FABundle[原始碼]¶
用於捆綁關聯資訊的 資料類,以便使用預訓練的
Wav2Vec2Model進行強制對齊。該類提供了用於例項化預訓練模型的介面,以及檢索預訓練權重和與模型一起使用的額外資料所需的資訊。
Torchaudio 庫會例項化此類的物件,每個物件代表一個不同的預訓練模型。客戶端程式碼應透過這些例項訪問預訓練模型。
請參閱下方瞭解用法和可用值。
- 示例 - 特徵提取
>>> import torchaudio >>> >>> bundle = torchaudio.pipelines.MMS_FA >>> >>> # Build the model and load pretrained weight. >>> model = bundle.get_model() Downloading: 100%|███████████████████████████████| 1.18G/1.18G [00:05<00:00, 216MB/s] >>> >>> # Resample audio to the expected sampling rate >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate) >>> >>> # Estimate the probability of token distribution >>> emission, _ = model(waveform) >>> >>> # Generate frame-wise alignment >>> alignment, scores = torchaudio.functional.forced_align( >>> emission, targets, input_lengths, target_lengths, blank=0) >>>
- 使用
Wav2Vec2FABundle的教程
屬性¶
sample_rate¶
方法¶
get_aligner¶
get_dict¶
- Wav2Vec2FABundle.get_dict(star: Optional[str] = '*', blank: str ='-') Dict[str, int][原始碼]¶
獲取從 token 到索引的對映(在 emission feature 維度中)
- 引數:
- 返回:
對於在 ASR 上微調的模型,返回表示輸出類別標籤的字串元組。
- 返回型別:
Tuple[str, ...]
- 示例
>>> from torchaudio.pipelines import MMS_FA as bundle >>> bundle.get_dict() {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '*': 28} >>> bundle.get_dict(star=None) {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27}
get_labels¶
- Wav2Vec2FABundle.get_labels(star: Optional[str] = '*', blank: str ='-') Tuple[str, ...][原始碼]¶
獲取與 emission 特徵維度相對應的標籤。
第一個是 blank token,並且它是可自定義的。
- 引數:
- 返回:
對於在 ASR 上微調的模型,返回表示輸出類別標籤的字串元組。
- 返回型別:
Tuple[str, ...]
- 示例
>>> from torchaudio.pipelines import MMS_FA as bundle >>> bundle.get_labels() ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '*') >>> bundle.get_labels(star=None) ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x')
get_model¶
- Wav2Vec2FABundle.get_model(with_star: bool = True, *, dl_kwargs=None) Module[原始碼]¶
構建模型並載入預訓練權重。
權重檔案會從網際網路下載並使用
torch.hub.load_state_dict_from_url()進行快取。- 引數:
with_star (bool,可選) – 如果啟用,輸出層的最後一維會擴充套件一個,這對應於 star token。
dl_kwargs (keyword arguments 字典) – 傳遞給
torch.hub.load_state_dict_from_url()。
- 返回:
Wav2Vec2Model的變體。注意
使用此方法建立的模型返回的是對數域的機率(即應用了
torch.nn.functional.log_softmax()),而其他 Wav2Vec2 模型返回的是 logit。