快捷方式

Identity

class torch.nn.utils.prune.Identity[source][source]

一個實用的剪枝方法,它不剪枝任何單元,但會生成一個包含全1掩碼的剪枝引數化。

classmethod apply(module, name)[source][source]

新增動態剪枝和張量重引數化。

新增啟用動態剪枝的前向預鉤子,以及根據原始張量和剪枝掩碼對張量進行重引數化。

引數
  • module (nn.Module) – 包含待剪枝張量的模組

  • name (str) – 在 module 中要進行剪枝的引數名稱。

apply_mask(module)[source]

簡單地處理被剪枝引數和生成的掩碼之間的乘法運算。

從模組中獲取掩碼和原始張量,並返回張量的剪枝版本。

引數

module (nn.Module) – 包含待剪枝張量的模組

返回值

輸入張量的剪枝版本

返回型別

pruned_tensor (torch.Tensor)

prune(t, default_mask=None, importance_scores=None)[source]

計算並返回輸入張量 t 的剪枝版本。

根據 compute_mask() 中指定的剪枝規則。

引數
  • t (torch.Tensor) – 要剪枝的張量(與 default_mask 維度相同)。

  • importance_scores (torch.Tensor) – 用於計算張量 t 剪枝掩碼的重要性分數張量(與 t 形狀相同)。此張量中的值表示被剪枝張量 t 中對應元素的重要性。如果未指定或為 None,則將使用張量 t 本身。

  • default_mask (torch.Tensor, optional) – 來自先前剪枝迭代的掩碼(如果有)。在確定剪枝應作用於張量的哪個部分時予以考慮。如果為 None,則預設使用全1掩碼。

返回值

張量 t 的剪枝版本。

remove(module)[source]

從模組中移除剪枝重引數化。

名為 name 的剪枝引數將永久保持剪枝狀態,名為 name+'_orig' 的引數將從引數列表中移除。類似地,名為 name+'_mask' 的緩衝區將從緩衝區中移除。

注意

剪枝本身是不可撤銷或逆轉的!

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源