注意
轉到文末下載完整示例程式碼
可變 Torch TensorRT 模組¶
我們將演示如何輕鬆使用可變 Torch TensorRT 模組來編譯、互動和修改 TensorRT 圖模組。
編譯 Torch-TensorRT 模組非常簡單,但修改編譯後的模組可能會面臨挑戰,尤其是在維護 PyTorch 模組與相應的 Torch-TensorRT 模組之間的狀態和連線時。在預編譯 (AoT) 場景中,將 Torch TensorRT 與複雜管道(例如 Hugging Face Stable Diffusion 管道)整合會更加困難。可變 Torch TensorRT 模組旨在解決這些挑戰,使與 Torch-TensorRT 模組的互動比以往任何時候都更容易。
- 在本教程中,我們將介紹以下內容:
使用 ResNet 18 的可變 Torch TensorRT 模組示例工作流程
儲存可變 Torch TensorRT 模組
LoRA 用例中與 Huggingface 管道的整合
在可變 Torch TensorRT 模組中使用動態形狀
import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
使用設定初始化可變 Torch TensorRT 模組。¶
settings = {
"use_python": False,
"enabled_precisions": {torch.float32},
"immutable_weights": False,
}
model = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
mutable_module(*inputs)
修改可變模組。¶
修改可變模組可能會觸發重新擬合或重新編譯。例如,載入不同的 state_dict 並設定新的權重值將觸發重新擬合,而向模型新增模組將觸發重新編譯。
model2 = models.resnet18(pretrained=False).eval().to("cuda")
mutable_module.load_state_dict(model2.state_dict())
# Check the output
# The refit happens while you call the mutable module again.
expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assert torch.allclose(
expected_output, refitted_output, 1e-2, 1e-2
), "Refit Result is not correct. Refit failed"
print("Refit successfully!")
儲存可變 Torch TensorRT 模組¶
# Currently, saving is only enabled when "use_python_runtime" = False in settings
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
使用 Huggingface 的 Stable Diffusion¶
from diffusers import DiffusionPipeline
with torch.no_grad():
settings = {
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
"immutable_weights": False,
}
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = "cuda:0"
prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.to(device)
# The only extra line you need
pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
BATCH = torch.export.Dim("BATCH", min=2, max=24)
_HEIGHT = torch.export.Dim("_HEIGHT", min=16, max=32)
_WIDTH = torch.export.Dim("_WIDTH", min=16, max=32)
HEIGHT = 4 * _HEIGHT
WIDTH = 4 * _WIDTH
args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {})
kwargs_dynamic_shapes = {
"encoder_hidden_states": {0: BATCH},
"added_cond_kwargs": {
"text_embeds": {0: BATCH},
"time_ids": {0: BATCH},
},
}
pipe.unet.set_expected_dynamic_shape_range(
args_dynamic_shapes, kwargs_dynamic_shapes
)
image = pipe(
prompt,
negative_prompt=negative,
num_inference_steps=30,
height=1024,
width=768,
num_images_per_prompt=2,
).images[0]
image.save("./without_LoRA_mutable.jpg")
# Standard Huggingface LoRA loading procedure
pipe.load_lora_weights(
"stablediffusionapi/load_lora_embeddings",
weight_name="all-disney-princess-xl-lo.safetensors",
adapter_name="lora1",
)
pipe.set_adapters(["lora1"], adapter_weights=[1])
pipe.fuse_lora()
pipe.unload_lora_weights()
# Refit triggered
image = pipe(
prompt,
negative_prompt=negative,
num_inference_steps=30,
height=1024,
width=1024,
num_images_per_prompt=1,
).images[0]
image.save("./with_LoRA_mutable.jpg")
將可變 Torch TensorRT 模組與動態形狀一起使用¶
向 MutableTorchTensorRTModule 新增動態形狀提示時,形狀提示應嚴格遵循傳遞給 forward 函式的 arg_inputs 和 kwarg_inputs 的語義,並且不應省略任何條目(kwarg_inputs 中的 None 除外)。如果輸入中存在巢狀的字典/列表,則對應條目的動態形狀也應為巢狀的字典/列表。如果輸入的動態形狀不需要,則應為此輸入提供一個空字典作為形狀提示。請注意,應排除值為 None 的關鍵字引數,因為它們將被過濾掉。
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b, c={}):
x = torch.matmul(a, b)
x = torch.matmul(c["a"], c["b"].T)
print(c["b"][0])
x = 2 * c["b"]
return x
device = "cuda:0"
model = Model().eval().to(device)
inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
kwargs = {
"c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
}
dim_0 = torch.export.Dim("dim", min=1, max=50)
dim_1 = torch.export.Dim("dim", min=1, max=50)
dim_2 = torch.export.Dim("dim2", min=1, max=50)
args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
kwarg_dynamic_shapes = {
"c": {
"a": {},
"b": {0: dim_2},
}, # a's shape does not change so we give it an empty dict
}
# Export the model first with custom dynamic shape constraints
model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
# Compile
model(*inputs, **kwargs)
# Change input shape
inputs_2 = (torch.rand(10, 5).to(device), torch.rand(10, 30).to(device))
kwargs_2 = {
"c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
}
# Run without recompiling
model(*inputs_2, **kwargs_2)
將可變 Torch TensorRT 模組與持久快取一起使用¶
利用引擎快取,我們可以縮短引擎編譯時間並節省大量時間。
import os
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
model = models.resnet18(pretrained=True).eval().to("cuda")
times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
model = torch_trt.MutableTorchTensorRTModule(
model,
use_python_runtime=True,
enabled_precisions={torch.float},
debug=True,
min_block_size=1,
immutable_weights=False,
cache_built_engines=True,
reuse_cached_engines=True,
engine_cache_size=1 << 30, # 1GB
)
def remove_timing_cache(path=TIMING_CACHE_PATH):
if os.path.exists(path):
os.remove(path)
remove_timing_cache()
for i in range(4):
inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
start.record()
model(*inputs) # Recompile
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))
print("----------------dynamo_compile----------------")
print("Without engine caching, used:", times[0], "ms")
print("With engine caching used:", times[1], "ms")
print("With engine caching used:", times[2], "ms")
print("With engine caching used:", times[3], "ms")
指令碼總執行時間: ( 0 分鐘 0.000 秒)