from_modules¶
- class tensordict.from_modules(*modules, as_module: bool = False, lock: bool = True, use_state_dict: bool = False, lazy_stack: bool = False, expand_identical: bool = False)¶
檢索多個模組的引數,用於整合學習/透過 vmap 實現的特性應用。
- 引數:
modules (nn.Module 序列) – 需要從中獲取引數的模組。如果模組結構不同,則需要惰性堆疊(請參閱下面的
lazy_stack引數)。- 關鍵字引數:
as_module (bool, 可選) – 如果為
True,將返回一個TensorDictParams例項,該例項可用於在torch.nn.Module中儲存引數。預設為False。lock (bool, 可選) – 如果為
True,則結果 tensordict 將被鎖定。預設為True。use_state_dict (bool, 可選) –
如果為
True,將使用模組的 state-dict,並將其展開為具有模型樹狀結構的 TensorDict。預設為False。注意
這在必須使用 state-dict 鉤子時特別有用。
lazy_stack (bool, 可選) –
引數應該密集堆疊還是惰性堆疊。預設為
False(密集堆疊)。注意
lazy_stack和as_module是互斥特性。警告
惰性輸出和非惰性輸出之間存在關鍵差異:非惰性輸出將重新例項化具有所需批次大小的引數,而
lazy_stack僅將引數表示為惰性堆疊。這意味著當lazy_stack=True時,原始引數可以安全地傳遞給最佳化器,而當設定為True時,需要傳遞新引數。警告
雖然使用惰性堆疊來保留原始引數引用可能很誘人,但請記住,每次呼叫
get()時,惰性堆疊都會執行一次堆疊操作。這需要計算記憶體(引數大小的 N 倍,如果構建了計算圖則更多)和時間。這也意味著最佳化器將包含更多引數,並且step()或zero_grad()等操作將需要更長時間執行。通常,lazy_stack應該僅保留給極少數用例。expand_identical (bool, 可選) – 如果為
True且同一引數(相同標識)被堆疊到自身,則將轉而返回此引數的擴充套件版本。當lazy_stack=True時,此引數將被忽略。
示例
>>> from torch import nn >>> from tensordict import from_modules >>> torch.manual_seed(0) >>> empty_module = nn.Linear(3, 4, device="meta") >>> n_models = 2 >>> modules = [nn.Linear(3, 4) for _ in range(n_models)] >>> params = from_modules(*modules) >>> print(params) TensorDict( fields={ bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False) >>> # example of batch execution >>> def exec_module(params, x): ... with params.to_module(empty_module): ... return empty_module(x) >>> x = torch.randn(3) >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> # since lazy_stack = False, backprop leaves the original params untouched >>> y.sum().backward() >>> assert params["weight"].grad.norm() > 0 >>> assert modules[0].weight.grad is None
當
lazy_stack=True時,情況略有不同>>> params = TensorDict.from_modules(*modules, lazy_stack=True) >>> print(params) LazyStackedTensorDict( fields={ bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([2]), device=None, is_shared=False, stack_dim=0) >>> # example of batch execution >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> y.sum().backward() >>> assert modules[0].weight.grad is not None