快捷方式

雜湊

class torchrl.envs.transforms.Hash(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], *, hash_fn: Callable = None, seed: Any | None = None, use_raw_nontensor: bool = False)[源]

向 tensordict 中新增雜湊值。

引數:
  • in_keys (巢狀鍵序列) – 要進行雜湊處理的值對應的鍵。

  • out_keys (巢狀鍵序列) – 生成的雜湊值對應的鍵。

  • in_keys_inv (巢狀鍵序列, 可選) –

    在 inv 呼叫期間要進行雜湊處理的值對應的鍵。

    注意

    如果需要反向對映,應與鍵列表一起傳遞一個雜湊到值的對映集 Dict[Tuple[int], Any],以便讓 Hash transform 知道如何從給定的雜湊中恢復值。此對映集不會被複制,因此在 transform 例項化後可以在同一工作區中修改它,並且這些修改將反映在對映中。缺失的雜湊將被對映到 None

  • out_keys_inv (巢狀鍵序列, 可選) – 在 inv 呼叫期間生成的雜湊值對應的鍵。

關鍵字引數:
  • hash_fn (可呼叫物件, 可選) – 要使用的雜湊函式。如果提供了 seed,則雜湊函式必須接受它作為第二個引數。預設值為 Hash.reproducible_hash

  • seed (可選) – 雜湊函式要使用的種子,如果需要的話。

  • use_raw_nontensor (bool, 可選) – 如果為 False,則在對 fn 呼叫 NonTensorData/NonTensorStack 輸入之前,會從中提取資料。如果為 True,則直接將原始 NonTensorData/NonTensorStack 輸入提供給 fnfn 必須支援這些輸入。預設值為 False

  • Hash (>>> from torchrl.envs import GymEnv, UnaryTransform,) –

  • GymEnv (>>> env =) –

  • output (>>> # 處理字串) –

  • env.append_transform( (>>> env =) –

  • UnaryTransform( (...) –

  • in_keys=["observation"], (...) –

  • out_keys=["observation_str"], (...) –

  • tensor (... fn=lambda) – str(tensor.numpy().tobytes())))

  • output

  • env.append_transform(

  • Hash( (...) –

  • in_keys=["observation_str"], (...) –

  • out_keys=["observation_hash"],) (...) –

  • ) (...) –

  • env.observation_spec (>>>) –

  • Composite(

    observation: BoundedContinuous(

    shape=torch.Size([3]), space=ContinuousBox(

    low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),

    device=cpu, dtype=torch.float32, domain=continuous),

    observation_str: NonTensor(

    shape=torch.Size([]), space=None, device=cpu, dtype=None, domain=None),

    observation_hash: UnboundedDiscrete(

    shape=torch.Size([32]), space=ContinuousBox(

    low=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True), high=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True)),

    device=cpu, dtype=torch.uint8, domain=discrete),

    device=None, shape=torch.Size([]))

  • env.rollout (>>>) –

  • TensorDict(

    fields={

    action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict(

    fields={

    done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), observation_str: NonTensorStack(

    [“b’g\x08\x8b\xbexav\xbf\x00\xee(>’”, “b’\x…, batch_size=torch.Size([3]), device=None),

    reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},

    batch_size=torch.Size([3]), device=None, is_shared=False),

    observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), observation_str: NonTensorStack(

    [“b’\xb5\x17\x8f\xbe\x88\xccu\xbf\xc0Vr?’”…, batch_size=torch.Size([3]), device=None),

    terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},

    batch_size=torch.Size([3]), device=None, is_shared=False)

  • env.check_env_specs() (>>>) –

  • succeeded! ([torchrl][INFO] check_env_specs)

classmethod reproducible_hash(string, seed=None)[源]

使用種子從字串建立可重現的 256 位雜湊。

引數:
  • string (str or None) – 輸入字串。如果為 None,則使用空字串 ""

  • seed (str, 可選) – 種子值。預設值為 None

返回:

形狀為 (32,),dtype 為 torch.uint8 的張量。

返回型別:

Tensor

state_dict(*args, destination=None, prefix='', keep_vars=False)[源]

返回一個包含對模組整個狀態的引用的字典。

包含引數和永續性緩衝區(例如,執行平均值)。鍵是相應的引數和緩衝區名稱。設定為 None 的引數和緩衝區不包含在內。

注意

返回的物件是淺複製。它包含對模組引數和緩衝區的引用。

警告

目前 state_dict() 也按順序接受 destinationprefixkeep_vars 的位置引數。然而,這已被棄用,未來版本將強制使用關鍵字引數。

警告

請避免使用引數 destination,因為它不是為終端使用者設計的。

引數:
  • destination (dict, 可選) – 如果提供,模組的狀態將更新到此字典中並返回同一物件。否則,將建立一個並返回 OrderedDict。預設值:None

  • prefix (str, 可選) – 新增到引數和緩衝區名稱前的字串,用於構成 state_dict 中的鍵。預設值:''

  • keep_vars (bool, 可選) – 預設情況下,state dict 中返回的 Tensor 會從 autograd 中分離。如果設定為 True,則不會執行分離操作。預設值:False

返回:

包含模組整個狀態的字典

返回型別:

dict

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源