虛擬張量¶
程式碼: fake_tensor.py
動機¶
在執行 Dynamo 符號評估和編譯器傳遞時,我們通常希望能夠執行張量運算以瞭解輸出大小/資料類型/裝置,而無需實際執行這些運算(或捨棄預先存在的張量),這會比較慢(如果您正在執行大量計算)並且佔用大量記憶體(如果您的編譯器在編譯程式時需要使用 GPU 記憶體,那就很糟糕了)。虛擬張量在所有方面都類似於真實張量,只是它實際上沒有任何資料。例如,當我們執行 Dynamo 追蹤時,我們需要追蹤使用者張量程式碼並回答有關中間結果的問題(例如,如果使用者對中間張量執行條件式)。如果沒有虛擬張量,我們將無法獲得這些查詢的準確資訊。
同樣地,假設您要儲存張量的中繼資料,例如,在 FX IR 節點上(meta['val'])。您可以改為直接在節點上儲存虛擬張量,這將為您提供張量所需的所有中繼資料,包括您可能不會處理的微妙內容(例如,別名關係)。
整體架構¶
所有虛擬張量都與 FakeTensorMode 相關聯。由於虛擬張量的主要用例是對真實張量執行分析,因此一般工作流程是您擁有一堆真實張量,配置一個 FakeTensorMode,然後使用 from_real_tensor 將所有這些真實張量轉換為虛擬張量,然後對虛擬張量執行操作。特別是,FakeTensorMode 會維護一個備忘錄表,將張量(和儲存體)持久地映射到相同的儲存體。如果您多次虛擬化同一個張量,您將獲得相同的虛擬張量;如果您虛擬化兩個相互別名的張量,您將獲得兩個別名相同的虛擬儲存體的虛擬張量。FakeTensors 是張量子類別,因此如果您對其執行運算,您將自動獲得一個虛擬張量,但通常您會希望在 FakeTensorMode 處於活動狀態時對虛擬張量執行運算(例如,如果您正在執行 FX 傳遞);張量運算的作用是自動開啟虛擬張量模式並重試。
虛擬張量表示為中繼張量的 __torch_dispatch__ 張量子類別。這意味著在底層,虛擬張量是中繼裝置張量;然後,它們使用額外的擴展性鉤子(特別是 dispatch_device)來謊報張量的實際裝置。這是早期虛擬張量更容易出錯的部分之一:有時,虛擬張量在謊報自己是 CPU/CUDA 等方面做得太好了,結果您會發現一個 CPU 核心被呼叫,而虛擬張量試圖取消引用資料指標,這顯然行不通。如果您在虛擬張量程式碼中遇到區段錯誤,這是您應該檢查的第一件事:C++ 回溯是在 CPU 核心(意外!)還是中繼核心(預期!)中?中繼核心類似於真實核心,但它所做的只是配置輸出,而不執行任何資料計算。
張量子類別必須定義如何實作各種運算。以下是通用的虛擬張量配方
- 在輸入虛擬張量上執行中繼核心,將其重新解釋為中繼張量。這是透過魔術上下文管理器 in_kernel_invocation_manager 完成的,該管理器指示所有 PyTorch 將虛擬張量視為其底層的中繼張量,而不是將虛擬張量「解包」為中繼張量(虛擬張量是一種中繼張量)。以這種方式表示虛擬張量是為了避免必須同步兩組中繼資料(中繼張量的中繼資料和虛擬張量的中繼資料);「是」關係確保只有一份規範的中繼資料副本。 
- 如果您使用的是工廠函數,則應改為呼叫裝置為「中繼」的底層工廠函數。 
- 將產生的中繼張量轉換為虛擬張量,計算張量的輸出裝置應該是什麼(這通常很簡單,但有時並非如此,例如,CPU 純量提升或裝置轉換運算。) 
API:重要部分¶
非 PT2 用法(查看 test/test_fake_tensor.py 以獲取更多範例)
# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
# Fakeify some real tensors
fake_x = fake_mode.from_real_tensor(x)
with fake_mode:
    # Do some operations on the fake tensors
    fake_y = fake_x * 2
    # Factory operations automatically get fakeified in the context manager
    fake_z = torch.empty(20)
問:為什麼要使用真實張量作為輸入?
答:在 PT2 環境中,這是因為您通常是在即時編譯,因此對於要編譯的圖形的所有輸入,您已經擁有「真實」的輸入,因為您是在執行程式時進行編譯。
PT2 pre-AOTAutograd 使用(這很不尋常,您可能不想這樣做)
# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
fake_args = [fake_mode.from_real_tensor(arg) for arg in args]
with fake_mode:
... do stuff with the fake args, if needed ...
detect_fake_mode 將搜尋多個位置,嘗試找到與生命週期關聯的「虛擬」張量模式。通常會從追蹤環境中提取。
PT2 post-AOTAutograd 使用
# 虛擬模式已啟用!example_inputs 通常已經是虛擬的 # TODO:我們可能想改變這個 # 仍然這樣做以存取虛擬模式 fake_mode = detect_fake_mode(example_inputs) # 但一般來說,您不需要開啟它
其他有用的東西
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
with maybe_disable_fake_tensor_mode():
    # fake mode is disabled here, you can do real tensor compute
您什麼時候可能想停用虛擬張量模式?通常您不會想這樣做。我們發現它有用的一個特殊情況是在虛擬張量上實現常數傳播:在這種情況下,即使我們處於虛擬張量模式中,也需要執行一些實際的張量計算。
FakeTensorProp
from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
詳細資訊¶
自動轉換與否?最初,如果您嘗試在 FakeTensorMode 區域內對真實張量進行計算,FakeTensorMode 不會自動將其虛擬化。這樣做的動機是為了防止以下陷阱
with FakeTensorMode():
real_tensor.t_()
這段程式碼應該做什麼?如果我們真的修改了真實張量的中繼資料,那將會令人驚訝。但與此同時,也沒有明顯的機會來建立虛擬張量。因此,我們保守地決定讓它引發錯誤:「在 FakeTensorMode 中使用非虛擬張量輸入呼叫運算子尚未受到支援。請先將所有張量轉換為虛擬張量。」
這個錯誤在實務中很煩人。例如,假設您有一個真實的 nn.Module,並且想將虛擬張量輸入其中。您需要以某種方式將 nn.Module 虛擬化。這促使了 FakeCopyMode 的誕生。
最終,我們放棄並添加了自動虛擬化功能。但是,在 FakeTensorMode 的許多用法中,這仍然不是預設啟用的。
虛擬張量的中繼資料變動 如果您有一個虛擬張量,並且您對其執行 t_(),則虛擬張量的中繼資料會發生變化。從表面上看,這是合理的,但有時您也想將虛擬張量作為中繼資料儲存在 FX 節點上;變動虛擬張量是不好的,因為這會使舊的中繼資料失效!
事實上,這裡存在一個根本性的矛盾,即虛擬張量會維護關於張量的極其準確的中繼資料,包括物件識別碼。如果物件中繼資料在 FX 圖表中隨著時間推移而發生變化,則實際上沒有任何方法可以表示這種隨時間推移的變化。大多數情況下,我們嚴謹的 FX 分析是在函式化的圖表上完成的,這些圖表沒有這個問題,但有時您需要在非函式化的圖表上進行分析。也許將虛擬張量放入 meta['val'] 中是一個錯誤
關於張量子類別¶
虛擬張量同時使用子類別和模式張量子類別模式,其中 FakeTensor.__torch_dispatch__ 啟用與虛擬張量關聯的 FakeTensorMode,然後重新分派(依賴 FakeTensorMode 進行繁重的工作)。如果虛擬張量運算收到一個它不認識的子類別參數,它將返回 NotImplemented,讓其他子類別有機會先執行(希望解構為普通的張量運算),然後再試一次。這可能會導致無限迴圈。
每個個別的運算子是如何實作的?¶
不幸的是,任何給定的運算子都可能在一組相當複雜的地方實作。一些需要注意的重要情況
- 如果元素數量非常少,張量子類別支援有限的常數傳播(這有助於處理一些我們立即對此類張量呼叫 item() 的情況。) 
- 出於效能原因,我們為某些運算子提供了一些快速路徑實作,這些實作完全在虛擬張量中完成。 
- 如果您使用 @custom_op 生成自訂張量,這些張量將直接向虛擬張量註冊 impl_abstract。 
- 虛擬張量本身對於裝置轉換操作有一些硬編碼的特殊情況。 
- 如果沒有中繼實作也沒有任何分解,我們將生成真實的零填充張量,並嘗試直接執行運算子以找出結果。如果運算子嘗試使用資料進行索引,這可能會導致區段錯誤,因此我們不會為自訂運算預設開啟此功能。 
轉換器是如何運作的?¶
由於虛擬張量用於對張量的確切屬性非常敏感的情況,因此虛擬張量會非常小心地進行轉換,保留葉節點、requires_grad、別名以及許多其他屬性。大部分的繁重工作都在 MetaConverter 中完成。
效能特點¶
您可能會認為虛擬張量速度很快,因為它們不做任何張量計算。但在張量大小較小的情況下,我們實際上完全受限於開銷,而且,虛擬張量是用 Python 編寫的,我們通常需要做很多工作才能完成單個張量運算(因為它們是作為分解實作的)。因此,虛擬張量在實務中實際上相當慢,尤其是在涉及符號形狀時。我們目前在虛擬張量中有兩個重要的快速路徑,它們在實務中產生了很大的差異
- 逐點運算不會經過 PrimTorch 分解,而是我們手動編寫了它們的傳播規則。 
- 如果可能,我們應該這樣做。 
虛擬張量的虛擬張量?¶
我們有興趣將虛擬張量作為使用者輸入發送到 PT2 堆疊,這意味著我們需要能夠建立虛擬張量的虛擬張量。這在目前還沒有真正受到支援,但也許實現起來不會太困難。
與動態形狀的互動¶
每個 FakeTensorMode 都包含一個 ShapeEnv,它追蹤所有符號形狀資訊。它們的生命週期通常是綁定的:它們同生共死。
由於 FakeTensorMode 有一個 ShapeEnv(但中繼實作沒有),因此依賴於資料且需要分配未支援的 SymInt 的中繼函式位於虛擬張量中。虛擬張量還負責記憶未支援的 SymInt,例如,如果您在同一個虛擬張量上呼叫兩次 nonzero(),您將獲得相同的符號大小。