torch.export IR 規範¶
Export IR 是用於編譯器的中間表示 (IR),與 MLIR 和 TorchScript 類似。它專門設計用於表達 PyTorch 程式的語義。Export IR 主要透過精簡的運算列表表示計算,對控制流等動態特性的支援有限。
為了建立 Export IR 圖,可以使用前端透過跟蹤特化機制可靠地捕獲 PyTorch 程式。生成的 Export IR 可以由後端進行最佳化和執行。目前可以透過 torch.export.export() 實現這一點。
本文件將涵蓋的關鍵概念包括
ExportedProgram:包含 Export IR 程式的資料結構
Graph:由節點列表組成。
Nodes:表示運算、控制流以及儲存在此節點上的元資料。
值由節點產生和消費。
型別與值和節點相關聯。
還定義了值的大小和記憶體佈局。
什麼是 Export IR¶
Export IR 是 PyTorch 程式的基於圖的中間表示 IR。Export IR 是在 torch.fx.Graph 之上實現的。換句話說,所有 Export IR 圖都是有效的 FX 圖,如果使用標準的 FX 語義解釋,Export IR 可以被可靠地解釋。一個推論是,匯出的圖可以透過標準的 FX 程式碼生成轉換為有效的 Python 程式。
本文件將主要重點介紹 Export IR 在嚴格性方面與 FX 的不同之處,同時跳過與 FX 相似的部分。
ExportedProgram¶
頂層 Export IR 構造是 torch.export.ExportedProgram 類。它將 PyTorch 模型的計算圖(通常是 torch.nn.Module)與該模型使用的引數或權重捆綁在一起。
torch.export.ExportedProgram 類的一些值得注意的屬性包括
graph_module(torch.fx.GraphModule):包含 PyTorch 模型展平計算圖的資料結構。可以透過 ExportedProgram.graph 直接訪問該圖。graph_signature(torch.export.ExportGraphSignature):圖簽名,指定圖中使用和修改的引數和緩衝區名稱。引數和緩衝區不作為圖的屬性儲存,而是作為圖的輸入被提升。graph_signature 用於跟蹤這些引數和緩衝區的附加資訊。state_dict(Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含引數和緩衝區的資料結構。range_constraints(Dict[sympy.Symbol, RangeConstraint]):對於匯出時具有資料依賴行為的程式,每個節點上的元資料將包含符號形狀(例如s0、i0)。此屬性將符號形狀對映到其下限/上限範圍。
Graph¶
Export IR Graph 是以 DAG(有向無環圖)形式表示的 PyTorch 程式。此圖中的每個節點表示特定的計算或運算,圖的邊由節點之間的引用組成。
我們可以將 Graph 視為具有以下模式
class Graph:
nodes: List[Node]
實際上,Export IR 的圖透過 torch.fx.Graph Python 類實現。
Export IR 圖包含以下節點(節點將在下一節中更詳細地描述)
0 個或多個 op 型別為
placeholder的節點0 個或多個 op 型別為
call_function的節點恰好 1 個 op 型別為
output的節點
推論:最小的有效 Graph 將包含一個節點。即,nodes 永遠不為空。
定義: Graph 的 placeholder 節點集合表示 GraphModule 的 Graph 的輸入。output 節點表示 GraphModule 的 Graph 的輸出。
示例
import torch
from torch import nn
class MyModule(nn.Module):
def forward(self, x, y):
return x + y
example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
return (add,)
上圖是 Graph 的文字表示,每行是一個節點。
節點¶
Node 表示特定的計算或運算,在 Python 中使用 torch.fx.Node 類表示。節點之間的邊透過 Node 類的 args 屬性表示為對其他節點的直接引用。使用相同的 FX 機制,我們可以表示計算圖通常需要的以下運算,例如運算元呼叫、placeholder(即輸入)、條件語句和迴圈。
Node 具有以下模式
class Node:
name: str # name of node
op_name: str # type of operation
# interpretation of the fields below depends on op_name
target: [str|Callable]
args: List[object]
kwargs: Dict[str, object]
meta: Dict[str, object]
FX 文字格式
如上例所示,請注意每行都具有此格式
%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})
此格式以緊湊形式捕獲 Node 類中的所有內容,但 meta 除外。
具體來說
<name> 是節點的名稱,如它在
node.name中顯示的那樣。<op_name> 是
node.op欄位,必須是以下之一:<call_function>、<placeholder>、<get_attr> 或 <output>。<target> 是節點的 target,如
node.target。此欄位的含義取決於op_name。args1, … args 4… 是
node.args元組中列出的內容。如果列表中的值是torch.fx.Node,則會特別用前導 %. 指示。
例如,對 add 運算元的呼叫將顯示為
%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})
其中 %x、%y 是另外兩個名稱為 x 和 y 的 Node。值得注意的是,字串 torch.op.aten.add.Tensor 表示實際儲存在 target 欄位中的可呼叫物件,而不僅僅是其字串名稱。
此文字格式的最後一行是
return [add]
它是一個 op_name = output 的 Node,表示我們正在返回此元素。
call_function¶
一個 call_function 節點表示對運算元的呼叫。
定義
函式式:如果一個可呼叫物件滿足以下所有要求,我們稱之為“函式式”
非變異:運算元不修改其輸入的值(對於張量,這包括元資料和資料)。
無副作用:運算元不修改外部可見的狀態,例如更改模組引數的值。
運算元:是具有預定義模式的函式式可呼叫物件。此類運算元示例包括函式式 ATen 運算元。
在 FX 中的表示
%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})
與普通 FX call_function 的區別
在 FX 圖中,一個 call_function 可以引用任何可呼叫物件,而在 Export IR 中,我們將其限制為僅一部分選定的 ATen 運算元、自定義運算元和控制流運算元。
在 Export IR 中,常量引數將嵌入到圖中。
在 FX 圖中,get_attr 節點可以表示讀取圖模組中儲存的任何屬性。然而,在 Export IR 中,這被限制為僅讀取子模組,因為所有引數/緩衝區都將作為輸入傳遞給圖模組。
元資料¶
Node.meta 是附加到每個 FX 節點的字典。然而,FX 規範並未指定其中可以或將包含哪些元資料。Export IR 提供了更強的契約,特別是所有 call_function 節點都保證具有且僅具有以下元資料欄位
node.meta["stack_trace"]是包含引用原始 Python 原始碼的 Python 堆疊跟蹤的字串。示例堆疊跟蹤如下所示File "my_module.py", line 19, in forward return x + dummy_helper(y) File "helper_utility.py", line 89, in dummy_helper return y + 1
node.meta["val"]描述了執行運算的輸出。它可以是 <symint>、<FakeTensor>、List[Union[FakeTensor, SymInt]]或None型別。node.meta["nn_module_stack"]描述了節點來源的torch.nn.Module的“堆疊跟蹤”,如果它來自torch.nn.Module呼叫的話。例如,如果包含來自torch.nn.Sequential模組內部的torch.nn.Linear模組呼叫的addmm運算元的節點,則nn_module_stack將類似於{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}node.meta["source_fn_stack"]包含此節點在分解之前從哪個 torch 函式或葉子torch.nn.Module類呼叫。例如,包含來自torch.nn.Linear模組呼叫的addmm運算元的節點將在其source_fn中包含torch.nn.Linear,而包含來自torch.nn.functional.Linear模組呼叫的addmm運算元的節點將在其source_fn中包含torch.nn.functional.Linear。
placeholder¶
Placeholder 表示圖的輸入。其語義與 FX 中完全相同。Placeholder 節點必須是圖節點列表中的前 N 個節點。N 可以為零。
在 FX 中的表示
%name = placeholder[target = name](args = ())
target 欄位是一個字串,它是輸入的名稱。
args,如果非空,其大小應為 1,表示此輸入的預設值。
元資料
Placeholder 節點也具有 meta[‘val’],就像 call_function 節點一樣。在這種情況下,val 欄位表示圖期望為此輸入引數接收的輸入形狀/dtype。
output¶
一個 output 呼叫表示函式中的返回語句;因此它終止當前圖。只有一個 output 節點,並且它始終是圖的最後一個節點。
在 FX 中的表示
output[](args = (%something, …))
這與 torch.fx 中的語義完全相同。args 表示要返回的節點。
元資料
Output 節點具有與 call_function 節點相同的元資料。
get_attr¶
get_attr 節點表示從封裝的 torch.fx.GraphModule 讀取子模組。與透過 torch.fx.symbolic_trace() 生成的普通 FX 圖不同,在普通 FX 圖中,get_attr 節點用於從頂層 torch.fx.GraphModule 讀取引數和緩衝區等屬性,而在 Export IR 中,引數和緩衝區作為輸入傳遞給圖模組,並存儲在頂層 torch.export.ExportedProgram 中。
在 FX 中的表示
%name = get_attr[target = name](args = ())
示例
考慮以下模型
from functorch.experimental.control_flow import cond
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
Graph
graph():
%x_1 : [num_users=1] = placeholder[target=x_1]
%y_1 : [num_users=1] = placeholder[target=y_1]
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
return conditional
行 %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 讀取子模組 true_graph_0,其中包含 sin 運算元。
參考¶
SymInt¶
SymInt 是一個可以表示字面整數或代表整數(在 Python 中由 sympy.Symbol 類表示)的符號的物件。當 SymInt 是一個符號時,它描述了一個在編譯時未知、其值僅在執行時已知的整數型別變數。
FakeTensor¶
FakeTensor 是包含張量元資料的物件。可以將其視為具有以下元資料。
class FakeTensor:
size: List[SymInt]
dtype: torch.dtype
device: torch.device
dim_order: List[int] # This doesn't exist yet
FakeTensor 的 size 欄位是整數或 SymInt 的列表。如果存在 SymInt,則表示此張量具有動態形狀。如果存在整數,則假定該張量將具有該精確的靜態形狀。TensorMeta 的秩永遠不是動態的。dtype 欄位表示該節點輸出的 dtype。Edge IR 中沒有隱式型別提升。FakeTensor 中沒有 strides。
換句話說
如果 node.target 中的運算元返回 Tensor,則
node.meta['val']是描述該張量的 FakeTensor。如果 node.target 中的運算元返回一個由 Tensor 組成的 n 元組,則
node.meta['val']是描述每個張量的 FakeTensor 組成的 n 元組。如果 node.target 中的運算元返回在編譯時已知的 int/float/標量,則
node.meta['val']為 None。如果 node.target 中的運算元返回在編譯時未知的 int/float/標量,則
node.meta['val']的型別為 SymInt。
例如
aten::add返回一個 Tensor;因此其 spec 將是 FakeTensor,包含此運算元返回的張量的 dtype 和 size。aten::sym_size返回一個整數;因此其 val 將是 SymInt,因為其值僅在執行時可用。max_pool2d_with_indexes返回一個 (Tensor, Tensor) 元組;因此 spec 也將是由 FakeTensor 物件組成的 2 元組,第一個 TensorMeta 描述返回值的第一個元素,依此類推。
Python 程式碼
def add_one(x):
return torch.ops.aten(x, 1)
Graph
graph():
%ph_0 : [#users=1] = placeholder[target=ph_0]
%add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
return [add_tensor]
FakeTensor
FakeTensor(dtype=torch.int, size=[2,], device=CPU)
Pytree-able 型別¶
我們定義一種型別為“Pytree-able”,如果它要麼是葉子型別,要麼是包含其他 Pytree-able 型別的容器型別。
注意
pytree 的概念與此處為 JAX 記錄的概念相同
以下型別定義為葉子型別
型別 |
定義 |
|---|---|
Tensor |
|
標量 |
Python 中的任何數值型別,包括整數型別、浮點型別和零維張量。 |
int |
Python int(在 C++ 中繫結為 int64_t) |
float |
Python float(在 C++ 中繫結為 double) |
bool |
Python bool |
str |
Python 字串 |
ScalarType |
|
Layout |
|
MemoryFormat |
|
裝置 |
以下型別定義為容器型別
型別 |
定義 |
|---|---|
元組 |
Python 元組 |
列表 |
Python 列表 |
字典 |
鍵為 Scalar 的 Python 字典 |
命名元組 |
Python namedtuple |
Dataclass |
必須透過 register_dataclass 註冊 |
自定義類 |
使用 _register_pytree_node 定義的任何自定義類 |