快捷方式

TensorDictTokenizer

class torchrl.data.TensorDictTokenizer(tokenizer, max_length, key='text', padding='max_length', truncation=True, return_tensordict=True, device=None)[原始碼]

一個將分詞器應用於文字示例的處理函式工廠。

引數:
  • tokenizer (來自 transformers 庫的分詞器) – 要使用的分詞器。

  • max_length (int) – 序列的最大長度。

  • key (str, 可選) – 查詢文字的鍵。預設為 "text"

  • padding (str, 可選) – 填充型別。預設為 "max_length"

  • truncation (bool, 可選) – 序列是否應截斷到 max_length。

  • return_tensordict (bool, 可選) – 如果為 True,則返回 TensoDict。否則,將返回原始資料。

  • device (torch.device, 可選) – 儲存資料的裝置。如果 return_tensordict=False,則忽略此選項。

有關分詞器的更多資訊,請參閱 transformers 庫

填充和截斷:https://huggingface.tw/docs/transformers/pad_truncation

返回:一個與輸入資料具有相同批大小的 tensordict.TensorDict 例項。

示例

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = 100
>>> process = TensorDictTokenizer(tokenizer, max_length=10)
>>> # example with a single input
>>> example = {"text": "I am a little worried"}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # example with a multiple inputs
>>> example = {"text": ["Let me reassure you", "It will be ok"]}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源