雜湊¶
- class torchrl.envs.transforms.Hash(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], *, hash_fn: Callable = None, seed: Any | None = None, use_raw_nontensor: bool = False)[源]¶
向 tensordict 中新增雜湊值。
- 引數:
in_keys (巢狀鍵序列) – 要進行雜湊處理的值對應的鍵。
out_keys (巢狀鍵序列) – 生成的雜湊值對應的鍵。
in_keys_inv (巢狀鍵序列, 可選) –
在 inv 呼叫期間要進行雜湊處理的值對應的鍵。
注意
如果需要反向對映,應與鍵列表一起傳遞一個雜湊到值的對映集
Dict[Tuple[int], Any],以便讓Hashtransform 知道如何從給定的雜湊中恢復值。此對映集不會被複制,因此在 transform 例項化後可以在同一工作區中修改它,並且這些修改將反映在對映中。缺失的雜湊將被對映到None。out_keys_inv (巢狀鍵序列, 可選) – 在 inv 呼叫期間生成的雜湊值對應的鍵。
- 關鍵字引數:
hash_fn (可呼叫物件, 可選) – 要使用的雜湊函式。如果提供了
seed,則雜湊函式必須接受它作為第二個引數。預設值為Hash.reproducible_hash。seed (可選) – 雜湊函式要使用的種子,如果需要的話。
use_raw_nontensor (bool, 可選) – 如果為
False,則在對fn呼叫NonTensorData/NonTensorStack輸入之前,會從中提取資料。如果為True,則直接將原始NonTensorData/NonTensorStack輸入提供給fn,fn必須支援這些輸入。預設值為False。Hash (>>> from torchrl.envs import GymEnv, UnaryTransform,) –
GymEnv (>>> env =) –
output (>>> # 處理字串) –
env.append_transform( (>>> env =) –
UnaryTransform( (...) –
in_keys=["observation"], (...) –
out_keys=["observation_str"], (...) –
tensor (... fn=lambda) – str(tensor.numpy().tobytes())))
output –
env.append_transform( –
Hash( (...) –
in_keys=["observation_str"], (...) –
out_keys=["observation_hash"],) (...) –
) (...) –
env.observation_spec (>>>) –
Composite( –
- observation: BoundedContinuous(
shape=torch.Size([3]), space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu, dtype=torch.float32, domain=continuous),
- observation_str: NonTensor(
shape=torch.Size([]), space=None, device=cpu, dtype=None, domain=None),
- observation_hash: UnboundedDiscrete(
shape=torch.Size([32]), space=ContinuousBox(
low=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True), high=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True)),
device=cpu, dtype=torch.uint8, domain=discrete),
device=None, shape=torch.Size([]))
env.rollout (>>>) –
TensorDict( –
- fields={
action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict(
- fields={
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), observation_str: NonTensorStack(
[“b’g\x08\x8b\xbexav\xbf\x00\xee(>’”, “b’\x…, batch_size=torch.Size([3]), device=None),
reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]), device=None, is_shared=False),
observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), observation_str: NonTensorStack(
[“b’\xb5\x17\x8f\xbe\x88\xccu\xbf\xc0Vr?’”…, batch_size=torch.Size([3]), device=None),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]), device=None, is_shared=False)
env.check_env_specs() (>>>) –
succeeded! ([torchrl][INFO] check_env_specs)
- classmethod reproducible_hash(string, seed=None)[源]¶
使用種子從字串建立可重現的 256 位雜湊。
- 引數:
string (str or None) – 輸入字串。如果為
None,則使用空字串""。seed (str, 可選) – 種子值。預設值為
None。
- 返回:
形狀為
(32,),dtype 為torch.uint8的張量。- 返回型別:
Tensor
- state_dict(*args, destination=None, prefix='', keep_vars=False)[源]¶
返回一個包含對模組整個狀態的引用的字典。
包含引數和永續性緩衝區(例如,執行平均值)。鍵是相應的引數和緩衝區名稱。設定為
None的引數和緩衝區不包含在內。注意
返回的物件是淺複製。它包含對模組引數和緩衝區的引用。
警告
目前
state_dict()也按順序接受destination、prefix和keep_vars的位置引數。然而,這已被棄用,未來版本將強制使用關鍵字引數。警告
請避免使用引數
destination,因為它不是為終端使用者設計的。- 引數:
destination (dict, 可選) – 如果提供,模組的狀態將更新到此字典中並返回同一物件。否則,將建立一個並返回
OrderedDict。預設值:None。prefix (str, 可選) – 新增到引數和緩衝區名稱前的字串,用於構成 state_dict 中的鍵。預設值:
''。keep_vars (bool, 可選) – 預設情況下,state dict 中返回的
Tensor會從 autograd 中分離。如果設定為True,則不會執行分離操作。預設值:False。
- 返回:
包含模組整個狀態的字典
- 返回型別:
dict
示例
>>> # xdoctest: +SKIP("undefined vars") >>> module.state_dict().keys() ['bias', 'weight']