快捷方式

CustomFromMask

class torch.nn.utils.prune.CustomFromMask(mask)[來源][來源]
classmethod apply(module, name, mask)[來源][來源]

新增動態剪枝和張量的重新引數化。

新增前向預鉤子,該鉤子可實現動態剪枝以及根據原始張量和剪枝掩碼對張量進行重新引數化。

引數
  • module (nn.Module) – 包含要剪枝張量的模組

  • name (str) – 模組中將進行剪枝操作的引數名稱。

apply_mask(module)[來源]

簡單處理被剪枝引數與生成的掩碼之間的乘法運算。

從模組中獲取掩碼和原始張量,並返回張量的剪枝版本。

引數

module (nn.Module) – 包含要剪枝張量的模組

返回值

輸入張量的剪枝版本

返回型別

pruned_tensor (torch.Tensor)

prune(t, default_mask=None, importance_scores=None)[來源]

計算並返回輸入張量 t 的剪枝版本。

根據 compute_mask() 中指定的剪枝規則。

引數
  • t (torch.Tensor) – 要剪枝的張量(與 default_mask 維度相同)。

  • importance_scores (torch.Tensor) – 用於計算張量 t 剪枝掩碼的重要性得分張量(與 t 形狀相同)。此張量中的值表示要剪枝的張量 t 中對應元素的重要性。如果未指定或為 None,則使用張量 t 代替。

  • default_mask (torch.Tensor, optional) – 如果有,則來自上一次剪枝迭代的掩碼。在確定剪枝應作用於張量的哪一部分時予以考慮。如果 None,預設為全一掩碼。

返回值

張量 t 的剪枝版本。

remove(module)[來源]

從模組中移除剪枝重新引數化。

名為 name 的剪枝引數保持永久剪枝狀態,名為 name+'_orig' 的引數將從引數列表中移除。類似地,名為 name+'_mask' 的緩衝區將從緩衝區中移除。

注意

剪枝本身不可撤銷或逆轉!

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源