快捷方式

torchrl.trainers 包

trainer 包提供了編寫可重用訓練指令碼的工具。核心思想是使用一個實現巢狀迴圈的訓練器,其中外層迴圈執行資料收集步驟,內層迴圈執行最佳化步驟。我們認為這適用於多種強化學習訓練方案,例如 on-policy、off-policy、基於模型和無模型的解決方案、離線 RL 等。更特殊的案例,例如 meta-RL 演算法,其訓練方案可能存在顯著差異。

trainer.train() 方法的示意圖如下

訓練器迴圈
        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

     >>> for batch in collector:
     ...     batch = self._process_batch_hook(batch)  # "batch_process"
     ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
     ...     self._pre_optim_hook()  # "pre_optim_steps"
     ...     for j in range(self.optim_steps_per_batch):
     ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
     ...         losses = self.loss_module(sub_batch)
     ...         self._post_loss_hook(sub_batch)  # "post_loss"
     ...         self.optimizer.step()
     ...         self.optimizer.zero_grad()
     ...         self._post_optim_hook()  # "post_optim"
     ...         self._post_optim_log(sub_batch)  # "post_optim_log"
     ...     self._post_steps_hook()  # "post_steps"
     ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

訓練器迴圈中可以使用 10 個鉤子:"batch_process", "pre_optim_steps", "process_optim_batch", "post_loss", "post_steps", "post_optim", "pre_steps_log", "post_steps_log", "post_optim_log""optimizer"。它們在應用的地方已在註釋中註明。鉤子可分為 3 類:資料處理 ("batch_process""process_optim_batch"),日誌記錄 ("pre_steps_log", "post_optim_log""post_steps_log") 和 操作 鉤子 ("pre_optim_steps", "post_loss", "post_optim""post_steps")。

  • 資料處理 鉤子用於更新資料的 TensorDict。鉤子的 __call__ 方法應接受一個 TensorDict 物件作為輸入,並根據某種策略對其進行更新。這類鉤子的示例包括回放緩衝區擴充套件 (ReplayBufferTrainer.extend)、資料歸一化(包括歸一化常數更新)、資料子取樣(:class:~torchrl.trainers.BatchSubSampler)等。

  • 日誌記錄 鉤子接收表示為 TensorDict 的資料批次,並在記錄器中寫入從該資料中檢索到的資訊。示例包括 LogValidationReward 鉤子、獎勵記錄器 (LogScalar) 等。鉤子應返回一個包含要記錄的資料的字典(或 None 值)。鍵 "log_pbar" 保留用於布林值,指示記錄的值是否應顯示在訓練日誌中列印的進度條上。

  • 操作 鉤子是在模型、資料收集器、目標網路更新等方面執行特定操作的鉤子。例如,使用 UpdateWeights 同步收集器的權重,或使用 ReplayBufferTrainer.update_priority 更新回放緩衝區的優先順序,都是操作鉤子的示例。它們是資料獨立的(不需要 TensorDict 輸入),只需在每次迭代(或每 N 次迭代)執行一次即可。

TorchRL 提供的鉤子通常繼承自一個共同的抽象基類 TrainerHookBase,並且都實現了三個基本方法:用於檢查點儲存的 state_dictload_state_dict 方法,以及在訓練器中以預設值註冊鉤子的 register 方法。此方法接收訓練器和模組名稱作為輸入。例如,以下日誌記錄鉤子在每次呼叫 "post_optim_log" 10 次後執行

>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name):
...         trainer.register_module(self, "logging_hook")
...         trainer.register_op("post_optim_log", self)
...
...     def save_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         if self.counter % 10 == 0:
...             self.counter += 1
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out

檢查點儲存

訓練器類和鉤子支援檢查點儲存,可以透過使用 torchsnapshot 後端或常規的 torch 後端來實現。這可以透過全域性變數 CKPT_BACKEND 控制

$ CKPT_BACKEND=torchsnapshot python script.py

CKPT_BACKEND 預設為 torch。torchsnapshot 相對於 pytorch 的優勢在於它是一個更靈活的 API,支援分散式檢查點儲存,並且允許使用者將儲存在磁碟檔案中的張量載入到具有物理儲存的張量中(pytorch 目前不支援此功能)。例如,這使得可以將張量從和載入到原本無法容納在記憶體中的回放緩衝區中。

構建訓練器時,可以提供檢查點儲存路徑。對於 torchsnapshot 後端,期望的是目錄路徑,而 torch 後端期望的是檔案路徑(通常是 .pt 檔案)。

>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)

Trainer.train() 方法可用於執行上述包含所有鉤子的迴圈,儘管僅將 Trainer 類用於其檢查點儲存功能也是完全有效的用法。

訓練器和鉤子

BatchSubSampler(batch_size[, sub_traj_len, ...])

用於線上 RL 最新實現的資料子取樣器。

ClearCudaCache(interval)

按給定間隔清除 cuda 快取。

CountFramesLog(*args, **kwargs)

一個幀計數鉤子。

LogScalar([logname, log_pbar, reward_key])

獎勵記錄器鉤子。

OptimizerHook(optimizer[, loss_components])

為一個或多個損失元件新增最佳化器。

LogValidationReward(*, record_interval, ...)

Trainer 的記錄器鉤子。

ReplayBufferTrainer(replay_buffer[, ...])

回放緩衝區鉤子提供者。

RewardNormalizer([decay, scale, eps, ...])

獎勵歸一化器鉤子。

SelectKeys(keys)

選擇 TensorDict 批次中的鍵。

Trainer(*args, **kwargs)

一個通用 Trainer 類。

TrainerHookBase()

用於 torchrl Trainer 類的抽象鉤子類。

UpdateWeights(collector, update_weights_interval)

一個收集器權重更新鉤子類。

構建器

make_collector_offpolicy(make_env, ...[, ...])

返回用於 off-policy 最新實現的資料收集器。

make_collector_onpolicy(make_env, ...[, ...])

在 on-policy 設定中建立一個收集器。

make_dqn_loss(model, cfg)

構建 DQN 損失模組。

make_replay_buffer(device, cfg)

使用從 ReplayArgsConfig 構建的配置構建回放緩衝區。

make_target_updater(cfg, loss_module)

構建一個目標網路權重更新物件。

make_trainer(collector, loss_module[, ...])

根據其組成部分建立一個 Trainer 例項。

parallel_env_constructor(cfg, **kwargs)

從使用適當的解析器構建器構建的 argparse.Namespace 返回一個並行環境。

sync_async_collector(env_fns, env_kwargs[, ...])

執行非同步收集器,每個收集器運行同步環境。

sync_sync_collector(env_fns, env_kwargs[, ...])

運行同步收集器,每個收集器運行同步環境。

transformed_env_constructor(cfg[, ...])

從使用適當的解析器構建器構建的 argparse.Namespace 返回一個環境建立器。

工具函式

correct_for_frame_skip(cfg)

根據輸入的 frame_skip 調整引數,將所有反映幀計數目的的引數除以 frame_skip。

get_stats_random_rollout(cfg[, ...])

使用隨機 rollout 從環境中收集統計資料(loc 和 scale)。

記錄器

Logger(exp_name, log_dir)

記錄器的模板。

csv.CSVLogger(exp_name[, log_dir, ...])

一個最小依賴的 CSV 記錄器。

mlflow.MLFlowLogger(exp_name, tracking_uri)

mlflow 記錄器的包裝器。

tensorboard.TensorboardLogger(exp_name[, ...])

Tensorboard 記錄器的包裝器。

wandb.WandbLogger(*args, **kwargs)

wandb 記錄器的包裝器。

get_logger(logger_type, logger_name, ...)

獲取指定 logger_type 的記錄器例項。

generate_exp_name(model_name, experiment_name)

使用 UUID 和當前日期為描述的實驗生成一個 ID (str)。

錄製工具函式

錄製工具函式在此處詳細介紹。

VideoRecorder(logger, tag[, in_keys, skip, ...])

影片錄製器 transform。

TensorDictRecorder(out_file_base[, ...])

TensorDict 記錄器。

PixelRenderTransform([out_keys, preproc, ...])

一個 transform,用於在父環境中呼叫 render,並在 tensordict 中註冊畫素觀察結果。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源