捷徑

Gradcheck 機制

本說明介紹 gradcheck()gradgradcheck() 函數的運作方式。

它將涵蓋實數值和複數值函數的正向和反向模式 AD,以及高階導數。本說明還涵蓋 gradcheck 的預設行為,以及傳遞 fast_mode=True 引數的情況(以下稱為快速 gradcheck)。

符號和背景資訊

在整個說明中,我們將使用以下慣例

  1. xxyyaabbvvuuururuiui 是實數值向量,而 zz 是複數值向量,可以用兩個實數值向量表示為 z=a+ibz = a + i b

  2. NNMM 是我們將用於輸入和輸出空間維度的兩個整數。

  3. f:RNRMf: \mathcal{R}^N \to \mathcal{R}^M 是我們基本的實數到實數函數,例如 y=f(x)y = f(x)

  4. g:CNRMg: \mathcal{C}^N \to \mathcal{R}^M 是我們基本的複數到實數函數,例如 y=g(z)y = g(z)

對於簡單的實數到實數情況,我們將與 ff 相關聯的雅可比矩陣寫成 JfJ_f,大小為 M×NM \times N。此矩陣包含所有偏導數,使得位置 (i,j)(i, j) 的條目包含 yixj\frac{\partial y_i}{\partial x_j}。然後,對於給定的 MM 大小的向量 vv,反向模式 AD 會計算量 vTJfv^T J_f。另一方面,對於給定的 NN 大小的向量 uu,正向模式 AD 會計算量 JfuJ_f u

對於包含複數值的函數,情況要複雜得多。我們在這裡只提供要點,完整的描述可以在 複數的 Autograd 中找到。

滿足複變可微性(柯西-黎曼方程式)的約束對於所有實值損失函數來說過於嚴格,因此我們選擇使用 Wirtinger 微積分。在 Wirtinger 微積分的基礎設定中,鏈式法則需要同時使用 Wirtinger 導數(以下稱為 WW)和共軛 Wirtinger 導數(以下稱為 CWCW)。 WWCWCW 都需要被傳播,因為一般來說,儘管它們的名稱如此,但它們並非彼此的複共軛。

為了避免必須傳播兩個值,對於反向模式自動微分,我們始終假設正在計算其導數的函數是實值函數或更大實值函數的一部分。這個假設意味著我們在反向傳播過程中計算的所有中間梯度也與實值函數相關聯。在實務上,這個假設在進行優化時並不具有限制性,因為此類問題需要實值目標(因為複數沒有自然的順序)。

在這個假設下,使用 WWCWCW 的定義,我們可以證明 W=CWW = CW^*(我們在此處使用 * 表示複共軛),因此實際上只需要將兩個值中的一個“反向傳播通過圖”,因為另一個值可以很容易地恢復。為了簡化內部計算,PyTorch 使用 2CW2 * CW 作為它在使用者要求梯度時反向傳播並返回的值。與實數情況類似,當輸出實際上在 RM\mathcal{R}^M 中時,反向模式自動微分不會計算 2CW2 * CW 而只計算 vT(2CW)v^T (2 * CW),其中給定向量 vRMv \in \mathcal{R}^M

對於正向模式自動微分,我們使用類似的邏輯,在這種情況下,假設該函數是更大函數的一部分,其輸入在 R\mathcal{R} 中。在此假設下,我們可以做出類似的聲明,即每個中間結果都對應於一個函數,其輸入在 R\mathcal{R} 中,並且在這種情況下,使用 WWCWCW 定義,我們可以證明對於中間函數, W=CWW = CW。為了確保正向和反向模式在一維函數的基本情況下計算相同的量,正向模式也會計算 2CW2 * CW。與實數情況類似,當輸入實際上在 RN\mathcal{R}^N 中時,正向模式自動微分不會計算 2CW2 * CW,而只會計算給定向量 uRNu \in \mathcal{R}^N(2CW)u(2 * CW) u

預設反向模式 gradcheck 行為

實數到實數函數

為了測試函數 f:RNRM,xyf: \mathcal{R}^N \to \mathcal{R}^M, x \to y,我們以兩種方式重建大小為 M×NM \times N 的完整雅可比矩陣 JfJ_f:解析法和數值法。解析版本使用我們的反向模式自動微分,而數值版本使用有限差分法。然後,逐元素比較兩個重建的雅可比矩陣是否相等。

預設實數輸入數值評估

如果我們考慮一維函數 (N=M=1N = M = 1) 的基本情況,那麼我們可以使用 維基百科文章 中的基本有限差分公式。為了更好的數值特性,我們使用「中心差分」

yxf(x+eps)f(xeps)2eps\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}

此公式可以輕鬆地推廣到多個輸出(M>1M \gt 1),方法是將 yx\frac{\partial y}{\partial x} 設為大小為 M×1M \times 1 的列向量,就像 f(x+eps)f(x + eps) 一樣。在這種情況下,上述公式可以按原樣重複使用,並僅通過兩次使用者函數求值(即 f(x+eps)f(x + eps)f(xeps)f(x - eps))來近似完整的雅可比矩陣。

處理多個輸入(N>1N \gt 1)的情況在計算上更為昂貴。在這種情況下,我們會循環遍歷所有輸入,並將 epseps 擾動依次應用於 xx 的每個元素。這使我們能夠逐列重建 JfJ_f 矩陣。

預設實數輸入的解析求值

對於解析求值,我們使用上述事實,即反向模式 AD 計算 vTJfv^T J_f。對於單一輸出的函數,我們簡單地使用 v=1v = 1 來通過單次反向傳遞恢復完整的雅可比矩陣。

對於具有多個輸出的函數,我們使用一個 for 迴圈迭代輸出,其中每個 vv 都是對應於每個輸出的單一熱向量。這允許逐行重建 JfJ_f 矩陣。

複數到實數函數

為了測試函數 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,我們重建包含 2CW2 * CW 的(複數值)矩陣。

預設複數輸入數值評估

首先考慮 N=M=1N = M = 1 的基本情況。我們從這篇研究論文(第 3 章)中知道

CW:=yz=12(ya+iyb)CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})

請注意,上式中的 ya\frac{\partial y}{\partial a}yb\frac{\partial y}{\partial b}RR\mathcal{R} \to \mathcal{R} 導數。為了對其進行數值評估,我們對實數到實數的情況使用上述方法。這使我們能夠計算 CWCW 矩陣,然後將其乘以 22

請注意,在撰寫本文時,代碼以稍微複雜的方式計算此值

# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.

預設複數輸入解析評估

由於反向模式 AD 已經計算了 CWCW 導數的兩倍,因此我們在此處對實數到實數的情況使用相同的技巧,並在有多個實數輸出時逐行重建矩陣。

具有複數輸出的函數

在這種情況下,使用者提供的函數不符合自動微分的假設,即我們計算反向自動微分的函數是實數值。這意味著直接在此函數上使用自動微分並未明確定義。為了解決這個問題,我們將使用兩個函數:hrhrhihi 來替換對函數 h:PNCMh: \mathcal{P}^N \to \mathcal{C}^M 的測試(其中 P\mathcal{P} 可以是 R\mathcal{R}C\mathcal{C}),使得

hr(q):=real(f(q))hi(q):=imag(f(q))\begin{aligned} hr(q) &:= real(f(q)) \\ hi(q) &:= imag(f(q)) \end{aligned}

其中 qPq \in \mathcal{P}。然後,我們根據 P\mathcal{P} 使用上述的實數到實數或複數到實數情況,對 hrhrhihi 進行基本的梯度檢查。

請注意,在撰寫本文時,程式碼並未明確建立這些函數,而是透過將 grad_out\text{grad\_out} 參數傳遞給不同的函數,手動執行帶有 realrealimagimag 函數的鏈式規則。當 grad_out=1\text{grad\_out} = 1 時,我們考慮的是 hrhr。當 grad_out=1j\text{grad\_out} = 1j 時,我們考慮的是 hihi

快速反向模式梯度檢查

雖然上述的梯度檢查公式非常出色,但為了確保正確性和可除錯性,它的速度非常慢,因為它需要重建完整的雅可比矩陣。本節將介紹一種在不影響正確性的情況下,以更快的方式執行梯度檢查的方法。可除錯性可以透過在偵測到錯誤時新增特殊邏輯來恢復。在這種情況下,我們可以執行重建完整矩陣的預設版本,以便向使用者提供完整的詳細資訊。

這裡的高階策略是找到一個可以透過數值和解析方法有效計算的純量,並且它能夠充分表示由慢速梯度檢查計算的完整矩陣,以確保它能夠捕捉雅可比矩陣中的任何差異。

實數到實數函數的快速梯度檢查

我們這裡要計算的純量是 vTJfuv^T J_f u,給定一個隨機向量 vRMv \in \mathcal{R}^M 和一個隨機單位範數向量 uRNu \in \mathcal{R}^N

對於數值評估,我們可以有效地計算

Jfuf(x+ueps)f(xueps)2eps.J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.

然後我們計算此向量和 vv 的點積,以獲得我們感興趣的純量值。

對於解析版本,我們可以使用反向模式自動微分直接計算 vTJfv^T J_f。然後我們與 uu 進行點積以獲得期望值。

複數到實數函數的快速梯度檢查

與實數到實數的情況類似,我們想要對完整矩陣進行降維。但是 2CW2 * CW 矩陣是複數值的,所以在這種情況下,我們將與複數純量進行比較。

由於我們可以在數值情況下有效計算的內容受到一些限制,並且為了將數值評估的次數保持在最低限度,我們計算以下(儘管令人驚訝)的純量值

s:=2vT(real(CW)ur+iimag(CW)ui)s := 2 * v^T (real(CW) ur + i * imag(CW) ui)

其中 vRMv \in \mathcal{R}^MurRNur \in \mathcal{R}^NuiRNui \in \mathcal{R}^N

快速複數輸入數值評估

我們首先考慮如何使用數值方法計算 ss。為此,請記住我們正在考慮 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,並且 CW=12(ya+iyb)CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}),我們將其改寫如下:

s=2vT(real(CW)ur+iimag(CW)ui)=2vT(12yaur+i12ybui)=vT(yaur+iybui)=vT((yaur)+i(ybui))\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}

在這個公式中,我們可以看到 yaur\frac{\partial y}{\partial a} urybui\frac{\partial y}{\partial b} ui 可以像實數到實數情況的快速版本一樣進行評估。一旦計算出這些實數值,我們就可以重建右側的複數向量,並與實數值 vv 向量進行點積。

快速複數輸入解析評估

對於解析情況,事情更簡單,我們將公式改寫為:

s=2vT(real(CW)ur+iimag(CW)ui)=vTreal(2CW)ur+ivTimag(2CW)ui)=real(vT(2CW))ur+iimag(vT(2CW))ui\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}

因此,我們可以利用反向模式自動微分法提供了一種有效計算 vT(2CW)v^T (2 * CW) 的方法,然後將其實部與 urur 進行點積,將其虛部與 uiui 進行點積,最後重建最終的複數純量 ss

為什麼不使用複數 uu

在這個時候,您可能會好奇為什麼我們不選擇一個複數的 uu 並直接進行 2vTCWu2 * v^T CW u' 的降維。為了深入探討,在本段中,我們將使用 uu 的複數形式,記為 u=ur+iuiu' = ur' + i ui'。使用這樣的複數 uu',問題在於當進行數值評估時,我們需要計算...

2CWu=(ya+iyb)(ur+iui)=yaur+iyaui+iyburybui\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}

這需要四次實數到實數的有限差分計算(與上述方法相比多了一倍)。由於此方法沒有更多自由度(實數變數數量相同),並且我們嘗試在此處獲得最快的計算速度,因此我們使用上述其他公式。

複數輸出函數的快速梯度檢查

就像在慢速情況下一樣,我們考慮兩個實值函數,並對每個函數使用上述適當的規則。

梯度梯度檢查實作

PyTorch 也提供了一個工具來驗證二階梯度。此處的目標是確保反向傳播實作也是可微分的,並且計算結果正確。

此功能是通過考慮函數 F:x,vvTJfF: x, v \to v^T J_f 並且對此函數使用上面定義的 gradcheck 來實現的。請注意,在這種情況下,vv 只是一個與 f(x)f(x) 類型相同的隨機向量。

gradgradcheck 的快速版本是通過在相同的函數 FF 上使用 gradcheck 的快速版本來實現的。

文件

獲取 PyTorch 的完整開發者文檔

查看文檔

教學

為初學者和進階開發者提供深入的教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源