torch.nn.utils.parametrize.cached¶
- torch.nn.utils.parametrize.cached()[原始碼][原始碼]¶
上下文管理器,用於在透過
register_parametrization()註冊的引數化中啟用快取系統。當此上下文管理器處於活動狀態時,引數化物件的值在首次需要時計算並快取。離開上下文管理器時,快取的值將被丟棄。
這在使用引數化引數在正向傳播中多次出現時非常有用。例如,當對 RNN 的迴圈核進行引數化或共享權重時。
啟用快取的最簡單方法是包裹神經網路的正向傳播
import torch.nn.utils.parametrize as P ... with P.cached(): output = model(inputs)
在訓練和評估中。也可以包裹多次使用引數化張量的模組部分。例如,帶有引數化迴圈核的 RNN 的迴圈部分
with P.cached(): for x in xs: out_rnn = self.rnn_cell(x, out_rnn)