RNNT¶
- 類 torchaudio.models.RNNT[source]¶
迴圈神經網路換能器 (RNN-T) 模型。
注意
要構建模型,請使用工廠函式之一。
另請參閱
torchaudio.pipelines.RNNTBundle: 帶有預訓練模型的 ASR pipeline。- 引數:
transcriber (torch.nn.Module) – 轉錄網路。
predictor (torch.nn.Module) – 預測網路。
joiner (torch.nn.Module) – 連線網路。
方法¶
forward¶
- RNNT.forward(sources: Tensor, source_lengths: Tensor, targets: Tensor, target_lengths: Tensor, predictor_state: Optional[List[List[Tensor]]] = None) Tuple[Tensor, Tensor, Tensor, List[List[Tensor]]][source]¶
用於訓練的正向傳播。
B: 批次大小;T: 批次中源序列的最大長度;U: 批次中目標序列的最大長度;D: 每個源序列元素的特徵維度。
- 引數:
sources (torch.Tensor) – 右側用右上下文填充的源幀序列,形狀為 (B, T, D)。
source_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示
sources中第 i 個批次元素的有效幀數。targets (torch.Tensor) – 目標序列,形狀為 (B, U),每個元素對映到一個目標符號。
target_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示
targets中第 i 個批次元素的有效幀數。predictor_state (List[List[torch.Tensor]] 或 None, 可選) – 張量列表的列表,表示在先前呼叫
forward中生成的預測網路內部狀態。(預設值:None)
- 返回值:
- torch.Tensor
連線網路的輸出,形狀為 (B, 最大輸出源長度, 最大輸出目標長度, 輸出維度 (目標符號數))。
- torch.Tensor
輸出源長度,形狀為 (B,),其中第 i 個元素表示連線網路輸出中第 i 個批次元素沿維度 1 的有效元素數。
- torch.Tensor
輸出目標長度,形狀為 (B,),其中第 i 個元素表示連線網路輸出中第 i 個批次元素沿維度 2 的有效元素數。
- List[List[torch.Tensor]]
輸出狀態;張量列表的列表,表示在當前呼叫
forward中生成的預測網路內部狀態。
- 返回型別:
(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
transcribe_streaming¶
- RNNT.transcribe_streaming(sources: Tensor, source_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]][source]¶
在流模式下將轉錄網路應用於源。
B: 批次大小;T: 批次中源序列段的最大長度;D: 每個源序列幀的特徵維度。
- 引數:
sources (torch.Tensor) – 右側用右上下文填充的源幀序列段,形狀為 (B, T + 右上下文長度, D)。
source_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示
sources中第 i 個批次元素的有效幀數。state (List[List[torch.Tensor]] 或 None) – 張量列表的列表,表示在先前呼叫
transcribe_streaming中生成的轉錄網路內部狀態。
- 返回值:
- torch.Tensor
輸出幀序列,形狀為 (B, T // time_reduction_stride, 輸出維度)。
- torch.Tensor
輸出長度,形狀為 (B,),其中第 i 個元素表示輸出中第 i 個批次元素的有效元素數。
- List[List[torch.Tensor]]
輸出狀態;張量列表的列表,表示在當前呼叫
transcribe_streaming中生成的轉錄網路內部狀態。
- 返回型別:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
transcribe¶
- RNNT.transcribe(sources: Tensor, source_lengths: Tensor) Tuple[Tensor, Tensor][source]¶
在非流模式下將轉錄網路應用於源。
B: 批次大小;T: 批次中源序列的最大長度;D: 每個源序列幀的特徵維度。
- 引數:
sources (torch.Tensor) – 右側用右上下文填充的源幀序列,形狀為 (B, T + 右上下文長度, D)。
source_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示
sources中第 i 個批次元素的有效幀數。
- 返回值:
- torch.Tensor
輸出幀序列,形狀為 (B, T // time_reduction_stride, 輸出維度)。
- torch.Tensor
輸出長度,形狀為 (B,),其中第 i 個元素表示輸出幀序列中第 i 個批次元素的有效元素數。
- 返回型別:
predict¶
- RNNT.predict(targets: Tensor, target_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]][source]¶
將預測網路應用於目標。
B: 批次大小;U: 批次中目標序列的最大長度;D: 每個目標序列幀的特徵維度。
- 引數:
targets (torch.Tensor) – 目標序列,形狀為 (B, U),每個元素對映到一個目標符號,即在範圍 [0, num_symbols) 內。
target_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示
targets中第 i 個批次元素的有效幀數。state (List[List[torch.Tensor]] 或 None) – 張量列表的列表,表示在先前呼叫
predict中生成的內部狀態。
- 返回值:
- torch.Tensor
輸出幀序列,形狀為 (B, U, 輸出維度)。
- torch.Tensor
輸出長度,形狀為 (B,),其中第 i 個元素表示輸出中第 i 個批次元素的有效元素數。
- List[List[torch.Tensor]]
輸出狀態;張量列表的列表,表示在當前呼叫
predict中生成的內部狀態。
- 返回型別:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
join¶
- RNNT.join(source_encodings: Tensor, source_lengths: Tensor, target_encodings: Tensor, target_lengths: Tensor) Tuple[Tensor, Tensor, Tensor][source]¶
將連線網路應用於源編碼和目標編碼。
B: 批次大小;T: 批次中源序列的最大長度;U: 批次中目標序列的最大長度;D: 每個源序列和目標序列編碼的維度。
- 引數:
source_encodings (torch.Tensor) – 源編碼序列,形狀為 (B, T, D)。
source_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示
source_encodings中第 i 個批次元素的有效序列長度。target_encodings (torch.Tensor) – 目標編碼序列,形狀為 (B, U, D)。
target_lengths (torch.Tensor) – 形狀為 (B,),其中第 i 個元素表示
target_encodings中第 i 個批次元素的有效序列長度。
- 返回值:
- torch.Tensor
連線網路的輸出,形狀為 (B, T, U, 輸出維度)。
- torch.Tensor
輸出源長度,形狀為 (B,),其中第 i 個元素表示連線網路輸出中第 i 個批次元素沿維度 1 的有效元素數。
- torch.Tensor
輸出目標長度,形狀為 (B,),其中第 i 個元素表示連線網路輸出中第 i 個批次元素沿維度 2 的有效元素數。
- 返回型別:
工廠函式¶
構建基於 Emformer 的 |
|
構建基於 Emformer 的 |
原型工廠函式¶
構建基於 Conformer 的迴圈神經網路換能器 (RNN-T) 模型。 |
|
構建 Conformer RNN-T 模型的基本版本。 |