快捷方式

BasePruningMethod

class torch.nn.utils.prune.BasePruningMethod[source][source]

建立新剪枝技術的抽象基類。

提供了一個用於定製的骨架,需要重寫諸如 compute_mask()apply() 等方法。

classmethod apply(module, name, *args, importance_scores=None, **kwargs)[source][source]

新增即時剪枝和張量重引數化。

新增前向預鉤子,實現即時剪枝,並將張量重引數化為原始張量和剪枝掩碼的形式。

引數
  • module (nn.Module) – 包含要剪枝的張量的模組

  • name (str) – module 中剪枝將作用的引數名稱。

  • args – 傳遞給 BasePruningMethod 子類的引數

  • importance_scores (torch.Tensor) – 用於計算剪枝掩碼的重要性分數張量(形狀與模組引數相同)。此張量中的值指示被剪枝引數中對應元素的重要性。如果未指定或為 None,將使用引數本身作為重要性分數。

  • kwargs – 傳遞給 BasePruningMethod 子類的關鍵字引數

apply_mask(module)[source][source]

簡單處理被剪枝引數與生成掩碼之間的乘法。

從模組中獲取掩碼和原始張量,並返回張量的剪枝版本。

引數

module (nn.Module) – 包含要剪枝的張量的模組

返回

輸入張量的剪枝版本

返回型別

pruned_tensor (torch.Tensor)

abstract compute_mask(t, default_mask)[source][source]

計算並返回輸入張量 t 的掩碼。

從基礎 default_mask 開始(如果張量尚未剪枝,則應為全一掩碼),根據具體的剪枝方法生成一個隨機掩碼應用於 default_mask 之上。

引數
  • t (torch.Tensor) – 表示要

  • 剪枝 (引數的) –

  • default_mask (torch.Tensor) – 來自先前剪枝的基礎掩碼

  • 迭代

  • (新掩碼應用後需要保留的部分。) –

  • t. (與 維度相同) –

返回

要應用於 t 的掩碼,維度與 t 相同

返回型別

mask (torch.Tensor)

prune(t, default_mask=None, importance_scores=None)[source][source]

計算並返回輸入張量 t 的剪枝版本。

根據 compute_mask() 中指定的剪枝規則進行。

引數
  • t (torch.Tensor) – 要剪枝的張量(與 default_mask 維度相同)。

  • importance_scores (torch.Tensor) – 用於計算剪枝 t 掩碼的重要性分數張量(形狀與 t 相同)。此張量中的值指示被剪枝的 t 中對應元素的重要性。如果未指定或為 None,將使用張量 t 本身作為重要性分數。

  • default_mask (torch.Tensor, 可選) – 如果存在,則為上一次剪枝迭代的掩碼。在確定剪枝應作用於張量的哪一部分時需要考慮。如果為 None,則預設為全一掩碼。

返回

張量 t 的剪枝版本。

remove(module)[source][source]

從模組中移除剪枝重引數化。

名為 name 的剪枝引數永久保留剪枝狀態,名為 name+'_orig' 的引數從引數列表中移除。類似地,名為 name+'_mask' 的緩衝區從緩衝區中移除。

注意

剪枝本身無法撤銷或反轉!

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源