自定義編譯器 Pass 和 分割槽器¶
Passes¶
Passes 大致可分為幾個維度
維度 A
建立一對多對映(例如,分解)
建立多對一對映(例如,融合)
維度 B
執行前向迭代(例如,形狀傳播)
執行後向迭代(例如,死程式碼消除)
維度 C
依賴於區域性節點資訊(例如,out 變體轉換)
依賴於全域性圖資訊(例如,記憶體規劃)
我們對這些用例發生頻率的預測是
A.1, B.1, C.1
A.2
B.2, C.2
級別 1¶
對於級別 1 的用例(建立一對多對映、執行前向迭代以及檢視區域性節點資訊),我們可以使用一個名為 ExportPass 的輔助類。這是一種基於直譯器的方式,我們執行每個節點並根據指定的轉換重新建立圖。這使我們能夠透過確保 Pass 期間建立的所有節點都符合 IR 規範來保留 IR 規範,包括確保堆疊跟蹤、FakeTensor 值和 torch.nn.Module 層次結構等元資料得到保留並根據所做的轉換進行更新。
要實現這個 Pass,我們可以建立一個 ExportPass 的子類並實現公開的函式。當使用一個圖模組呼叫時,它會執行該圖模組並建立一個包含 Pass 指定更改的新圖。這意味著傳入的圖模組必須能夠在 CPU 上執行,並且在 Pass 執行後會保持這個不變性。
一對一 Pass¶
一對一對映的一個示例是,如果我們想用另一個 op B 替換 op A,我們可以執行給定的 fx.GraphModule,並且每次看到 op A 時,返回 op B。
考慮以下示例
class ReplaceInPlaceReluWithOutOfPlaceReluPass(ExportPass):
"""
relu_ is the in-place version. Replace it with relu, which is the
out-of-place version
"""
def call_operator(self, op, args, kwargs, meta):
if op != torch.ops.aten.relu_.default:
return super().call_operator(op, args, kwargs, meta)
return super().call_operator(Op(torch.ops.aten.relu.default), args, kwargs, meta)
# To create a pass
replace_pass = ReplaceInPlaceReluWithOutOfPlaceReluPass()
# To run a pass
new_graph_module = replace_pass(graph_module).graph_module
`super().call_operator(op, args, kwargs, meta)` 呼叫會建立一個 call_function FX 節點,並返回使用給定引數執行該運算元的結果。
一對多 Pass¶
如果我們想進行一對多對映,例如用另外兩個 op B 和 C 替換 op A,那麼我們將對 super().call_operator 進行兩次呼叫,建立兩個 FX 節點,一個使用 op B,另一個使用 op C,並返回執行 op C 的結果。
例如
class ReplaceAddWithMulSub(ExportPass):
"""
Original:
def f(x, y):
return x + y
After pass:
def f(x, y):
z = x * y
return z - y
"""
def call_operator(self, op, args, kwargs, meta):
if op != torch.ops.aten.add.default:
return super().call_operator(op, args, kwargs, meta)
x, y = args
mul_res = super().call_operator(
torch.ops.aten.mul.default,
args,
{},
meta
)
return super().call_operator(
torch.ops.aten.sub.default,
(mul_res, y),
{},
meta
)
一對無 Pass¶
如果我們想移除一個 op,我們可以直接返回傳遞給函式的值
class RemoveDetachPass(ExportPass):
def call_operator(self, op, args, kwargs, meta):
if op not in (
torch.ops.aten.detach.default,
torch.ops.aten.detach_copy.default,
):
return super().call_operator(op, args, kwargs, meta)
assert len(args) == 1
return args[0]
利用區域性資訊¶
利用區域性節點資訊的一個示例是,如果我們想將圖中的所有標量轉換為張量,我們可以執行給定的 fx.GraphModule,對於包含標量的每個引數,我們將其轉換為張量。它可能看起來像這樣
def args_map(op, fn, args, kwargs):
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
args = list(args)
kwargs = kwargs.copy()
# Update the argument based on the function passed
def update(key, args, schema):
args[key] = fn(args[key], schema)
# Update each argument in the schema
for i, schema in enumerate(self.op._schema.arguments):
if schema.name in kwargs:
update(schema.name, kwargs, schema)
elif not schema.kwarg_only and i < len(args):
update(i, args, schema)
class ScalarToTensorPass(ExportPass):
def call_operator(self, op, args, kwargs):
def try_coerce(value, arg):
return (
torch.tensor(value)
if isinstance(value, (float, int, bool))
and type(arg.type) == torch.TensorType
else value
)
args, kwargs = args_map(op, try_coerce, args, kwargs)
return super().call_operator(op, args, kwargs)
級別 2¶
為了建立多對一對映,我們可以利用 FX 的子圖重寫器。給定一個 pattern,它會建立一個與該模式匹配的運算元子圖,然後將每個匹配的子圖替換為 replacement。
注意
This is an inplace operation.
`pattern` 和 replacement 輸入必須是可呼叫函式,它們使用與您要匹配的 EXIR 圖中使用的相同運算元 (ATen 運算元) 編寫,以便子圖重寫器可以在圖中找到正確的模式。pattern/replacement 可呼叫函式的輸入將被視為萬用字元。
考慮以下示例
from torch.fx import subgraph_rewriter
def replace_patterns(graph_module):
def pattern(x, y):
x = torch.ops.aten.add.Tensor(x, y)
x = torch.ops.aten.mul.Tensor(x, y)
return x
def replacement(x, y):
return torch.ops.aten.sub.Tensor(x, y)
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
traced_module, pattern, replacement
)
子圖重寫器返回一個 ReplacedPatterns 列表
@dataclass
class ReplacedPatterns:
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
# List of nodes that were added into the graph
replacements: List[Node]
注意
The nodes created by the subgraph rewriter will not have the metadata that
is normally in EXIR nodes (`stack_trace`, `val`, `nn_module_stack`).
級別 3¶
建立 Pass 的第三種方式是利用最基本的 PassBase。要建立一個 Pass,我們可以子類化它並實現包含 Pass 內容的 call 函式。此外,我們可以實現 requires 和 ensures 函式,它們將在 call 函式之前和之後被呼叫。請注意,這些函式也可以在 ExportPass 中被覆蓋。要在圖模組上執行 Pass,我們可以將圖模組直接傳遞給該類的一個例項。
考慮以下示例
class ReplaceAddPass(PassBase):
def __init__(self, replace_op):
self.replace_op = replace_op
def call(self, graph_module):
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.add:
node.target = self.replace_op
# Optional to implement, will be called before call()
def requires(self, graph_module) -> None:
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == torch.add:
return
raise ValueError("No torch.add ops!")
# Optional to implement, will be called after call()
def ensures(self, graph_module: torch.fx.GraphModule) -> None:
pass
# To create a pass
replace_add_with_div = ReplaceAddPass(torch.div)
# To run a pass
replace_add_with_div(graph_module)
Pass Manager¶
PassManager 是一個用於在給定圖模組上執行多個 Pass 的類。初始化 PassManager 例項時,我們傳入要執行的 Pass 列表並設定一些標誌。要在圖模組上執行這一系列 Pass,我們可以將圖模組直接傳遞給 PassManager 例項。
一個示例
from executorch.exir.pass_manager import PassManager
pm = PassManager(
passes=[replace_add_with_div, replace_div_with_mul],
run_checks_after_each_pass=True,
suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
要新增一組在每個 Pass 執行後執行的常見檢查,我們可以呼叫接受可呼叫函式作為輸入的 set_checks(check: Callable) 函式。如果設定了 run_checks_after_each_pass 標誌,則在每個 Pass 在圖模組上執行後,將呼叫 check。
一個示例
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])
def check_div_target(graph_module):
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target != torch.div:
raise ValueError("Target should be div!")
pm.add_checks(check_div_target)
pm(graph_module) # raises ValueError after replace_div_with_mul pass
分割槽器¶
有一些常見的基於 FX 圖的分割槽器可用於對圖進行分割槽。但是,這些分割槽器不一定能生成符合 IR 規範的圖,因此在使用時請小心。
子圖匹配器¶
為了在圖中找到匹配特定模式的子圖,我們可以利用 FX 的SubgraphMatcher。
類屬性
pattern (Graph):目標匹配模式。圖中的佔位符節點在匹配時將被視為萬用字元。match_output (bool):如果為 True,模式圖中的輸出節點將被視為目標模式的一部分。如果為 False,匹配時將忽略輸出節點。match_placeholder (bool):如果為 True,模式圖中的佔位符節點將被視為目標模式的一部分。如果為 False,佔位符節點將用作萬用字元。remove_overlapping_matches (bool):如果為 True,在匹配重疊的情況下,只返回第一個匹配項。ignore_literals (bool):如果為 True,將不檢查字面量是否相等,而是將其視為萬用字元。
考慮以下示例
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
class LargeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight = torch.nn.Parameter(torch.ones(3, 3))
self._bias = torch.nn.Parameter(torch.ones(3, 3))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias, x, self._weight)
large_model_graph = to_edge(export(LargeModel(), large_inputs)).exported_program().graph_module.graph
class PatternModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
pattern_graph = to_edge(export(PatternModel(), pattern_inputs)).exported_program().graph_module.graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
match 函式返回一個 InternalMatch 列表
@dataclass
class InternalMatch():
# Nodes from which the match was found
anchors: List[Node]
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] = field(default_factory=dict)
# Nodes in target graph that are matched placeholder in pattern
placeholder_nodes: List[Node] = field(default_factory=list)
# Nodes in matched subgraph returned by output
returning_nodes: List[Node] = field(default_factory=list)
基於能力的分割槽器¶
為了找到支援特定不變性的最大節點子圖,我們可以利用 FX 的CapabilityBasedPartitioner。
類屬性
graph_module (torch.fx.GraphModule):我們要對其進行分割槽的圖模組。operator_support (OperatorSupportBase):用於確定圖中某個節點是否在分割槽中受支援的物件。allows_single_node_partition (bool):如果為 True,則允許形成單節點分割槽。non_compute_ops (Optional[Sequence[str]]):一組被認為是“非計算”的運算元(例如torch.ops.aten.view和_operator.getitem),這樣分割槽器就不會建立僅包含這些非計算運算元的圖allowed_single_node_partition_ops (Optional[Sequence[str]]):允許出現在單節點分割槽中的一組運算元。
OperatorSupportBase 類由分割槽器使用,以確定圖中某個特定節點是否屬於該分割槽。這是透過覆蓋 is_node_supported 函式來實現的。您可以使用 chain(如果任何 OperatorSupportBase 返回 False,則返回 False)和 any_chain(如果任何 OperatorSupportBase 返回 True,則返回 True)來鏈式組合多個 OperatorSuppportBase。
考慮以下示例
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
class AddMulOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
]
capability_partitioner = CapabilityBasedPartitioner(
graph_module,
op_support,
)
# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
如果您檢視基於能力的分割槽器,您可能還會發現一個 fuse_partition 函式,它將返回一個修改後的圖,其中分割槽作為子模組,並透過 call_module 節點在頂層圖中呼叫這些子模組。然而,這不符合 IR 規範,因為我們不允許 call_module 節點。
組合¶
我們還提供了一個組合輔助函式:generate_pattern_op_partitions
引數
graph_module (fx.GraphModule):我們要分割槽的模組patterns (List[torch.fx.Graph]):以 torch.fx.Graph 形式表示的模式列表。這些圖可以透過 exir.capture 獲取的 GraphModule 的graph欄位獲得(推薦),或者透過符號跟蹤獲得(可能無法產生準確的 edge dialect 圖),或者透過手動建立圖模組獲得。op_support (OperatorSupportBase):可以透過以下方式建立的 OperatorSupportBase直接子類化並實現
is_node_supported()獲取
create_op_support()的結果獲取
create_pattern_support()的結果使用
chain()或any_chain()鏈式組合多個 OperatorSupportBase 類
返回
包含由給定 OperatorSupportBase 物件和給定模式圖的並集支援的節點的(最大可能子圖)分割槽列表。
源分割槽器¶
對於更復雜的用例,使用者希望基於更高級別的模組(torch.nn.Linear 或 torch.nn.functional.Linear)進行分割槽,這些模組現已被分解為其運算元(aten.permute, aten.addmm),我們有以下輔助函式
get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]
引數
graph:我們要分割槽的圖wanted_sources:從該源分解而來的節點源列表。這可以是函式(例如torch.nn.functional.linear)或葉模組型別(例如torch.nn.Linear)
返回
將源(例如
torch.nn.modules.linear.Linear)對映到與從該型別模組展平的節點列表相對應的SourcePartitions列表的字典。
@dataclass
class SourcePartition():
# Nodes in a particular partition
nodes: List[Node]
# Module type
module_type: Type
# Nodes in the graph that are needed as inputs to the partition
input_nodes: List[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the partition
output_nodes: List[Node] = field(default_factory=list)
# Parameters that are being used
params: List[str] = field(default_factory=list)
一個示例
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(3, 3)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(3, 5)
def forward(self, x):
x = self.linear1(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
inputs = (torch.randn(3, 3),)
edge_graph = to_edge(export(M(), inputs)).exported_program().graph_module.graph
print(edge_graph)
"""
graph():
%arg0 : [#users=1] = placeholder[target=arg0]
%_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
%permute_default : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0,), kwargs = {})
%_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
%_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
%permute_default_1 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0_1,), kwargs = {})
%_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
%relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
%_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
%permute_default_2 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant2,), kwargs = {})
%_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
%addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
return [addmm_default_2]
"""
module_partitions = get_source_partitions(edge_graph, [torch.nn.Linear, torch.nn.ReLU])
print(module_partitions)
"""
{<class 'torch.nn.modules.linear.Linear'>: [
ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],
<class 'torch.nn.modules.activation.ReLU'>: [
ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
"""