torch.export IR 規範¶
Export IR 是一種用於編譯器的中介表示 (IR),與 MLIR 和 TorchScript 類似。它專為表達 PyTorch 程式的語義而設計。Export IR 主要以簡化的操作列表來表示計算,對動態性(如控制流程)的支援有限。
若要建立 Export IR 圖形,可以使用前端透過追蹤特化機制完整擷取 PyTorch 程式。然後,產生的 Export IR 可以由後端進行最佳化和執行。這在今天可以透過 torch.export.export() 完成。
本文件將涵蓋的關鍵概念包括
- ExportedProgram:包含 Export IR 程式的資料結構 
- 圖形:由節點列表組成。 
- 節點:表示儲存在此節點上的操作、控制流程和中繼資料。 
- 值由節點產生和使用。 
- 類型與值和節點相關聯。 
- 值的記憶體配置和大小也有定義。 
什麼是 Export IR¶
Export IR 是一種基於圖形的 PyTorch 程式中介表示 (IR)。Export IR 是在 torch.fx.Graph 之上實現的。換句話說,**所有 Export IR 圖形也是有效的 FX 圖形**,如果使用標準 FX 語義解釋,則可以完整地解釋 Export IR。這意味著可以透過標準 FX 程式碼產生將匯出的圖形轉換為有效的 Python 程式。
本文件將主要重點說明 Export IR 在嚴謹性方面與 FX 的不同之處,而跳過與 FX 相似之處。
ExportedProgram¶
頂級 Export IR 建構是一個 torch.export.ExportedProgram 類別。它將 PyTorch 模型的計算圖(通常是 torch.nn.Module)與該模型使用的參數或權重捆綁在一起。
torch.export.ExportedProgram 類別的一些值得注意的屬性是
- graph_module(- torch.fx.GraphModule):包含 PyTorch 模型的扁平化計算圖的資料結構。可以直接透過 ExportedProgram.graph 訪問該圖形。
- graph_signature(- torch.export.ExportGraphSignature):圖形簽章,指定在圖形中使用和變更的參數和緩衝區名稱。它們不是將參數和緩衝區儲存為圖形的屬性,而是提升為圖形的輸入。graph_signature 用於追蹤有關這些參數和緩衝區的其他資訊。
- state_dict(- Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含參數和緩衝區的資料結構。
- range_constraints(- Dict[sympy.Symbol, RangeConstraint]):對於使用資料依賴行為匯出的程式,每個節點上的中繼資料將包含符號形狀(看起來像- s0、- i0)。此屬性將符號形狀對應到其上下範圍。
圖形¶
Export IR 圖形是以 DAG(有向無環圖)形式表示的 PyTorch 程式。此圖形中的每個節點表示一個特定計算或操作,而此圖形的邊緣則由節點之間的參考組成。
我們可以查看具有此結構描述的圖形
class Graph:
  nodes: List[Node]
實際上,匯出 IR 的圖形是透過 torch.fx.Graph Python 類別實現的。
匯出 IR 圖形包含下列節點(下一節將更詳細地描述節點)
- 0 個或多個 op 類型為 - placeholder的節點
- 0 個或多個 op 類型為 - call_function的節點
- 恰好 1 個 op 類型為 - output的節點
推論: 最小的有效圖形將只有一個節點。也就是說,節點永遠不會是空的。
定義: 圖形的 placeholder 節點集表示 GraphModule 圖形的**輸入**。圖形的 output 節點表示 GraphModule 圖形的**輸出**。
範例
from torch import nn
class MyModule(nn.Module):
    def forward(self, x, y):
      return x + y
mod = torch.export.export(MyModule())
print(mod.graph)
以上是圖形的文字表示形式,其中每一行都是一個節點。
節點¶
節點表示特定的計算或運算,並在 Python 中使用 torch.fx.Node 類別表示。節點之間的邊緣表示為透過 Node 類別的 args 屬性直接參考其他節點。使用相同的 FX 機制,我們可以表示計算圖形通常需要的下列運算,例如運算子呼叫、佔位符(又稱為輸入)、條件式和迴圈。
節點具有下列結構描述
class Node:
  name: str # name of node
  op_name: str  # type of operation
  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]
FX 文字格式
如上例所示,請注意每一行都具有此格式
%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})
此格式以緊湊的格式擷取 Node 類別中存在的所有內容,但 meta 除外。
具體而言
- <名稱> 是節點的名稱,如同其在 - node.name中顯示的名稱。
- <op_name> 是 - node.op欄位,其必須是下列其中之一:<call_function>、<placeholder>、<get_attr> 或 <output>。
- <目標> 是節點的目標,如同 - node.target。此欄位的含義取決於- op_name。
- args1、… args 4… 是 - node.args元組中列出的內容。如果清單中的值是- torch.fx.Node,則會特別以開頭的 % 表示。
例如,對 add 運算子的呼叫將顯示為
%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})
其中 %x、%y 是另外兩個名稱分別為 x 和 y 的節點。值得注意的是,字串 torch.op.aten.add.Tensor 表示實際儲存在目標欄位中的可呼叫物件,而不僅僅是其字串名稱。
此文字格式的最後一行是
return [add]
這是一個 op_name = output 的節點,表示我們正在傳回這個元素。
call_function¶
call_function 節點表示對運算子的呼叫。
定義
- 函數式: 如果可呼叫物件滿足下列所有需求,我們會說它是「函數式」 - 非突變:運算子不會改變其輸入的值(對於張量,這包括中資料和資料)。 
- 沒有副作用:運算子不會改變從外部可見的狀態,例如改變模組參數的值。 
 
- 運算子: 是一個具有預定義結構描述的函數式可呼叫物件。此類運算子的範例包括函數式 ATen 運算子。 
在 FX 中的表示形式
%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})
與標準 FX call_function 的差異
- 在 FX 圖形中,call_function 可以參考任何可呼叫物件,在匯出 IR 中,我們將其限制為僅限於 ATen 運算子、自訂運算子和控制流程運算子的選定子集。 
- 在匯出 IR 中,常數引數將嵌入圖形中。 
- 在 FX 圖形中,get_attr 節點可以表示讀取儲存在圖形模組中的任何屬性。但是,在匯出 IR 中,這僅限於讀取子模組,因為所有參數/緩衝區都將作為輸入傳遞給圖形模組。 
中繼資料¶
Node.meta 是一個附加到每個 FX 節點的 dict。但是,FX 規範沒有指定哪些中繼資料可以或將會在那裡。匯出 IR 提供了更強的合約,特別是所有 call_function 節點都將保證具有且僅具有下列中繼資料欄位
- node.meta["stack_trace"]是一個包含 Python 堆疊追蹤的字串,參考原始 Python 原始碼。堆疊追蹤的範例如下所示- File "my_module.py", line 19, in forward return x + dummy_helper(y) File "helper_utility.py", line 89, in dummy_helper return y + 1 
- node.meta["val"]描述執行運算的輸出。其類型可以是 <symint>、<FakeTensor>、- List[Union[FakeTensor, SymInt]]或- None。
- node.meta["nn_module_stack"]描述節點來源的- torch.nn.Module的「堆疊追蹤」,如果其來自- torch.nn.Module呼叫。例如,如果包含從- torch.nn.Sequential模組內部的- torch.nn.Linear模組呼叫的- addmmop 的節點,則- nn_module_stack將如下所示- {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
- node.meta["source_fn_stack"]包含在分解之前呼叫此節點的 torch 函數或葉子- torch.nn.Module類別。例如,包含來自- torch.nn.Linear模組呼叫的- addmmop 的節點將在其- source_fn中包含- torch.nn.Linear,而包含來自- torch.nn.functional.Linear模組呼叫的- addmmop 的節點將在其- source_fn中包含- torch.nn.functional.Linear。
placeholder¶
Placeholder 表示圖形的輸入。其語義與 FX 中的語義完全相同。Placeholder 節點必須是圖形節點清單中的前 N 個節點。N 可以是零。
在 FX 中的表示形式
%name = placeholder[target = name](args = ())
目標欄位是輸入的名稱字串。
如果 args 不是空的,則其大小應為 1,表示此輸入的預設值。
中繼資料
Placeholder 節點也具有 meta[‘val’],例如 call_function 節點。在這種情況下,val 欄位表示圖形預期會收到此輸入參數的輸入形狀/dtype。
output¶
輸出呼叫表示函數中的 return 陳述式;因此它會終止目前的圖形。只有一個輸出節點,而且它始終是圖形的最後一個節點。
在 FX 中的表示形式
output[](args = (%something, …))
這與 torch.fx 中的語義完全相同。args 表示要傳回的節點。
中繼資料
輸出節點與 call_function 節點具有相同的中繼資料。
get_attr¶
get_attr 節點表示從封裝的 torch.fx.GraphModule 讀取子模組。與來自 torch.fx.symbolic_trace() 的標準 FX 圖形不同,在標準 FX 圖形中,get_attr 節點用於從頂級 torch.fx.GraphModule 讀取屬性(例如參數和緩衝區),參數和緩衝區作為輸入傳遞給圖形模組,並儲存在頂級 torch.export.ExportedProgram 中。
在 FX 中的表示形式
%name = get_attr[target = name](args = ())
範例
請考慮下列模型
from functorch.experimental.control_flow import cond
def true_fn(x):
    return x.sin()
def false_fn(x):
    return x.cos()
def f(x, y):
    return cond(y, true_fn, false_fn, [x])
圖形
graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional
這一行 %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 讀取包含 sin 運算子的子模組 true_graph_0。
參考¶
SymInt¶
SymInt 是一個可以是字面整數或表示整數的符號的物件(在 Python 中由 sympy.Symbol 類別表示)。當 SymInt 是一個符號時,它描述了一個在編譯時圖形未知的整數類型變數,也就是說,它的值僅在運行時才知道。
FakeTensor¶
FakeTensor 是一個包含張量中繼資料的物件。可以將其視為具有以下中繼資料。
class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet
FakeTensor 的 size 欄位是一個整數或 SymInt 的列表。如果存在 SymInt,則表示此張量具有動態形狀。如果存在整數,則假設張量將具有該確切的靜態形狀。TensorMeta 的秩從不动态。dtype 欄位表示該節點輸出的 dtype。Edge IR 中沒有隱式類型提升。FakeTensor 中沒有步幅。
換句話說
- 如果 node.target 中的運算子返回一個張量,則 - node.meta['val']是一個描述該張量的 FakeTensor。
- 如果 node.target 中的運算子返回一個 n 元組的張量,則 - node.meta['val']是一個描述每個張量的 n 元組的 FakeTensors。
- 如果 node.target 中的運算子返回一個在編譯時已知的整數/浮點數/純量,則 - node.meta['val']為 None。
- 如果 node.target 中的運算子返回一個在編譯時未知的整數/浮點數/純量,則 - node.meta['val']的類型為 SymInt。
舉例來說
- aten::add返回一個張量;因此它的規範將是一個 FakeTensor,其中包含此運算子返回的張量的 dtype 和大小。
- aten::sym_size返回一個整數;因此它的 val 將是一個 SymInt,因為它的值僅在運行時可用。
- max_pool2d_with_indexes返回一個 (Tensor, Tensor) 的元組;因此規範也將是一個 FakeTensor 物件的二元組,第一個 TensorMeta 描述返回值的第一個元素,依此類推。
Python 程式碼
def add_one(x):
  return torch.ops.aten(x, 1)
圖形
graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]
FakeTensor
FakeTensor(dtype=torch.int, size=[2,], device=CPU)
Pytree 可用類型¶
我們定義了一種類型「Pytree 可用」,如果它是葉類型或包含其他 Pytree 可用類型的容器類型。
注意事項
pytree 的概念與 JAX 此處 文件中記錄的概念相同
以下類型定義為**葉類型**
| 類型 | 定義 | 
|---|---|
| 張量 | |
| 純量 | 來自 Python 的任何數值類型,包括整數類型、浮點數類型和零維張量。 | 
| int | Python int(在 C++ 中綁定為 int64_t) | 
| float | Python float(在 C++ 中綁定為 double) | 
| bool | Python bool | 
| str | Python 字串 | 
| ScalarType | |
| Layout | |
| MemoryFormat | |
| Device | 
以下類型定義為**容器類型**
| 類型 | 定義 | 
|---|---|
| Tuple | Python 元組 | 
| List | Python 列表 | 
| Dict | 具有純量鍵的 Python dict | 
| NamedTuple | Python namedtuple | 
| Dataclass | 必須透過 register_dataclass 註冊 | 
| 自訂類別 | 使用 _register_pytree_node 定義的任何自訂類別 |