快捷方式

ProbabilisticTensorDictModule

class tensordict.nn.ProbabilisticTensorDictModule(*args, **kwargs)

一個機率性 TD 模組。

ProbabilisticTensorDictModule 是一個非引數模組,嵌入了一個機率分佈構造器。它使用指定的 in_keys 從輸入的 TensorDict 中讀取分佈引數,並輸出該分佈的一個樣本(非嚴格意義上)。

輸出的“樣本”是根據特定規則生成的,該規則由輸入的 default_interaction_type 引數和 interaction_type() 全域性函式指定。

ProbabilisticTensorDictModule 可用於構建分佈(透過 get_dist() 方法)和/或從該分佈中進行取樣(透過對模組進行常規的 __call__() 呼叫)。

一個 ProbabilisticTensorDictModule 例項具有兩個主要特性

  • 它可以從 TensorDict 物件讀取和寫入資料;

  • 它使用一個實值對映 R^n -> R^m 來在 R^d 中建立一個分佈,可以從中取樣或計算值。

當呼叫 __call__()forward() 方法時,會建立一個分佈並計算一個值(取決於 interaction_type 的值,可以使用 'dist.mean'、'dist.mode'、'dist.median' 屬性,以及 'dist.rsample'、'dist.sample' 方法)。如果提供的 TensorDict 已經包含所有期望的鍵值對,則會跳過取樣步驟。

預設情況下,ProbabilisticTensorDictModule 的分佈類是 Delta 分佈,這使得 ProbabilisticTensorDictModule 成為確定性對映函式的一個簡單包裝器。

引數:
  • in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]) – 將從輸入的 TensorDict 中讀取並用於構建分佈的鍵。重要的是,如果它是 NestedKey 列表或單個 NestedKey,這些鍵的葉子(最後一個元素)必須與感興趣的分佈類使用的關鍵字匹配,例如 "loc""scale" 對於 Normal 分佈等。如果 in_keys 是一個字典,則字典的鍵是分佈的鍵,值是 tensordict 中將與相應分佈鍵匹配的鍵。

  • out_keys (NestedKey | List[NestedKey] | None) – 將寫入取樣值的鍵。重要的是,如果在輸入的 TensorDict 中找到了這些鍵,則會跳過取樣步驟。

關鍵字引數:
  • default_interaction_type (InteractionType, optional) –

    僅限關鍵字引數。用於檢索輸出值的預設方法。應為 InteractionType 中的一個:MODE、MEDIAN、MEAN 或 RANDOM(在這種情況下,值從分佈中隨機取樣)。預設值是 MODE。

    注意

    當抽取樣本時,ProbabilisticTensorDictModule 例項將首先查詢由 interaction_type() 全域性函式指定的互動模式。如果此函式返回 None(其預設值),則將使用 ProbabilisticTDModule 例項的 default_interaction_type。請注意,DataCollectorBase 例項預設將 set_interaction_type 設定為 tensordict.nn.InteractionType.RANDOM

    注意

    在某些情況下,模式、中位數或均值可能無法透過相應的屬性直接獲得。為解決此問題,ProbabilisticTensorDictModule 會首先嚐試透過呼叫 get_mode()get_median()get_mean()(如果方法存在)來獲取值。

  • distribution_class (Type or Callable[[Any], Distribution], optional) –

    僅限關鍵字引數。用於取樣的 torch.distributions.Distribution 類。預設值是 Delta

    注意

    如果分佈類是 CompositeDistribution 型別,則可以直接從此類的 distribution_kwargs 關鍵字引數中提供的 "distribution_map""name_map" 關鍵字引數推斷出 out_keys,從而在這些情況下 out_keys 是可選的。

  • distribution_kwargs (dict, optional) –

    僅限關鍵字引數。要傳遞給分佈的關鍵字引數對。

    注意

    如果您的 kwargs 包含希望隨模組一起傳輸到裝置的張量,或者在呼叫 module.to(dtype) 時應修改其 dtype 的張量,您可以將 kwargs 包裝在 TensorDictParams 中以自動完成此操作。

  • return_log_prob (bool, optional) – 僅限關鍵字引數。如果為 True,則分佈樣本的對數機率將寫入 tensordict 中,使用鍵 log_prob_key。預設值為 False

  • log_prob_keys (List[NestedKey], optional) –

    如果 return_log_prob=True,則寫入 log_prob 的鍵。預設為 ‘<sample_key_name>_log_prob’,其中 <sample_key_name>out_keys 中的每一個。

    注意

    這僅在 composite_lp_aggregate() 設定為 False 時可用。

  • log_prob_key (NestedKey, optional) –

    如果 return_log_prob=True,則寫入 log_prob 的鍵。當 composite_lp_aggregate() 設定為 True 時預設為 ‘sample_log_prob’,否則預設為 ‘<sample_key_name>_log_prob’

    注意

    當有多個樣本時,這僅在 composite_lp_aggregate() 設定為 True 時可用。

  • cache_dist (bool, optional) – 僅限關鍵字引數。實驗性:如果為 True,則分佈的引數(即模組的輸出)將與樣本一起寫入 tensordict。這些引數可用於稍後重新計算原始分佈(例如,計算用於取樣動作的分佈與 PPO 中更新的分佈之間的散度)。預設值為 False

  • n_empirical_estimate (int, optional) – 僅限關鍵字引數。當經驗均值不可用時,用於計算經驗均值的樣本數量。預設為 1000。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch.distributions import Normal, Independent
>>> td = TensorDict(
...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> module = TensorDictModule(
...     net, in_keys=["input", "hidden"], out_keys=["params"]
... )
>>> normal_params = TensorDictModule(
...     NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
... )
>>> def IndepNormal(**kwargs):
...     return Independent(Normal(**kwargs), 1)
>>> prob_module = ProbabilisticTensorDictModule(
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=IndepNormal,
...     return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(
...     module, normal_params, prob_module
... )
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> with params.to_module(td_module):
...     dist = td_module.get_dist(td)
>>> print(dist)
Independent(Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])), 1)
>>> # we can also apply the module to the TensorDict with vmap
>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
build_dist_from_params(tensordict: TensorDictBase) Distribution

使用輸入的 tensordict 中提供的引數建立一個 torch.distribution.Distribution 例項。

引數:

tensordict (TensorDictBase) – 包含分佈引數的輸入 tensordict。

返回:

使用輸入的 tensordict 建立的 torch.distribution.Distribution 例項。

丟擲:

TypeError – 如果輸入的 tensordict 與分佈關鍵字不匹配。

property dist_params_keys: List[NestedKey]

返回指向分佈引數的所有鍵。

property dist_sample_keys: List[NestedKey]

返回指向分佈樣本的所有鍵。

forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, _requires_sample: bool = True) TensorDictBase

定義每次呼叫時執行的計算。

應被所有子類覆蓋。

注意

雖然 forward pass 的實現需要在函式內部定義,但之後應該呼叫 Module 例項而不是此函式本身,因為前者負責執行註冊的鉤子,而後者會默默地忽略它們。

get_dist(tensordict: TensorDictBase) Distribution

使用輸入的 tensordict 中提供的引數建立一個 torch.distribution.Distribution 例項。

引數:

tensordict (TensorDictBase) – 包含分佈引數的輸入 tensordict。

返回:

使用輸入的 tensordict 建立的 torch.distribution.Distribution 例項。

丟擲:

TypeError – 如果輸入的 tensordict 與分佈關鍵字不匹配。

log_prob(tensordict, *, dist: Optional[Distribution] = None)

計算分佈樣本的對數機率。

引數:
  • tensordict (TensorDictBase) – 包含分佈引數的輸入 tensordict。

  • dist (torch.distributions.Distribution, optional) – 分佈例項。預設為 None。如果為 None,則將使用 get_dist 方法計算分佈。

返回:

表示分佈樣本對數機率的張量。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源