快捷方式

LossModule

torchrl.objectives.LossModule(*args, **kwargs)[原始碼]

RL 損失函式的父類。

LossModule 繼承自 nn.Module。它旨在讀取輸入的 TensorDict 並返回另一個 tensordict,其中包含名為 "loss_*" 的損失鍵。

將損失分解為其組成部分,然後由訓練器用於記錄訓練過程中的各種損失值。輸出 tensordict 中存在的其他標量也會被記錄。

變數:

default_value_estimator – 類的預設值型別。需要值估計的損失函式配備了預設值指標。此類屬性指示在未指定其他值估計器時將使用哪個值估計器。可以使用 make_value_estimator() 方法更改值估計器。

預設情況下,forward 方法始終使用 gh torchrl.envs.ExplorationType.MEAN 進行修飾

要利用透過 set_keys() 配置 tensordict 鍵的能力,子類必須定義一個 _AcceptedKeys 資料類。此資料類應包含所有預期可配置的鍵。此外,子類必須實現 :meth:._forward_value_estimator_keys() 方法。此函式對於將任何修改後的 tensordict 鍵轉發到底層 value_estimator 至關重要。

示例

>>> class MyLoss(LossModule):
>>>     @dataclass
>>>     class _AcceptedKeys:
>>>         action = "action"
>>>
>>>     def _forward_value_estimator_keys(self, **kwargs) -> None:
>>>         pass
>>>
>>> loss = MyLoss()
>>> loss.set_keys(action="action2")

注意

當將用探索模組包裝或增強的策略傳遞給損失函式時,我們希望透過 set_exploration_type(<exploration>) 來停用探索,其中 <exploration> 可以是 ExplorationType.MEANExplorationType.MODEExplorationType.DETERMINISTIC。預設值為 DETERMINISTIC,透過 deterministic_sampling_mode 損失屬性設定。如果需要其他探索模式(或如果 DETERMINISTIC 不可用),可以更改此屬性的值,這將改變模式。

convert_to_functional(module: TensorDictModule, module_name: str, expand_dim: Optional[int] = None, create_target_params: bool = False, compare_against: Optional[List[Parameter]] = None, **kwargs) None[原始碼]

將模組轉換為函式式以便在損失函式中使用。

引數:
  • module (TensorDictModule相容型別) – 一個有狀態的 tensordict 模組。此模組的引數將被隔離在 <module_name>_params 屬性中,並且此模組的無狀態版本將註冊在 module_name 屬性下。

  • module_name (str) – 查詢模組的名稱。模組的引數將在 loss_module.<module_name>_params 下找到,而模組將在 loss_module.<module_name> 下找到。

  • expand_dim (int, 可選) –

    如果提供,模組的引數

    將沿第一個維度擴充套件 N 次,其中 N = expand_dim。當需要使用具有多個配置的目標網路時,應使用此選項。

    注意

    如果提供了 compare_against 值列表,則結果引數將僅是原始引數的獨立(detached)擴充套件。如果未提供 compare_against,則引數的值將在引數內容的最小值和最大值之間均勻重取樣。

    create_target_params (bool, 可選): 如果為 True,則會提供一個獨立的(detached)

    引數副本,用於提供給目標網路,名稱為 loss_module.<module_name>_target_params。如果為 False(預設),此屬性仍然可用,但它將是引數的一個獨立(detached)例項,而非副本。換句話說,引數值的任何修改將直接反映在目標引數中。

  • compare_against (引數可迭代物件, 可選) – 如果提供,此引數列表將用作模組引數的比較集。如果引數被擴充套件(expand_dim > 0),則模組的結果引數將是原始引數的簡單擴充套件。否則,結果引數將是原始引數的獨立(detached)版本。如果為 None,則結果引數將按預期帶有梯度。

forward(tensordict: TensorDictBase) TensorDictBase[原始碼]

它旨在讀取輸入的 TensorDict 並返回另一個 tensordict,其中包含名為“loss*”的損失鍵。

將損失分解為其組成部分,然後由訓練器用於記錄訓練過程中的各種損失值。輸出 tensordict 中存在的其他標量也會被記錄。

引數:

tensordict – 包含計算損失所需值的輸入 tensordict。

返回值:

一個不帶批次維度的新的 tensordict,包含各種損失標量,這些標量將命名為“loss*”。損失以這種名稱返回至關重要,因為它們將在反向傳播之前由訓練器讀取。

from_stateful_net(network_name: str, stateful_net: Module)[原始碼]

根據網路的有狀態版本填充模型的引數。

有關如何獲取網路的有狀態版本的詳細資訊,請參閱 get_stateful_net()

引數:
  • network_name (str) – 要重置的網路名稱。

  • stateful_net (nn.Module) – 應從中獲取引數的有狀態網路。

屬性 functional

模組是否為函式式的。

除非特別設計為非函式式,否則所有損失函式都是函式式的。

get_stateful_net(network_name: str, copy: bool | None = None)[原始碼]

返回網路的有狀態版本。

這可用於初始化引數。

此類網路通常無法直接呼叫,需要透過 vmap 呼叫才能執行。

引數:
  • network_name (str) – 要獲取的網路名稱。

  • copy (bool, 可選) –

    如果為 True,則對網路進行深複製。預設為 True

    注意

    如果模組不是函式式的,則不進行複製。

make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)[原始碼]

值函式構造器。

如果需要非預設值函式,則必須使用此方法構建。

引數:
  • value_type (值估計器) – 一個 ValueEstimators 列舉型別,指示要使用的值函式。如果未提供,將使用儲存在 default_value_estimator 屬性中的預設值。結果值估計器類將註冊在 self.value_type 中,以便將來進行細化。

  • **hyperparams – 用於值函式的超引數。如果未提供,將使用 default_value_kwargs() 指示的值。

示例

>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Parameter]][原始碼]

返回一個模組引數的迭代器,同時生成引數的名稱和引數本身。

引數:
  • prefix (str) – 要前置到所有引數名稱的字首。

  • recurse (bool) – 如果為 True,則生成此模組和所有子模組的引數。否則,僅生成屬於此模組直接成員的引數。

  • remove_duplicate (bool, 可選) – 是否移除結果中的重複引數。預設為 True。

生成內容:

(str, Parameter) – 包含名稱和引數的元組

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
parameters(recurse: bool = True) Iterator[Parameter][原始碼]

返回一個模組引數的迭代器。

這通常傳遞給最佳化器。

引數:

recurse (bool) – 如果為 True,則生成此模組和所有子模組的引數。否則,僅生成屬於此模組直接成員的引數。

生成內容:

Parameter – 模組引數

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
reset_parameters_recursive()[原始碼]

重置模組的引數。

set_keys(**kwargs) None[原始碼]

設定 tensordict 鍵的名稱。

示例

>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
屬性 value_estimator: ValueEstimatorBase

值函式將來自後續狀態/狀態-動作對的獎勵和值估計融合到值網路的目標值估計中。

屬性 vmap_randomness

Vmap 隨機模式。

vmap 隨機模式控制 vmap() 在處理具有隨機結果的函式(如 randn()rand())時應執行的操作。如果為 “error”,任何隨機函式都將引發異常,表明 vmap 不知道如何處理隨機呼叫。

如果為 “different”,則沿其呼叫 vmap 的批處理的每個元素將表現不同。如果為 “same”,vmap 將在所有元素中複製相同的結果。

如果未檢測到隨機模組,則 vmap_randomness 預設為 “error”,否則預設為 “different”。預設情況下,只有有限數量的模組被列為隨機模組,但可以使用 add_random_module() 函式擴充套件此列表。

此屬性支援設定其值。

文件

獲取 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源