捷徑

UX 限制

torch.func 與 JAX 相似,對可轉換的內容有限制。一般來說,JAX 的限制是轉換僅適用於純函式:也就是輸出完全由輸入決定且不涉及副作用(如變動)的函式。

我們有一個類似的保證:我們的轉換適用於純函式。但是,我們確實支援某些就地運算。一方面,撰寫與函式轉換相容的程式碼可能需要改變您撰寫 PyTorch 程式碼的方式,另一方面,您可能會發現我們的轉換讓您能夠表達以前在 PyTorch 中難以表達的事物。

一般限制

所有 torch.func 轉換都有一個共同的限制,即函式不應指派給全域變數。相反的,函式的所有輸出都必須從函式中返回。這個限制來自於 torch.func 的實作方式:每個轉換都會將張量輸入包裝在特殊的 torch.func 張量子類別中,以促進轉換。

因此,請勿使用以下程式碼:

import torch
from torch.func import grad

# Don't do this
intermediate = None

def f(x):
  global intermediate
  intermediate = x.sin()
  z = intermediate.sin()
  return z

x = torch.randn([])
grad_x = grad(f)(x)

請改寫 f 以返回 intermediate

def f(x):
  intermediate = x.sin()
  z = intermediate.sin()
  return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd API

如果您嘗試在由 vmap() 或其中一個 torch.func 的 AD 轉換(vjp()jvp()jacrev()jacfwd())轉換的函式內部使用 torch.autograd API(例如 torch.autograd.gradtorch.autograd.backward),則轉換可能無法對其進行轉換。如果無法進行轉換,您將會收到錯誤訊息。

這是 PyTorch 的 AD 支援實作方式的基本設計限制,也是我們設計 torch.func 函式庫的原因。請改用 torch.autograd API 的 torch.func 等效項:- torch.autograd.gradTensor.backward -> torch.func.vjptorch.func.grad - torch.autograd.functional.jvp -> torch.func.jvp - torch.autograd.functional.jacobian -> torch.func.jacrevtorch.func.jacfwd - torch.autograd.functional.hessian -> torch.func.hessian

vmap 限制

備註

vmap() 是我們限制最多的轉換。與梯度相關的轉換(grad()vjp()jvp())沒有這些限制。jacfwd()(以及 hessian(),它使用 jacfwd() 實作)是 vmap()jvp() 的組合,因此它也有這些限制。

vmap(func) 是一個轉換,它會返回一個函式,該函式會將 func 映射到每個輸入張量的新維度上。vmap 的心智模型就像是在運行一個 for 迴圈:對於純函式(即在沒有副作用的情況下),vmap(f)(x) 等效於

torch.stack([f(x_i) for x_i in x.unbind(0)])

變動:任意變動 Python 資料結構

當存在副作用時,vmap() 的行為將不再像是在執行 for 迴圈。例如,以下函式

def f(x, list):
  list.pop()
  print("hello!")
  return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]

result = vmap(f, in_dims=(0, None))(x, lst)

將只會列印一次「hello!」,並且只會從 lst 中彈出一個元素。

vmap() 只會執行 f 一次,因此所有副作用也只會發生一次。

這是 vmap 實作方式所導致的結果。torch.func 有一個特殊的內部 BatchedTensor 類別。vmap(f)(*inputs) 會將所有 Tensor 輸入轉換為 BatchedTensors,然後呼叫 f(*batched_tensor_inputs)。BatchedTensor 會覆寫 PyTorch API,以便為每個 PyTorch 運算子產生批次處理(亦即向量化)行為。

變異:就地 PyTorch 運算

您可能是因為收到有關 vmap 不相容的就地運算錯誤訊息而來到這裡。vmap() 如果遇到不支援的 PyTorch 就地運算,就會引發錯誤,否則就會成功。不支援的運算是指那些會導致元素較多的 Tensor 寫入元素較少的 Tensor 的運算。以下是一個可能發生的範例

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(1)
y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

# Raises an error because `x` has fewer elements than `y`.
vmap(f, in_dims=(None, 0))(x, y)

x 是一個只有一個元素的 Tensor,y 是一個有三個元素的 Tensor。x + y 有三個元素(由於廣播),但嘗試將三個元素寫回只有一個元素的 x 中,會因為嘗試將三個元素寫入只有一個元素的 Tensor 中而引發錯誤。

如果正在寫入的 Tensor 是在 vmap() 下批次處理的(亦即正在對其進行 vmap 運算),則不會有任何問題。

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(3, 1)
y = torch.randn(3, 1)
expected = x + y

# Does not raise an error because x is being vmapped over.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)

解決此問題的一個常見方法是將對工廠函式的呼叫替換為其「new_*」等效項。例如

以下說明為什麼這樣做會有幫助。

def diag_embed(vec):
  assert vec.dim() == 1
  result = torch.zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
vmap(diag_embed)(vecs)

vmap() 內部,result 是一個形狀為 [3, 3] 的 Tensor。然而,雖然 vec 看起來像是形狀為 [3],但 vec 實際上的底層形狀是 [2, 3]。無法將 vec 複製到形狀為 [3] 的 result.diagonal() 中,因為它的元素太多了。

def diag_embed(vec):
  assert vec.dim() == 1
  result = vec.new_zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
vmap(diag_embed)(vecs)

torch.zeros() 替換為 Tensor.new_zeros() 可以讓 result 的底層 Tensor 形狀變為 [2, 3, 3],因此現在可以將底層形狀為 [2, 3] 的 vec 複製到 result.diagonal() 中。

變異:out= PyTorch 運算

vmap() 不支援 PyTorch 運算中的 out= 關鍵字引數。如果在程式碼中遇到這個問題,它會引發錯誤。

這不是一個根本的限制;理論上我們可以在未來支援這個功能,但我們目前選擇不支援。

資料相依的 Python 控制流程

我們尚不支援對資料相依的控制流程使用 vmap。資料相依的控制流程是指 if-statement、while-loop 或 for-loop 的條件是一個正在進行 vmap 運算的 Tensor。例如,以下程式碼將會引發錯誤訊息

def relu(x):
  if x > 0:
    return x
  return 0

x = torch.randn(3)
vmap(relu)(x)

然而,任何不依賴於 vmap 的 Tensor 中的值的控制流程都可以運作

def custom_dot(x):
  if x.dim() == 1:
    return torch.dot(x, x)
  return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)

JAX 支援使用特殊的控制流程運算子(例如 jax.lax.condjax.lax.while_loop)對 資料相依的控制流程 進行轉換。我們正在研究將這些功能的等效項添加到 PyTorch 中。

資料相依的運算(.item())

我們不支援(並且將來也不會支援)對呼叫 Tensor 的 .item() 的使用者自訂函式使用 vmap。例如,以下程式碼將會引發錯誤訊息

def f(x):
  return x.item()

x = torch.randn(3)
vmap(f)(x)

請嘗試修改您的程式碼,不要使用 .item() 呼叫。

您也可能會遇到有關使用 .item() 的錯誤訊息,但您可能沒有使用它。在這些情況下,可能是 PyTorch 內部正在呼叫 .item() - 請在 GitHub 上提交問題,我們會修復 PyTorch 內部程式碼。

動態形狀運算(nonzero 等)

vmap(f) 要求 f 套用於輸入中的每個「範例」時,都必須傳回形狀相同的 Tensor。不支援 torch.nonzerotorch.is_nonzero 等運算,因此會產生錯誤。

以下範例說明原因

xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)

torch.nonzero(xs[0]) 傳回形狀為 2 的 Tensor;但 torch.nonzero(xs[1]) 傳回形狀為 1 的 Tensor。我們無法建構單一 Tensor 作為輸出;輸出必須是不規則 Tensor(而 PyTorch 尚未支援不規則 Tensor 的概念)。

隨機性

使用者呼叫隨機運算時的意圖可能不明確。具體來說,有些使用者可能希望隨機行為在批次間保持一致,而有些使用者則希望在批次間有所不同。為了處理這個問題,vmap 採用了一個隨機性旗標。

該旗標只能傳遞給 vmap,並且可以採用 3 個值:「error」、「different」或「same」,預設值為 error。在「error」模式下,任何對隨機函式的呼叫都會產生錯誤,要求使用者根據其使用情況使用其他兩個旗標之一。

在「different」隨機性下,批次中的元素會產生不同的隨機值。例如:

def add_noise(x):
  y = torch.randn(())  # y will be different across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

在「same」隨機性下,批次中的元素會產生相同的隨機值。例如:

def add_noise(x):
  y = torch.randn(())  # y will be the same across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

警告

我們的系統只會決定 PyTorch 運算子的隨機性行為,無法控制 numpy 等其他程式庫的行為。這與 JAX 解決方案的限制類似

備註

使用任何一種支援的隨機性進行多次 vmap 呼叫不會產生相同的結果。與標準 PyTorch 一樣,使用者可以透過在 vmap 之外使用 torch.manual_seed() 或使用產生器來獲得隨機性重現性。

備註

最後,我們的隨機性與 JAX 不同,因為我們沒有使用無狀態 PRNG,部分原因是 PyTorch 不完全支援無狀態 PRNG。相反地,我們引入了一個旗標系統,以允許我們看到最常見的隨機性形式。如果您的使用情況不符合這些隨機性形式,請提交問題。

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得問題解答

檢視資源