UX 限制¶
torch.func,與 JAX 類似,對可轉換的內容有限制。一般來說,JAX 的限制是轉換隻適用於純函式:也就是說,其輸出完全由輸入決定且不涉及副作用(如修改)的函式。
我們也有類似的保證:我們的轉換與純函式配合得很好。然而,我們也支援某些原地操作。一方面,編寫與函式轉換相容的程式碼可能需要改變你編寫 PyTorch 程式碼的方式;另一方面,你可能會發現我們的轉換可以讓你表達以前在 PyTorch 中難以表達的內容。
一般限制¶
所有的 torch.func 轉換都有一個共同的限制,即函式不應向全域性變數賦值。相反,函式的所有輸出必須從函式中返回。這一限制源於 torch.func 的實現方式:每個轉換都會將 Tensor 輸入包裝在特殊的 torch.func Tensor 子類中,以促進轉換。
因此,不要像下面這樣
import torch
from torch.func import grad
# Don't do this
intermediate = None
def f(x):
global intermediate
intermediate = x.sin()
z = intermediate.sin()
return z
x = torch.randn([])
grad_x = grad(f)(x)
請重寫 f 使其返回 intermediate
def f(x):
intermediate = x.sin()
z = intermediate.sin()
return z, intermediate
grad_x, intermediate = grad(f, has_aux=True)(x)
torch.autograd API¶
如果你試圖在被 vmap() 或 torch.func 的 AD 轉換之一(vjp(), jvp(), jacrev(), jacfwd())轉換的函式內部使用 torch.autograd API(例如 torch.autograd.grad 或 torch.autograd.backward),則該轉換可能無法對其進行轉換。如果無法做到,你將收到錯誤訊息。
這是 PyTorch 的 AD 支援實現方式上的一個基本設計限制,也是我們設計 torch.func 庫的原因。請轉而使用 torch.autograd API 的 torch.func 等價物: - torch.autograd.grad, Tensor.backward -> torch.func.vjp 或 torch.func.grad - torch.autograd.functional.jvp -> torch.func.jvp - torch.autograd.functional.jacobian -> torch.func.jacrev 或 torch.func.jacfwd - torch.autograd.functional.hessian -> torch.func.hessian
vmap 限制¶
注意
vmap() 是我們限制最多的轉換。與梯度相關的轉換(grad(), vjp(), jvp())沒有這些限制。jacfwd()(以及 hessian(),它透過 jacfwd() 實現)是 vmap() 和 jvp() 的組合,因此它也具有這些限制。
vmap(func) 是一種轉換,它返回一個函式,該函式將 func 對映到每個輸入 Tensor 的某個新維度上。vmap 的思維模型就像執行一個 for 迴圈:對於純函式(即沒有副作用的函式),vmap(f)(x) 等價於
torch.stack([f(x_i) for x_i in x.unbind(0)])
修改:任意修改 Python 資料結構¶
在存在副作用的情況下,vmap() 不再像執行 for 迴圈那樣。例如,以下函式
def f(x, list):
list.pop()
print("hello!")
return x.sum(0)
x = torch.randn(3, 1)
lst = [0, 1, 2, 3]
result = vmap(f, in_dims=(0, None))(x, lst)
將只打印一次“hello!” 並只從 lst 中彈出一個元素。
vmap() 只執行 f 一次,因此所有副作用只發生一次。
這是 vmap 實現方式的結果。torch.func 有一個特殊的內部 BatchedTensor 類。vmap(f)(*inputs) 接受所有 Tensor 輸入,將它們轉換為 BatchedTensor,然後呼叫 f(*batched_tensor_inputs)。BatchedTensor 重寫了 PyTorch API,以便為每個 PyTorch 運算元產生批處理(即向量化)行為。
修改:PyTorch 原地操作¶
你看到這裡可能是因為收到了關於 vmap 不相容原地操作的錯誤。vmap() 如果遇到不支援的 PyTorch 原地操作就會引發錯誤,否則會成功。不支援的操作是指那些會導致元素數量更多的 Tensor 被寫入元素數量更少的 Tensor 的操作。下面是一個示例說明這種情況如何發生
def f(x, y):
x.add_(y)
return x
x = torch.randn(1)
y = torch.randn(3, 1) # When vmapped over, looks like it has shape [1]
# Raises an error because `x` has fewer elements than `y`.
vmap(f, in_dims=(None, 0))(x, y)
x 是一個只有一個元素的 Tensor,y 是一個有三個元素的 Tensor。x + y 有三個元素(由於廣播),但試圖將三個元素寫回只有一個元素的 x 中會引發錯誤,因為試圖將三個元素寫入只有一個元素的 Tensor。
如果被寫入的 Tensor 在 vmap() 下進行批處理(即正在對其進行 vmap 操作),則沒有問題。
def f(x, y):
x.add_(y)
return x
x = torch.randn(3, 1)
y = torch.randn(3, 1)
expected = x + y
# Does not raise an error because x is being vmapped over.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)
一個常見的解決方法是將對工廠函式的呼叫替換為其“new_*”等價物。例如
將
torch.zeros()替換為Tensor.new_zeros()將
torch.empty()替換為Tensor.new_empty()
要了解為何這有幫助,請考慮以下情況。
def diag_embed(vec):
assert vec.dim() == 1
result = torch.zeros(vec.shape[0], vec.shape[0])
result.diagonal().copy_(vec)
return result
vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
vmap(diag_embed)(vecs)
在 vmap() 內部,result 是一個形狀為 [3, 3] 的 Tensor。然而,儘管 vec 看似形狀為 [3],vec 的實際底層形狀是 [2, 3]。無法將 vec 複製到形狀為 [3] 的 result.diagonal() 中,因為它包含過多元素。
def diag_embed(vec):
assert vec.dim() == 1
result = vec.new_zeros(vec.shape[0], vec.shape[0])
result.diagonal().copy_(vec)
return result
vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
vmap(diag_embed)(vecs)
將 torch.zeros() 替換為 Tensor.new_zeros() 使 result 擁有一個底層形狀為 [2, 3, 3] 的 Tensor,因此現在可以將底層形狀為 [2, 3] 的 vec 複製到 result.diagonal() 中。
修改:out= PyTorch 操作¶
vmap() 不支援 PyTorch 操作中的 out= 關鍵字引數。如果在你的程式碼中遇到這種情況,它會優雅地報錯。
這不是一個根本性限制;理論上我們未來可以支援這一點,但目前我們選擇不這樣做。
資料依賴的 Python 控制流¶
我們尚不支援對資料依賴的控制流進行 vmap 操作。資料依賴的控制流是指 if 語句、while 迴圈或 for 迴圈的條件是一個正在被 vmap 處理的 Tensor。例如,以下程式碼將引發錯誤訊息
def relu(x):
if x > 0:
return x
return 0
x = torch.randn(3)
vmap(relu)(x)
但是,任何不依賴於 vmap 處理的 Tensor 中值的控制流都將正常工作
def custom_dot(x):
if x.dim() == 1:
return torch.dot(x, x)
return (x * x).sum()
x = torch.randn(3)
vmap(custom_dot)(x)
JAX 支援使用特殊的控制流運算元(例如 jax.lax.cond, jax.lax.while_loop)對資料依賴的控制流進行轉換。我們正在研究向 PyTorch 新增與這些等價的功能。
資料依賴操作 (.item())¶
我們不支援(將來也不會支援)對在 Tensor 上呼叫 .item() 的使用者定義函式進行 vmap 操作。例如,以下程式碼將引發錯誤訊息
def f(x):
return x.item()
x = torch.randn(3)
vmap(f)(x)
請嘗試重寫你的程式碼,避免使用 .item() 呼叫。
你可能還會遇到關於使用 .item() 的錯誤訊息,但你可能並沒有使用它。在這些情況下,PyTorch 內部可能正在呼叫 .item() – 請在 GitHub 上提交問題,我們將修復 PyTorch 的內部實現。
動態形狀操作 (nonzero 等)¶
vmap(f) 要求 f 應用於輸入中的每個“示例”時返回具有相同形狀的 Tensor。因此,諸如 torch.nonzero、torch.is_nonzero 之類的操作不受支援,並將導致錯誤。
要了解原因,請考慮以下示例
xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)
torch.nonzero(xs[0]) 返回形狀為 2 的 Tensor;但 torch.nonzero(xs[1]) 返回形狀為 1 的 Tensor。我們無法構造一個單個 Tensor 作為輸出;輸出需要是一個不規則(ragged)Tensor(而 PyTorch 尚未有不規則 Tensor 的概念)。
隨機性¶
使用者呼叫隨機操作時的意圖可能不明確。具體來說,有些使用者希望隨機行為在批次之間保持一致,而另一些使用者可能希望它在批次之間有所不同。為了解決這個問題,vmap 接受一個隨機性標誌。
該標誌只能傳遞給 vmap,可以取三個值:“error”、“different”或“same”,預設為“error”。在“error”模式下,任何呼叫隨機函式的行為都會產生錯誤,要求使用者根據其用例使用其他兩個標誌之一。
在“different”隨機性下,批次中的元素產生不同的隨機值。例如,
def add_noise(x):
y = torch.randn(()) # y will be different across the batch
return x + y
x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x) # we get 3 different values
在“same”隨機性下,批次中的元素產生相同的隨機值。例如,
def add_noise(x):
y = torch.randn(()) # y will be the same across the batch
return x + y
x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x) # we get the same value, repeated 3 times
警告
我們的系統只決定 PyTorch 運算元的隨機性行為,無法控制其他庫(如 numpy)的行為。這類似於 JAX 及其解決方案的限制。
注意
使用任一型別受支援隨機性的多次 vmap 呼叫不會產生相同的結果。與標準的 PyTorch 一樣,使用者可以透過在 vmap 之外使用 torch.manual_seed() 或使用生成器來實現隨機性的可復現性。
注意
最後,我們的隨機性與 JAX 不同,因為我們沒有使用無狀態的 PRNG(偽隨機數生成器),部分原因在於 PyTorch 尚未完全支援無狀態 PRNG。相反,我們引入了一個標誌系統,以支援我們看到的最常見的隨機性形式。如果你的用例不符合這些隨機性形式,請提交問題。