• 文件 >
  • 使用 Torch-TensorRT dynamo 後端編譯 FLUX.1-dev 模型
快捷方式

使用 Torch-TensorRT dynamo 後端編譯 FLUX.1-dev 模型

本示例展示瞭如何使用 Torch-TensorRT 最佳化最新的模型 FLUX.1-dev

FLUX.1 [dev] 是一個擁有 120 億引數的 rectified flow Transformer,能夠從文字描述生成影像。這是一個用於非商業應用的開放權重、指導蒸餾模型。

要執行此演示,您需要獲得 Flux 模型訪問許可權(如果尚未獲得,請在 FLUX.1-dev 頁面申請)並安裝以下依賴項

pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"

FLUX.1-dev 流水線包含不同的元件,例如 transformervaetext_encodertokenizerscheduler。在本示例中,我們演示瞭如何最佳化模型的 transformer 元件(該元件通常佔整個端到端擴散延遲的 >95%)

匯入以下庫

import torch
import torch_tensorrt
from diffusers import FluxPipeline
from torch.export._trace import _export

定義 FLUX-1.dev 模型

使用 FluxPipeline 類載入 FLUX-1.dev 預訓練流水線。FluxPipeline 包含生成影像所需的各種元件,例如 transformervaetext_encodertokenizerscheduler。我們使用 torch_dtype 引數以 FP16 精度載入權重

DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)

# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer.to(DEVICE)

使用 torch.export 匯出主幹網路

定義 dummy 輸入及其對應的動態形狀。由於 0/1 特化,我們以 batch_size=2 的動態形狀匯出 transformer 主幹網路

batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}
# This will create an exported program which is going to be compiled with Torch-TensorRT
ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    allow_complex_guards_as_runtime_asserts=True,
)

Torch-TensorRT 編譯

注意

編譯需要具有大記憶體(> 80GB)的 GPU,因為 TensorRT 以 FP32 精度儲存權重。這是一個已知問題,將在未來版本中解決。

我們透過設定 use_fp32_acc=True 啟用 FP32 矩陣乘法累加,透過引入轉換為 FP32 的節點來確保精度得以保留。我們還啟用了顯式型別指定,以確保 TensorRT 遵守使用者設定的資料型別,這是 FP32 矩陣乘法累加的要求。由於這是一個 120 億引數的模型,在 H100 GPU 上編譯大約需要 20-30 分鐘。該模型完全可轉換,並生成一個單一的 TensorRT 引擎。

trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    min_block_size=1,
    use_fp32_acc=True,
    use_explicit_typing=True,
)

後處理

釋放由匯出程式和 pipe.transformer 佔用的 GPU 記憶體 將 Flux 流水線中的 transformer 設定為 Torch-TRT 編譯後的模型

del ep
backbone.to("cpu")
pipe.to(DEVICE)
torch.cuda.empty_cache()
pipe.transformer = trt_gm
pipe.transformer.config = config

使用提示詞生成影像

提供提示詞和要生成影像的檔名。這裡我們使用提示詞 A golden retriever holding a sign to code

# Function which generates images from the flux pipeline
def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Image generated using {image_name} model saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

生成的影像如下所示

tutorials/_rendered_examples/dynamo/dog_code.png

指令碼總執行時間:( 0 分鐘 0.000 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源