快捷方式

Fake tensor

程式碼:fake_tensor.py

動機

在進行 Dynamo 符號求值和編譯器遍時,我們通常希望能夠執行張量操作,以瞭解輸出的尺寸/資料型別/裝置是什麼,而無需實際執行這些操作(或破壞現有張量),因為這樣做會更慢(如果計算量很大)且佔用大量記憶體(如果編譯器在編譯程式時需要使用 GPU 記憶體,那就不好了)。一個 fake tensor 在所有方面都像一個真實的張量,不同之處在於它實際上沒有任何資料。例如,當我們進行 Dynamo 追蹤時,我們需要追蹤使用者的 Tensor 程式碼,並回答關於中間結果的問題(例如,如果使用者對中間張量進行條件判斷)。沒有 fake tensor,我們將無法獲得這些查詢的準確資訊。

同樣,假設您想儲存張量的元資料,例如在 FX IR 節點 (meta['val']) 上。您可以直接在節點上儲存一個 fake tensor,它將為您提供張量所需的所有元資料,包括一些您可能無法處理的細微之處(例如,別名關係)。

整體架構

所有 fake tensor 都與 FakeTensorMode 相關聯。由於 fake tensor 的主要用例是對真實張量進行分析,因此一般的工作流程是:您有一堆真實張量,分配一個 FakeTensorMode,然後使用 from_real_tensor 將所有真實張量轉換為 fake tensor,接著對這些 fake tensor 進行操作。特別地,FakeTensorMode 維護一個持久的備忘錄表,將張量(和儲存)對映到相同的儲存。如果您多次 fakeify 同一個張量,您將得到同一個 fake tensor;如果您 fakeify 兩個相互別名的張量,您將得到兩個別名同一 fake storage 的 fake tensor。Fake tensor 是張量子類,因此如果您對它們進行操作,您會自動獲得一個 fake tensor,但通常您會希望在 FakeTensorMode 啟用的情況下對 fake tensor 進行操作(例如,如果您正在執行 FX pass);張量操作會自動開啟 fake tensor 模式並重試。

Fake tensor 被表示為 meta tensor 的一個 __torch_dispatch__ 張量子類。這意味著在底層,fake tensor 是 meta device 張量;然後它們利用額外的可擴充套件性鉤子,特別是 dispatch_device,來偽稱張量的實際裝置。這是 fake tensor 早期階段更容易出錯的部分之一:有時,fake tensor 太擅長偽裝成 CPU/CUDA 或其他裝置,結果會導致 CPU kernel 被呼叫,而 fake tensor 試圖解引用資料指標,這顯然行不通。如果在 fake tensor 程式碼中發生段錯誤,這是您首先應該檢查的事情:C++ 回溯是在 CPU kernel 中(意外!)還是在 meta kernel 中(預期!)。meta kernel 類似於真實 kernel,但它只負責分配輸出,不進行任何資料計算。

張量子類必須定義如何實現各種操作。以下是一般的 fake tensor 實現方法

  • 在輸入 fake tensor 上執行 meta kernel,並將它們重新解釋為 meta tensor。這是透過一個神奇的上下文管理器 `in_kernel_invocation_manager` 完成的,該管理器指示 PyTorch 將 fake tensor 視為其底層 meta tensor,而不是將 fake tensor “展開”為 meta tensor(因為 fake tensor 就是 meta tensor)。fake tensor 以這種方式表示,以避免必須同步兩組元資料(meta tensor 的元資料和 fake tensor 的元資料);“is a” 關係確保只有一份規範的元資料副本。

  • 如果您是一個工廠函式,您將轉而呼叫底層工廠函式,裝置設定為 `device='meta'`。

  • 將生成的 meta tensor 轉換為 fake tensor,計算張量的輸出裝置應該是什麼(這通常很簡單,但有時並非如此,例如 CPU 標量提升或裝置轉換操作)。

API:重要部分

非 PT2 用法(更多示例請檢視 test/test_fake_tensor.py)

# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
    # Do some operations on the fake tensors
    fake_y = fake_x * 2
    # Factory operations automatically get fakeified in the context manager
    fake_z = torch.empty(20)

問:為什麼輸入是真實張量?

答:在 PT2 環境中,這是因為您通常是即時編譯,因此對於您正在編譯的圖的所有輸入,您已經有了“真實”輸入,因為您在執行程式時進行編譯。

PT2 pre-AOTAutograd 用法(這不常見,您可能不希望這樣做)

# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
    ... # do stuff with the fake args, if needed ...

`detect_fake_mode` 會搜尋多個位置,嘗試找到與生命週期相關的“那個”fake tensor 模式。通常它會從追蹤上下文中獲取。

PT2 post-AOTAutograd 用法

# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on

其他有用內容

from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
    ... # fake mode is disabled here, you can do real tensor compute

您何時可能想要停用 fake tensor 模式?通常您不希望這樣做。一個我們發現有用的特殊情況是在 fake tensor 上實現常量傳播:在這種情況下,即使我們處於 fake tensor 模式,也需要進行一些實際的張量計算。

import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)

詳細資訊

是否自動轉換?

with FakeTensorMode():
    real_tensor.t_()

最初,如果您嘗試在 FakeTensorMode 區域內對真實張量進行計算,FakeTensorMode 不會自動 fakeify 這些真實張量。這樣做是為了防止以下陷阱

這段程式碼應該做什麼?如果我們真的修改了真實張量的元資料,那會很令人驚訝。但與此同時,也沒有任何明顯的機會去建立一個 FakeTensor。因此,我們保守地決定讓它丟擲一個錯誤:“在 FakeTensorMode 中使用非 Fake Tensor 輸入呼叫運算元尚不受支援。請先將所有 Tensor 轉換為 FakeTensor。”

這個錯誤在實踐中相當煩人。例如,假設您有一個真實的 nn.Module,並且想讓 fake tensor 透過它。您需要以某種方式 fakeify 這個 nn.Module。這促使了 FakeCopyMode 的出現。

最終,我們放棄了限制,並添加了自動 fakeification 功能。然而,在許多 FakeTensorMode 的使用場景中,此功能預設尚未啟用。

fake tensor 上的元資料修改

如果您有一個 fake tensor 並對其呼叫 `t_()`,則該 fake tensor 上的元資料會改變。這表面上看來合理,但有時您也希望將 fake tensor 作為元資料儲存在 FX 節點上;修改 fake tensor 是不好的,因為它會使舊的元資料失效!

事實上,這裡存在一個根本性的矛盾,即 fake tensor 維護著極其準確的張量元資料,包括物件標識。如果 FX 圖中的物件元資料隨時間變化,實際上沒有任何方法可以表示這種隨時間的變化。大多數時候,我們的重要 FX 分析是在 函式化圖 上進行的,這些圖沒有這個問題,但偶爾您需要在 非函式化圖 上進行分析。也許將 fake tensor 放在 meta['val'] 中是一個錯誤。

關於張量子類

Fake tensor 同時使用了子類和 mode 張量子類模式,其中 FakeTensor.__torch_dispatch__ 啟用與 fake tensor 關聯的 FakeTensorMode,然後重新排程(依賴 FakeTensorMode 完成繁重工作)。如果 fake tensor 操作收到一個它不認識的子類引數,它將返回 NotImplemented,讓其他子類有機會先執行(希望能去糖化為普通張量操作),然後再嘗試。這可能導致無限迴圈。

  • 每個運算元是如何實現的?

  • 不幸的是,任何給定運算元的實現位置都相當複雜。一些需要了解的重要情況包括:

  • 如果元素數量非常小,張量子類支援有限的常量傳播(這有助於處理一些我們立即對此類張量呼叫 `item()` 的情況)。

  • 出於效能考慮,我們對某些運算元有一些快速路徑實現,這些實現完全在 fake tensor 中完成。

  • 如果您使用 `@custom_op` 生成自定義張量,這些將直接向 fake tensor 註冊 `impl_abstract`。

Fake tensor 本身對裝置轉換操作有一些硬編碼的特殊情況。

如果沒有 meta 實現或任何分解,我們將生成真實的零填充張量,並嘗試直接執行運算元以確定結果。如果運算元嘗試使用資料進行索引,這可能導致段錯誤,因此我們預設不為自定義運算元啟用此功能。

轉換器是如何工作的?

由於 fake tensor 用於對張量精確屬性非常敏感的情況,因此 fake tensor 會非常仔細地進行轉換,保留 leaf 屬性、requires_grad 屬性、別名關係以及許多其他屬性。大部分繁重工作由 MetaConverter 完成。

  • 效能特徵

  • 您可能會認為 fake tensor 速度很快,因為它不進行任何張量計算。但在張量尺寸較小時,我們實際上完全受開銷限制,而且 fake tensor 是用 Python 實現的,我們通常需要做很多工作才能完成一個張量操作(因為它們被實現為分解)。因此,fake tensor 在實踐中實際上相當慢,尤其是在涉及符號形狀時。目前我們在 fake tensor 中有兩個重要的快速路徑,它們在實踐中起著重要作用

逐點運算元不經過 PrimTorch 分解,而是我們手動編碼了它們的傳播規則。

如果可能,我們應該這樣做。

Fake tensor 的 fake tensor?

有人有興趣將 fake tensor 作為使用者輸入傳送到 PT2 堆疊中,這意味著我們需要能夠建立 fake tensor 的 fake tensor。這目前尚不支援,但也許做起來不會太難。

與動態形狀的互動

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源