快捷方式

在 ATen IR 上編寫圖變換

Passes

由於 ATen IR 位於 FX Graph/GraphModule 層面,因此為 FX Graph 編寫的任何變換都可以輕鬆應用於 ATen IR。如果您熟悉如何編寫 FX 圖變換,那麼這裡也是一樣的。

最直接的編寫變換的方式是遍歷給定的圖並直接操作圖中的節點。

例如,假設我們想將 torch.ops.aten.add.Tensor() 呼叫替換為 torch.ops.aten.mul.Tensor() 呼叫

import torch

def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor

我們還可以透過 FX 工具函式刪除和新增新節點,這些工具函式可以在 Graph 文件中找到。例如,如果我們想在 add 呼叫後插入一個 torch.ops.aten.relu.default()

import torch

def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:

            # Specifies the insertion point. Any nodes added to the graph within
            # this scope will be inserted after `node`
            with gm.graph.inserting_after(node):
                # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                # Replace all the places that use `node` to now use the `new_relu_node`
                node.replace_all_uses_with(new_relu_node)

總的來說,變換可以大致分為幾個維度

維度 A: 1. 建立一對多對映(例如,分解) 2. 建立多對一對映(例如,融合)

維度 B: 1. 進行前向迭代(例如,形狀傳播) 2. 進行後向迭代(例如,死程式碼消除)

維度 C: 1. 依賴於區域性節點資訊(例如, out-variant 轉換) 2. 依賴於全域性圖資訊(例如,記憶體規劃)

我們對這些用例出現頻率的預測是: 1. A.1, B.1, C.1 2. A.2 3. B.2, C.2

雖然我們可以透過直接操作圖來進行所有圖變換,但我們也提供了一些輔助工具,以便於處理級別 1 和 2 的用例。

Transformer

對於級別 1 的用例(建立一對多對映、進行前向迭代以及檢視區域性節點資訊),我們可以利用 Transformer 類來執行每個節點並重新建立圖,但會應用指定的變換。

一對一 Pass

一對一對映的一個示例是,如果我們想用另一個 op B 替換 op A,我們可以執行 GraphModule,並且每次看到 op A 時,返回 op B。

例如

class ReplaceAddWithMul(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

transformed_graph_module = ReplaceAddWithMul(graph_module).transform()

super().call_function(target, args, kwargs, meta) 呼叫會建立一個 call_function FX 節點,並返回使用給定引數執行運算元的結果。

一對多 Pass

如果我們想進行一對多對映,例如用另外兩個 op B 和 C 替換 op A,那麼我們將呼叫 super().call_function 兩次,建立兩個 FX 節點,一個使用 op B,另一個使用 op C,並返回執行 op C 的結果。

例如

class ReplaceAddWithMulSub(torch.fx.Transformer):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)

        x, y = args

        mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
        return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()

一對零 Pass

如果我們想刪除一個 op,我們可以直接返回傳入函式的值

class RemoveDetachPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_function(target, args, kwargs, meta)

        assert len(args) == 1
        return args[0]

transformed_graph_module = RemoveDetachPass(graph_module).transform()

利用區域性資訊

利用區域性節點資訊的一個例子是,如果我們想將圖中的所有標量轉換為 tensor,我們可以執行給定的 fx.GraphModule,對於包含標量的每個引數,我們將其轉換為 tensor。這可能看起來像

def args_map(target, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(target._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)
    return tuple(args), kwargs

class ScalarToTensorPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        breakpoint()
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(target, try_coerce, args, kwargs)
        return super().call_function(target, args, kwargs)

transformed_graph_module = ScalarToTensorPass(graph_module).transform()

子圖重寫器

為了建立多對一對映,我們可以利用 FX 的 子圖重寫器。給定一個 pattern,它會建立與該模式匹配的運算元子圖,然後將每個匹配的子圖替換為 replacement

注意

This is an inplace operation.

patternreplacement 輸入必須是可呼叫函式或包含圖中使用的相同運算元(ATen ops)的 GraphModules,以便子圖重寫器能在圖中找到正確的模式。模式/替換可呼叫函式的輸入在匹配時將被視為萬用字元。

例如

from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
    traced_module, pattern, replacement
)

子圖重寫器返回一個 ReplacedPatterns 列表

@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]

注意

The nodes created by the subgraph rewriter will not have the metadata that
is populated in the matched nodes, but you can use
`ReplacedPatterns.nodes_map` to find the nodes in the original graph that
were matched, and `ReplacedPatterns.replacements` to find the nodes that
were replaced in the transformed graph.

Pass 管理器

`PassManager <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py>`__ 是一個用於在給定 graph module 上執行多個 passes 的類。初始化 PassManager 例項時,我們傳入一個要執行的 passes 列表並設定一些標誌。要在 graph module 上執行這組 passes,我們可以直接將 graph module 傳遞給 PassManager 例項。

例如

from torch.fx.passes.infra.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)

為了新增在每個 pass 執行後執行的一組通用檢查,我們可以呼叫函式 set_checks(check: Callable),它接收一個可呼叫函式作為輸入。如果設定了 run_checks_after_each_pass 標誌,則在 graph module 上執行每個 pass 後,將呼叫 check

例如

pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)    # raises ValueError after replace_div_with_mul pass

Partitioner

有一些常用的基於 FX 圖的分割槽器可用於對圖進行分割槽。

子圖匹配器

為了在圖中找到與特定模式匹配的子圖,我們可以利用 FX 的 `SubgraphMatcher <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py>`__。

類屬性

  • pattern (Graph):目標匹配模式。圖中的佔位符節點在匹配時將被視為萬用字元。

  • match_output (bool):如果為 True,模式圖中的輸出節點將被視為目標模式的一部分。如果為 False,則在匹配時忽略輸出節點。

  • match_placeholder (bool):如果為 True,模式圖中的佔位符節點將被視為目標模式的一部分。如果為 False,佔位符節點將用作萬用字元。

  • remove_overlapping_matches (bool):如果為 True,在匹配重疊的情況下,只返回第一個匹配項。

  • ignore_literals (bool):如果為 True,將不檢查字面量是否相等,而是將其視為萬用字元。

例如

from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

large_model_graph = torch.export(LargeModel(), inputs).graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_graph = torch.export(PatternModel(), inputs).graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)

match 函式返回一個 InternalMatch 列表

@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

基於能力的 Partitioner

為了找到支援特定不變性的最大節點子圖,我們可以利用 FX 的 `CapabilityBasedPartitioner <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34>`__。

類屬性

  • graph_module (torch.fx.GraphModule):我們進行分割槽的 graph module。

  • operator_support (OperatorSupportBase):用於確定圖中節點是否在分割槽中受支援的物件。

  • allows_single_node_partition (bool):如果為 True,允許形成單節點分割槽。

  • non_compute_ops (Optional[Sequence[str]]):一組被認為是“非計算”的 ops(例如 torch.ops.aten.view_operator.getitem),以便分割槽器不會建立僅包含這些非計算 ops 的圖

  • allowed_single_node_partition_ops (Optional[Sequence[str]]):一組允許存在於單節點分割槽中的 ops。

`OperatorSupportBase <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1>`__ 類被分割槽器用來確定圖中的特定節點是否屬於該分割槽。這是透過重寫 is_node_supported 函式來完成的。你可以使用 `chain <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150>`__(如果任何 OperatorSupportBase 返回 False 則返回 False)和 `any_chain <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164>`__(如果任何 OperatorSupportBase 返回 True 則返回 True)來鏈式組合多個 OperatorSupportBase

例如

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源