• 文件 >
  • 使用 Torch-TensorRT 的動態形狀
快捷方式

使用 Torch-TensorRT 的動態形狀

預設情況下,您可以使用不同的輸入形狀執行 PyTorch 模型,輸出形狀會即時確定。然而,Torch-TensorRT 是一款 AOT(提前)編譯器,它需要關於輸入形狀的一些預先資訊才能編譯和最佳化模型。

使用 torch.export 的動態形狀 (AOT)

對於動態輸入形狀,我們必須提供 (min_shape, opt_shape, max_shape) 引數,以便模型能夠針對此輸入形狀範圍進行最佳化。靜態和動態形狀的使用示例如下。

注意:以下程式碼使用 Dynamo 前端。如果使用 Torchscript 前端,請將 ir=dynamo 替換為 ir=ts,行為完全相同。

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
# Compile with static shapes
inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
# or compile with dynamic shapes
inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                              opt_shape=[4, 3, 224, 224],
                              max_shape=[8, 3, 224, 224],
                              dtype=torch.float32)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)

幕後工作原理

當我們使用 torch_tensorrt.compile API 並指定 ir=dynamo (預設) 時,編譯過程分為兩個階段。

  • torch_tensorrt.dynamo.trace (使用 torch.export 根據給定的輸入跟蹤圖)

我們使用 torch.export.export() API 來跟蹤 PyTorch 模組並將其匯出為 torch.export.ExportedProgram。對於動態形狀的輸入,透過 torch_tensorrt.Input API 提供的 (min_shape, opt_shape, max_shape) 範圍用於構造 torch.export.Dim 物件,這些物件在 export API 的 dynamic_shapes 引數中使用。請檢視 _tracer.py 檔案以瞭解其幕後工作原理。

  • torch_tensorrt.dynamo.compile (使用 TensorRT 編譯一個 torch.export.ExportedProgram 物件)

在轉換為 TensorRT 時,圖已經在節點的元資料中包含了動態形狀資訊,這些資訊將在 engine 構建階段使用。

自定義動態形狀約束

給定輸入 x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype),Torch-TensorRT 嘗試在 torch.export 跟蹤期間透過相應地使用提供的動態維度構造 torch.export.Dim 物件來自動設定約束。有時,我們可能需要設定額外的約束,如果我們不指定它們,Torchdynamo 會報錯。如果您必須為模型設定任何自定義約束(透過使用 torch.export.Dim),我們建議您先匯出程式,然後再使用 Torch-TensorRT 進行編譯。請參閱此文件,瞭解如何匯出具有動態形狀的 PyTorch 模組。這裡是一個簡單的示例,演示如何匯出一個對動態維度施加了一些限制的 matmul 層。

import torch
import torch_tensorrt

class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key):
        attn_weight = torch.matmul(query, key.transpose(-1, -2))
        return attn_weight

model = MatMul().eval().cuda()
inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
seq_len = torch.export.Dim("seq_len", min=1, max=10)
dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
exp_program = torch.export.export(model, tuple(inputs), dynamic_shapes=dynamic_shapes)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
trt_gm(*inputs)

使用 torch.compile 的動態形狀 (JIT)

torch_tensorrt.compile(model, inputs, ir="torch_compile") 返回一個配置了 TensorRT 後端的 torch.compile 封裝函式。在使用 ir=torch_compile 的情況下,使用者可以使用 torch._dynamo.mark_dynamic API (https://pytorch.com.tw/docs/stable/torch.compiler_dynamic_shapes.html) 為輸入提供動態形狀資訊,以避免 TensorRT engine 的重新編譯。

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224), dtype=float32)
# This indicates the dimension 0 is dynamic and the range is [1, 8]
torch._dynamo.mark_dynamic(inputs, 0, min=1, max=8)
trt_gm = torch.compile(model, backend="tensorrt")
# Compilation happens when you call the model
trt_gm(inputs)

# No recompilation of TRT engines with modified batch size
inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32)
trt_gm(inputs_bs2)

文件

獲取 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源