RandomUnstructured¶
- class torch.nn.utils.prune.RandomUnstructured(amount)[source][source]¶
隨機修剪張量中(當前未修剪的)單元。
- 引數
- classmethod apply(module, name, amount)[source][source]¶
動態新增修剪和張量的重引數化。
添加了前向預鉤子(forward pre-hook),該鉤子支援動態修剪以及將張量重引數化為原始張量和修剪掩碼的形式。
- apply_mask(module)[source]¶
簡單地處理要修剪的引數與生成的掩碼之間的乘法。
從模組中獲取掩碼和原始張量,並返回張量的修剪版本。
- 引數
module (nn.Module) – 包含要修剪的張量的模組
- 返回
輸入張量的修剪版本
- 返回型別
pruned_tensor (torch.Tensor)
- prune(t, default_mask=None, importance_scores=None)[source]¶
計算並返回輸入張量
t的修剪版本。根據
compute_mask()中指定的修剪規則。- 引數
t (torch.Tensor) – 要修剪的張量(與
default_mask維度相同)。importance_scores (torch.Tensor) – 用於計算修剪
t掩碼的重要性分數張量(與t形狀相同)。此張量中的值表示t中相應元素的重要性,該t正在被修剪。如果未指定或為 None,將使用張量t代替。default_mask (torch.Tensor, 可選) – 先前修剪迭代(如果存在)的掩碼。在確定修剪應作用於張量的哪一部分時需要考慮。如果為 None,預設為全 1 掩碼。
- 返回
張量
t的修剪版本。