快捷方式

TanhModule

class torchrl.modules.tensordict_module.TanhModule(*args, **kwargs)[source]

一個用於具有有限動作空間的確定性策略的 Tanh 模組。

此 transform 用作 TensorDictModule 層,將網路輸出對映到有限空間。

引數:
  • in_keys (str 列表str 元組) – 模組的輸入鍵。

  • out_keys (str 列表str 元組, 可選) – 模組的輸出鍵。如果未提供,則假定與 in_keys 相同的鍵。

關鍵字引數:
  • spec (TensorSpec, 可選) – 如果提供,則為輸出的 spec。如果提供 Composite,其鍵必須與 out_keys 中的鍵匹配。否則,假定使用 out_keys 中的鍵,並對所有輸出使用相同的 spec。

  • low (float, np.ndarray 或 torch.Tensor) – 空間的下界。如果未提供且未提供 spec,則假定為 -1。如果提供 spec,將檢索 spec 的最小值。

  • high (float, np.ndarray 或 torch.Tensor) – 空間的上界。如果未提供且未提供 spec,則假定為 1。如果提供 spec,將檢索 spec 的最大值。

  • clamp (bool, 可選) – 如果為 True,輸出將被限制在邊界內,但與邊界保持最小解析度。預設為 False

示例

>>> from tensordict import TensorDict
>>> # simplest use case: -1 - 1 boundaries
>>> torch.manual_seed(0)
>>> in_keys = ["action"]
>>> mod = TanhModule(
...     in_keys=in_keys,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([ 1.0000, -0.9944, -1.0000,  1.0000, -1.0000])
>>> # low and high can be customized
>>> low = -2
>>> high = 1
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([-2.0000,  0.9991,  1.0000, -2.0000, -1.9991])
>>> # A spec can be provided
>>> from torchrl.data import Bounded
>>> spec = Bounded(low, high, shape=())
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
...     spec=spec,
...     clamp=False,
... )
>>> # One can also work with multiple keys
>>> in_keys = ['a', 'b']
>>> spec = Composite(
...     a=Bounded(-3, 0, shape=()),
...     b=Bounded(0, 3, shape=()))
>>> mod = TanhModule(
...     in_keys=in_keys,
...     spec=spec,
... )
>>> data = TensorDict(
...     {'a': torch.randn(10), 'b': torch.randn(10)}, batch_size=[])
>>> data = mod(data)
>>> data['a']
tensor([-2.3020, -1.2299, -2.5418, -0.2989, -2.6849, -1.3169, -2.2690, -0.9649,
        -2.5686, -2.8602])
>>> data['b']
tensor([2.0315, 2.8455, 2.6027, 2.4746, 1.7843, 2.7782, 0.2111, 0.5115, 1.4687,
        0.5760])
forward(tensordict=None)[source]

定義每次呼叫時執行的計算。

應由所有子類覆蓋。

注意

雖然 forward pass 的實現需要在本函式中定義,但之後應呼叫 Module 例項而不是本函式,因為前者負責執行註冊的鉤子,而後者會靜默忽略它們。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源