RolloutFromModel¶
- class torchrl.data.RolloutFromModel(model, ref_model, reward_model, kl_coef=0.1, max_new_tokens=50, score_clip=10.0, kl_scheduler: KLControllerBase | None = None, num_steps: int | None = None)[原始碼]¶
用於對因果語言模型執行 rollout 的類。
假設此類封裝的模型接收分詞後的文字作為輸入,其任務是根據已讀取的前 n 個詞預測句子中的下一個詞。
- 引數:
model (transformers.Transformer) – 要使用的模型。應具有
generate()方法。ref_model (transformers.Transformer) –
model的凍結版本,其引數處於初始配置。這用於計算獎勵的 KL 懲罰,以防止模型在訓練期間偏離參考模型太遠。reward_model – (nn.Module, tensordict.nn.TensorDictModule): 一個模型,給定
input_ids和attention_mask,計算每個 token 的獎勵以及 end_scores(每個序列中最後一個 token 的獎勵)。kl_coef – (
float, 可選): 初始 kl 係數。max_new_tokens (int, 可選) – 序列的最大長度。預設為 50。
score_clip (
float, 可選) – 獎勵模型的分數將被裁剪到(-score_clip, score_clip)範圍內。預設為 10。kl_scheduler (KLControllerBase, 可選) – KL 係數排程器。
num_steps (int, 可選) – 兩次最佳化之間的步數。
示例
>>> from tensordict.nn import TensorDictModule >>> from torchrl.modules.models.rlhf import GPT2RewardModel >>> from torchrl.data.rlhf.utils import RolloutFromModel >>> from torchrl.data.rlhf.dataset import get_dataloader >>> from torchrl.data.rlhf.prompt import PromptData >>> from transformers import GPT2LMHeadModel >>> >>> dl = get_dataloader( ... batch_size=4, ... block_size=550, ... tensorclass_type=PromptData, ... device="cpu", ... dataset_name="CarperAI/openai_summarize_tldr", ... ) >>> model = GPT2LMHeadModel.from_pretrained("gpt2") >>> # we load ref_model with random weights so it differs from model >>> ref_model = GPT2LMHeadModel(GPT2LMHeadModel.config_class()) >>> reward_model = GPT2RewardModel(model_path="gpt2") >>> rollout_from_model = RolloutFromModel(model, ref_model, reward_model) >>> >>> batch = next(dl) >>> rollout = rollout_from_model.rollout_from_data(batch) >>> rollout TensorDict( fields={ action: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False), attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False), input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False), next: TensorDict( fields={ attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False), done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False), input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False), reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), reward_kl: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), reward_raw: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 50]), device=cpu, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 50]), device=cpu, is_shared=False)
- create_rollout_td(batch, generated, log_probs, log_ratio)[原始碼]¶
一個用於生成資料的 TensorDict 包裝器。
此函式接收一個 batch 以及生成的 tokens,並複製了從每次 timestep 取樣一個 token 的 TorchRL 環境 rollout 中獲得的 tensordict 結構。
- 引數:
batch (TensorDict) – 包含原始 prompt 以及指示 prompt 右側索引的欄位“rindex”的資料批次。
generated (torch.Tensor) – 分詞後的 prompt 後面跟著生成的 tokens。這可以透過呼叫
generate方法獲得。log_probs (torch.Tensor) – 生成 tokens 的對數機率。這可以透過呼叫
generate方法獲得。log_ratio (torch.Tensor) – 根據生成模型和凍結版本計算的生成 tokens 機率的對數比率。這可以透過呼叫
generate方法獲得。
- 返回:
"action": 動作序列(生成的 tokens)"input_ids": 在每個時間步傳遞給生成模型的 input_ids。"attention_mask": 在每個時間步傳遞給生成模型的 attention_mask"sample_log_prob": 生成過程中每個 token 的對數機率("next", "input_ids"): 生成後的 tokens 序列。構成用於生成下一個 token 的輸入的一部分。("next", "attention_mask"): token 生成後更新的 attention_mask。在下一個時間步傳遞給生成模型("next", "terminated"): 布林陣列,指示是否已達到終止狀態(是因為生成了 EOS token 還是因為達到了 token 限制)("next", "done"): 布林陣列,指示是否已達到最終狀態。當前是"terminated"的副本。("next", "reward"): 在每個時間步收到的獎勵("next", "reward_raw"): 來自獎勵模型的原始獎勵,不包含 KL 項。這主要用於除錯和日誌記錄,不用於訓練。("next", "reward_kl"): 來自獎勵的 KL 項。這主要用於除錯和日誌記錄,不用於訓練。
- 返回型別:
一個
TensorDict,包含以下鍵
- generate(batch: PromptData, generation_config=None)[原始碼]¶
從資料收集器取樣的批次資料中生成 tokens 序列。
- 引數:
batch (PromptData) – 要使用的資料。必須包含
input_ids和prompt_rindex欄位。generation_config (GenerationConfig, 可選) – 呼叫 generate 的配置。
- 返回:
- 一個 [B x (Ti +To)] 整數序列 (tokens),
其中 Ti 是輸入序列的長度,To 是生成序列的長度。
log_probs_gen: 生成 token 的對數機率。log_ratio: 生成
模型下機率與凍結版本之間機率的對數比率。
- 返回型別:
generated (torch.Tensor)