torch_tensorrt.fx¶
函式¶
- torch_tensorrt.fx.compile(module: Module, input, min_acc_module_size: int = 10, max_batch_size: int = 2048, max_workspace_size=33554432, explicit_batch_dimension=False, lower_precision=LowerPrecision.FP16, verbose_log=False, timing_cache_prefix='', save_timing_cache=False, cuda_graph_batch_size=- 1, dynamic_batch=True, is_aten=False, use_experimental_fx_rt=False, correctness_atol=0.1, correctness_rtol=0.1) Module[原始碼]¶
獲取原始模組、輸入和下降設定,執行下降工作流將模組轉換為下降後的模組,即所謂的 TRTModule。
- 引數
module – 原始模組,用於下降。
input – 模組的輸入。
max_batch_size – 最大批次大小(必須 ≥ 1 才能設定,0 表示未設定)
min_acc_module_size – 加速子模組所需的最少節點數
max_workspace_size – 提供給 TensorRT 的最大工作空間大小。
explicit_batch_dimension – 如果設定為 True,在 TensorRT 中使用顯式批次維度,否則使用隱式批次維度。
lower_precision – 提供給 TRTModule 的 lower_precision 配置。
verbose_log – 如果設定為 True,啟用 TensorRT 的詳細日誌。
timing_cache_prefix – fx2trt 使用的計時快取檔案的名稱。
save_timing_cache – 如果設定為 True,使用當前的計時快取資料更新計時快取。
cuda_graph_batch_size – Cuda 圖批次大小,預設為 -1。
dynamic_batch – 批次維度 (dim=0) 是否為動態。
use_experimental_fx_rt – 使用下一代 TRTModule,它支援基於 Python 和 TorchScript 的執行(包括在 C++ 中)。
- 返回值
由 TensorRT 下降處理後的 torch.nn.Module。
類¶
- class torch_tensorrt.fx.TRTModule(engine=None, input_names=None, output_names=None, cuda_graph_batch_size=- 1)[原始碼]¶
- class torch_tensorrt.fx.InputTensorSpec(shape: Sequence[int], dtype: dtype, device: device = device(type='cpu'), shape_ranges: List[Tuple[Sequence[int], Sequence[int], Sequence[int]]] = [], has_batch_dim: bool = True)[原始碼]¶
此類包含輸入張量的資訊。
shape: 張量的形狀。
dtype: 張量的資料型別。
- device: 張量的裝置。這僅用於生成給定模型的輸入,以便執行形狀推理。對於 TensorRT 引擎,輸入必須在 cuda 裝置上。
(續 device 描述)
- shape_ranges: 如果需要動態形狀(形狀維度為 -1),則必須提供此欄位(預設為空列表)。每個 shape_range 是一個包含三個元組的元組 ((min_input_shape), (optimized_input_shape), (max_input_shape))。每個 shape_range 用於填充一個 TensorRT 最佳化配置。例如,如果輸入形狀從 (1, 224) 變化到 (100, 224),並且我們希望對 (25, 224) 進行最佳化,因為它是最常見的輸入形狀,那麼我們將 shape_ranges 設定為 ((1, 224), (25, 225), (100, 224))。
(續 shape_ranges 描述)
- has_batch_dim: 形狀是否包含批次維度。如果引擎需要使用動態形狀執行,則必須提供批次維度。
(續 has_batch_dim 描述)
- class torch_tensorrt.fx.TRTInterpreter(module: GraphModule, input_specs: List[InputTensorSpec], explicit_batch_dimension: bool = False, explicit_precision: bool =False, logger_level=None)[原始碼]¶