快捷方式

Gradcheck 機制

本文件概述了 gradcheck()gradgradcheck() 函式的工作原理。

它將涵蓋實值和復值函式的前向和後向 AD(自動微分),以及高階導數。本文件還涵蓋了 gradcheck 的預設行為以及傳遞 fast_mode=True 引數(下文稱為 fast gradcheck)的情況。

符號和背景資訊

在本文件中,我們將使用以下約定

  1. xx, yy, aa, bb, vv, uu, ururuiui 是實值向量,而 zz 是一個復值向量,可以寫成兩個實值向量 z=a+ibz = a + i b 的形式。

  2. NNMM 是我們將分別用於輸入和輸出空間維度的兩個整數。

  3. f:RNRMf: \mathcal{R}^N \to \mathcal{R}^M 是我們的基本實數到實數函式,滿足 y=f(x)y = f(x)

  4. g:CNRMg: \mathcal{C}^N \to \mathcal{R}^M 是我們的基本複數到實數函式,滿足 y=g(z)y = g(z)

對於簡單的實數到實數情況,我們將 JfJ_f 表示與 ff 相關的 Jacobian 矩陣,其大小為 M×NM \times N。該矩陣包含所有偏導數,其中位置 (i,j)(i, j) 的元素為 yixj\frac{\partial y_i}{\partial x_j}。後向模式 AD 計算的是給定大小為 MM 的向量 vv 的量 vTJfv^T J_f。另一方面,前向模式 AD 計算的是給定大小為 NN 的向量 uu 的量 JfuJ_f u

對於包含複數值的函式,情況要複雜得多。此處僅提供要點,完整描述請參閱 複數的 Autograd

滿足複數可微性(Cauchy-Riemann 方程)的約束對於所有實值損失函式來說都過於嚴格,因此我們轉而使用 Wirtinger 微積分。在 Wirtinger 微積分的基本設定中,鏈式法則需要訪問 Wirtinger 導數(下文稱為 WW)和共軛 Wirtinger 導數(下文稱為 CWCW)。 WWCWCW 都需要傳播,因為通常情況下,儘管它們的名字如此,但一個並不是另一個的複共軛。

為了避免必須傳播這兩個值,對於後向模式 AD,我們始終假設正在計算導數的函式要麼是一個實值函式,要麼是更大的實值函式的一部分。這個假設意味著我們在反向傳播過程中計算的所有中間梯度也與實值函式相關聯。在實踐中,這個假設在進行最佳化時並不受限,因為這類問題需要實值目標函式(複數沒有自然排序)。

在此假設下,使用 WWCWCW 的定義,我們可以證明 W=CWW = CW^* (這裡我們使用 * 表示複共軛),因此這兩個值中實際上只需要將其中一個“透過計算圖進行反向傳播”,因為另一個可以很容易地恢復。為了簡化內部計算,PyTorch 使用 2CW2 * CW 作為它在使用者請求梯度時進行反向傳播並返回的值。與實數情況類似,當輸出實際上在 RM\mathcal{R}^M 中時,反向模式自動微分不計算 2CW2 * CW,而只計算給定向量 vRMv \in \mathcal{R}^MvT(2CW)v^T (2 * CW)

對於前向模式自動微分,我們使用類似的邏輯,在這種情況下,假設該函式是輸入在 R\mathcal{R} 中的更大函式的一部分。在此假設下,我們可以做出類似的斷言,即每個中間結果對應於一個輸入在 R\mathcal{R} 中的函式,在這種情況下,使用 WWCWCW 的定義,我們可以證明中間函式的 W=CWW = CW。為了確保前向和反向模式在基本的一維函式情況下計算出相同的值,前向模式也計算 2CW2 * CW。與實數情況類似,當輸入實際上在 RN\mathcal{R}^N 中時,前向模式自動微分不計算 2CW2 * CW,而只計算給定向量 uRNu \in \mathcal{R}^N(2CW)u(2 * CW) u

Default backward mode gradcheck behavior

Real-to-real functions

為了測試函式 f:RNRM,xyf: \mathcal{R}^N \to \mathcal{R}^M, x \to y,我們以兩種方式重建大小為 M×NM \times N 的完整雅可比矩陣 JfJ_f:解析方法和數值方法。解析方法使用我們的反向模式自動微分,而數值方法使用有限差分。然後逐元素比較兩個重建的雅可比矩陣是否相等。

Default real input numerical evaluation

如果我們考慮一維函式的基本情況(N=M=1N = M = 1),那麼我們可以使用來自 the wikipedia article 的基本有限差分公式。為了獲得更好的數值特性,我們使用“中心差分”。

yxf(x+eps)f(xeps)2eps\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}

這個公式可以輕鬆推廣到多個輸出(M>1M \gt 1)的情況,方法是讓 yx\frac{\partial y}{\partial x} 成為一個大小為 M×1M \times 1 的列向量,例如 f(x+eps)f(x + eps)。在這種情況下,上述公式可以直接重用,並且僅透過兩次使用者函式評估(即 f(x+eps)f(x + eps)f(xeps)f(x - eps))來近似完整的雅可比矩陣。

處理多個輸入(N>1N \gt 1)的情況計算成本更高。在這種情況下,我們依次迴圈遍歷所有輸入,並依次對 xx 的每個元素應用 epseps 擾動。這使得我們能夠逐列重建 JfJ_f 矩陣。

Default real input analytical evaluation

對於解析評估,我們利用上述事實,即反向模式自動微分計算 vTJfv^T J_f。對於單個輸出的函式,我們只需使用 v=1v = 1 透過單次反向傳播來恢復完整的雅可比矩陣。

對於有多個輸出的函式,我們採用一個 for 迴圈遍歷輸出,其中每個 vv 都是一個對應於各個輸出的獨熱向量(one-hot vector),依次進行。這使得我們能夠逐行重建 JfJ_f 矩陣。

Complex-to-real functions

為了測試函式 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,我們重建包含 2CW2 * CW 的(復值)矩陣。

預設複數輸入數值評估

首先考慮簡單的情況,其中 N=M=1N = M = 1。我們從這篇研究論文(第3章)瞭解到

CW:=yz=12(ya+iyb)CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})

注意,在上述方程中,ya\frac{\partial y}{\partial a}yb\frac{\partial y}{\partial b}RR\mathcal{R} \to \mathcal{R} 導數。為了數值評估這些,我們使用上面描述的實數到實數情況的方法。這使我們能夠計算出 CWCW 矩陣,然後乘以 22

注意,在撰寫本文時,程式碼以一種稍微複雜的方式計算此值。

# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.

預設複數輸入解析評估

由於反向模式AD已經精確計算出 CWCW 導數的兩倍,我們只需在此處使用與實數到實數情況相同的技巧,並在存在多個實數輸出時逐行重建矩陣。

具有複數輸出的函式

在這種情況下,使用者提供的函式不符合autograd的假設,即我們進行反向AD計算的函式是實數值的。這意味著直接在這種函式上使用autograd是未明確定義的。為了解決這個問題,我們將替換對函式 h:PNCMh: \mathcal{P}^N \to \mathcal{C}^M 的測試(其中 P\mathcal{P} 可以是 R\mathcal{R}C\mathcal{C}),用兩個函式替換:hrhrhihi,使得

hr(q):=real(f(q))hi(q):=imag(f(q))\begin{aligned} hr(q) &:= real(f(q)) \\ hi(q) &:= imag(f(q)) \end{aligned}

其中 qPq \in \mathcal{P}。然後,我們根據 P\mathcal{P},使用上面描述的實數到實數情況或複數到實數情況,對 hrhrhihi 進行基本gradcheck。

注意,在撰寫本文時,程式碼不顯式建立這些函式,而是透過將 grad_out\text{grad\_out} 引數傳遞給不同的函式,手動使用 realrealimagimag 函式執行鏈式法則。當 grad_out=1\text{grad\_out} = 1 時,我們考慮的是 hrhr。當 grad_out=1j\text{grad\_out} = 1j 時,我們考慮的是 hihi

快速反向模式gradcheck

雖然上述gradcheck的表述很好,既能確保正確性,又能方便除錯,但它非常慢,因為它重建了完整的雅可比矩陣。本節介紹了一種更快執行gradcheck的方法,而不影響其正確性。透過在檢測到錯誤時新增特殊邏輯可以恢復除錯能力。在這種情況下,我們可以執行重建完整矩陣的預設版本,以向用戶提供詳細資訊。

這裡的高層策略是找到一個標量量,它可以由數值方法和解析方法高效計算,並且能夠很好地代表慢速gradcheck計算出的完整矩陣,以確保它能捕獲雅可比矩陣中的任何差異。

實數到實數函式的快速gradcheck

我們想要在此處計算的標量量是 vTJfuv^T J_f u,對於給定的隨機向量 vRMv \in \mathcal{R}^M 和隨機單位範數向量 uRNu \in \mathcal{R}^N

對於數值評估,我們可以高效地計算

Jfuf(x+ueps)f(xueps)2eps.J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.

然後,我們對該向量與 vv 進行點積,以得到我們關注的標量值。

對於解析版本,我們可以直接使用反向模式AD計算 vTJfv^T J_f。然後,我們與 uu 進行點積,以得到期望值。

複數到實數函式的快速gradcheck

與實到實情況類似,我們想對全矩陣進行約化。但 2CW2 * CW 矩陣是複數值的,因此在這種情況下,我們將與複數標量進行比較。

由於數值情況下我們能高效計算的內容存在一些約束,併為了將數值計算次數降至最低,我們計算以下(儘管令人驚訝的)標量值

s:=2vT(real(CW)ur+iimag(CW)ui)s := 2 * v^T (real(CW) ur + i * imag(CW) ui)

其中 vRMv \in \mathcal{R}^M, urRNur \in \mathcal{R}^NuiRNui \in \mathcal{R}^N

快速複數輸入數值計算

我們首先考慮如何用數值方法計算 ss。為此,請記住我們考慮的是函式 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,並且 CW=12(ya+iyb)CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}),我們將其重寫如下

s=2vT(real(CW)ur+iimag(CW)ui)=2vT(12yaur+i12ybui)=vT(yaur+iybui)=vT((yaur)+i(ybui))\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}

在此公式中,我們可以看到 yaur\frac{\partial y}{\partial a} urybui\frac{\partial y}{\partial b} ui 可以像實到實情況的快速版本一樣進行計算。一旦這些實數值量被計算出來,我們可以重構右側的復向量,並與實數值向量 vv 進行點積。

快速複數輸入解析計算

對於解析情況,事情更簡單,我們將公式重寫為

s=2vT(real(CW)ur+iimag(CW)ui)=vTreal(2CW)ur+ivTimag(2CW)ui)=real(vT(2CW))ur+iimag(vT(2CW))ui\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}

因此,我們可以利用反向模式 AD 為我們提供了一種有效的方法來計算 vT(2CW)v^T (2 * CW),然後將實部與 urur 進行點積,虛部與 uiui 進行點積,最後重建最終的複數標量 ss

為什麼不使用複數 uu

此時,您可能想知道為什麼我們沒有選擇一個複數 uu 並直接執行歸約 2vTCWu2 * v^T CW u'。為了深入探討這一點,在本段中,我們將使用複數版本的 uu,記作 u=ur+iuiu' = ur' + i ui'。使用這樣的複數 uu' 的問題是,在進行數值評估時,我們需要計算:

2CWu=(ya+iyb)(ur+iui)=yaur+iyaui+iyburybui\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}

這將需要進行四次實數到實數的有限差分評估(是上述方法的兩倍)。由於這種方法沒有更多的自由度(實值變數的數量相同),並且我們在這裡嘗試獲得最快的評估速度,因此我們使用了上面提到的另一種公式。

對具有複雜輸出的函式的快速 gradcheck

就像在慢速情況下一樣,我們考慮兩個實值函式,並對每個函式使用上述的適當規則。

Gradgradcheck 實現

PyTorch 還提供了用於驗證二階梯度的實用工具。這裡的目標是確保反向實現也具有適當的可微性並計算正確的結果。

此功能透過考慮函式 F:x,vvTJfF: x, v \to v^T J_f 並使用上述定義的 gradcheck 對此函式進行檢查。請注意,本例中的 vv 只是一個與 f(x)f(x) 型別相同的隨機向量。

gradgradcheck 的快速版本是透過對同一函式 FF 使用 gradcheck 的快速版本來實現的。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源