快捷方式

TokenizedDatasetLoader

class torchrl.data.TokenizedDatasetLoader(split, max_length, dataset_name, tokenizer_fn: Type[TensorDictTokenizer], pre_tokenization_hook=None, root_dir=None, from_disk=False, valid_size: int = 2000, num_workers: Optional[int] = None, tokenizer_class=None, tokenizer_model_name=None)[source]

載入一個分詞後的資料集,並快取一個記憶體對映的副本。

引數:
  • split (str) – 取值可以是 "train""valid"

  • max_length (int) – 最大序列長度。

  • dataset_name (str) – 資料集的名稱。

  • tokenizer_fn (callable) – 分詞方法的建構函式,例如 torchrl.data.rlhf.TensorDictTokenizer。呼叫時,應返回一個 tensordict.TensorDict 例項或一個包含分詞資料的字典狀結構。

  • pre_tokenization_hook (callable, optional) – 在分詞之前在資料集上呼叫。它應返回一個修改後的 Dataset 物件。其預期用途是執行需要修改整個資料集而不是修改單個數據點的任務,例如根據特定條件丟棄某些資料點。資料的分詞和其他“按元素”操作由對映到資料集上的處理函式執行。

  • root_dir (path, optional) – 資料集儲存的路徑。預設為 "$HOME/.cache/torchrl/data"

  • from_disk (bool, optional) – 如果為 True,則使用 datasets.load_from_disk()。否則,使用 datasets.load_dataset()。預設為 False

  • valid_size (int, optional) – 驗證資料集的大小(如果 split 以 "valid" 開頭)將被截斷為此值。預設為 2000 項。

  • num_workers (int, optional) – datasets.dataset.map() 的工作程序數,該方法在分詞期間呼叫。預設為 max(os.cpu_count() // 2, 1)

  • tokenizer_class (Type, optional) – 一個分詞器類,例如 AutoTokenizer(預設)。

  • tokenizer_model_name (str, optional) – 從中收集詞彙表的模型。預設為 "gpt2"

資料集將儲存在 <root_dir>/<split>/<max_length>/ 中。

示例

>>> from torchrl.data.rlhf import TensorDictTokenizer
>>> from torchrl.data.rlhf.reward import  pre_tokenization_hook
>>> split = "train"
>>> max_length = 550
>>> dataset_name = "CarperAI/openai_summarize_comparisons"
>>> loader = TokenizedDatasetLoader(
...     split,
...     max_length,
...     dataset_name,
...     TensorDictTokenizer,
...     pre_tokenization_hook=pre_tokenization_hook,
... )
>>> dataset = loader.load()
>>> print(dataset)
TensorDict(
    fields={
        attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([185068]),
    device=None,
    is_shared=False)
static dataset_to_tensordict(dataset: 'datasets.Dataset' | TensorDict, data_dir: Path, prefix: NestedKey = None, features: Sequence[str] =None, batch_dims=1, valid_mask_key=None)[source]

將資料集轉換為記憶體對映的 TensorDict。

如果資料集已經是 TensorDict 例項,則直接將其轉換為記憶體對映的 TensorDict。否則,資料集應具有一個 features 屬性,該屬性是一個字串序列,指示資料集中可以找到的特徵。如果沒有該屬性,則必須顯式地將 features 傳遞給此函式。

引數:
  • dataset (datasets.Dataset, TensorDict or equivalent) – 要轉換為記憶體對映 TensorDict 的資料集。如果 featuresNone,則必須具有一個 features 屬性,其中包含要寫入 tensordict 的鍵列表。

  • data_dir (Path or equivalent) – 應寫入資料的目錄。

  • prefix (NestedKey, optional) – 資料集位置的字首。這可用於區分經過不同預處理的同一資料集的多個副本。

  • features (sequence of str, optional) – 一個字串序列,指示可以在資料集中找到的特徵。

  • batch_dims (int, optional) – 資料的批處理維度數(即 tensordict 可以沿其索引的維度數)。預設為 1。

  • valid_mask_key (NestedKey, optional) – 如果提供,將嘗試收集此條目並用於過濾資料。預設為 None(即沒有過濾鍵)。

返回值: 一個包含資料集記憶體對映張量的 TensorDict。

示例

>>> from datasets import Dataset
>>> import tempfile
>>> data = Dataset.from_dict({"tokens": torch.randint(20, (10, 11)), "labels": torch.zeros(10, 11)})
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     data_memmap = TokenizedDatasetLoader.dataset_to_tensordict(
...         data, data_dir=tmpdir, prefix=("some", "prefix"), features=["tokens", "labels"]
...     )
...     print(data_memmap)
TensorDict(
    fields={
        some: TensorDict(
            fields={
                prefix: TensorDict(
                    fields={
                        labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                        tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([10]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
load()[source]

如果存在預處理的記憶體對映資料集,則載入它,否則建立它。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源