torch.func.functionalize¶
- torch.func.functionalize(func, *, remove='mutations')[源]¶
functionalize 是一個變換,可用於從函式中移除(中間)修改和別名,同時保留函式的語義。
functionalize(func)返回一個新函式,該函式與func具有相同的語義,但移除了所有中間修改。在中間張量上執行的每一個原地操作:intermediate.foo_()都將被其非原地等效操作替換:intermediate_updated = intermediate.foo()。functionalize 對於將 PyTorch 程式傳送到無法輕鬆表示修改或別名運算子的後端或編譯器非常有用。
- 引數
func (Callable) – 一個接受一個或多個引數的 Python 函式。
remove (str) – 一個可選的字串引數,其值可以是 ‘mutations’ 或 ‘mutations_and_views’。如果傳入 ‘mutations’,則所有修改型運算子都將替換為其非修改型等效項。如果傳入 ‘mutations_and_views’,則此外,所有別名運算子都將替換為其非別名等效項。預設值:‘mutations’。
- 返回值
返回一個已“函式化”的新函式。它接受與
func相同的輸入,並具有相同的行為,但函式中對中間張量執行的任何修改(以及可選的別名)都將被移除。- 返回型別
functionalize 也會移除對函式輸入進行的修改(和檢視)。但是為了保留語義,functionalize 會在變換執行完成後“修補”修改,方法是檢測是否有任何張量輸入“應該”被修改,並在必要時將新資料複製回輸入。
示例
>>> import torch >>> from torch.fx.experimental.proxy_tensor import make_fx >>> from torch.func import functionalize >>> >>> # A function that uses mutations and views, but only on intermediate tensors. >>> def f(a): ... b = a + 1 ... c = b.view(-1) ... c.add_(1) ... return b ... >>> inpt = torch.randn(2) >>> >>> out1 = f(inpt) >>> out2 = functionalize(f)(inpt) >>> >>> # semantics are the same (outputs are equivalent) >>> print(torch.allclose(out1, out2)) True >>> >>> f_traced = make_fx(f)(inpt) >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> >>> print(f_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view = torch.ops.aten.view(add, [-1]) add_ = torch.ops.aten.add_(view, 1); view = None return add >>> print(f_no_mutations_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view = torch.ops.aten.view(add, [-1]); add = None add_1 = torch.ops.aten.add(view, 1); view = None view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None return view_1 >>> print(f_no_mutations_and_views_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view_copy = torch.ops.aten.view_copy(add, [-1]); add = None add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None return view_copy_1 >>> # A function that mutates its input tensor >>> def f(a): ... b = a.view(-1) ... b.add_(1) ... return a ... >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> # >>> # All mutations and views have been removed, >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input >>> # after the function has completed. >>> print(f_no_mutations_and_views_traced.code) def forward(self, a_1): view_copy = torch.ops.aten.view_copy(a_1, [-1]) add = torch.ops.aten.add(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None return view_copy_1
- functionalize 有幾個值得指出的“失敗模式”
與其他 torch.func 變換一樣,functionalize() 不適用於直接使用 .backward() 的函式。對於 torch.autograd.grad 也是如此。如果想使用 autograd,可以直接使用 functionalize(grad(f)) 計算梯度。
與其他 torch.func 變換一樣,functionalize() 不適用於全域性狀態。如果在對非區域性狀態進行檢視/修改的函式上呼叫 functionalize(f),函式化將只是空操作(no-op),並將檢視/修改呼叫直接傳遞給後端。解決此問題的一種方法是確保任何非區域性狀態的建立都被封裝到一個更大的函式中,然後在該函式上呼叫 functionalize。
resize_() 有一些限制:functionalize 僅適用於使用 resize_()` 的程式,前提是正在 resize 的張量不是檢視。
as_strided() 有一些限制:functionalize 不適用於導致張量記憶體重疊的 as_strided() 呼叫。
最後,理解函式化(functionalization)的一個有用思維模型是,大多數使用者 PyTorch 程式都是使用公共 torch API 編寫的。執行時,torch 運算子通常被分解為我們的內部 C++“ATen” API。函式化的邏輯完全發生在 ATen 層面。函式化知道如何將 ATen 中的每個別名運算子對映到其非別名等效項(例如
tensor.view({-1})->at::view_copy(tensor, {-1})),以及如何將 ATen 中的每個修改型運算子對映到其非修改型等效項(例如tensor.add_(1)->at::add(tensor, -1)),同時在旁路跟蹤別名和修改,以便知道何時進行修補。關於哪些 ATen 運算子是別名或修改型的資訊都來自 https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml。