torch.nn.utils.parametrize.register_parametrization¶
- torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)[原始碼][原始碼]¶
在一個模組中為張量註冊引數化(parametrization)。
為簡單起見,假設
tensor_name="weight"。當訪問module.weight時,模組將返回引數化版本parametrization(module.weight)。如果原始張量需要梯度,反向傳播將透過parametrization進行微分,並且最佳化器將相應地更新張量。模組首次註冊引數化時,此函式會向模組新增一個型別為
ParametrizationList的屬性parametrizations。張量
weight上的引數化列表可透過module.parametrizations.weight訪問。原始張量可透過
module.parametrizations.weight.original訪問。透過在同一屬性上註冊多個引數化,可以將引數化串聯起來。
註冊的引數化的訓練模式在註冊時會更新,以匹配宿主模組的訓練模式。
引數化的引數和緩衝區有一個內建的快取系統,可以使用上下文管理器
cached()啟用。parametrization可以選擇實現一個方法,簽名如下:def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]
當註冊第一個引數化時,此方法在未引數化的張量上呼叫,以計算原始張量的初始值。如果未實現此方法,則原始張量即為未引數化的張量。
如果在張量上註冊的所有引數化都實現了 right_inverse,則可以透過賦值來初始化引數化張量,如下例所示。
第一個引數化可能依賴於多個輸入。這可以透過從
right_inverse返回一個張量元組來實現(參見下面RankOne引數化的示例實現)。在這種情況下,無約束張量也位於
module.parametrizations.weight下,名稱分別為original0、original1等。注意
如果 unsafe=False(預設值),則會分別呼叫 forward 和 right_inverse 方法一次,以執行多項一致性檢查。如果 unsafe=True,則在張量未引數化時會呼叫 right_inverse,否則不會呼叫任何方法。
注意
在大多數情況下,
right_inverse將是一個函式,使得forward(right_inverse(X)) == X(參見右逆)。有時,當引數化不是滿射時,放寬此要求可能是合理的。警告
如果引數化依賴於多個輸入,
register_parametrization()將註冊一些新的引數。如果在建立最佳化器後註冊此類引數化,則需要手動將這些新引數新增到最佳化器中。參見torch.Optimizer.add_param_group()。- 引數
- 關鍵字引數
unsafe (bool) – 一個布林標誌,表示引數化是否可以改變張量的資料型別和形狀。預設值:False 警告:註冊時不會檢查引數化的一致性。啟用此標誌的風險自負。
- 丟擲異常
ValueError – 如果模組沒有名為
tensor_name的引數或緩衝區- 返回型別
示例
>>> import torch >>> import torch.nn as nn >>> import torch.nn.utils.parametrize as P >>> >>> class Symmetric(nn.Module): >>> def forward(self, X): >>> return X.triu() + X.triu(1).T # Return a symmetric matrix >>> >>> def right_inverse(self, A): >>> return A.triu() >>> >>> m = nn.Linear(5, 5) >>> P.register_parametrization(m, "weight", Symmetric()) >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric True >>> A = torch.rand(5, 5) >>> A = A + A.T # A is now symmetric >>> m.weight = A # Initialize the weight to be the symmetric matrix A >>> print(torch.allclose(m.weight, A)) True
>>> class RankOne(nn.Module): >>> def forward(self, x, y): >>> # Form a rank 1 matrix multiplying two vectors >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) >>> >>> def right_inverse(self, Z): >>> # Project Z onto the rank 1 matrices >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) >>> # Return rescaled singular vectors >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt >>> >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne()) >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) 1