• 文件 >
  • 模型最佳化入門
快捷方式

模型最佳化入門

作者: Vincent Moens

注意

要在 Notebook 中執行此教程,請在開頭新增一個安裝單元,其中包含

!pip install tensordict
!pip install torchrl

在 TorchRL 中,我們嘗試以 PyTorch 慣用的方式處理最佳化,使用專門的損失模組,這些模組的唯一目的就是最佳化模型。這種方法有效地將策略的執行與其訓練解耦開來,並允許我們設計與傳統監督學習示例中相似的訓練迴圈。

因此,典型的訓練迴圈如下所示

..code - block::Python

>>> for i in range(n_collections):
...     data = get_next_batch(env, policy)
...     for j in range(n_optim):
...         loss = loss_fn(data)
...         loss.backward()
...         optim.step()

在這個簡潔的教程中,您將簡要了解損失模組。由於 API 在基本用法上通常很直接,本教程將保持簡短。

RL 目標函式

在強化學習 (RL) 中,創新通常涉及探索最佳化策略的新方法(即新演算法),而不是像在其他領域那樣專注於新架構。在 TorchRL 中,這些演算法被封裝在損失模組中。一個損失模組協調您演算法的各個組成部分,併產生一組損失值,這些值可以透過反向傳播來訓練相應的組成部分。

在本教程中,我們將以一種流行的離策略演算法 DDPG 作為示例。

要構建損失模組,唯一需要的是一組定義為 :class:`~tensordict.nn.TensorDictModule` 的網路。大多數時候,其中一個模組將是策略。可能還需要其他輔助網路,例如 Q 值網路或某種評論家網路。讓我們看看這在實踐中是什麼樣子的:DDPG 需要一個從觀測空間到動作空間的確定性對映,以及一個預測狀態-動作對值的價值網路。DDPG 損失函式將嘗試找到能夠輸出在給定狀態下最大化值的動作的策略引數。

要構建損失函式,我們需要 Actor 網路和價值網路。如果它們是根據 DDPG 的期望構建的,那麼這就是我們獲得可訓練損失模組所需的全部內容。

from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")

from torchrl.modules import Actor, MLP, ValueOperator
from torchrl.objectives import DDPGLoss

n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]
actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]))
value_net = ValueOperator(
    MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]),
    in_keys=["observation", "action"],
)

ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net)

就是這樣!我們的損失模組現在可以使用來自環境的資料運行了(我們省略了探索、儲存和其他功能,以專注於損失函式的功能)

rollout = env.rollout(max_steps=100, policy=actor)
loss_vals = ddpg_loss(rollout)
print(loss_vals)

LossModule 的輸出

如您所見,我們從損失模組獲得的值不是一個單一的標量,而是一個包含多個損失的字典。

原因很簡單:因為可能同時訓練多個網路,並且由於一些使用者可能希望在不同步驟中分開最佳化每個模組,TorchRL 的目標函式將返回包含各種損失組成部分的字典。

這種格式還允許我們 همراه 損失值傳遞元資料。一般來說,我們確保只有損失值是可微分的,這樣您就可以簡單地對字典中的值求和以獲得總損失。如果您想確保完全控制正在發生的事情,您可以僅對鍵以 "loss_" 字首開頭的條目求和。

total_loss = 0
for key, val in loss_vals.items():
    if key.startswith("loss_"):
        total_loss += val

訓練 LossModule

鑑於這一切,訓練模組與在任何其他訓練迴圈中所做的沒有太大區別。因為它封裝了模組,獲取可訓練引數列表的最簡單方法是呼叫 parameters() 方法。

我們將需要一個最佳化器(如果您的選擇是每個模組一個最佳化器)。

from torch.optim import Adam

optim = Adam(ddpg_loss.parameters())
total_loss.backward()

以下專案通常會在您的訓練迴圈中找到

optim.step()
optim.zero_grad()

進一步考慮:目標引數

另一個重要的方面需要考慮的是離策略演算法(如 DDPG)中目標引數的存在。目標引數通常代表引數隨時間的延遲或平滑版本,它們在策略訓練期間的價值估計中起著至關重要的作用。與使用價值網路引數的當前配置相比,利用目標引數進行策略訓練通常會顯著提高效率。一般來說,目標引數的管理由損失模組處理,減輕了使用者的直接顧慮。但是,根據具體要求更新這些值仍然是使用者的責任。TorchRL 提供了一些更新器,即 HardUpdateSoftUpdate,它們可以輕鬆例項化,無需深入瞭解損失模組的底層機制。

from torchrl.objectives import SoftUpdate

updater = SoftUpdate(ddpg_loss, eps=0.99)

在您的訓練迴圈中,您需要在每個最佳化步驟或每個收集步驟中更新目標引數

updater.step()

這就是關於損失模組您入門所需瞭解的全部內容!

要進一步探索該主題,請檢視

  • 損失模組參考頁;

  • 編碼 DDPG 損失函式的教程;

  • PPODQN 中實際應用的損失函式。

由 Sphinx-Gallery 生成的畫廊

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源