快捷方式

PruningContainer

class torch.nn.utils.prune.PruningContainer(*args)[source][source]

包含一系列用於迭代剪枝的剪枝方法的容器。

跟蹤剪枝方法的應用順序,並處理連續剪枝呼叫的合併。

接受 BasePruningMethod 的例項或其可迭代物件作為引數。

add_pruning_method(method)[source][source]

向容器新增子剪枝 method

引數

method (subclass of BasePruningMethod) – 要新增到容器的子剪枝方法。

classmethod apply(module, name, *args, importance_scores=None, **kwargs)[source]

新增即時剪枝和張量的重新引數化。

新增啟用即時剪枝的前向預鉤子,以及根據原始張量和剪枝掩碼對張量進行重新引數化。

引數
  • module (nn.Module) – 包含要剪枝的張量的模組

  • name (str) – module 中剪枝將作用的引數名稱。

  • args – 傳遞給 BasePruningMethod 子類的引數

  • importance_scores (torch.Tensor) – 重要性分數張量(與模組引數形狀相同),用於計算剪枝掩碼。此張量中的值表示被剪枝引數中對應元素的重要性。如果未指定或為 None,則將使用引數本身。

  • kwargs – 傳遞給 BasePruningMethod 子類的關鍵字引數

apply_mask(module)[source]

僅處理被剪枝引數與生成的掩碼之間的乘法。

從模組中獲取掩碼和原始張量,並返回張量的剪枝版本。

引數

module (nn.Module) – 包含要剪枝的張量的模組

返回值

輸入張量的剪枝版本

返回值型別

pruned_tensor (torch.Tensor)

compute_mask(t, default_mask)[source][source]

應用最新的 method,透過計算新的部分掩碼並返回其與 default_mask 的組合。

新的部分掩碼應在未被 default_mask 置零的條目或通道上計算。新的掩碼將從張量 t 的哪些部分計算取決於 PRUNING_TYPE(由型別處理器處理)

  • 對於 ‘unstructured’,掩碼將從非掩碼條目的展平列表計算;

  • 對於 ‘structured’,掩碼將從張量中非掩碼的通道計算;

  • 對於 ‘global’,掩碼將在所有條目上計算。

引數
  • t (torch.Tensor) – 表示要剪枝引數的張量(與 default_mask 尺寸相同)。

  • default_mask (torch.Tensor) – 來自前一次剪枝迭代的掩碼。

返回值

新的掩碼,結合了 default_mask 和當前剪枝 method 生成的新掩碼的效果(與 default_maskt 尺寸相同)。

返回值型別

mask (torch.Tensor)

prune(t, default_mask=None, importance_scores=None)[source]

計算並返回輸入張量 t 的剪枝版本。

根據 compute_mask() 中指定的剪枝規則。

引數
  • t (torch.Tensor) – 要剪枝的張量(與 default_mask 尺寸相同)。

  • importance_scores (torch.Tensor) – 重要性分數張量(與 t 形狀相同),用於計算對 t 進行剪枝的掩碼。此張量中的值表示被剪枝 t 中對應元素的重要性。如果未指定或為 None,則將使用張量 t 本身。

  • default_mask (torch.Tensor, 可選) – 來自前一次剪枝迭代的掩碼(如果有)。在確定剪枝應作用於張量的哪個部分時需要考慮。如果為 None,則預設為全一掩碼。

返回值

張量 t 的剪枝版本。

remove(module)[source]

從模組中移除剪枝重新引數化。

名為 name 的剪枝引數將永久保持剪枝狀態,名為 name+'_orig' 的引數將從引數列表中移除。類似地,名為 name+'_mask' 的緩衝區將從緩衝區中移除。

注意

剪枝本身不會被撤銷或反轉!

文件

獲取 PyTorch 全面的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源