快捷方式

Trainer

class torchrl.trainers.Trainer(*args, **kwargs)[]

一個通用的 Trainer 類。

一個 trainer 負責收集資料和訓練模型。為了使此類儘可能通用,Trainer 不構建任何其特定的操作:所有操作都必須在訓練迴圈中的特定點進行鉤入 (hooked)。

要構建一個 Trainer,需要一個可迭代的資料來源(一個 collector)、一個損失模組和一個最佳化器。

引數:
  • collector (Sequence[TensorDictBase]) – 一個可迭代物件,返回形狀為 [batch x time steps] 的 TensorDict 形式的資料批次。

  • total_frames (int) – 訓練期間收集的總幀數。

  • loss_module (LossModule) – 一個模組,讀取 TensorDict 批次(可能從回放緩衝區中取樣),並返回一個損失 TensorDict,其中每個鍵指向不同的損失分量。

  • optimizer (optim.Optimizer) – 一個訓練模型引數的最佳化器。

  • logger (Logger, optional) – 一個將處理日誌記錄的 Logger。

  • optim_steps_per_batch (int) – 每次資料收集的最佳化步數。一個 trainer 工作方式如下:一個主迴圈收集資料批次(epoch loop),一個子迴圈(training loop)在兩次資料收集之間執行模型更新。

  • clip_grad_norm (bool, optional) – 如果為 True,梯度將根據模型引數的總範數進行裁剪。如果為 False,所有偏導數將被限制在 (-clip_norm, clip_norm) 之間。預設為 True

  • clip_norm (Number, optional) – 用於裁剪梯度的值。預設為 None(不進行範數裁剪)。

  • progress_bar (bool, optional) – 如果為 True,將使用 tqdm 顯示進度條。如果未安裝 tqdm,此選項將不起作用。預設為 True

  • seed (int, optional) – 用於 collector、pytorch 和 numpy 的種子。預設為 None

  • save_trainer_interval (int, optional) – trainer 應多久儲存到磁碟一次,以幀數計。預設為 10000。

  • log_interval (int, optional) – 應多久記錄一次值,以幀數計。預設為 10000。

  • save_trainer_file (path, optional) – 儲存 trainer 的路徑。預設為 None(不儲存)。

load_from_file(file: Union[str, Path], **kwargs) Trainer[]

載入檔案及其 state-dict 到 trainer 中。

關鍵字引數傳遞給 load() 函式。


© 版權所有 2022,Meta。

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源