快捷方式

RNNT

torchaudio.models.RNNT[source]

迴圈神經網路換能器 (RNN-T) 模型。

注意

要構建模型,請使用工廠函式之一。

另請參閱

torchaudio.pipelines.RNNTBundle: 帶有預訓練模型的 ASR pipeline。

引數:

方法

forward

RNNT.forward(sources: Tensor, source_lengths: Tensor, targets: Tensor, target_lengths: Tensor, predictor_state: Optional[List[List[Tensor]]] = None) Tuple[Tensor, Tensor, Tensor, List[List[Tensor]]][source]

用於訓練的正向傳播。

B: 批次大小;T: 批次中源序列的最大長度;U: 批次中目標序列的最大長度;D: 每個源序列元素的特徵維度。

引數:
  • sources (torch.Tensor) – 右側用右上下文填充的源幀序列,形狀為 (B, T, D)

  • source_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示 sources 中第 i 個批次元素的有效幀數。

  • targets (torch.Tensor) – 目標序列,形狀為 (B, U),每個元素對映到一個目標符號。

  • target_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示 targets 中第 i 個批次元素的有效幀數。

  • predictor_state (List[List[torch.Tensor]] 或 None, 可選) – 張量列表的列表,表示在先前呼叫 forward 中生成的預測網路內部狀態。(預設值: None)

返回值:

torch.Tensor

連線網路的輸出,形狀為 (B, 最大輸出源長度, 最大輸出目標長度, 輸出維度 (目標符號數))

torch.Tensor

輸出源長度,形狀為 (B,),其中第 i 個元素表示連線網路輸出中第 i 個批次元素沿維度 1 的有效元素數。

torch.Tensor

輸出目標長度,形狀為 (B,),其中第 i 個元素表示連線網路輸出中第 i 個批次元素沿維度 2 的有效元素數。

List[List[torch.Tensor]]

輸出狀態;張量列表的列表,表示在當前呼叫 forward 中生成的預測網路內部狀態。

返回型別:

(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

transcribe_streaming

RNNT.transcribe_streaming(sources: Tensor, source_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]][source]

在流模式下將轉錄網路應用於源。

B: 批次大小;T: 批次中源序列段的最大長度;D: 每個源序列幀的特徵維度。

引數:
  • sources (torch.Tensor) – 右側用右上下文填充的源幀序列段,形狀為 (B, T + 右上下文長度, D)

  • source_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示 sources 中第 i 個批次元素的有效幀數。

  • state (List[List[torch.Tensor]] 或 None) – 張量列表的列表,表示在先前呼叫 transcribe_streaming 中生成的轉錄網路內部狀態。

返回值:

torch.Tensor

輸出幀序列,形狀為 (B, T // time_reduction_stride, 輸出維度)

torch.Tensor

輸出長度,形狀為 (B,),其中第 i 個元素表示輸出中第 i 個批次元素的有效元素數。

List[List[torch.Tensor]]

輸出狀態;張量列表的列表,表示在當前呼叫 transcribe_streaming 中生成的轉錄網路內部狀態。

返回型別:

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

transcribe

RNNT.transcribe(sources: Tensor, source_lengths: Tensor) Tuple[Tensor, Tensor][source]

在非流模式下將轉錄網路應用於源。

B: 批次大小;T: 批次中源序列的最大長度;D: 每個源序列幀的特徵維度。

引數:
  • sources (torch.Tensor) – 右側用右上下文填充的源幀序列,形狀為 (B, T + 右上下文長度, D)

  • source_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示 sources 中第 i 個批次元素的有效幀數。

返回值:

torch.Tensor

輸出幀序列,形狀為 (B, T // time_reduction_stride, 輸出維度)

torch.Tensor

輸出長度,形狀為 (B,),其中第 i 個元素表示輸出幀序列中第 i 個批次元素的有效元素數。

返回型別:

(torch.Tensor, torch.Tensor)

predict

RNNT.predict(targets: Tensor, target_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]][source]

將預測網路應用於目標。

B: 批次大小;U: 批次中目標序列的最大長度;D: 每個目標序列幀的特徵維度。

引數:
  • targets (torch.Tensor) – 目標序列,形狀為 (B, U),每個元素對映到一個目標符號,即在範圍 [0, num_symbols) 內。

  • target_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示 targets 中第 i 個批次元素的有效幀數。

  • state (List[List[torch.Tensor]] 或 None) – 張量列表的列表,表示在先前呼叫 predict 中生成的內部狀態。

返回值:

torch.Tensor

輸出幀序列,形狀為 (B, U, 輸出維度)

torch.Tensor

輸出長度,形狀為 (B,),其中第 i 個元素表示輸出中第 i 個批次元素的有效元素數。

List[List[torch.Tensor]]

輸出狀態;張量列表的列表,表示在當前呼叫 predict 中生成的內部狀態。

返回型別:

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

join

RNNT.join(source_encodings: Tensor, source_lengths: Tensor, target_encodings: Tensor, target_lengths: Tensor) Tuple[Tensor, Tensor, Tensor][source]

將連線網路應用於源編碼和目標編碼。

B: 批次大小;T: 批次中源序列的最大長度;U: 批次中目標序列的最大長度;D: 每個源序列和目標序列編碼的維度。

引數:
  • source_encodings (torch.Tensor) – 源編碼序列,形狀為 (B, T, D)

  • source_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示 source_encodings 中第 i 個批次元素的有效序列長度。

  • target_encodings (torch.Tensor) – 目標編碼序列,形狀為 (B, U, D)

  • target_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示 target_encodings 中第 i 個批次元素的有效序列長度。

返回值:

torch.Tensor

連線網路的輸出,形狀為 (B, T, U, 輸出維度)

torch.Tensor

輸出源長度,形狀為 (B,),其中第 i 個元素表示連線網路輸出中第 i 個批次元素沿維度 1 的有效元素數。

torch.Tensor

輸出目標長度,形狀為 (B,),其中第 i 個元素表示連線網路輸出中第 i 個批次元素沿維度 2 的有效元素數。

返回型別:

(torch.Tensor, torch.Tensor, torch.Tensor)

工廠函式

emformer_rnnt_model

構建基於 Emformer 的 RNNT

emformer_rnnt_base

構建基於 Emformer 的 RNNT 的基本版本。

原型工廠函式

conformer_rnnt_model

構建基於 Conformer 的迴圈神經網路換能器 (RNN-T) 模型。

conformer_rnnt_base

構建 Conformer RNN-T 模型的基本版本。

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源