捷徑

torch.export IR 規範

Export IR 是一種用於編譯器的中介表示 (IR),與 MLIR 和 TorchScript 類似。它專為表達 PyTorch 程式的語義而設計。Export IR 主要以簡化的操作列表來表示計算,對動態性(如控制流程)的支援有限。

若要建立 Export IR 圖形,可以使用前端透過追蹤特化機制完整擷取 PyTorch 程式。然後,產生的 Export IR 可以由後端進行最佳化和執行。這在今天可以透過 torch.export.export() 完成。

本文件將涵蓋的關鍵概念包括

  • ExportedProgram:包含 Export IR 程式的資料結構

  • 圖形:由節點列表組成。

  • 節點:表示儲存在此節點上的操作、控制流程和中繼資料。

  • 值由節點產生和使用。

  • 類型與值和節點相關聯。

  • 值的記憶體配置和大小也有定義。

假設

本文件假設讀者充分熟悉 PyTorch,特別是 torch.fx 及其相關工具。因此,它將不再描述 torch.fx 文件和文件中出現的內容。

什麼是 Export IR

Export IR 是一種基於圖形的 PyTorch 程式中介表示 (IR)。Export IR 是在 torch.fx.Graph 之上實現的。換句話說,**所有 Export IR 圖形也是有效的 FX 圖形**,如果使用標準 FX 語義解釋,則可以完整地解釋 Export IR。這意味著可以透過標準 FX 程式碼產生將匯出的圖形轉換為有效的 Python 程式。

本文件將主要重點說明 Export IR 在嚴謹性方面與 FX 的不同之處,而跳過與 FX 相似之處。

ExportedProgram

頂級 Export IR 建構是一個 torch.export.ExportedProgram 類別。它將 PyTorch 模型的計算圖(通常是 torch.nn.Module)與該模型使用的參數或權重捆綁在一起。

torch.export.ExportedProgram 類別的一些值得注意的屬性是

  • graph_module (torch.fx.GraphModule):包含 PyTorch 模型的扁平化計算圖的資料結構。可以直接透過 ExportedProgram.graph 訪問該圖形。

  • graph_signature (torch.export.ExportGraphSignature):圖形簽章,指定在圖形中使用和變更的參數和緩衝區名稱。它們不是將參數和緩衝區儲存為圖形的屬性,而是提升為圖形的輸入。graph_signature 用於追蹤有關這些參數和緩衝區的其他資訊。

  • state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含參數和緩衝區的資料結構。

  • range_constraints (Dict[sympy.Symbol, RangeConstraint]):對於使用資料依賴行為匯出的程式,每個節點上的中繼資料將包含符號形狀(看起來像 s0i0)。此屬性將符號形狀對應到其上下範圍。

圖形

Export IR 圖形是以 DAG(有向無環圖)形式表示的 PyTorch 程式。此圖形中的每個節點表示一個特定計算或操作,而此圖形的邊緣則由節點之間的參考組成。

我們可以查看具有此結構描述的圖形

class Graph:
  nodes: List[Node]

實際上,匯出 IR 的圖形是透過 torch.fx.Graph Python 類別實現的。

匯出 IR 圖形包含下列節點(下一節將更詳細地描述節點)

  • 0 個或多個 op 類型為 placeholder 的節點

  • 0 個或多個 op 類型為 call_function 的節點

  • 恰好 1 個 op 類型為 output 的節點

推論: 最小的有效圖形將只有一個節點。也就是說,節點永遠不會是空的。

定義: 圖形的 placeholder 節點集表示 GraphModule 圖形的**輸入**。圖形的 output 節點表示 GraphModule 圖形的**輸出**。

範例

from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

mod = torch.export.export(MyModule())
print(mod.graph)

以上是圖形的文字表示形式,其中每一行都是一個節點。

節點

節點表示特定的計算或運算,並在 Python 中使用 torch.fx.Node 類別表示。節點之間的邊緣表示為透過 Node 類別的 args 屬性直接參考其他節點。使用相同的 FX 機制,我們可以表示計算圖形通常需要的下列運算,例如運算子呼叫、佔位符(又稱為輸入)、條件式和迴圈。

節點具有下列結構描述

class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文字格式

如上例所示,請注意每一行都具有此格式

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以緊湊的格式擷取 Node 類別中存在的所有內容,但 meta 除外。

具體而言

  • <名稱> 是節點的名稱,如同其在 node.name 中顯示的名稱。

  • <op_name>node.op 欄位,其必須是下列其中之一:<call_function><placeholder><get_attr><output>

  • <目標> 是節點的目標,如同 node.target。此欄位的含義取決於 op_name

  • args1、… args 4…node.args 元組中列出的內容。如果清單中的值是 torch.fx.Node,則會特別以開頭的 % 表示。

例如,對 add 運算子的呼叫將顯示為

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x%y 是另外兩個名稱分別為 x 和 y 的節點。值得注意的是,字串 torch.op.aten.add.Tensor 表示實際儲存在目標欄位中的可呼叫物件,而不僅僅是其字串名稱。

此文字格式的最後一行是

return [add]

這是一個 op_name = output 的節點,表示我們正在傳回這個元素。

call_function

call_function 節點表示對運算子的呼叫。

定義

  • 函數式: 如果可呼叫物件滿足下列所有需求,我們會說它是「函數式」

    • 非突變:運算子不會改變其輸入的值(對於張量,這包括中資料和資料)。

    • 沒有副作用:運算子不會改變從外部可見的狀態,例如改變模組參數的值。

  • 運算子: 是一個具有預定義結構描述的函數式可呼叫物件。此類運算子的範例包括函數式 ATen 運算子。

在 FX 中的表示形式

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

與標準 FX call_function 的差異

  1. 在 FX 圖形中,call_function 可以參考任何可呼叫物件,在匯出 IR 中,我們將其限制為僅限於 ATen 運算子、自訂運算子和控制流程運算子的選定子集。

  2. 在匯出 IR 中,常數引數將嵌入圖形中。

  3. 在 FX 圖形中,get_attr 節點可以表示讀取儲存在圖形模組中的任何屬性。但是,在匯出 IR 中,這僅限於讀取子模組,因為所有參數/緩衝區都將作為輸入傳遞給圖形模組。

中繼資料

Node.meta 是一個附加到每個 FX 節點的 dict。但是,FX 規範沒有指定哪些中繼資料可以或將會在那裡。匯出 IR 提供了更強的合約,特別是所有 call_function 節點都將保證具有且僅具有下列中繼資料欄位

  • node.meta["stack_trace"] 是一個包含 Python 堆疊追蹤的字串,參考原始 Python 原始碼。堆疊追蹤的範例如下所示

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"] 描述執行運算的輸出。其類型可以是 <symint><FakeTensor>List[Union[FakeTensor, SymInt]]None

  • node.meta["nn_module_stack"] 描述節點來源的 torch.nn.Module 的「堆疊追蹤」,如果其來自 torch.nn.Module 呼叫。例如,如果包含從 torch.nn.Sequential 模組內部的 torch.nn.Linear 模組呼叫的 addmm op 的節點,則 nn_module_stack 將如下所示

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"] 包含在分解之前呼叫此節點的 torch 函數或葉子 torch.nn.Module 類別。例如,包含來自 torch.nn.Linear 模組呼叫的 addmm op 的節點將在其 source_fn 中包含 torch.nn.Linear,而包含來自 torch.nn.functional.Linear 模組呼叫的 addmm op 的節點將在其 source_fn 中包含 torch.nn.functional.Linear

placeholder

Placeholder 表示圖形的輸入。其語義與 FX 中的語義完全相同。Placeholder 節點必須是圖形節點清單中的前 N 個節點。N 可以是零。

在 FX 中的表示形式

%name = placeholder[target = name](args = ())

目標欄位是輸入的名稱字串。

如果 args 不是空的,則其大小應為 1,表示此輸入的預設值。

中繼資料

Placeholder 節點也具有 meta[‘val’],例如 call_function 節點。在這種情況下,val 欄位表示圖形預期會收到此輸入參數的輸入形狀/dtype。

output

輸出呼叫表示函數中的 return 陳述式;因此它會終止目前的圖形。只有一個輸出節點,而且它始終是圖形的最後一個節點。

在 FX 中的表示形式

output[](args = (%something, …))

這與 torch.fx 中的語義完全相同。args 表示要傳回的節點。

中繼資料

輸出節點與 call_function 節點具有相同的中繼資料。

get_attr

get_attr 節點表示從封裝的 torch.fx.GraphModule 讀取子模組。與來自 torch.fx.symbolic_trace() 的標準 FX 圖形不同,在標準 FX 圖形中,get_attr 節點用於從頂級 torch.fx.GraphModule 讀取屬性(例如參數和緩衝區),參數和緩衝區作為輸入傳遞給圖形模組,並儲存在頂級 torch.export.ExportedProgram 中。

在 FX 中的表示形式

%name = get_attr[target = name](args = ())

範例

請考慮下列模型

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

圖形

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

這一行 %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 讀取包含 sin 運算子的子模組 true_graph_0

參考

SymInt

SymInt 是一個可以是字面整數或表示整數的符號的物件(在 Python 中由 sympy.Symbol 類別表示)。當 SymInt 是一個符號時,它描述了一個在編譯時圖形未知的整數類型變數,也就是說,它的值僅在運行時才知道。

FakeTensor

FakeTensor 是一個包含張量中繼資料的物件。可以將其視為具有以下中繼資料。

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor 的 size 欄位是一個整數或 SymInt 的列表。如果存在 SymInt,則表示此張量具有動態形狀。如果存在整數,則假設張量將具有該確切的靜態形狀。TensorMeta 的秩從不动态。dtype 欄位表示該節點輸出的 dtype。Edge IR 中沒有隱式類型提升。FakeTensor 中沒有步幅。

換句話說

  • 如果 node.target 中的運算子返回一個張量,則 node.meta['val'] 是一個描述該張量的 FakeTensor。

  • 如果 node.target 中的運算子返回一個 n 元組的張量,則 node.meta['val'] 是一個描述每個張量的 n 元組的 FakeTensors。

  • 如果 node.target 中的運算子返回一個在編譯時已知的整數/浮點數/純量,則 node.meta['val'] 為 None。

  • 如果 node.target 中的運算子返回一個在編譯時未知的整數/浮點數/純量,則 node.meta['val'] 的類型為 SymInt。

舉例來說

  • aten::add 返回一個張量;因此它的規範將是一個 FakeTensor,其中包含此運算子返回的張量的 dtype 和大小。

  • aten::sym_size 返回一個整數;因此它的 val 將是一個 SymInt,因為它的值僅在運行時可用。

  • max_pool2d_with_indexes 返回一個 (Tensor, Tensor) 的元組;因此規範也將是一個 FakeTensor 物件的二元組,第一個 TensorMeta 描述返回值的第一個元素,依此類推。

Python 程式碼

def add_one(x):
  return torch.ops.aten(x, 1)

圖形

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree 可用類型

我們定義了一種類型「Pytree 可用」,如果它是葉類型或包含其他 Pytree 可用類型的容器類型。

注意事項

pytree 的概念與 JAX 此處 文件中記錄的概念相同

以下類型定義為**葉類型**

類型

定義

張量

torch.Tensor

純量

來自 Python 的任何數值類型,包括整數類型、浮點數類型和零維張量。

int

Python int(在 C++ 中綁定為 int64_t)

float

Python float(在 C++ 中綁定為 double)

bool

Python bool

str

Python 字串

ScalarType

torch.dtype

Layout

torch.layout

MemoryFormat

torch.memory_format

Device

torch.device

以下類型定義為**容器類型**

類型

定義

Tuple

Python 元組

List

Python 列表

Dict

具有純量鍵的 Python dict

NamedTuple

Python namedtuple

Dataclass

必須透過 register_dataclass 註冊

自訂類別

使用 _register_pytree_node 定義的任何自訂類別

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源