get_dataloader¶
- class torchrl.data.get_dataloader(batch_size: int, block_size: int, tensorclass_type: Type, device: torch.device, dataset_name: str | None = None, infinite: bool = True, prefetch: int = 0, split: str = 'train', root_dir: str | None = None, from_disk: bool = False, num_workers: int | None =None)[source]¶
建立資料集並返回其資料載入器。
- 引數:
batch_size (int) – 資料載入器樣本的批大小。
block_size (int) – 資料載入器中序列的最大長度。
tensorclass_type (tensorclass 類) – 具有
from_dataset()方法的 tensorclass 類,該方法必須接受三個關鍵字引數:split(見下文)、max_length(用於訓練的塊大小)和dataset_name(指示資料集的字串)。還應支援root_dir和from_disk引數。device (torch.device 或等效型別) – 應將樣本投射到的裝置。
dataset_name (str, 可選) – 資料集名稱。如果未提供且 tensorclass 支援,則將為所使用的 tensorclass 收集預設資料集名稱。
infinite (bool, 可選) – 如果為
True,迭代將是無限的,以便next(iterator)始終返回值。預設為True。prefetch (int, 可選) – 如果使用多執行緒資料載入,要預取的專案數。
split (str, 可選) – 資料分割。可以是
"train"或"valid"`. 預設為"train"。root_dir (path, 可選) – 資料集儲存的路徑。預設為
"$HOME/.cache/torchrl/data"from_disk (bool, 可選) – 如果為
True,將使用datasets.load_from_disk()。否則,將使用datasets.load_dataset()。預設為False。num_workers (int, 可選) –
datasets.dataset.map()的工作程序數,該方法在 tokenization 期間被呼叫。預設為max(os.cpu_count() // 2, 1)。
示例
>>> from torchrl.data.rlhf.reward import PairwiseDataset >>> dataloader = get_dataloader( ... batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu") >>> for d in dataloader: ... print(d) ... break PairwiseDataset( chosen_data=RewardData( attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([256]), device=cpu, is_shared=False), rejected_data=RewardData( attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([256]), device=cpu, is_shared=False), batch_size=torch.Size([256]), device=cpu, is_shared=False)