快捷方式

torchaudio.models.emformer_rnnt_model

torchaudio.models.emformer_rnnt_model(*, input_dim: int, encoding_dim: int, num_symbols: int, segment_length: int, right_context_length: int, time_reduction_input_dim: int, time_reduction_stride: int, transformer_num_heads: int, transformer_ffn_dim: int, transformer_num_layers: int, transformer_dropout: float, transformer_activation: str, transformer_left_context_length: int, transformer_max_memory_size: int, transformer_weight_init_scale_strategy: str, transformer_tanh_on_mem: bool, symbol_embedding_dim: int, num_lstm_layers: int, lstm_layer_norm: bool, lstm_layer_norm_epsilon: float, lstm_dropout: float) RNNT[source]

構建基於 Emformer 的 RNNT 模型。

注意

對於非流式推理,期望在輸入序列上呼叫 transcribe 方法,輸入序列需要透過右上下文長度 (right_context_length) 的幀進行右側拼接。

對於流式推理,期望在輸入塊上呼叫 transcribe_streaming 方法,輸入塊應包含 segment_length 幀並與 right_context_length 幀進行右側拼接。

引數:
  • input_dim (int) – 傳遞給轉錄網路的輸入序列幀的維度。

  • encoding_dim (int) – 傳遞給聯合網路的轉錄網路和預測網路生成的編碼的維度。

  • num_symbols (int) – 目標標記集合的基數。

  • segment_length (int) – 輸入段的長度,以幀數表示。

  • right_context_length (int) – 右上下文的長度,以幀數表示。

  • time_reduction_input_dim (int) – 在應用時間縮減塊之前,將輸入序列中每個元素縮放到的維度。

  • time_reduction_stride (int) – 縮減輸入序列長度的因子(步長)。

  • transformer_num_heads (int) – 每個 Emformer 層中的注意力頭數量。

  • transformer_ffn_dim (int) – 每個 Emformer 層的全連線網路中的隱藏層維度。

  • transformer_num_layers (int) – 要例項化的 Emformer 層數量。

  • transformer_left_context_length (int) – Emformer 考慮的左上下文長度。

  • transformer_dropout (float) – Emformer 的 dropout 機率。

  • transformer_activation (str) – 在每個 Emformer 層的全連線網路中使用的啟用函式。必須是 (“relu”, “gelu”, “silu”) 之一。

  • transformer_max_memory_size (int) – 要使用的最大記憶體元素數量。

  • transformer_weight_init_scale_strategy (str) – 每層權重初始化縮放策略。必須是 (“depthwise”, “constant”, None) 之一。

  • transformer_tanh_on_mem (bool) – 如果為 True,則對記憶體元素應用 tanh。

  • symbol_embedding_dim (int) – 每個目標標記嵌入的維度。

  • num_lstm_layers (int) – 要例項化的 LSTM 層數量。

  • lstm_layer_norm (bool) – 如果為 True,則為 LSTM 層啟用層歸一化。

  • lstm_layer_norm_epsilon (float) – 在 LSTM 層歸一化層中使用的 epsilon 值。

  • lstm_dropout (float) – LSTM 的 dropout 機率。

返回:

Emformer RNN-T 模型。

返回型別:

RNNT

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源