LossModule¶
- 類 torchrl.objectives.LossModule(*args, **kwargs)[原始碼]¶
RL 損失函式的父類。
LossModule 繼承自 nn.Module。它旨在讀取輸入的 TensorDict 並返回另一個 tensordict,其中包含名為
"loss_*"的損失鍵。將損失分解為其組成部分,然後由訓練器用於記錄訓練過程中的各種損失值。輸出 tensordict 中存在的其他標量也會被記錄。
- 變數:
default_value_estimator – 類的預設值型別。需要值估計的損失函式配備了預設值指標。此類屬性指示在未指定其他值估計器時將使用哪個值估計器。可以使用
make_value_estimator()方法更改值估計器。
預設情況下,forward 方法始終使用 gh
torchrl.envs.ExplorationType.MEAN進行修飾要利用透過
set_keys()配置 tensordict 鍵的能力,子類必須定義一個 _AcceptedKeys 資料類。此資料類應包含所有預期可配置的鍵。此外,子類必須實現 :meth:._forward_value_estimator_keys() 方法。此函式對於將任何修改後的 tensordict 鍵轉發到底層 value_estimator 至關重要。示例
>>> class MyLoss(LossModule): >>> @dataclass >>> class _AcceptedKeys: >>> action = "action" >>> >>> def _forward_value_estimator_keys(self, **kwargs) -> None: >>> pass >>> >>> loss = MyLoss() >>> loss.set_keys(action="action2")
注意
當將用探索模組包裝或增強的策略傳遞給損失函式時,我們希望透過
set_exploration_type(<exploration>)來停用探索,其中<exploration>可以是ExplorationType.MEAN、ExplorationType.MODE或ExplorationType.DETERMINISTIC。預設值為DETERMINISTIC,透過deterministic_sampling_mode損失屬性設定。如果需要其他探索模式(或如果DETERMINISTIC不可用),可以更改此屬性的值,這將改變模式。- convert_to_functional(module: TensorDictModule, module_name: str, expand_dim: Optional[int] = None, create_target_params: bool = False, compare_against: Optional[List[Parameter]] = None, **kwargs) None[原始碼]¶
將模組轉換為函式式以便在損失函式中使用。
- 引數:
module (TensorDictModule 或 相容型別) – 一個有狀態的 tensordict 模組。此模組的引數將被隔離在 <module_name>_params 屬性中,並且此模組的無狀態版本將註冊在 module_name 屬性下。
module_name (str) – 查詢模組的名稱。模組的引數將在
loss_module.<module_name>_params下找到,而模組將在loss_module.<module_name>下找到。expand_dim (int, 可選) –
- 如果提供,模組的引數
將沿第一個維度擴充套件
N次,其中N = expand_dim。當需要使用具有多個配置的目標網路時,應使用此選項。注意
如果提供了
compare_against值列表,則結果引數將僅是原始引數的獨立(detached)擴充套件。如果未提供compare_against,則引數的值將在引數內容的最小值和最大值之間均勻重取樣。- create_target_params (bool, 可選): 如果為
True,則會提供一個獨立的(detached) 引數副本,用於提供給目標網路,名稱為
loss_module.<module_name>_target_params。如果為False(預設),此屬性仍然可用,但它將是引數的一個獨立(detached)例項,而非副本。換句話說,引數值的任何修改將直接反映在目標引數中。
compare_against (引數可迭代物件, 可選) – 如果提供,此引數列表將用作模組引數的比較集。如果引數被擴充套件(
expand_dim > 0),則模組的結果引數將是原始引數的簡單擴充套件。否則,結果引數將是原始引數的獨立(detached)版本。如果為None,則結果引數將按預期帶有梯度。
- forward(tensordict: TensorDictBase) TensorDictBase[原始碼]¶
它旨在讀取輸入的 TensorDict 並返回另一個 tensordict,其中包含名為“loss*”的損失鍵。
將損失分解為其組成部分,然後由訓練器用於記錄訓練過程中的各種損失值。輸出 tensordict 中存在的其他標量也會被記錄。
- 引數:
tensordict – 包含計算損失所需值的輸入 tensordict。
- 返回值:
一個不帶批次維度的新的 tensordict,包含各種損失標量,這些標量將命名為“loss*”。損失以這種名稱返回至關重要,因為它們將在反向傳播之前由訓練器讀取。
- from_stateful_net(network_name: str, stateful_net: Module)[原始碼]¶
根據網路的有狀態版本填充模型的引數。
有關如何獲取網路的有狀態版本的詳細資訊,請參閱
get_stateful_net()。- 引數:
network_name (str) – 要重置的網路名稱。
stateful_net (nn.Module) – 應從中獲取引數的有狀態網路。
- 屬性 functional¶
模組是否為函式式的。
除非特別設計為非函式式,否則所有損失函式都是函式式的。
- get_stateful_net(network_name: str, copy: bool | None = None)[原始碼]¶
返回網路的有狀態版本。
這可用於初始化引數。
此類網路通常無法直接呼叫,需要透過 vmap 呼叫才能執行。
- 引數:
network_name (str) – 要獲取的網路名稱。
copy (bool, 可選) –
如果為
True,則對網路進行深複製。預設為True。注意
如果模組不是函式式的,則不進行複製。
- make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)[原始碼]¶
值函式構造器。
如果需要非預設值函式,則必須使用此方法構建。
- 引數:
value_type (值估計器) – 一個
ValueEstimators列舉型別,指示要使用的值函式。如果未提供,將使用儲存在default_value_estimator屬性中的預設值。結果值估計器類將註冊在self.value_type中,以便將來進行細化。**hyperparams – 用於值函式的超引數。如果未提供,將使用
default_value_kwargs()指示的值。
示例
>>> from torchrl.objectives import DQNLoss >>> # initialize the DQN loss >>> actor = torch.nn.Linear(3, 4) >>> dqn_loss = DQNLoss(actor, action_space="one-hot") >>> # updating the parameters of the default value estimator >>> dqn_loss.make_value_estimator(gamma=0.9) >>> dqn_loss.make_value_estimator( ... ValueEstimators.TD1, ... gamma=0.9) >>> # if we want to change the gamma value >>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Parameter]][原始碼]¶
返回一個模組引數的迭代器,同時生成引數的名稱和引數本身。
- 引數:
prefix (str) – 要前置到所有引數名稱的字首。
recurse (bool) – 如果為 True,則生成此模組和所有子模組的引數。否則,僅生成屬於此模組直接成員的引數。
remove_duplicate (bool, 可選) – 是否移除結果中的重複引數。預設為 True。
- 生成內容:
(str, Parameter) – 包含名稱和引數的元組
示例
>>> # xdoctest: +SKIP("undefined vars") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size())
- parameters(recurse: bool = True) Iterator[Parameter][原始碼]¶
返回一個模組引數的迭代器。
這通常傳遞給最佳化器。
- 引數:
recurse (bool) – 如果為 True,則生成此模組和所有子模組的引數。否則,僅生成屬於此模組直接成員的引數。
- 生成內容:
Parameter – 模組引數
示例
>>> # xdoctest: +SKIP("undefined vars") >>> for param in model.parameters(): >>> print(type(param), param.size()) <class 'torch.Tensor'> (20L,) <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- set_keys(**kwargs) None[原始碼]¶
設定 tensordict 鍵的名稱。
示例
>>> from torchrl.objectives import DQNLoss >>> # initialize the DQN loss >>> actor = torch.nn.Linear(3, 4) >>> dqn_loss = DQNLoss(actor, action_space="one-hot") >>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
- 屬性 value_estimator: ValueEstimatorBase¶
值函式將來自後續狀態/狀態-動作對的獎勵和值估計融合到值網路的目標值估計中。
- 屬性 vmap_randomness¶
Vmap 隨機模式。
vmap 隨機模式控制
vmap()在處理具有隨機結果的函式(如randn()和 rand())時應執行的操作。如果為 “error”,任何隨機函式都將引發異常,表明 vmap 不知道如何處理隨機呼叫。如果為 “different”,則沿其呼叫 vmap 的批處理的每個元素將表現不同。如果為 “same”,vmap 將在所有元素中複製相同的結果。
如果未檢測到隨機模組,則
vmap_randomness預設為 “error”,否則預設為 “different”。預設情況下,只有有限數量的模組被列為隨機模組,但可以使用add_random_module()函式擴充套件此列表。此屬性支援設定其值。