Gradcheck 機制¶
本文件概述了 gradcheck() 和 gradgradcheck() 函式的工作原理。
它將涵蓋實值和復值函式的前向和後向 AD(自動微分),以及高階導數。本文件還涵蓋了 gradcheck 的預設行為以及傳遞 fast_mode=True 引數(下文稱為 fast gradcheck)的情況。
符號和背景資訊¶
在本文件中,我們將使用以下約定
, , , , , , 和 是實值向量,而 是一個復值向量,可以寫成兩個實值向量 的形式。
和 是我們將分別用於輸入和輸出空間維度的兩個整數。
是我們的基本實數到實數函式,滿足 。
是我們的基本複數到實數函式,滿足 。
對於簡單的實數到實數情況,我們將 表示與 相關的 Jacobian 矩陣,其大小為 。該矩陣包含所有偏導數,其中位置 的元素為 。後向模式 AD 計算的是給定大小為 的向量 的量 。另一方面,前向模式 AD 計算的是給定大小為 的向量 的量 。
對於包含複數值的函式,情況要複雜得多。此處僅提供要點,完整描述請參閱 複數的 Autograd。
滿足複數可微性(Cauchy-Riemann 方程)的約束對於所有實值損失函式來說都過於嚴格,因此我們轉而使用 Wirtinger 微積分。在 Wirtinger 微積分的基本設定中,鏈式法則需要訪問 Wirtinger 導數(下文稱為 )和共軛 Wirtinger 導數(下文稱為 )。 和 都需要傳播,因為通常情況下,儘管它們的名字如此,但一個並不是另一個的複共軛。
為了避免必須傳播這兩個值,對於後向模式 AD,我們始終假設正在計算導數的函式要麼是一個實值函式,要麼是更大的實值函式的一部分。這個假設意味著我們在反向傳播過程中計算的所有中間梯度也與實值函式相關聯。在實踐中,這個假設在進行最佳化時並不受限,因為這類問題需要實值目標函式(複數沒有自然排序)。
在此假設下,使用 和 的定義,我們可以證明 (這裡我們使用 表示複共軛),因此這兩個值中實際上只需要將其中一個“透過計算圖進行反向傳播”,因為另一個可以很容易地恢復。為了簡化內部計算,PyTorch 使用 作為它在使用者請求梯度時進行反向傳播並返回的值。與實數情況類似,當輸出實際上在 中時,反向模式自動微分不計算 ,而只計算給定向量 的 。
對於前向模式自動微分,我們使用類似的邏輯,在這種情況下,假設該函式是輸入在 中的更大函式的一部分。在此假設下,我們可以做出類似的斷言,即每個中間結果對應於一個輸入在 中的函式,在這種情況下,使用 和 的定義,我們可以證明中間函式的 。為了確保前向和反向模式在基本的一維函式情況下計算出相同的值,前向模式也計算 。與實數情況類似,當輸入實際上在 中時,前向模式自動微分不計算 ,而只計算給定向量 的 。
Default backward mode gradcheck behavior¶
Real-to-real functions¶
為了測試函式 ,我們以兩種方式重建大小為 的完整雅可比矩陣 :解析方法和數值方法。解析方法使用我們的反向模式自動微分,而數值方法使用有限差分。然後逐元素比較兩個重建的雅可比矩陣是否相等。
Default real input numerical evaluation¶
如果我們考慮一維函式的基本情況(),那麼我們可以使用來自 the wikipedia article 的基本有限差分公式。為了獲得更好的數值特性,我們使用“中心差分”。
這個公式可以輕鬆推廣到多個輸出()的情況,方法是讓 成為一個大小為 的列向量,例如 。在這種情況下,上述公式可以直接重用,並且僅透過兩次使用者函式評估(即 和 )來近似完整的雅可比矩陣。
處理多個輸入()的情況計算成本更高。在這種情況下,我們依次迴圈遍歷所有輸入,並依次對 的每個元素應用 擾動。這使得我們能夠逐列重建 矩陣。
Default real input analytical evaluation¶
對於解析評估,我們利用上述事實,即反向模式自動微分計算 。對於單個輸出的函式,我們只需使用 透過單次反向傳播來恢復完整的雅可比矩陣。
對於有多個輸出的函式,我們採用一個 for 迴圈遍歷輸出,其中每個 都是一個對應於各個輸出的獨熱向量(one-hot vector),依次進行。這使得我們能夠逐行重建 矩陣。
Complex-to-real functions¶
為了測試函式 ,其中 ,我們重建包含 的(復值)矩陣。
預設複數輸入數值評估¶
首先考慮簡單的情況,其中 。我們從這篇研究論文(第3章)瞭解到
注意,在上述方程中, 和 是 導數。為了數值評估這些,我們使用上面描述的實數到實數情況的方法。這使我們能夠計算出 矩陣,然後乘以 。
注意,在撰寫本文時,程式碼以一種稍微複雜的方式計算此值。
# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above
ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.
預設複數輸入解析評估¶
由於反向模式AD已經精確計算出 導數的兩倍,我們只需在此處使用與實數到實數情況相同的技巧,並在存在多個實數輸出時逐行重建矩陣。
快速反向模式gradcheck¶
雖然上述gradcheck的表述很好,既能確保正確性,又能方便除錯,但它非常慢,因為它重建了完整的雅可比矩陣。本節介紹了一種更快執行gradcheck的方法,而不影響其正確性。透過在檢測到錯誤時新增特殊邏輯可以恢復除錯能力。在這種情況下,我們可以執行重建完整矩陣的預設版本,以向用戶提供詳細資訊。
這裡的高層策略是找到一個標量量,它可以由數值方法和解析方法高效計算,並且能夠很好地代表慢速gradcheck計算出的完整矩陣,以確保它能捕獲雅可比矩陣中的任何差異。
實數到實數函式的快速gradcheck¶
我們想要在此處計算的標量量是 ,對於給定的隨機向量 和隨機單位範數向量 。
對於數值評估,我們可以高效地計算
然後,我們對該向量與 進行點積,以得到我們關注的標量值。
對於解析版本,我們可以直接使用反向模式AD計算 。然後,我們與 進行點積,以得到期望值。
複數到實數函式的快速gradcheck¶
與實到實情況類似,我們想對全矩陣進行約化。但 矩陣是複數值的,因此在這種情況下,我們將與複數標量進行比較。
由於數值情況下我們能高效計算的內容存在一些約束,併為了將數值計算次數降至最低,我們計算以下(儘管令人驚訝的)標量值
其中 , 和 。
快速複數輸入數值計算¶
我們首先考慮如何用數值方法計算 。為此,請記住我們考慮的是函式 ,其中 ,並且 ,我們將其重寫如下
在此公式中,我們可以看到 和 可以像實到實情況的快速版本一樣進行計算。一旦這些實數值量被計算出來,我們可以重構右側的復向量,並與實數值向量 進行點積。
快速複數輸入解析計算¶
對於解析情況,事情更簡單,我們將公式重寫為
因此,我們可以利用反向模式 AD 為我們提供了一種有效的方法來計算 ,然後將實部與 進行點積,虛部與 進行點積,最後重建最終的複數標量 。
為什麼不使用複數 ?¶
此時,您可能想知道為什麼我們沒有選擇一個複數 並直接執行歸約 。為了深入探討這一點,在本段中,我們將使用複數版本的 ,記作 。使用這樣的複數 的問題是,在進行數值評估時,我們需要計算:
這將需要進行四次實數到實數的有限差分評估(是上述方法的兩倍)。由於這種方法沒有更多的自由度(實值變數的數量相同),並且我們在這裡嘗試獲得最快的評估速度,因此我們使用了上面提到的另一種公式。
對具有複雜輸出的函式的快速 gradcheck¶
就像在慢速情況下一樣,我們考慮兩個實值函式,並對每個函式使用上述的適當規則。
Gradgradcheck 實現¶
PyTorch 還提供了用於驗證二階梯度的實用工具。這裡的目標是確保反向實現也具有適當的可微性並計算正確的結果。
此功能透過考慮函式 並使用上述定義的 gradcheck 對此函式進行檢查。請注意,本例中的 只是一個與 型別相同的隨機向量。
gradgradcheck 的快速版本是透過對同一函式 使用 gradcheck 的快速版本來實現的。