快捷方式

UnaryTransform

class torchrl.envs.transforms.UnaryTransform(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, *, fn: Callable[[Any], Tensor | TensorDictBase], inv_fn: Callable[[Any], Any] | None = None, use_raw_nontensor: bool = False)[source]

對指定的輸入應用一元操作。

引數:
  • in_keys (sequence of NestedKey) – 一元操作的輸入鍵。

  • out_keys (sequence of NestedKey) – 一元操作的輸出鍵。

  • in_keys_inv (sequence of NestedKey, optional) – 逆向呼叫期間一元操作的輸入鍵。

  • out_keys_inv (sequence of NestedKey, optional) – 逆向呼叫期間一元操作的輸出鍵。

關鍵字引數:
  • fn (Callable[[Any], Tensor | TensorDictBase]) – 用作一元操作的函式。如果它接受非張量輸入,它也必須接受 None

  • inv_fn (Callable[[Any], Any], optional) – 在逆向呼叫期間用作一元操作的函式。如果它接受非張量輸入,它也必須接受 None。可以省略,在這種情況下 fn 將用於逆向對映。

  • use_raw_nontensor (bool, optional) – 如果為 False,則在呼叫 fn 之前,從 NonTensorData/NonTensorStack 輸入中提取資料。如果為 True,則直接將原始 NonTensorData/NonTensorStack 輸入提供給 fn,它必須支援這些輸入。預設為 False

示例

>>> from torchrl.envs import GymEnv, UnaryTransform
>>> env = GymEnv("Pendulum-v1")
>>> env = env.append_transform(
...     UnaryTransform(
...         in_keys=["observation"],
...         out_keys=["observation_trsf"],
...             fn=lambda tensor: str(tensor.numpy().tobytes())))
>>> env.observation_spec
Composite(
    observation: BoundedContinuous(
        shape=torch.Size([3]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    observation_trsf: NonTensor(
        shape=torch.Size([]),
        space=None,
        device=cpu,
        dtype=None,
        domain=None),
    device=None,
    shape=torch.Size([]))
>>> env.rollout(3)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                observation_trsf: NonTensorStack(
                    ["b'\\xbe\\xbc\\x7f?8\\x859=/\\x81\\xbe;'", "b'\\x...,
                    batch_size=torch.Size([3]),
                    device=None),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation_trsf: NonTensorStack(
            ["b'\\x9a\\xbd\\x7f?\\xb8T8=8.c>'", "b'\\xbe\\xbc\...,
            batch_size=torch.Size([3]),
            device=None),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> env.check_env_specs()
[torchrl][INFO] check_env_specs succeeded!
transform_action_spec(action_spec: TensorSpec, test_input_spec: TensorSpec) TensorSpec[source]

轉換 action spec 以便生成的 spec 與 transform 對映匹配。

引數:

action_spec (TensorSpec) – transform 前的 spec

返回:

transform 後的預期 spec

transform_done_spec(done_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec[source]

轉換 done spec 以便生成的 spec 與 transform 對映匹配。

引數:

done_spec (TensorSpec) – transform 前的 spec

返回:

transform 後的預期 spec

transform_input_spec(input_spec: Composite) Composite[source]

轉換 input spec 以便生成的 spec 與 transform 對映匹配。

引數:

input_spec (TensorSpec) – transform 前的 spec

返回:

transform 後的預期 spec

transform_observation_spec(observation_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec[source]

轉換 observation spec 以便生成的 spec 與 transform 對映匹配。

引數:

observation_spec (TensorSpec) – transform 前的 spec

返回:

transform 後的預期 spec

transform_output_spec(output_spec: Composite) Composite[source]

轉換 output spec 以便生成的 spec 與 transform 對映匹配。

此方法通常不應修改。更改應使用 transform_observation_spec()transform_reward_spec()transform_full_done_spec() 實現。 :param output_spec: transform 前的 spec :type output_spec: TensorSpec

返回:

transform 後的預期 spec

transform_reward_spec(reward_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec[source]

轉換 reward spec 以便生成的 spec 與 transform 對映匹配。

引數:

reward_spec (TensorSpec) – transform 前的 spec

返回:

transform 後的預期 spec

transform_state_spec(state_spec: TensorSpec, test_input_spec: TensorSpec) TensorSpec[source]

轉換 state spec 以便生成的 spec 與 transform 對映匹配。

引數:

state_spec (TensorSpec) – transform 前的 spec

返回:

transform 後的預期 spec

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源