快捷方式

torch_tensorrt.fx

函式

torch_tensorrt.fx.compile(module: Module, input, min_acc_module_size: int = 10, max_batch_size: int = 2048, max_workspace_size=33554432, explicit_batch_dimension=False, lower_precision=LowerPrecision.FP16, verbose_log=False, timing_cache_prefix='', save_timing_cache=False, cuda_graph_batch_size=- 1, dynamic_batch=True, is_aten=False, use_experimental_fx_rt=False, correctness_atol=0.1, correctness_rtol=0.1) Module[原始碼]

獲取原始模組、輸入和下降設定,執行下降工作流將模組轉換為下降後的模組,即所謂的 TRTModule。

引數
  • module – 原始模組,用於下降。

  • input – 模組的輸入。

  • max_batch_size – 最大批次大小(必須 ≥ 1 才能設定,0 表示未設定)

  • min_acc_module_size – 加速子模組所需的最少節點數

  • max_workspace_size – 提供給 TensorRT 的最大工作空間大小。

  • explicit_batch_dimension – 如果設定為 True,在 TensorRT 中使用顯式批次維度,否則使用隱式批次維度。

  • lower_precision – 提供給 TRTModule 的 lower_precision 配置。

  • verbose_log – 如果設定為 True,啟用 TensorRT 的詳細日誌。

  • timing_cache_prefix – fx2trt 使用的計時快取檔案的名稱。

  • save_timing_cache – 如果設定為 True,使用當前的計時快取資料更新計時快取。

  • cuda_graph_batch_size – Cuda 圖批次大小,預設為 -1。

  • dynamic_batch – 批次維度 (dim=0) 是否為動態。

  • use_experimental_fx_rt – 使用下一代 TRTModule,它支援基於 Python 和 TorchScript 的執行(包括在 C++ 中)。

返回值

由 TensorRT 下降處理後的 torch.nn.Module。

class torch_tensorrt.fx.TRTModule(engine=None, input_names=None, output_names=None, cuda_graph_batch_size=- 1)[原始碼]
class torch_tensorrt.fx.InputTensorSpec(shape: Sequence[int], dtype: dtype, device: device = device(type='cpu'), shape_ranges: List[Tuple[Sequence[int], Sequence[int], Sequence[int]]] = [], has_batch_dim: bool = True)[原始碼]

此類包含輸入張量的資訊。

shape: 張量的形狀。

dtype: 張量的資料型別。

device: 張量的裝置。這僅用於生成給定模型的輸入,以便執行形狀推理。對於 TensorRT 引擎,輸入必須在 cuda 裝置上。

(續 device 描述)

shape_ranges: 如果需要動態形狀(形狀維度為 -1),則必須提供此欄位(預設為空列表)。每個 shape_range 是一個包含三個元組的元組 ((min_input_shape), (optimized_input_shape), (max_input_shape))。每個 shape_range 用於填充一個 TensorRT 最佳化配置。例如,如果輸入形狀從 (1, 224) 變化到 (100, 224),並且我們希望對 (25, 224) 進行最佳化,因為它是最常見的輸入形狀,那麼我們將 shape_ranges 設定為 ((1, 224), (25, 225), (100, 224))。

(續 shape_ranges 描述)

has_batch_dim: 形狀是否包含批次維度。如果引擎需要使用動態形狀執行,則必須提供批次維度。

(續 has_batch_dim 描述)

class torch_tensorrt.fx.TRTInterpreter(module: GraphModule, input_specs: List[InputTensorSpec], explicit_batch_dimension: bool = False, explicit_precision: bool =False, logger_level=None)[原始碼]
class torch_tensorrt.fx.TRTInterpreterResult(engine, input_names, output_names, serialized_cache)[原始碼]

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源