Fake tensor¶
程式碼:fake_tensor.py
動機¶
在進行 Dynamo 符號求值和編譯器遍時,我們通常希望能夠執行張量操作,以瞭解輸出的尺寸/資料型別/裝置是什麼,而無需實際執行這些操作(或破壞現有張量),因為這樣做會更慢(如果計算量很大)且佔用大量記憶體(如果編譯器在編譯程式時需要使用 GPU 記憶體,那就不好了)。一個 fake tensor 在所有方面都像一個真實的張量,不同之處在於它實際上沒有任何資料。例如,當我們進行 Dynamo 追蹤時,我們需要追蹤使用者的 Tensor 程式碼,並回答關於中間結果的問題(例如,如果使用者對中間張量進行條件判斷)。沒有 fake tensor,我們將無法獲得這些查詢的準確資訊。
同樣,假設您想儲存張量的元資料,例如在 FX IR 節點 (meta['val']) 上。您可以直接在節點上儲存一個 fake tensor,它將為您提供張量所需的所有元資料,包括一些您可能無法處理的細微之處(例如,別名關係)。
整體架構¶
所有 fake tensor 都與 FakeTensorMode 相關聯。由於 fake tensor 的主要用例是對真實張量進行分析,因此一般的工作流程是:您有一堆真實張量,分配一個 FakeTensorMode,然後使用 from_real_tensor 將所有真實張量轉換為 fake tensor,接著對這些 fake tensor 進行操作。特別地,FakeTensorMode 維護一個持久的備忘錄表,將張量(和儲存)對映到相同的儲存。如果您多次 fakeify 同一個張量,您將得到同一個 fake tensor;如果您 fakeify 兩個相互別名的張量,您將得到兩個別名同一 fake storage 的 fake tensor。Fake tensor 是張量子類,因此如果您對它們進行操作,您會自動獲得一個 fake tensor,但通常您會希望在 FakeTensorMode 啟用的情況下對 fake tensor 進行操作(例如,如果您正在執行 FX pass);張量操作會自動開啟 fake tensor 模式並重試。
Fake tensor 被表示為 meta tensor 的一個 __torch_dispatch__ 張量子類。這意味著在底層,fake tensor 是 meta device 張量;然後它們利用額外的可擴充套件性鉤子,特別是 dispatch_device,來偽稱張量的實際裝置。這是 fake tensor 早期階段更容易出錯的部分之一:有時,fake tensor 太擅長偽裝成 CPU/CUDA 或其他裝置,結果會導致 CPU kernel 被呼叫,而 fake tensor 試圖解引用資料指標,這顯然行不通。如果在 fake tensor 程式碼中發生段錯誤,這是您首先應該檢查的事情:C++ 回溯是在 CPU kernel 中(意外!)還是在 meta kernel 中(預期!)。meta kernel 類似於真實 kernel,但它只負責分配輸出,不進行任何資料計算。
張量子類必須定義如何實現各種操作。以下是一般的 fake tensor 實現方法
在輸入 fake tensor 上執行 meta kernel,並將它們重新解釋為 meta tensor。這是透過一個神奇的上下文管理器 `in_kernel_invocation_manager` 完成的,該管理器指示 PyTorch 將 fake tensor 視為其底層 meta tensor,而不是將 fake tensor “展開”為 meta tensor(因為 fake tensor 就是 meta tensor)。fake tensor 以這種方式表示,以避免必須同步兩組元資料(meta tensor 的元資料和 fake tensor 的元資料);“is a” 關係確保只有一份規範的元資料副本。
如果您是一個工廠函式,您將轉而呼叫底層工廠函式,裝置設定為 `device='meta'`。
將生成的 meta tensor 轉換為 fake tensor,計算張量的輸出裝置應該是什麼(這通常很簡單,但有時並非如此,例如 CPU 標量提升或裝置轉換操作)。
API:重要部分¶
非 PT2 用法(更多示例請檢視 test/test_fake_tensor.py)
# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
# Do some operations on the fake tensors
fake_y = fake_x * 2
# Factory operations automatically get fakeified in the context manager
fake_z = torch.empty(20)
問:為什麼輸入是真實張量?
答:在 PT2 環境中,這是因為您通常是即時編譯,因此對於您正在編譯的圖的所有輸入,您已經有了“真實”輸入,因為您在執行程式時進行編譯。
PT2 pre-AOTAutograd 用法(這不常見,您可能不希望這樣做)
# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
... # do stuff with the fake args, if needed ...
`detect_fake_mode` 會搜尋多個位置,嘗試找到與生命週期相關的“那個”fake tensor 模式。通常它會從追蹤上下文中獲取。
PT2 post-AOTAutograd 用法
# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on
其他有用內容
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
... # fake mode is disabled here, you can do real tensor compute
您何時可能想要停用 fake tensor 模式?通常您不希望這樣做。一個我們發現有用的特殊情況是在 fake tensor 上實現常量傳播:在這種情況下,即使我們處於 fake tensor 模式,也需要進行一些實際的張量計算。
import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
詳細資訊¶
是否自動轉換?
with FakeTensorMode():
real_tensor.t_()
最初,如果您嘗試在 FakeTensorMode 區域內對真實張量進行計算,FakeTensorMode 不會自動 fakeify 這些真實張量。這樣做是為了防止以下陷阱
這段程式碼應該做什麼?如果我們真的修改了真實張量的元資料,那會很令人驚訝。但與此同時,也沒有任何明顯的機會去建立一個 FakeTensor。因此,我們保守地決定讓它丟擲一個錯誤:“在 FakeTensorMode 中使用非 Fake Tensor 輸入呼叫運算元尚不受支援。請先將所有 Tensor 轉換為 FakeTensor。”
這個錯誤在實踐中相當煩人。例如,假設您有一個真實的 nn.Module,並且想讓 fake tensor 透過它。您需要以某種方式 fakeify 這個 nn.Module。這促使了 FakeCopyMode 的出現。
最終,我們放棄了限制,並添加了自動 fakeification 功能。然而,在許多 FakeTensorMode 的使用場景中,此功能預設尚未啟用。
fake tensor 上的元資料修改
如果您有一個 fake tensor 並對其呼叫 `t_()`,則該 fake tensor 上的元資料會改變。這表面上看來合理,但有時您也希望將 fake tensor 作為元資料儲存在 FX 節點上;修改 fake tensor 是不好的,因為它會使舊的元資料失效!
事實上,這裡存在一個根本性的矛盾,即 fake tensor 維護著極其準確的張量元資料,包括物件標識。如果 FX 圖中的物件元資料隨時間變化,實際上沒有任何方法可以表示這種隨時間的變化。大多數時候,我們的重要 FX 分析是在 函式化圖 上進行的,這些圖沒有這個問題,但偶爾您需要在 非函式化圖 上進行分析。也許將 fake tensor 放在 meta['val'] 中是一個錯誤。
關於張量子類¶
Fake tensor 同時使用了子類和 mode 張量子類模式,其中 FakeTensor.__torch_dispatch__ 啟用與 fake tensor 關聯的 FakeTensorMode,然後重新排程(依賴 FakeTensorMode 完成繁重工作)。如果 fake tensor 操作收到一個它不認識的子類引數,它將返回 NotImplemented,讓其他子類有機會先執行(希望能去糖化為普通張量操作),然後再嘗試。這可能導致無限迴圈。
每個運算元是如何實現的?¶
不幸的是,任何給定運算元的實現位置都相當複雜。一些需要了解的重要情況包括:
如果元素數量非常小,張量子類支援有限的常量傳播(這有助於處理一些我們立即對此類張量呼叫 `item()` 的情況)。
出於效能考慮,我們對某些運算元有一些快速路徑實現,這些實現完全在 fake tensor 中完成。
如果您使用 `@custom_op` 生成自定義張量,這些將直接向 fake tensor 註冊 `impl_abstract`。
Fake tensor 本身對裝置轉換操作有一些硬編碼的特殊情況。
如果沒有 meta 實現或任何分解,我們將生成真實的零填充張量,並嘗試直接執行運算元以確定結果。如果運算元嘗試使用資料進行索引,這可能導致段錯誤,因此我們預設不為自定義運算元啟用此功能。
轉換器是如何工作的?¶
由於 fake tensor 用於對張量精確屬性非常敏感的情況,因此 fake tensor 會非常仔細地進行轉換,保留 leaf 屬性、requires_grad 屬性、別名關係以及許多其他屬性。大部分繁重工作由 MetaConverter 完成。
效能特徵¶
您可能會認為 fake tensor 速度很快,因為它不進行任何張量計算。但在張量尺寸較小時,我們實際上完全受開銷限制,而且 fake tensor 是用 Python 實現的,我們通常需要做很多工作才能完成一個張量操作(因為它們被實現為分解)。因此,fake tensor 在實踐中實際上相當慢,尤其是在涉及符號形狀時。目前我們在 fake tensor 中有兩個重要的快速路徑,它們在實踐中起著重要作用
逐點運算元不經過 PrimTorch 分解,而是我們手動編碼了它們的傳播規則。
如果可能,我們應該這樣做。
Fake tensor 的 fake tensor?¶
有人有興趣將 fake tensor 作為使用者輸入傳送到 PT2 堆疊中,這意味著我們需要能夠建立 fake tensor 的 fake tensor。這目前尚不支援,但也許做起來不會太難。
與動態形狀的互動¶