• 文件 >
  • 自定義編譯器 Pass 和 分割槽器
快捷方式

自定義編譯器 Pass 和 分割槽器

Passes

Passes 大致可分為幾個維度

維度 A

  1. 建立一對多對映(例如,分解)

  2. 建立多對一對映(例如,融合)

維度 B

  1. 執行前向迭代(例如,形狀傳播)

  2. 執行後向迭代(例如,死程式碼消除)

維度 C

  1. 依賴於區域性節點資訊(例如,out 變體轉換)

  2. 依賴於全域性圖資訊(例如,記憶體規劃)

我們對這些用例發生頻率的預測是

  1. A.1, B.1, C.1

  2. A.2

  3. B.2, C.2

級別 1

對於級別 1 的用例(建立一對多對映、執行前向迭代以及檢視區域性節點資訊),我們可以使用一個名為 ExportPass 的輔助類。這是一種基於直譯器的方式,我們執行每個節點並根據指定的轉換重新建立圖。這使我們能夠透過確保 Pass 期間建立的所有節點都符合 IR 規範來保留 IR 規範,包括確保堆疊跟蹤、FakeTensor 值和 torch.nn.Module 層次結構等元資料得到保留並根據所做的轉換進行更新。

要實現這個 Pass,我們可以建立一個 ExportPass 的子類並實現公開的函式。當使用一個圖模組呼叫時,它會執行該圖模組並建立一個包含 Pass 指定更改的新圖。這意味著傳入的圖模組必須能夠在 CPU 上執行,並且在 Pass 執行後會保持這個不變性。

一對一 Pass

一對一對映的一個示例是,如果我們想用另一個 op B 替換 op A,我們可以執行給定的 fx.GraphModule,並且每次看到 op A 時,返回 op B。

考慮以下示例

class ReplaceInPlaceReluWithOutOfPlaceReluPass(ExportPass):
    """
    relu_ is the in-place version. Replace it with relu, which is the
    out-of-place version
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.relu_.default:
            return super().call_operator(op, args, kwargs, meta)
        return super().call_operator(Op(torch.ops.aten.relu.default), args, kwargs, meta)

# To create a pass
replace_pass = ReplaceInPlaceReluWithOutOfPlaceReluPass()
# To run a pass
new_graph_module = replace_pass(graph_module).graph_module

`super().call_operator(op, args, kwargs, meta)` 呼叫會建立一個 call_function FX 節點,並返回使用給定引數執行該運算元的結果。

一對多 Pass

如果我們想進行一對多對映,例如用另外兩個 op B 和 C 替換 op A,那麼我們將對 super().call_operator 進行兩次呼叫,建立兩個 FX 節點,一個使用 op B,另一個使用 op C,並返回執行 op C 的結果。

例如

class ReplaceAddWithMulSub(ExportPass):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.add.default:
            return super().call_operator(op, args, kwargs, meta)

        x, y = args

        mul_res = super().call_operator(
            torch.ops.aten.mul.default,
            args,
            {},
            meta
        )

        return super().call_operator(
            torch.ops.aten.sub.default,
            (mul_res, y),
            {},
            meta
        )

一對無 Pass

如果我們想移除一個 op,我們可以直接返回傳遞給函式的值

class RemoveDetachPass(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return args[0]

利用區域性資訊

利用區域性節點資訊的一個示例是,如果我們想將圖中的所有標量轉換為張量,我們可以執行給定的 fx.GraphModule,對於包含標量的每個引數,我們將其轉換為張量。它可能看起來像這樣

def args_map(op, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(self.op._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)

class ScalarToTensorPass(ExportPass):
    def call_operator(self, op, args, kwargs):
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(op, try_coerce, args, kwargs)
        return super().call_operator(op, args, kwargs)

級別 2

為了建立多對一對映,我們可以利用 FX 的子圖重寫器。給定一個 pattern,它會建立一個與該模式匹配的運算元子圖,然後將每個匹配的子圖替換為 replacement

注意

This is an inplace operation.

`pattern` 和 replacement 輸入必須是可呼叫函式,它們使用與您要匹配的 EXIR 圖中使用的相同運算元 (ATen 運算元) 編寫,以便子圖重寫器可以在圖中找到正確的模式。pattern/replacement 可呼叫函式的輸入將被視為萬用字元。

考慮以下示例

from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
    traced_module, pattern, replacement
)

子圖重寫器返回一個 ReplacedPatterns 列表

@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]

注意

The nodes created by the subgraph rewriter will not have the metadata that
is normally in EXIR nodes (`stack_trace`, `val`, `nn_module_stack`).

級別 3

建立 Pass 的第三種方式是利用最基本的 PassBase。要建立一個 Pass,我們可以子類化它並實現包含 Pass 內容的 call 函式。此外,我們可以實現 requiresensures 函式,它們將在 call 函式之前和之後被呼叫。請注意,這些函式也可以在 ExportPass 中被覆蓋。要在圖模組上執行 Pass,我們可以將圖模組直接傳遞給該類的一個例項。

考慮以下示例

class ReplaceAddPass(PassBase):

    def __init__(self, replace_op):
        self.replace_op = replace_op

    def call(self, graph_module):
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.target = self.replace_op

    # Optional to implement, will be called before call()
    def requires(self, graph_module) -> None:
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                return
        raise ValueError("No torch.add ops!")

    # Optional to implement, will be called after call()
    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        pass

# To create a pass
replace_add_with_div = ReplaceAddPass(torch.div)
# To run a pass
replace_add_with_div(graph_module)

Pass Manager

PassManager 是一個用於在給定圖模組上執行多個 Pass 的類。初始化 PassManager 例項時,我們傳入要執行的 Pass 列表並設定一些標誌。要在圖模組上執行這一系列 Pass,我們可以將圖模組直接傳遞給 PassManager 例項。

一個示例

from executorch.exir.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)

要新增一組在每個 Pass 執行後執行的常見檢查,我們可以呼叫接受可呼叫函式作為輸入的 set_checks(check: Callable) 函式。如果設定了 run_checks_after_each_pass 標誌,則在每個 Pass 在圖模組上執行後,將呼叫 check

一個示例

pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)    # raises ValueError after replace_div_with_mul pass

分割槽器

有一些常見的基於 FX 圖的分割槽器可用於對圖進行分割槽。但是,這些分割槽器不一定能生成符合 IR 規範的圖,因此在使用時請小心。

子圖匹配器

為了在圖中找到匹配特定模式的子圖,我們可以利用 FX 的SubgraphMatcher

類屬性

  • pattern (Graph):目標匹配模式。圖中的佔位符節點在匹配時將被視為萬用字元。

  • match_output (bool):如果為 True,模式圖中的輸出節點將被視為目標模式的一部分。如果為 False,匹配時將忽略輸出節點。

  • match_placeholder (bool):如果為 True,模式圖中的佔位符節點將被視為目標模式的一部分。如果為 False,佔位符節點將用作萬用字元。

  • remove_overlapping_matches (bool):如果為 True,在匹配重疊的情況下,只返回第一個匹配項。

  • ignore_literals (bool):如果為 True,將不檢查字面量是否相等,而是將其視為萬用字元。

考慮以下示例

from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

large_model_graph = to_edge(export(LargeModel(), large_inputs)).exported_program().graph_module.graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_graph = to_edge(export(PatternModel(), pattern_inputs)).exported_program().graph_module.graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)

match 函式返回一個 InternalMatch 列表

@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

基於能力的分割槽器

為了找到支援特定不變性的最大節點子圖,我們可以利用 FX 的CapabilityBasedPartitioner

類屬性

  • graph_module (torch.fx.GraphModule):我們要對其進行分割槽的圖模組。

  • operator_support (OperatorSupportBase):用於確定圖中某個節點是否在分割槽中受支援的物件。

  • allows_single_node_partition (bool):如果為 True,則允許形成單節點分割槽。

  • non_compute_ops (Optional[Sequence[str]]):一組被認為是“非計算”的運算元(例如 torch.ops.aten.view_operator.getitem),這樣分割槽器就不會建立僅包含這些非計算運算元的圖

  • allowed_single_node_partition_ops (Optional[Sequence[str]]):允許出現在單節點分割槽中的一組運算元。

OperatorSupportBase 類由分割槽器使用,以確定圖中某個特定節點是否屬於該分割槽。這是透過覆蓋 is_node_supported 函式來實現的。您可以使用 chain(如果任何 OperatorSupportBase 返回 False,則返回 False)和 any_chain(如果任何 OperatorSupportBase 返回 True,則返回 True)來鏈式組合多個 OperatorSuppportBase

考慮以下示例

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()

如果您檢視基於能力的分割槽器,您可能還會發現一個 fuse_partition 函式,它將返回一個修改後的圖,其中分割槽作為子模組,並透過 call_module 節點在頂層圖中呼叫這些子模組。然而,這不符合 IR 規範,因為我們不允許 call_module 節點。

組合

我們還提供了一個組合輔助函式:generate_pattern_op_partitions

引數

  • graph_module (fx.GraphModule):我們要分割槽的模組

  • patterns (List[torch.fx.Graph]):以 torch.fx.Graph 形式表示的模式列表。這些圖可以透過 exir.capture 獲取的 GraphModule 的 graph 欄位獲得(推薦),或者透過符號跟蹤獲得(可能無法產生準確的 edge dialect 圖),或者透過手動建立圖模組獲得。

  • op_support (OperatorSupportBase):可以透過以下方式建立的 OperatorSupportBase

    • 直接子類化並實現 is_node_supported()

    • 獲取 create_op_support() 的結果

    • 獲取 create_pattern_support() 的結果

    • 使用 chain()any_chain() 鏈式組合多個 OperatorSupportBase 類

返回

  • 包含由給定 OperatorSupportBase 物件和給定模式圖的並集支援的節點的(最大可能子圖)分割槽列表。

源分割槽器

對於更復雜的用例,使用者希望基於更高級別的模組(torch.nn.Lineartorch.nn.functional.Linear)進行分割槽,這些模組現已被分解為其運算元(aten.permute, aten.addmm),我們有以下輔助函式

get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]

引數

  • graph:我們要分割槽的圖

  • wanted_sources:從該源分解而來的節點源列表。這可以是函式(例如 torch.nn.functional.linear)或葉模組型別(例如 torch.nn.Linear

返回

  • 將源(例如 torch.nn.modules.linear.Linear)對映到與從該型別模組展平的節點列表相對應的 SourcePartitions 列表的字典。

@dataclass
class SourcePartition():
    # Nodes in a particular partition
    nodes: List[Node]
    # Module type
    module_type: Type
    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)
    # Nodes in the partition that are being used by nodes outside of the partition
    output_nodes: List[Node] = field(default_factory=list)
    # Parameters that are being used
    params: List[str] = field(default_factory=list)

一個示例

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

inputs = (torch.randn(3, 3),)
edge_graph = to_edge(export(M(), inputs)).exported_program().graph_module.graph
print(edge_graph)
"""
graph():
    %arg0 : [#users=1] = placeholder[target=arg0]
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0,), kwargs = {})
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
    %_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default_1 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0_1,), kwargs = {})
    %_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
    %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
    %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
    %permute_default_2 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant2,), kwargs = {})
    %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
    %addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
    return [addmm_default_2]
"""

module_partitions = get_source_partitions(edge_graph, [torch.nn.Linear, torch.nn.ReLU])
print(module_partitions)
"""
{<class 'torch.nn.modules.linear.Linear'>: [
    ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
    ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
    ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],

 <class 'torch.nn.modules.activation.ReLU'>: [
    ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
"""

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源