跳過模組引數初始化¶
建立日期:2021年6月17日 | 最後更新:2021年6月17日 | 最後驗證:未驗證
引言¶
建立模組時,其可學習引數會根據與模組型別關聯的預設初始化方案進行初始化。例如,torch.nn.Linear 模組的 weight 引數是從 uniform(-1/sqrt(in_features), 1/sqrt(in_features)) 分佈初始化的。如果需要其他初始化方案,傳統上需要在模組例項化後重新初始化引數。
from torch import nn
# Initializes weight from the default distribution: uniform(-1/sqrt(10), 1/sqrt(10)).
m = nn.Linear(10, 5)
# Re-initialize weight from a different distribution.
nn.init.orthogonal_(m.weight)
在這種情況下,構造過程中進行的初始化是浪費的計算,如果 weight 引數很大,這可能不是一件微不足道的事情。
跳過初始化¶
現在可以在模組構造期間跳過引數初始化,從而避免浪費計算。使用 torch.nn.utils.skip_init() 函式可以輕鬆實現這一點。
from torch import nn
from torch.nn.utils import skip_init
m = skip_init(nn.Linear, 10, 5)
# Example: Do custom, non-default parameter initialization.
nn.init.orthogonal_(m.weight)
這可以應用於滿足下方 更新模組以支援跳過初始化 部分所述條件的任何模組。請注意,torch.nn 提供的所有模組都滿足這些條件,因此支援跳過初始化。
更新模組以支援跳過初始化¶
由於 torch.nn.utils.skip_init() 的實現方式(參見 實現細節),模組必須滿足兩個要求才能與該函式相容。透過遵守這些要求,您就可以為您的自定義模組啟用引數初始化跳過功能。
1. 模組的建構函式必須接受 device 關鍵字引數,並將其傳遞給構造過程中建立的任何引數或緩衝區。
2. 模組在建構函式中不能對引數或緩衝區執行除初始化之外的任何計算(即 torch.nn.init 中的函式)。
以下示例演示了一個模組,它透過將 device 關鍵字引數傳遞給任何建立的引數、緩衝區或子模組來支援該功能。
import torch
from torch import nn
class MyModule(torch.nn.Module):
def __init__(self, foo, bar, device=None):
super().__init__()
# ==== Case 1: Module creates parameters directly. ====
# Pass device along to any created parameters.
self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))
# To ensure support for the meta device, avoid using ops except those in
# torch.nn.init on parameters in your module's constructor.
with torch.no_grad():
nn.init.kaiming_uniform_(self.param1)
nn.init.uniform_(self.param2)
# ==== Case 2: Module creates submodules. ====
# Pass device along recursively. All submodules will need to support
# them as well; this is the case for all torch.nn provided modules.
self.fc = nn.Linear(bar, 5, device=device)
# This also works with containers.
self.linears = nn.Sequential(
nn.Linear(5, 5, device=device),
nn.Linear(5, 1, device=device)
)
# ==== Case 3: Module creates buffers. ====
# Pass device along during buffer tensor creation.
self.register_buffer('some_buffer', torch.ones(7, device=device))
...
實現細節¶
在幕後,torch.nn.utils.skip_init() 函式是按照兩步模式實現的。
# 1. Initialize module on the meta device; all torch.nn.init ops have
# no-op behavior on the meta device.
m = nn.Linear(10, 5, device='meta')
# 2. Materialize an uninitialized (empty) form of the module on the CPU device.
# The result of this is a module instance with uninitialized parameters.
m.to_empty(device='cpu')
它的工作原理是將模組例項化到一個“meta”裝置上,該裝置具有張量形狀資訊但不分配任何儲存空間。torch.nn.init 操作是專門為這個 meta 裝置實現的,以便它們具有無操作行為。這導致引數初始化邏輯基本被跳過。
請注意,此模式僅適用於在構造期間正確支援 device 關鍵字引數的模組,如 更新模組以支援跳過初始化 中所述。