快捷方式

torchrl.objectives 包

TorchRL 提供了一系列可在訓練指令碼中使用的損失函式。其目標是提供易於複用/替換且簽名簡單的損失函式。

TorchRL 損失函式的主要特點如下:

  • 它們是有狀態物件:它們包含可訓練引數的副本,因此 loss_module.parameters() 提供了訓練演算法所需的一切。

  • 它們遵循 tensordict 約定:torch.nn.Module.forward() 方法將接收一個 tensordict 作為輸入,其中包含返回損失值所需的所有資訊。

  • 它們輸出一個 tensordict.TensorDict 例項,損失值以 "loss_<smth>" 為鍵寫入,其中 smth 是描述損失的字串。tensordict 中的其他鍵可能是訓練期間有用的指標。

注意

我們返回獨立損失的原因是為了讓使用者可以對不同的引數集使用不同的最佳化器。簡單地將損失相加可以透過以下方式完成:

>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))

注意

損失函式中的引數初始化可以透過呼叫 get_stateful_net() 完成,該方法將返回網路的有狀態版本,可以像其他任何模組一樣初始化。如果修改是就地進行的 (in-place),它將下游傳遞到使用相同引數集的任何其他模組(無論在損失函式內部還是外部):例如,從損失函式修改 actor_network 引數也會修改收集器中的 actor。如果引數是非就地修改的 (out-of-place),則可以使用 from_stateful_net() 將損失函式中的引數重置為新值。

torch.vmap 與隨機性

TorchRL 損失模組大量呼叫 vmap(),以分攤在迴圈中呼叫多個相似模型的成本,並將其向量化。當需要在呼叫中生成隨機數時,需要明確告訴 vmap 如何處理。為此,需要設定一個隨機性模式,該模式必須是 “error”(預設,處理偽隨機函式時出錯)、“same”(跨批次複製結果)或 “different”(批次中的每個元素單獨處理)之一。依賴預設設定通常會導致如下錯誤:

>>> RuntimeError: vmap: called random operation while in randomness error mode.

由於對 vmap 的呼叫隱藏在損失模組內部,TorchRL 提供了一個外部介面來設定 vmap 模式,透過 loss.vmap_randomness = str_value 完成,更多資訊請參見 vmap_randomness()

如果未檢測到隨機模組,LossModule.vmap_randomness 預設為 “error”,在其他情況下預設為 “different”。預設情況下,只有有限數量的模組被列為隨機模組,但可以使用 add_random_module() 函式擴充套件該列表。

訓練價值函式

TorchRL 提供了一系列價值估計器,例如 TD(0)、TD(1)、TD(\(\lambda\)) 和 GAE。簡單來說,價值估計器是資料(主要是獎勵和完成狀態)和狀態值(即,用於估計狀態值的函式返回的值)的函式。要了解更多關於價值估計器的資訊,請查閱 Sutton 和 Barto 的 RL 入門,特別是關於價值迭代和 TD 學習的章節。它根據資料和代理對映,對遵循某個狀態或狀態-動作對的折扣回報給出了一個有偏的估計。這些估計器用於兩種場景:

  • 為了訓練價值網路來學習“真實”的狀態值(或狀態-動作值)對映,需要一個目標值來擬合。估計器越好(偏差越小,方差越小),價值網路就越好,這反過來可以顯著加快策略訓練。通常,價值網路的損失函式如下所示:

    >>> value = value_network(states)
    >>> target_value = value_estimator(rewards, done, value_network(next_state))
    >>> value_net_loss = (value - target_value).pow(2).mean()
    
  • 計算策略最佳化的“優勢”訊號。優勢是價值估計值(來自估計器,即來自“真實”資料)與價值網路輸出(即該值的代理)之間的差值。正優勢可以看作是策略實際表現優於預期的訊號,因此如果以該軌跡為例,則表示有改進的空間。相反,負優勢表示策略表現不如預期。

事情並非總是像上面的例子那樣簡單,計算價值估計器或優勢的公式可能比這稍微複雜一些。為了幫助使用者靈活地使用不同的價值估計器,我們提供了一個簡單的 API 來動態更改它。這裡以 DQN 為例,但所有模組都遵循類似的結構:

>>> from torchrl.objectives import DQNLoss, ValueEstimators
>>> loss_module = DQNLoss(actor)
>>> kwargs = {"gamma": 0.9, "lmbda": 0.9}
>>> loss_module.make_value_estimator(ValueEstimators.TDLambda, **kwargs)

ValueEstimators 類枚舉了可供選擇的價值估計器。這使得使用者可以輕鬆地依靠自動補全來做出選擇。

LossModule(*args, **kwargs)

RL 損失函式的父類。

DQN

DQNLoss(*args, **kwargs)

DQN 損失類。

DistributionalDQNLoss(*args, **kwargs)

分散式 DQN 損失類。

DDPG

DDPGLoss(*args, **kwargs)

DDPG 損失類。

SAC

SACLoss(*args, **kwargs)

TorchRL 實現的 SAC 損失函式。

DiscreteSACLoss(*args, **kwargs)

離散 SAC 損失模組。

REDQ

REDQLoss(*args, **kwargs)

REDQ 損失模組。

CrossQ

CrossQLoss(*args, **kwargs)

TorchRL 實現的 CrossQ 損失函式。

IQL

IQLLoss(*args, **kwargs)

TorchRL 實現的 IQL 損失函式。

DiscreteIQLLoss(*args, **kwargs)

TorchRL 實現的離散 IQL 損失函式。

CQL

CQLLoss(*args, **kwargs)

TorchRL 實現的連續 CQL 損失函式。

DiscreteCQLLoss(*args, **kwargs)

TorchRL 實現的離散 CQL 損失函式。

GAIL

GAILLoss(*args, **kwargs)

TorchRL 實現的生成對抗模仿學習 (GAIL) 損失函式。

DT

DTLoss(*args, **kwargs)

TorchRL 實現的線上決策 Transformer 損失函式。

OnlineDTLoss(*args, **kwargs)

TorchRL 實現的線上決策 Transformer 損失函式。

TD3

TD3Loss(*args, **kwargs)

TD3 損失模組。

TD3+BC

TD3BCLoss(*args, **kwargs)

TD3+BC 損失模組。

PPO

PPOLoss(*args, **kwargs)

PPO 損失父類。

ClipPPOLoss(*args, **kwargs)

裁剪 PPO 損失。

KLPENPPOLoss(*args, **kwargs)

KL 懲罰 PPO 損失。

將 PPO 與多頭動作策略一起使用

注意

構建多頭策略時要考慮的主要工具有:CompositeDistributionProbabilisticTensorDictModuleProbabilisticTensorDictSequential。處理這些模組時,建議在指令碼開頭呼叫 tensordict.nn.set_composite_lp_aggregate(False).set(),以指示 CompositeDistribution 不應聚合對數機率,而應將其作為葉節點寫入 tensordict 中。

在某些情況下,我們有一個單一的優勢值,但採取了多個動作。每個動作都有自己的對數機率和形狀。例如,動作空間可以結構化如下:

>>> action_td = TensorDict(
...     agents=TensorDict(
...         action0=Tensor(batch, n_agents, f0),
...         action1=Tensor(batch, n_agents, f1, f2),
...         batch_size=torch.Size((batch, n_agents))
...     ),
...     batch_size=torch.Size((batch,))
... )

其中 f0f1f2 是任意整數。

請注意,在 TorchRL 中,根 tensordict 的形狀與環境的形狀相同(如果環境是批次鎖定的,否則其形狀與正在執行的批次環境數相同)。如果 tensordict 是從緩衝區取樣的,其形狀將與回放緩衝區 batch_size 的形狀相同。儘管 n_agent 維度對每個動作都常見,但它通常不會出現在根 tensordict 的批次大小中(儘管根據 MARL API,它會出現在包含智慧體特定資料的子 tensordict 中)。

這種情況是有合理原因的:智慧體數量可能會影響環境的某些規範,但不是全部。例如,有些環境在所有智慧體之間共享一個完成狀態。在這種情況下,一個更完整的 tensordict 可能看起來像這樣:

>>> action_td = TensorDict(
...     agents=TensorDict(
...         action0=Tensor(batch, n_agents, f0),
...         action1=Tensor(batch, n_agents, f1, f2),
...         observation=Tensor(batch, n_agents, f3),
...         batch_size=torch.Size((batch, n_agents))
...     ),
...     done=Tensor(batch, 1),
...     [...] # etc
...     batch_size=torch.Size((batch,))
... )

請注意,done 狀態和 reward 通常在最右側伴隨一個單例維度。請參閱文件的這部分,瞭解更多關於此限制的資訊。

在給定其各自分佈的情況下,我們的動作的對數機率可能看起來像這樣:

>>> action_td = TensorDict(
...     agents=TensorDict(
...         action0_log_prob=Tensor(batch, n_agents),
...         action1_log_prob=Tensor(batch, n_agents, f1),
...         batch_size=torch.Size((batch, n_agents))
...     ),
...     batch_size=torch.Size((batch,))
... )

>>> action_td = TensorDict(
...     agents=TensorDict(
...         action0_log_prob=Tensor(batch, n_agents),
...         action1_log_prob=Tensor(batch, n_agents),
...         batch_size=torch.Size((batch, n_agents))
...     ),
...     batch_size=torch.Size((batch,))
... )

即,分佈對數機率的維數通常從樣本的維數到任何小於該維數的值不等,例如,如果分佈是多元的(例如 Dirichlet)或是一個 Independent 例項。相反,tensordict 的維數仍然與環境/回放緩衝區的批次大小匹配。

在呼叫 PPO 損失時,損失模組將 схематически 執行以下一系列操作:

>>> def ppo(tensordict):
...     prev_log_prob = tensordict.select(*log_prob_keys)
...     action = tensordict.select(*action_keys)
...     new_log_prob = dist.log_prob(action)
...     log_weight = new_log_prob - prev_log_prob
...     advantage = tensordict.get("advantage") # computed by GAE earlier
...     # attempt to map shape
...     log_weight.batch_size = advantage.batch_size[:-1]
...     log_weight = sum(log_weight.sum(dim="feature").values(True, True)) # get a single tensor of log_weights
...     return minimum(log_weight.exp() * advantage, log_weight.exp().clamp(1-eps, 1+eps) * advantage)

要了解多頭策略下的 PPO 流水線是什麼樣的,可以在庫的示例目錄中找到一個示例。

A2C

A2CLoss(*args, **kwargs)

TorchRL 實現的 A2C 損失函式。

Reinforce

ReinforceLoss(*args, **kwargs)

Reinforce 損失模組。

Dreamer

DreamerActorLoss(*args, **kwargs)

Dreamer Actor 損失。

DreamerModelLoss(*args, **kwargs)

Dreamer Model 損失。

DreamerValueLoss(*args, **kwargs)

Dreamer Value 損失。

多智慧體目標

這些目標是多智慧體演算法特有的。

QMixer

QMixerLoss(*args, **kwargs)

QMixer 損失類。

返回值

ValueEstimatorBase(*args, **kwargs)

價值函式模組的抽象父類。

TD0Estimator(*args, **kwargs)

時序差分 (TD(0)) 優勢函式估計器。

TD1Estimator(*args, **kwargs)

\(\infty\)-時序差分 (TD(1)) 優勢函式估計器。

TDLambdaEstimator(*args, **kwargs)

TD(\(\lambda\)) 優勢函式估計器。

GAE(*args, **kwargs)

廣義優勢估計函式的類封裝器。

functional.td0_return_estimate(gamma, ...[, ...])

軌跡的 TD(0) 折扣回報估計。

functional.td0_advantage_estimate(gamma, ...)

軌跡的 TD(0) 優勢估計。

functional.td1_return_estimate(gamma, ...[, ...])

TD(1) 回報估計。

functional.vec_td1_return_estimate(gamma, ...)

向量化 TD(1) 回報估計。

functional.td1_advantage_estimate(gamma, ...)

TD(1) 優勢估計。

functional.vec_td1_advantage_estimate(gamma, ...)

向量化 TD(1) 優勢估計。

functional.td_lambda_return_estimate(gamma, ...)

TD(\(\lambda\)) 回報估計。

functional.vec_td_lambda_return_estimate(...)

向量化 TD(\(\lambda\)) 回報估計。

functional.td_lambda_advantage_estimate(...)

TD(\(\lambda\)) 優勢估計。

functional.vec_td_lambda_advantage_estimate(...)

向量化 TD(\(\lambda\)) 優勢估計。

functional.generalized_advantage_estimate(...)

軌跡的廣義優勢估計。

functional.vec_generalized_advantage_estimate(...)

軌跡的向量化廣義優勢估計。

functional.reward2go(reward, done, gamma, *)

計算給定多條軌跡和情節結束的折扣累積獎勵和。

工具函式

HardUpdate(loss_module, *[, ...])

用於 Double DQN/DDPG 中目標網路硬更新的類(與軟更新相對)。

SoftUpdate(loss_module, *[, eps, tau])

用於 Double DQN/DDPG 中目標網路軟更新的類。

ValueEstimators(value)

用於自定義估計器的價值函式列舉器。

default_value_kwargs(value_type)

預設價值函式關鍵字引數生成器。

distance_loss(v1, v2, loss_function[, ...])

計算兩個張量之間的距離損失。

group_optimizers(*optimizers)

將多個最佳化器組合成一個。

hold_out_net(network)

將網路排除在計算圖之外的上下文管理器。

hold_out_params(params)

將引數列表排除在計算圖之外的上下文管理器。

next_state_value(tensordict[, operator, ...])

計算下一狀態值(無梯度)以計算目標值。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源