快捷方式

torch.nn.utils.parametrize.cached

torch.nn.utils.parametrize.cached()[原始碼][原始碼]

上下文管理器,用於在透過 register_parametrization() 註冊的引數化中啟用快取系統。

當此上下文管理器處於活動狀態時,引數化物件的值在首次需要時計算並快取。離開上下文管理器時,快取的值將被丟棄。

這在使用引數化引數在正向傳播中多次出現時非常有用。例如,當對 RNN 的迴圈核進行引數化或共享權重時。

啟用快取的最簡單方法是包裹神經網路的正向傳播

import torch.nn.utils.parametrize as P
...
with P.cached():
    output = model(inputs)

在訓練和評估中。也可以包裹多次使用引數化張量的模組部分。例如,帶有引數化迴圈核的 RNN 的迴圈部分

with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源