直接從 PyTorch 使用 Torch-TensorRT TorchScript 前端¶
現在您將能夠直接從 PyTorch API 訪問 TensorRT。使用此功能的過程與在 Python 中使用 Torch-TensorRT中描述的編譯工作流程非常相似。
首先將 torch_tensorrt 載入到您的應用程式中。
import torch
import torch_tensorrt
然後,給定一個 TorchScript 模組,您可以使用 torch._C._jit_to_backend("tensorrt", ...) API 透過 TensorRT 編譯它。
import torchvision.models as models
model = models.mobilenet_v2(pretrained=True)
script_model = torch.jit.script(model)
與 Torch-TensorRT 中的 compile API 不同(該 API 假定您正在嘗試編譯模組的 forward 函式)或 convert_method_to_trt_engine(該 API 將指定函式轉換為 TensorRT 引擎),後端 API 將接受一個字典,該字典將要編譯的函式名稱對映到 Compilation Spec 物件,這些物件封裝了與您提供給 compile 的字典相同的內容。有關編譯規範字典的更多資訊,請參閱 Torch-TensorRT TensorRTCompileSpec API 的文件。
spec = {
"forward": torch_tensorrt.ts.TensorRTCompileSpec(
**{
"inputs": [torch_tensorrt.Input([1, 3, 300, 300])],
"enabled_precisions": {torch.float, torch.half},
"refit": False,
"debug": False,
"device": {
"device_type": torch_tensorrt.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": True,
},
"capability": torch_tensorrt.EngineCapability.default,
"num_avg_timing_iters": 1,
}
)
}
現在要使用 Torch-TensorRT 進行編譯,請將目標模組物件和規範字典提供給 torch._C._jit_to_backend("tensorrt", ...)
trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)
要執行,請顯式呼叫您要執行的方法函式(而不是像在標準 PyTorch 中那樣直接在模組本身上呼叫)
input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
print(trt_model.forward(input))