• 文件 >
  • 編寫 Dynamo ATen Lowering Passes
快捷方式

編寫 Dynamo ATen Lowering Passes

Lowering Pass 的基礎知識

ATen Lowering Pass 是 Python 函式,它接受 ATen 運算元圖作為輸入,應用一些所需的修改,例如運算元合併/融合、運算元替換、子圖重寫、自定義運算元插入,或對 torch.fx.GraphModule 進行其他操作,然後將修改後的圖返回給呼叫者。這些 Lowering Pass 通常會就地修改圖並返回相同的輸入物件。

Lowering Pass 要求

Torch-TRT 中的 ATen Lowering Pass 函式必須滿足兩個要求: - 函式必須接受 torch.fx.GraphModule 和 torch Tensor 序列 Sequence[torch.Tensor] 作為輸入,並返回 Lowering 後的 torch.fx.GraphModule - 函式必須使圖處於有效且可呼叫的狀態,包括執行任何必要的 linting 和重新編譯

有關 FX 中 圖操作 (Graph Manipulations) 的資訊,請參閱此連結。下面是一個 Lowering Pass 示例,它修復輸入同時也是輸出的圖,這種情況是 TRT Engines 不允許的配置。

Lowering Pass 示例

def repair_input_as_output(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
    """Repair scenarios where inputs are also outputs of the graph

    TRT does not allow such cases, so we insert a clone (identity) layer
    """
    modified_graph = False

    # Extract graph placeholder Tensors
    placeholders = [
        node
        for node in gm.graph.nodes
        if (
            node.op == "placeholder"
            and isinstance(node.type, type)
            and issubclass(node.type, torch.Tensor)
        )
    ]

    for placeholder in placeholders:
        # If any placeholder has any users which are direct graph outputs
        if len(placeholder.users) >= 1 and any(
            user.op == "output" for user in placeholder.users
        ):
            modified_graph = True

            # Get direct graph outputs which are direct uses of placeholders
            direct_outputs = [user for user in placeholder.users if user.op == "output"]

            # Insert clone node for placeholder to ensure
            # placeholder is not a direct output
            with gm.graph.inserting_after(placeholder):
                cloned_placeholder = gm.graph.call_function(
                    torch.ops.aten.clone.default,
                    args=(placeholder,),
                )

            # Replace placeholder as output with cloned version
            for output in direct_outputs:
                output.replace_input_with(placeholder, cloned_placeholder)

    # If the graph was modified, clean up the graph and ensure it is up-to-date
    if modified_graph:
        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
        logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

    return gm

註冊 Lowering Pass

Lowering Pass 目前註冊在 py/torch_tensorrt/dynamo/lowering/passes/__init__.py 中,使用 torch.fx.passes.pass_manager.PassManager 工具按期望順序組裝 pass 列表。直接新增到該列表的新 pass 將應用於 Torch-TensorRT torch.compile 後端中的圖。目前,我們提供一個 ATen Lowering Pass 註冊裝飾器以方便使用,可以直接呼叫,也可以使用可選的 index 關鍵字引數來控制 Lowering Pass 在 pass 列表中的插入位置。

例如,要在預設位置(列表末尾)插入 pass,可以使用以下程式碼

@_aten_lowering_pass
def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
    ...

或者,要在 pass 列表中的自定義索引(例如列表開頭)插入 pass,可以使用以下程式碼

@_aten_lowering_pass(index=0)
def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
    ...

torch_tensorrt.dynamo.lowering.passes 中還提供了實用工具,用於顯示當前可用的 Lowering Pass 列表,將這些 pass 應用於任意 torch.fx.GraphModule,以及刪除特定索引處的 Lowering Pass。

# Print all lowering passes in the list
print(dump_lowering_passes())

# Apply lowering passes to a GraphModule
apply_lowering_passes(graph_module, sample_inputs)

# Remove the lowering pass at index 1
_remove_lowering_pass(index=1)

注意: 上述 API 可能會發生變化,因為 Lowering Pass 系統正在發展中。

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源