• 文件 >
  • torch.utils.checkpoint
捷徑

torch.utils.checkpoint

注意事項

檢查點的實作方式是在反向傳播期間針對每個檢查點區段重新執行正向傳遞區段。這可能會導致持續性狀態(例如 RNG 狀態)比沒有檢查點的情況下更進階。根據預設,檢查點包含處理 RNG 狀態的邏輯,以便與非檢查點傳遞相比,使用 RNG(例如透過 dropout)的檢查點傳遞具有確定性輸出。儲存和還原 RNG 狀態的邏輯可能會產生適度的效能影響,具體取決於檢查點作業的執行時間。如果不需要與非檢查點傳遞相比的確定性輸出,請提供 preserve_rng_state=Falsecheckpointcheckpoint_sequential 以在每個檢查點期間省略儲存和還原 RNG 狀態。

儲存邏輯會將 CPU 和其他裝置類型(透過 _infer_device_type 從張量引數中排除 CPU 張量來推斷裝置類型)的 RNG 狀態儲存並還原至 run_fn。如果有多個裝置,則只會針對單一裝置類型的裝置儲存裝置狀態,其餘裝置將被忽略。因此,如果任何檢查點函式涉及隨機性,則可能會導致梯度不正確。(請注意,如果偵測到的裝置中有 CUDA 裝置,則會優先考慮它;否則,將會選擇遇到的第一個裝置。)如果沒有 CPU 張量,則會儲存並還原預設裝置類型狀態(預設值為 cuda,並且可以使用 DefaultDeviceType 將其設定為其他裝置)。但是,邏輯無法預測使用者是否會在 run_fn 本身內將張量移動到新裝置。因此,如果您在 run_fn 內將張量移動到新裝置(「新」表示不屬於 [目前裝置 + 張量引數的裝置] 集合),則永遠無法保證與非檢查點傳遞相比的確定性輸出。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[來源]

檢查點模型或模型的一部分。

啟動檢查點是一種以計算換取記憶體的技術。在檢查點區域中,正向計算不會將反向傳播所需的張量保持在活動狀態,直到在反向傳播期間的梯度計算中使用它們為止,而是省略儲存張量以進行反向傳播,並在反向傳遞期間重新計算它們。啟動檢查點可以應用於模型的任何部分。

目前有兩種檢查點實作可用,由 use_reentrant 參數決定。建議您使用 use_reentrant=False。有關它們差異的討論,請參閱下面的注意事項。

警告

如果在反向傳遞期間呼叫 function 與正向傳遞不同(例如,由於全域變數),則檢查點版本可能不相等,可能會導致引發錯誤或導致靜默不正確的梯度。

警告

應該明確傳遞 use_reentrant 參數。在 2.4 版中,如果未傳遞 use_reentrant,我們將會引發例外狀況。如果您使用的是 use_reentrant=True 變體,請參閱下面的注意事項以瞭解重要的注意事項和潛在的限制。

注意事項

檢查點的可重入變體 (use_reentrant=True) 和檢查點的不可重入變體 (use_reentrant=False) 在以下方面有所不同

  • 不可重入檢查點會在重新計算所有需要的中間啟動後立即停止重新計算。此功能預設為啟用,但可以使用 set_checkpoint_early_stop() 禁用。可重入檢查點始終在反向傳遞期間完整地重新計算 function

  • 可重入變體在前向傳遞期間不會記錄 autograd 圖,因為它在 torch.no_grad() 下運行前向傳遞。不可重入版本會記錄 autograd 圖,允許在檢查點區域內對圖執行反向傳播。

  • 可重入檢查點僅支持 torch.autograd.backward() API 進行反向傳遞,而不使用其 inputs 參數,而非可重入版本則支持所有執行反向傳遞的方法。

  • 對於可重入變體,至少有一個輸入和輸出必須具有 requires_grad=True。如果未滿足此條件,則模型的檢查點部分將沒有梯度。非可重入版本沒有此要求。

  • 可重入版本不會將嵌套結構(例如,自定義對象、列表、字典等)中的張量視為參與 autograd,而非可重入版本則會。

  • 可重入檢查點不支持具有與計算圖分離的張量的檢查點區域,而非可重入版本則支持。對於可重入變體,如果檢查點段包含使用 detach()torch.no_grad() 分離的張量,則反向傳遞將引發錯誤。這是因為 checkpoint 會使所有輸出都需要梯度,並且當張量在模型中定義為沒有梯度時,就會導致問題。為避免此問題,請在 checkpoint 函數之外分離張量。

參數
  • function – 描述在模型或模型部分的前向傳遞中要運行的內容。它還應該知道如何處理作為元組傳遞的輸入。例如,在 LSTM 中,如果用戶傳遞 (activation, hidden),則 function 應正確使用第一個輸入作為 activation,第二個輸入作為 hidden

  • preserve_rng_state (布林值, 可選) – 在每個檢查點期間省略隱藏和恢復 RNG 狀態。請注意,在 torch.compile 下,此標記不起作用,我們始終會保留 RNG 狀態。默認值:True

  • use_reentrant (布林值) – 指定是否使用需要可重入 autograd 的激活檢查點變體。應明確傳遞此參數。在 2.4 版中,如果未傳遞 use_reentrant,我們將引發異常。如果 use_reentrant=False,則 checkpoint 將使用不需要可重入 autograd 的實現。這允許 checkpoint 支持其他功能,例如按預期使用 torch.autograd.grad 以及對輸入到檢查點函數的關鍵字參數的支持。

  • context_fn (可調用對象, 可選) – 一個可調用對象,返回兩個上下文管理器的元組。函數及其重新計算將分別在第一個和第二個上下文管理器下運行。僅當 use_reentrant=False 時才支持此參數。

  • determinism_check (字串, 可選) – 指定要執行的確定性檢查的字串。默認情況下,它設置為 "default",它將重新計算的張量的形狀、數據類型和設備與保存的張量進行比較。要關閉此檢查,請指定 "none"。目前只有這兩個值受支持。如果您想查看更多確定性檢查,請提出問題。僅當 use_reentrant=False 時才支持此參數,如果 use_reentrant=True,則始終禁用確定性檢查。

  • debug (布林值, 可選) – 如果為 True,則錯誤消息還將包含在原始前向計算和重新計算期間運行的運算符的跟踪。僅當 use_reentrant=False 時才支持此參數。

  • args – 包含 function 輸入的元組

返回值

*args 上運行 function 的輸出

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]

檢查點順序模型以節省內存。

順序模型按順序(順序)執行模塊/函數列表。因此,我們可以將這樣的模型分成不同的段,並為每個段設置檢查點。除最後一個段外,所有段都不會存儲中間激活。將保存每個檢查點段的輸入,以便在反向傳遞中重新運行該段。

警告

應明確傳遞 use_reentrant 參數。在 2.4 版中,如果未傳遞 use_reentrant,我們將引發異常。如果您使用的是 use_reentrant=True` 變體,請參閱 :func:`~torch.utils.checkpoint.checkpoint` 了解此變體的重要注意事項和限制。建議您使用 ``use_reentrant=False

參數
  • functions – 要按順序運行的 torch.nn.Sequential 或模塊或函數列表(組成模型)。

  • segments – 在模型中創建的塊數

  • input – 作為 functions 輸入的張量

  • preserve_rng_state (布林值, 可選) – 在每個檢查點期間省略隱藏和恢復 RNG 狀態。默認值:True

  • use_reentrant (布林值) – 指定是否使用需要可重入 autograd 的激活檢查點變體。應明確傳遞此參數。在 2.4 版中,如果未傳遞 use_reentrant,我們將引發異常。如果 use_reentrant=False,則 checkpoint 將使用不需要可重入 autograd 的實現。這允許 checkpoint 支持其他功能,例如按預期使用 torch.autograd.grad 以及對輸入到檢查點函數的關鍵字參數的支持。

返回值

*inputs 上按順序運行 functions 的輸出

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]

設置檢查點在運行時是否應打印其他調試信息的上下文管理器。有關更多信息,請參閱 checkpoint()debug 標記。請注意,設置後,此上下文管理器將覆蓋傳遞給檢查點的 debug 的值。要延遲到本地設置,請將 None 傳遞給此上下文。

參數

enabled (布林值) – 檢查點是否應打印調試信息。默認為「無」。

文件

訪問 PyTorch 的全面開發人員文檔

查看文檔

教程

獲取針對初學者和高級開發人員的深入教程

查看教程

資源

查找開發資源並獲得問題解答

查看資源