PruningContainer¶
- class torch.nn.utils.prune.PruningContainer(*args)[source][source]¶
包含一系列用於迭代剪枝的剪枝方法的容器。
跟蹤剪枝方法的應用順序,並處理連續剪枝呼叫的合併。
接受 BasePruningMethod 的例項或其可迭代物件作為引數。
- add_pruning_method(method)[source][source]¶
向容器新增子剪枝
method。- 引數
method (subclass of BasePruningMethod) – 要新增到容器的子剪枝方法。
- classmethod apply(module, name, *args, importance_scores=None, **kwargs)[source]¶
新增即時剪枝和張量的重新引數化。
新增啟用即時剪枝的前向預鉤子,以及根據原始張量和剪枝掩碼對張量進行重新引數化。
- 引數
module (nn.Module) – 包含要剪枝的張量的模組
name (str) –
module中剪枝將作用的引數名稱。args – 傳遞給
BasePruningMethod子類的引數importance_scores (torch.Tensor) – 重要性分數張量(與模組引數形狀相同),用於計算剪枝掩碼。此張量中的值表示被剪枝引數中對應元素的重要性。如果未指定或為 None,則將使用引數本身。
kwargs – 傳遞給
BasePruningMethod子類的關鍵字引數
- apply_mask(module)[source]¶
僅處理被剪枝引數與生成的掩碼之間的乘法。
從模組中獲取掩碼和原始張量,並返回張量的剪枝版本。
- 引數
module (nn.Module) – 包含要剪枝的張量的模組
- 返回值
輸入張量的剪枝版本
- 返回值型別
pruned_tensor (torch.Tensor)
- compute_mask(t, default_mask)[source][source]¶
應用最新的
method,透過計算新的部分掩碼並返回其與default_mask的組合。新的部分掩碼應在未被
default_mask置零的條目或通道上計算。新的掩碼將從張量t的哪些部分計算取決於PRUNING_TYPE(由型別處理器處理)對於 ‘unstructured’,掩碼將從非掩碼條目的展平列表計算;
對於 ‘structured’,掩碼將從張量中非掩碼的通道計算;
對於 ‘global’,掩碼將在所有條目上計算。
- 引數
t (torch.Tensor) – 表示要剪枝引數的張量(與
default_mask尺寸相同)。default_mask (torch.Tensor) – 來自前一次剪枝迭代的掩碼。
- 返回值
新的掩碼,結合了
default_mask和當前剪枝method生成的新掩碼的效果(與default_mask和t尺寸相同)。- 返回值型別
mask (torch.Tensor)
- prune(t, default_mask=None, importance_scores=None)[source]¶
計算並返回輸入張量
t的剪枝版本。根據
compute_mask()中指定的剪枝規則。- 引數
t (torch.Tensor) – 要剪枝的張量(與
default_mask尺寸相同)。importance_scores (torch.Tensor) – 重要性分數張量(與
t形狀相同),用於計算對t進行剪枝的掩碼。此張量中的值表示被剪枝t中對應元素的重要性。如果未指定或為 None,則將使用張量t本身。default_mask (torch.Tensor, 可選) – 來自前一次剪枝迭代的掩碼(如果有)。在確定剪枝應作用於張量的哪個部分時需要考慮。如果為 None,則預設為全一掩碼。
- 返回值
張量
t的剪枝版本。