在 ATen IR 上編寫圖形轉換¶
遍歷¶
由於 ATen IR 位於 FX Graph/GraphModule 層級,因此針對 FX Graphs 編寫的任何轉換都可以輕鬆套用到 ATen IR 上。如果您熟悉編寫 FX 圖形轉換,那麼這將是相同的。
編寫轉換最直接的方法是迴圈遍歷給定的圖形並直接操作圖形中的節點。
例如,假設我們要將 torch.ops.aten.add.Tensor() 呼叫替換為 torch.ops.aten.mul.Tensor() 呼叫
import torch
def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor
我們也可以透過 FX 公用程式函式刪除和附加新節點,這些函式可以在 Graph 文件中找到。例如,如果我們要在 add 呼叫之後插入一個 torch.ops.aten.relu.default()
import torch
def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            # Specifies the insertion point. Any nodes added to the graph within
            # this scope will be inserted after `node`
            with gm.graph.inserting_after(node):
                # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                # Replace all the places that use `node` to now use the `new_relu_node`
                node.replace_all_uses_with(new_relu_node)
一般而言,轉換可以大致分為幾個軸
軸 A:1. 建立一對多映射(例如分解)2. 建立多對一映射(例如融合)
軸 B:1. 進行正向迭代(例如形狀傳播)2. 進行反向迭代(例如無用程式碼消除)
軸 C:1. 依賴於本地節點資訊(例如輸出變體轉換)2. 依賴於全域圖形資訊(例如記憶體規劃)
我們預計這些用例的頻率為:1. A.1、B.1、C.1 2. A.2 3. B.2、C.2
雖然我們可以透過直接操作圖形來進行所有圖形轉換,但我們也提供了一些輔助公用程式,以便於使用層級 1 和 2 的用例。
轉換器¶
對於層級 1 的用例(建立一對多映射、進行正向迭代和查看本地節點資訊),我們可以使用 Transformer 類別來執行每個節點並重新建立圖形,但使用指定的轉換。
一對一傳遞¶
一對一映射的範例,如果我們想將運算子 A 替換為另一個運算子 B,我們可以執行 GraphModule,並且每次看到運算子 A 時,都返回運算子 B。
範例如下
class ReplaceAddWithMul(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)
transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
super().call_function(target, args, kwargs, meta) 呼叫會建立一個 call_function FX 節點,並返回使用給定參數執行運算子的結果。
一對多傳遞¶
如果我們想進行一對多映射,例如將運算子 A 替換為另外兩個運算子 B 和 C,那麼我們會呼叫兩次 super().call_function 來建立兩個 FX 節點,一個使用運算子 B,另一個使用運算子 C,並返回執行運算子 C 的結果。
例如
class ReplaceAddWithMulSub(torch.fx.Transformer):
    """
    Original:
        def f(x, y):
            return x + y
    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        x, y = args
        mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
        return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})
transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()
一對無傳遞¶
如果我們想移除一個運算子,我們可以直接返回傳遞給函式的值
class RemoveDetachPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_function(target, args, kwargs, meta)
        assert len(args) == 1
        return args[0]
transformed_graph_module = RemoveDetachPass(graph_module).transform()
利用本地資訊¶
利用本地節點資訊的一個例子是,如果我們想將圖形中的所有純量轉換為張量,我們可以執行給定的 fx.GraphModule,並且對於每個包含純量的參數,我們將其轉換為張量。它可能看起來像這樣
def args_map(target, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()
    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)
    # Update each argument in the schema
    for i, schema in enumerate(target._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)
    return tuple(args), kwargs
class ScalarToTensorPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        breakpoint()
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )
        args, kwargs = args_map(target, try_coerce, args, kwargs)
        return super().call_function(target, args, kwargs)
transformed_graph_module = ScalarToTensorPass(graph_module).transform()
子圖形重寫器¶
為了建立多對一映射,我們可以使用 FX 的 子圖形重寫器。給定一個 pattern,它會建立一個與模式匹配的運算子子圖形,然後將每個匹配的子圖形替換為 replacement。
注意
This is an inplace operation.
pattern 和 replacement 輸入必須是可呼叫的函式或包含圖形中使用的相同運算子(ATen 運算子)的 GraphModules,以便子圖形重寫器可以在圖形中找到正確的模式。在匹配時,模式/替換可呼叫物件的輸入將被視為萬用字元。
一個例子
from torch.fx import subgraph_rewriter
def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x
    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
    traced_module, pattern, replacement
)
子圖形重寫器返回一個 ReplacedPatterns 列表
@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]
注意
The nodes created by the subgraph rewriter will not have the metadata that
is populated in the matched nodes, but you can use
`ReplacedPatterns.nodes_map` to find the nodes in the original graph that
were matched, and `ReplacedPatterns.replacements` to find the nodes that
were replaced in the transformed graph.
傳遞管理器¶
`PassManager` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py>`__ 是一個用於在給定的圖形模組上運行多個遍歷的類別。在初始化 PassManager 實例時,我們傳入要運行的遍歷列表並設置一些標誌。若要在一組圖形模組上運行遍歷集合,我們可以直接將圖形模組傳遞給 PassManager 實例。
一個例子
from torch.fx.passes.infra.pass_manager import PassManager
pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
若要添加在每次遍歷後運行的常見檢查集,我們可以呼叫 set_checks(check: Callable) 函數,該函數將可呼叫函數作為輸入。如果設置了 run_checks_after_each_pass 標誌,則在每次遍歷在圖形模組上運行後,都會呼叫 check。
一個例子
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])
def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")
pm.add_checks(check_div_target)
pm(graph_module)    # raises ValueError after replace_div_with_mul pass
分割器¶
我們可以使用幾個常見的基於 FX 圖形的分割器來分割圖形。
子圖形匹配器¶
為了在圖形中查找與特定模式匹配的子圖形,我們可以使用 FX 的 `SubgraphMatcher` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py>`__。
類別屬性
- pattern (Graph):目標匹配模式。圖形中的佔位符節點在匹配時將被視為萬用字元。
- match_output (bool):如果為 True,則模式圖形中的輸出節點將被視為目標模式的一部分。如果為 False,則在匹配期間會忽略輸出節點。
- match_placeholder (bool):如果為 True,則模式圖形中的佔位符節點將被視為目標模式的一部分。如果為 False,則佔位符節點將被用作萬用字元。
- remove_overlapping_matches (bool):如果為 True,則在重疊匹配的情況下,只會返回第一個匹配。
- ignore_literals (bool):如果為 True,則不會檢查字面值是否相等,而是將其視為萬用字元。
一個例子
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))
    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)
large_model_graph = torch.export(LargeModel(), inputs).graph
class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
pattern_graph = torch.export(PatternModel(), inputs).graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
match 函數返回一個 InternalMatch 列表
@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)
基於能力的分割器¶
為了找到支持特定不變量的節點的最大子圖形,我們可以使用 FX 的 `CapabilityBasedPartitioner` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34>`__。
類別屬性
- graph_module (torch.fx.GraphModule):我們要分割的圖形模組。
- operator_support (OperatorSupportBase):用於確定分割中是否支持圖形中節點的物件。
- allows_single_node_partition (bool):如果為 True,則允許形成單節點分割。
- non_compute_ops (Optional[Sequence[str]]):一組被認為是「非計算」的操作(例如- torch.ops.aten.view和- _operator.getitem),因此分割器不會創建僅包含這些非計算操作的圖形
- allowed_single_node_partition_ops (Optional[Sequence[str]]):一組允許在單節點分割中的操作。
分割器使用 `OperatorSupportBase` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1>`__ 類別來確定圖形中的特定節點是否屬於該分割。這是通過覆蓋 is_node_supported 函數來完成的。您可以使用 `chain` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150>`__(如果任何 OperatorSupportBase 返回 False,則返回 False)和 `any_chain` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164>`__(如果任何 OperatorSupportBase 返回 True,則返回 True)來鏈接多個 OperatorSupportBase。
一個例子
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]
capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)
# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)