捷徑

Autograd 機制

本說明將概述 autograd 的運作方式以及如何記錄運算。並非一定要完全理解這些內容,但我們建議您熟悉這些內容,因為這將有助於您撰寫更高效、更乾淨的程式,並協助您進行除錯。

Autograd 如何編碼歷史記錄

Autograd 是一個反向自動微分系統。從概念上講,autograd 會記錄一個圖表,在您執行運算時記錄建立資料的所有運算,讓您得到一個有向無環圖,其葉子是輸入張量,根是輸出張量。透過從根到葉子追蹤此圖表,您可以使用鏈式法則自動計算梯度。

在內部,autograd 將此圖表表示為 Function 物件(實際上是運算式)的圖表,這些物件可以透過 apply() 計算評估圖表的結果。在計算正向傳遞時,autograd 會同時執行請求的計算,並建立一個圖表,表示計算梯度的函數(每個 torch.Tensor.grad_fn 屬性是進入此圖表的入口點)。正向傳遞完成後,我們會在反向傳遞中評估此圖表以計算梯度。

需要注意的是,圖表會在每次迭代時從頭開始重新建立,這正是允許使用任意 Python 控制流程語句的原因,這些語句可以在每次迭代時更改圖表的整體形狀和大小。您不必在開始訓練之前編碼所有可能的路徑 - 您執行的內容就是您要微分的內容。

已儲存的張量

某些運算需要在正向傳遞期間儲存中間結果,以便執行反向傳遞。例如,函數 xx2x\mapsto x^2 會儲存輸入 xx 以計算梯度。

在定義自訂 Python Function 時,您可以使用 save_for_backward() 在正向傳遞期間儲存張量,並使用 saved_tensors 在反向傳遞期間擷取這些張量。如需更多資訊,請參閱擴展 PyTorch

對於 PyTorch 定義的運算(例如 torch.pow()),會根據需要自動儲存張量。您可以透過尋找以 _saved 為首碼的屬性來探索(出於教育或除錯目的)哪些張量是由特定的 grad_fn 儲存的。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))  # True
print(x is y.grad_fn._saved_self)  # True

在先前的程式碼中,y.grad_fn._saved_self 指的是與 x 相同的張量物件。但情況並非總是如此。例如:

x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))  # True
print(y is y.grad_fn._saved_result)  # False

在幕後,為了防止循環引用,PyTorch 在儲存張量時會將其「打包」,並在讀取時將其「解包」成不同的張量。在這裡,您從存取 y.grad_fn._saved_result 取得的張量與 y 是不同的張量物件(但它們仍然共用相同的儲存空間)。

張量是否會被打包成不同的張量物件,取決於它是否為其自身 grad_fn 的輸出,這是一個可能會變更的實作細節,使用者不應依賴它。

您可以使用已儲存張量的鉤子來控制 PyTorch 如何進行打包/解包。

不可微分函數的梯度

只有在使用的每個基本函數都是可微分的情況下,使用自動微分進行的梯度計算才有效。遺憾的是,我們在實務中使用的許多函數並不具備此屬性(例如 relusqrt0 處)。為了嘗試減少不可微分函數的影響,我們透過依序套用以下規則來定義基本運算的梯度

  1. 如果函數是可微分的,因此在當前點存在梯度,則使用它。

  2. 如果函數是凸函數(至少在局部),則使用最小範數的次梯度(它是下降最快的方向)。

  3. 如果函數是凹函數(至少在局部),則使用最小範數的超梯度(考慮 -f(x) 並應用上一個觀點)。

  4. 如果函數有定義,則透過連續性在當前點定義梯度(請注意,此處 inf 是可能的,例如 sqrt(0))。如果可能有多個值,請任意選擇一個。

  5. 如果函數未定義(例如,sqrt(-1)log(-1) 或大多數函數在輸入為 NaN 時),則用作梯度的值是任意的(我們也可能引發錯誤,但這不保證)。大多數函數將使用 NaN 作為梯度,但出於效能原因,某些函數將使用其他值(例如,log(-1))。

  6. 如果函數不是確定性映射(即,它不是 數學函數),則它將被標記為不可微分。如果在需要梯度的張量上使用它,則會在反向過程中發生錯誤,除非在 no_grad 環境中使用。

局部禁用梯度計算

Python 提供了幾種機制來局部禁用梯度計算

要禁用整個程式碼塊的梯度,可以使用上下文管理器,例如無梯度模式和推論模式。若要更精細地從梯度計算中排除子圖,可以設定張量的 requires_grad 欄位。

除了討論上述機制之外,我們還將說明評估模式(nn.Module.eval()),這是一種不用於禁用梯度計算的方法,但由於其名稱,經常與這三種方法混淆。

設定 requires_grad

requires_grad 是一個旗標,預設為 false,除非包裝在 nn.Parameter ,它允許從梯度計算中精細地排除子圖。它在正向和反向傳遞中都有效

在正向傳遞期間,只有當至少一個輸入張量需要梯度時,才會在反向圖中記錄運算。在反向傳遞(.backward())期間,只有 requires_grad=True 的葉張量才會將梯度累積到其 .grad 欄位中。

重要的是要注意,即使每個張量都有這個旗標,但設定它只對葉張量有意義(沒有 grad_fn 的張量,例如 nn.Module 的參數)。非葉張量(具有 grad_fn 的張量)是具有與其關聯的反向圖的張量。因此,需要將其梯度作為中間結果來計算需要梯度的葉張量的梯度。從這個定義可以清楚地看出,所有非葉張量都將自動具有 require_grad=True

設定 requires_grad 應該是控制模型哪些部分參與梯度計算的主要方法,例如,如果您需要在模型微調期間凍結預先訓練的模型的某些部分。

若要凍結模型的某些部分,只需將 .requires_grad_(False) 應用於您不想更新的參數。如上所述,由於使用這些參數作為輸入的計算不會記錄在正向傳遞中,因此它們的 .grad 欄位不會在反向傳遞中更新,因為它們一開始就不屬於反向圖,如您所願。

因為這是一種非常常見的模式,所以也可以使用 nn.Module.requires_grad_() 在模組級別設定 requires_grad。當應用於模組時,.requires_grad_() 對模組的所有參數都有效(預設情況下,這些參數的 requires_grad=True)。

梯度模式

除了設定 requires_grad 之外,還有三種可以從 Python 中選擇的梯度模式,它們會影響 PyTorch 中的計算如何由 autograd 在內部處理:預設模式(梯度模式)、無梯度模式和推論模式,所有這些模式都可以透過上下文管理器和裝飾器進行切換。

模式

將運算排除在反向圖記錄之外

跳過額外的 autograd 追蹤開銷

在啟用模式時建立的張量可以在稍後的梯度模式中使用

範例

預設

正向傳遞

無梯度

優化器更新

推論

數據處理、模型評估

預設模式(梯度模式)

「預設模式」是指當沒有啟用其他模式(如無梯度模式和推論模式)時,我們隱式處於的模式。為了與「無梯度模式」形成對比,預設模式有時也稱為「梯度模式」。

關於預設模式,最重要的一點是,它是唯一一種 requires_grad 生效的模式。在其他兩種模式下,requires_grad 總是被覆蓋為 False

無梯度模式

無梯度模式下的計算行為就像沒有任何輸入需要梯度一樣。換句話說,即使存在 require_grad=True 的輸入,無梯度模式下的計算也不會記錄在反向圖中。

當您需要執行不應由 autograd 記錄的運算,但您仍然希望稍後在梯度模式下使用這些計算的輸出時,請啟用無梯度模式。此上下文管理器可以方便地禁用程式碼塊或函數的梯度,而無需臨時將張量設定為 requires_grad=False,然後再設定回 True

例如,在編寫優化器時,無梯度模式可能很有用:在執行訓練更新時,您希望就地更新參數,而無需由 autograd 記錄更新。您還打算在下一次正向傳遞中將更新後的參數用於梯度模式下的計算。

torch.nn.init 中的實現也依賴於無梯度模式來初始化參數,以避免在就地更新初始化參數時進行 autograd 追蹤。

推論模式

推論模式是無梯度模式的極端版本。就像在無梯度模式下一樣,推論模式下的計算不會記錄在反向圖中,但啟用推論模式將允許 PyTorch 更多地加快模型的速度。這種更好的執行時伴隨著一個缺點:在退出推論模式後,在推論模式下建立的張量將無法用於由 autograd 記錄的計算中。

當您執行的計算不需要記錄在反向圖中,並且您不打算在稍後由 autograd 記錄的任何計算中使用在推論模式下建立的張量時,請啟用推論模式。

建議您在不需要 autograd 追蹤的程式碼部分(例如,數據處理和模型評估)中嘗試推論模式。如果它在您的用例中開箱即用,那就是免費的效能提升。如果在啟用推論模式後遇到錯誤,請檢查您是否在退出推論模式後,在由 autograd 記錄的計算中使用了在推論模式下建立的張量。如果在您的情況下無法避免這種使用,您可以隨時切換回無梯度模式。

有關推論模式的詳細資訊,請參閱 推論模式

有關推論模式的實現細節,請參閱 RFC-0011-InferenceMode

評估模式(nn.Module.eval()

評估模式不是局部禁用梯度計算的機制。儘管如此,它還是包含在這裡,因為它有時會被誤認為是這樣一種機制。

在功能上,module.eval()(或等效的 module.train(False))與無梯度模式和推論模式完全正交。model.eval() 如何影響您的模型完全取決於模型中使用的特定模組,以及它們是否定義了任何特定於訓練模式的行為。

如果您需要呼叫 model.eval()model.train(),則您有責任這樣做,如果您的模型依賴於 torch.nn.Dropouttorch.nn.BatchNorm2d 等模組,這些模組的行為可能會因訓練模式而異,例如,為了避免在驗證數據上更新 BatchNorm 執行統計數據。

建議您在訓練時始終使用 model.train(),在評估模型(驗證/測試)時始終使用 model.eval(),即使您不確定您的模型是否具有特定於訓練模式的行為,因為您正在使用的模組可能會更新為在訓練和評估模式下表現不同。

使用 autograd 進行就地運算

在 autograd 中支援就地運算是一件困難的事情,我們不鼓勵在大多数情況下使用它們。Autograd 積極的緩衝區釋放和重用使其非常高效,而且很少有情況下,就地運算會顯著降低記憶體使用量。除非您在嚴重的記憶體壓力下運作,否則您可能永遠不需要使用它們。

有兩個主要原因限制了就地運算的適用性

  1. 就地運算可能會覆蓋計算梯度所需的值。

  2. 每個就地運算都需要實現來重寫計算圖。非就地版本只是分配新物件並保留對舊圖的引用,而就地運算則需要將所有輸入的建立者更改為表示此運算的 Function。這可能會很棘手,尤其是在有許多張量引用同一個儲存體的情況下(例如,透過索引或轉置建立),如果修改後的輸入的儲存體被任何其他 Tensor 引用,則就地函數將引發錯誤。

就地正確性檢查

每個張量都有一個版本計數器,每次在任何操作中被標記為髒污時都會遞增。當函數為反向傳播保存任何張量時,也會保存其包含張量的版本計數器。一旦你訪問 self.saved_tensors 時,就會檢查它,如果它大於保存的值,則會引發錯誤。這確保了如果你正在使用就地函數並且沒有看到任何錯誤,你可以確定計算出的梯度是正確的。

多執行緒 Autograd

Autograd 引擎負責運行所有必要的反向操作來計算反向傳播。本節將描述所有可以幫助你在多執行緒環境中充分利用它的細節。(這僅與 PyTorch 1.6+ 相關,因為先前版本的行為有所不同。)

使用者可以使用多執行緒程式碼訓練他們的模型(例如 Hogwild 訓練),並且不會阻塞併發的反向計算,範例程式碼如下:

# Define a train function to be used in different threads
def train_fn():
    x = torch.ones(5, 5, requires_grad=True)
    # forward
    y = (x + 3) * (x + 4) * 0.5
    # backward
    y.sum().backward()
    # potential optimizer update


# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
    p = threading.Thread(target=train_fn, args=())
    p.start()
    threads.append(p)

for p in threads:
    p.join()

請注意,使用者應該注意一些行為:

CPU 上的併發性

當你在 CPU 上的多個執行緒中通過 Python 或 C++ API 執行 backward()grad() 時,你期望看到額外的併發性,而不是在執行期間以特定順序序列化所有反向調用(PyTorch 1.6 之前的行為)。

非確定性

如果你從多個執行緒中同時調用 backward() 並且具有共享輸入(即 Hogwild CPU 訓練),那麼應該預期非確定性。這可能會發生,因為參數會自動在執行緒之間共享,因此,多個執行緒可能會在梯度累積期間訪問並嘗試累積相同的 .grad 屬性。從技術上講,這是不安全的,並且可能會導致競爭條件,並且結果可能無效。

開發具有共享參數的多執行緒模型的使用者應該牢記執行緒模型,並且應該了解上述問題。

可以使用函數式 API torch.autograd.grad() 來計算梯度,而不是使用 backward() 來避免非確定性。

圖形保留

如果 Autograd 圖的一部分在執行緒之間共享,即單執行緒執行第一部分正向傳播,然後在多個執行緒中執行第二部分,則第一部分圖形是共享的。在這種情況下,不同的執行緒在同一個圖形上執行 grad()backward() 可能會出現一個執行緒正在動態銷毀圖形的問題,而另一個執行緒在這種情況下會崩潰。Autograd 將向使用者發出類似於在沒有 retain_graph=True 的情況下調用兩次 backward() 的錯誤,並讓使用者知道他們應該使用 retain_graph=True

Autograd 節點上的執行緒安全

由於 Autograd 允許調用者執行緒驅動其反向執行以實現潛在的並行性,因此確保 CPU 上具有共享部分/全部 GraphTask 的並行 backward() 調用的執行緒安全非常重要。

自訂 Python autograd.Function 由於 GIL 的原因自動是執行緒安全的。對於內建的 C++ Autograd 節點(例如 AccumulateGrad、CopySlices)和自訂的 autograd::Function,Autograd 引擎使用執行緒互斥鎖來確保可能具有狀態寫入/讀取的 Autograd 節點上的執行緒安全。

C++ 鉤子沒有執行緒安全

Autograd 依賴使用者編寫執行緒安全的 C++ 鉤子。如果你希望鉤子在多執行緒環境中正確應用,則需要編寫適當的執行緒鎖定程式碼,以確保鉤子是執行緒安全的。

複數的 Autograd

簡短版本:

  • 當你使用 PyTorch 對任何具有複數域或對應域的函數 f(z)f(z) 進行微分時,梯度是在假設該函數是更大實值損失函數 g(input)=Lg(input)=L 的一部分的情況下計算的。計算出的梯度是 Lz\frac{\partial L}{\partial z^*}(注意 z 的共軛),其負數正是梯度下降演算法中使用的最速下降方向。因此,有一種可行的方法可以使現有的優化器在使用複數參數的情況下開箱即用。

  • 此約定與 TensorFlow 的複數微分約定相符,但與 JAX 不同(JAX 計算 Lz\frac{\partial L}{\partial z})。

  • 如果你有一個實數到實數的函數,它在內部使用複數運算,那麼這裡的約定並不重要:你將始終獲得與僅使用實數運算實現時相同的結果。

如果你對數學細節感到好奇,或者想知道如何在 PyTorch 中定義複數導數,請繼續閱讀。

什麼是複數導數?

複數可微性的數學定義採用導數的極限定義,並將其推廣到對複數進行運算。考慮一個函數 f:CCf: ℂ → ℂ

f(z=x+yj)=u(x,y)+v(x,y)jf(z=x+yj) = u(x, y) + v(x, y)j

其中 uuvv 是兩個實值變數函數,而 jj 是虛數單位。

使用導數定義,我們可以寫成

f(z)=limh0,hCf(z+h)f(z)hf'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h}

為了讓這個極限存在,不僅 uuvv 必須是實可微的,而且 ff 也必須滿足柯西-黎曼方程式。換句話說:針對實數和虛數步長(hh)計算的極限必須相等。這是一個更嚴格的條件。

複可微函數通常稱為全純函數。它們表現良好,具有您在實可微函數中看到的所有良好特性,但在優化領域實際上毫無用處。對於優化問題,研究界只使用實值目標函數,因為複數不屬於任何有序體,因此具有複值損失沒有多大意義。

事實還證明,沒有一個有趣的實值目標函數能滿足柯西-黎曼方程式。因此,全純函數理論不能用於優化,因此大多數人使用 Wirtinger 微積分。

Wirtinger 微積分應運而生…

所以,我們擁有這個關於複可微性和全純函數的偉大理論,但我們根本無法使用它,因為許多常用的函數都不是全純的。一位可憐的數學家該怎麼辦?嗯,Wirtinger 觀察到,即使 f(z)f(z) 不是全純的,也可以將其重寫為雙變數函數 f(z,z)f(z, z*),它始終是全純的。這是因為 zz 的實部和虛部可以用 zzzz^* 表示為

Re(z)=z+z2Im(z)=zz2j\begin{aligned} \mathrm{Re}(z) &= \frac {z + z^*}{2} \\ \mathrm{Im}(z) &= \frac {z - z^*}{2j} \end{aligned}

維丁格微積分建議改為研究 f(z,z)f(z, z^*),如果 ff 是實微分的,則保證它是全純的(另一種思考方式是將其視為坐標系的變化,從 f(x,y)f(x, y)f(z,z)f(z, z^*))。 這個函數具有偏導數 z\frac{\partial }{\partial z}z\frac{\partial}{\partial z^{*}}。我們可以使用鏈式法則來建立這些偏導數與 zz 的實部和虛部偏導數之間的關係。

x=zxz+zxz=z+zy=zyz+zyz=1j(zz)\begin{aligned} \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ \\ \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ &= 1j * \left(\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}\right) \end{aligned}

從上述等式,我們得到

z=1/2(x1jy)z=1/2(x+1jy)\begin{aligned} \frac{\partial }{\partial z} &= 1/2 * \left(\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}\right) \\ \frac{\partial }{\partial z^*} &= 1/2 * \left(\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}\right) \end{aligned}

這就是您可以在維基百科上找到的維丁格微積分的經典定義。

這種變化會產生許多優美的結果。

  • 首先,Cauchy-Riemann 方程式可以簡單地翻譯成 fz=0\frac{\partial f}{\partial z^*} = 0(也就是說,函數 ff 可以完全用 zz 來表示,而無需參考 zz^*)。

  • 另一個重要的(並且有點違背直覺的)結果是,當我們對實值損失函數進行優化時,我們在更新變數時應該採取的步驟是由 Lossz\frac{\partial Loss}{\partial z^*} 給出的(而不是 Lossz\frac{\partial Loss}{\partial z}),我們將在後面看到。

如需更多資訊,請查看:https://arxiv.org/pdf/0906.4835.pdf

Wirtinger 微積分在優化中如何發揮作用?

音訊和其他領域的研究人員更常使用梯度下降法,利用複變數來優化實值損失函數。通常,這些人將實部和虛部視為可以更新的獨立通道。對於步長 α/2\alpha/2 和損失 LL,我們可以在 R2ℝ^2 中寫入以下方程式

xn+1=xn(α/2)Lxyn+1=yn(α/2)Ly\begin{aligned} x_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} \\ y_{n+1} &= y_n - (\alpha/2) * \frac{\partial L}{\partial y} \end{aligned}

這些方程式如何在複數空間 C 中轉換?

zn+1=xn(α/2)Lx+1j(yn(α/2)Ly)=znα1/2(Lx+jLy)=znαLz\begin{aligned} z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * 1/2 * \left(\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}\right) \\ &= z_n - \alpha * \frac{\partial L}{\partial z^*} \end{aligned}

一件非常有趣的事情發生了:Wirtinger 微積分告訴我們,我們可以簡化上述複變量更新公式,使其僅引用共軛 Wirtinger 導數 Lz\frac{\partial L}{\partial z^*},這給了我們在優化中採取的確切步驟。

由於共軛 Wirtinger 導數為實值損失函數提供了正確的步長,因此當您對具有實值損失的函數進行微分時,PyTorch 會為您提供此導數。

PyTorch 如何計算共軛 Wirtinger 導數?

通常,我們的導數公式會將 grad_output 作為輸入,表示我們已經計算出的輸入向量-雅可比積,又稱為 Ls\frac{\partial L}{\partial s^*},其中 LL 是整個計算的損失(產生實數損失),而 ss 是我們函數的輸出。我們的目標是計算 Lz\frac{\partial L}{\partial z^*},其中 zz 是函數的輸入。事實證明,在實數損失的情況下,我們可以*只*計算 Ls\frac{\partial L}{\partial s^*},即使鏈式法則意味著我們也需要訪問 Ls\frac{\partial L}{\partial s}。如果您想跳過此推導,請查看本節中的最後一個等式,然後跳到下一節。

讓我們繼續使用定義為 f(z)=f(x+yj)=u(x,y)+v(x,y)jf(z) = f(x+yj) = u(x, y) + v(x, y)jf:CCf: ℂ → ℂ。如上所述,自動梯度的梯度慣例以實值損失函數的優化為中心,所以讓我們假設 ff 是較大的實值損失函數 gg 的一部分。使用鏈式法則,我們可以寫成

(1)Lz=Luuz+Lvvz\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}

現在使用 Wirtinger 導數定義,我們可以寫成

Ls=1/2(LuLvj)Ls=1/2(Lu+Lvj)\begin{aligned} \frac{\partial L}{\partial s} = 1/2 * \left(\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j\right) \\ \frac{\partial L}{\partial s^*} = 1/2 * \left(\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j\right) \end{aligned}

這裡需要注意的是,因為 uuvv 是實函數,而根據我們對 ff 是實值函數的一部分的假設,LL 也是實數,我們有

(2)(Ls)=Ls\left( \frac{\partial L}{\partial s} \right)^* = \frac{\partial L}{\partial s^*}

也就是說,Ls\frac{\partial L}{\partial s} 等於 grad_outputgrad\_output^*

解上述關於 Lu\frac{\partial L}{\partial u}Lv\frac{\partial L}{\partial v} 的方程式,我們得到

(3)Lu=Ls+LsLv=1j(LsLs)\begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ \frac{\partial L}{\partial v} = 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \end{aligned}

(3)代入(1),我們得到

Lz=(Ls+Ls)uz+1j(LsLs)vz=Ls(uz+vzj)+Ls(uzvzj)=Ls(u+vj)z+Ls(u+vj)z=Lssz+Lssz\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} + 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\ &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned}

使用(2),我們得到

(4)Lz=(Ls)sz+Ls(sz)=(grad_output)sz+grad_output(sz)\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s^*}\right)^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \left(\frac{\partial s}{\partial z}\right)^* \\ &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * \left(\frac{\partial s}{\partial z}\right)^* } \\ \end{aligned}

最後這個方程式是編寫您自己的梯度的重點,因為它將我們的導數公式分解成一個更容易用手計算的簡單公式。

我如何為複函數寫出我自己的導數公式?

上面的方框公式給出了複函數所有導數的一般公式。但是,我們仍然需要計算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*}。您可以通過兩種方式做到這一點

  • 第一種方法是直接使用 Wirtinger 導數的定義,並使用 sx\frac{\partial s}{\partial x}sy\frac{\partial s}{\partial y}(您可以用一般方法計算)來計算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*}

  • 第二種方法是使用變數變換技巧,將 f(z)f(z) 重寫為二元函數 f(z,z)f(z, z^*),並通過將 zzzz^* 視為獨立變量來計算共軛 Wirtinger 導數。這通常更容易;例如,如果所討論的函數是全純的,則只會使用 zz(而 sz\frac{\partial s}{\partial z^*} 將為零)。

讓我們以函數 f(z=x+yj)=cz=c(x+yj)f(z = x + yj) = c * z = c * (x+yj) 為例,其中 cRc \in ℝ

使用第一種方法計算 Wirtinger 導數,我們有:

sz=1/2(sxsyj)=1/2(c(c1j)1j)=csz=1/2(sx+syj)=1/2(c+(c1j)1j)=0\begin{aligned} \frac{\partial s}{\partial z} &= 1/2 * \left(\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c - (c * 1j) * 1j) \\ &= c \\ \\ \\ \frac{\partial s}{\partial z^*} &= 1/2 * \left(\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c + (c * 1j) * 1j) \\ &= 0 \\ \end{aligned}

使用(4)grad_output = 1.0(這是 PyTorch 中對純量輸出呼叫 backward() 時使用的預設梯度輸出值),我們得到

Lz=10+1c=c\frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c

使用第二種計算 Wirtinger 導數的方法,我們可以直接得到

sz=(cz)z=csz=(cz)z=0\begin{aligned} \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ &= c \\ \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ &= 0 \end{aligned}

再次使用 (4),我們得到 Lz=c\frac{\partial L}{\partial z^*} = c。如您所見,第二種方法涉及較少的計算,並且更便於快速計算。

跨域函數呢?

有些函數從複數輸入映射到實數輸出,反之亦然。這些函數形成 (4) 的一個特殊情況,我們可以使用鏈式規則推導出來

  • 對於 f:CRf: ℂ → ℝ,我們得到

    Lz=2grad_outputsz\frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}
  • 對於 f:RCf: ℝ → ℂ,我們得到

    Lz=2Re(grad_outputsz)\frac{\partial L}{\partial z^*} = 2 * \mathrm{Re}(grad\_output^* * \frac{\partial s}{\partial z^{*}})

已保存張量的鉤子

您可以透過定義一對 pack_hook / unpack_hook 鉤子來控制 已保存張量的打包 / 解包方式pack_hook 函數應將張量作為其單一參數,但可以返回任何 Python 物件(例如另一個張量、一個元組,甚至是一個包含檔名的字串)。unpack_hook 函數將 pack_hook 的輸出作為其單一參數,並應返回要在反向傳遞中使用的張量。unpack_hook 返回的張量只需要與作為輸入傳遞給 pack_hook 的張量具有相同的內容。特別是,任何與自動梯度相關的元資料都可以忽略,因為它們會在解包過程中被覆蓋。

以下是一個這樣的配對範例:

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name)

請注意,unpack_hook 不應刪除臨時檔案,因為它可能會被多次呼叫:只要返回的 SelfDeletingTempFile 物件存在,臨時檔案就應該存在。在上面的範例中,我們透過在不再需要臨時檔案時關閉它(在刪除 SelfDeletingTempFile 物件時)來防止洩漏臨時檔案。

注意

我們保證 pack_hook 只會被呼叫一次,但 unpack_hook 可以被呼叫多次,具體取決於反向傳遞的需求,我們希望它每次都返回相同的資料。

警告

禁止對任何函數的輸入執行就地操作,因為它們可能會導致意外的副作用。如果 pack_hook 的輸入被就地修改,PyTorch 將會拋出錯誤,但如果 unpack_hook 的輸入被就地修改,則不會捕捉到這種情況。

為已保存的張量註冊鉤子

您可以透過在 SavedTensor 物件上呼叫 register_hooks() 方法,在已保存的張量上註冊一對鉤子。這些物件作為 grad_fn 的屬性公開,並以 _raw_saved_ 字首開頭。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)

一旦註冊了這對鉤子,就會立即呼叫 pack_hook 方法。每次需要訪問已保存的張量時,無論是透過 y.grad_fn._saved_self 還是反向傳遞過程中,都會呼叫 unpack_hook 方法。

警告

如果您在釋放已保存的張量之後(即在呼叫反向傳遞之後)仍然保持對 SavedTensor 的引用,則禁止呼叫其 register_hooks() 方法。PyTorch 在大多數情況下會拋出錯誤,但在某些情況下可能會失敗,並可能出現未定義的行為。

為已保存的張量註冊預設鉤子

或者,您可以使用上下文管理器 saved_tensors_hooks 來註冊一對鉤子,這些鉤子將應用於在該上下文中建立的*所有*已保存的張量。

範例

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class Model(nn.Module):
    def forward(self, x):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
          # ... compute output
          output = x
        return output

model = Model()
net = nn.DataParallel(model)

使用此上下文管理器定義的鉤子是執行緒本地的。因此,以下程式碼不會產生預期的效果,因為鉤子不會通過 DataParallel

# Example what NOT to do

net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = net(input)

請注意,使用這些鉤子會禁用所有用於減少張量物件建立的優化。例如

with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
    x = torch.randn(5, requires_grad=True)
    y = x * x

在沒有鉤子的情況下,xy.grad_fn._saved_selfy.grad_fn._saved_other 都指向同一個張量物件。在有鉤子的情況下,PyTorch 會將 x 打包和解包到兩個新的張量物件中,這些物件與原始的 x 共用相同的儲存空間(不執行複製)。

反向鉤子執行

本節將討論不同鉤子觸發或不觸發的時機。然後它將討論它們觸發的順序。將涵蓋的鉤子是:透過 torch.Tensor.register_hook() 註冊到張量的反向鉤子、透過 torch.Tensor.register_post_accumulate_grad_hook() 註冊到張量的累積後梯度鉤子、透過 torch.autograd.graph.Node.register_hook() 註冊到節點的後置鉤子,以及透過 torch.autograd.graph.Node.register_prehook() 註冊到節點的前置鉤子。

特定鉤子是否會被觸發

當正在計算張量的梯度時,會執行透過 torch.Tensor.register_hook() 註冊到張量的鉤子。(請注意,這並不需要執行張量的 grad_fn。例如,如果張量作為 inputs 參數的一部分傳遞給 torch.autograd.grad(),則張量的 grad_fn 可能不會被執行,但註冊到該張量的鉤子將始終被執行。)

在累積張量的梯度之後,也就是設定張量的 grad 欄位之後,會執行透過 torch.Tensor.register_post_accumulate_grad_hook() 註冊到張量的鉤子。透過 torch.Tensor.register_hook() 註冊的鉤子是在計算梯度時執行的,而透過 torch.Tensor.register_post_accumulate_grad_hook() 註冊的鉤子只有在反向傳遞結束時,自動梯度更新張量的 grad 欄位時才會被觸發。因此,累積後梯度鉤子只能註冊到葉張量。在非葉張量上透過 torch.Tensor.register_post_accumulate_grad_hook() 註冊鉤子會出錯,即使您呼叫 backward(retain_graph=True) 也會出錯。

使用 torch.autograd.graph.Node.register_hook()torch.autograd.graph.Node.register_prehook() 註冊到 torch.autograd.graph.Node 的鉤子只有在註冊到的節點被執行時才會被觸發。

特定節點是否被執行可能取決於反向傳遞是使用 torch.autograd.grad() 還是 torch.autograd.backward() 呼叫的。具體來說,當您在對應於您要傳遞給 torch.autograd.grad()torch.autograd.backward() 作為 inputs 參數一部分的張量的節點上註冊鉤子時,您應該注意這些差異。

如果您使用的是 torch.autograd.backward(),則無論您是否指定了 inputs 參數,所有上述鉤子都將被執行。這是因為 .backward() 會執行所有節點,即使它們對應於指定為輸入的張量。(請注意,執行這個對應於作為 inputs 傳遞的張量的額外節點通常是不必要的,但還是會執行。這種行為可能會改變;您不應該依賴它。)

另一方面,如果您使用的是 torch.autograd.grad(),則註冊到對應於傳遞給 input 的張量的節點的反向鉤子可能不會被執行,因為除非有另一個輸入依賴於此節點的梯度結果,否則這些節點將不會被執行。

不同鉤子觸發的順序

事情發生的順序是

  1. 執行註冊到張量的鉤子

  2. 執行註冊到節點的前置鉤子(如果節點被執行)。

  3. 更新 retain_grad 張量的 .grad 欄位

  4. 執行節點(根據上述規則)

  5. 對於已累積 .grad 的葉張量,執行累積後梯度鉤子

  6. 執行註冊到節點的後置鉤子(如果節點被執行)。

如果在同一個張量或節點上註冊了多個相同類型的鉤子,則它們將按照註冊的順序執行。稍後執行的鉤子可以觀察到先前鉤子對梯度所做的修改。

特殊鉤子

torch.autograd.graph.register_multi_grad_hook() 是使用註冊到張量的鉤子實現的。每個單獨的張量鉤子都按照上面定義的張量鉤子順序觸發,並且在計算最後一個張量梯度時呼叫註冊的多重梯度鉤子。

torch.nn.modules.module.register_module_full_backward_hook() 是使用註冊到節點的鉤子實現的。在計算正向傳遞時,鉤子會註冊到對應於模組輸入和輸出的 grad_fn。由於模組可能需要多個輸入並返回多個輸出,因此在正向傳遞之前,首先將虛擬的自定義自動梯度函數應用於模組的輸入,並在正向傳遞的輸出返回之前應用於模組的輸出,以確保這些張量共享一個單一的 grad_fn,然後我們可以將鉤子附加到該 grad_fn。

張量在原地修改時張量鉤子的行為

通常,註冊到張量的鉤子會接收輸出相對於該張量的梯度,其中張量的值取為計算反向傳遞時的值。

但是,如果您將鉤子註冊到一個張量,然後原地修改該張量,則在原地修改之前註冊的鉤子同樣會接收輸出相對於該張量的梯度,但張量的值取為原地修改之前的值。

如果您希望使用前一種情況的行為,則應該在對張量進行所有原地修改之後再將其註冊到該張量。例如

t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()

此外,了解在幕後,當鉤子註冊到張量時,它們實際上會永久綁定到該張量的 grad_fn,這一點會有所幫助,因此如果該張量隨後被原地修改,即使該張量現在有一個新的 grad_fn,在它被原地修改之前註冊的鉤子將繼續與舊的 grad_fn 相關聯,例如,當自動梯度引擎在圖中到達該張量的舊 grad_fn 時,它們將被觸發。

文件

取得 PyTorch 的完整開發人員文件

查看文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並取得您的問題解答

查看資源