ProbabilisticTensorDictSequential¶
- class tensordict.nn.ProbabilisticTensorDictSequential(*args, **kwargs)¶
一個包含至少一個
ProbabilisticTensorDictModule的TensorDictModules序列。此類擴充套件了
TensorDictSequential,通常配置為一個模組序列,其中最後一個模組是ProbabilisticTensorDictModule的例項。然而,它也支援一個或多箇中間模組是ProbabilisticTensorDictModule的例項,而最後一個模組可能不是機率性的配置。在所有情況下,它都暴露了get_dist()方法,以從序列中的ProbabilisticTensorDictModule例項中恢復分佈物件。多個機率性模組可以共存於一個
ProbabilisticTensorDictSequential中。如果 return_composite 為False(預設),則只有最後一個模組會產生分佈,而其他模組將作為常規的TensorDictModule例項執行。然而,如果一個 ProbabilisticTensorDictModule 不是序列中的最後一個模組,並且 return_composite=False,則在嘗試查詢該模組時將引發 ValueError。如果 return_composite=True,所有中間的 ProbabilisticTensorDictModule 例項將共同組成一個單獨的CompositeDistribution例項。如果樣本相互依賴,則結果對數機率將是條件機率:當
\[Z = F(X, Y)\]則 Z 的對數機率將是
\[log(p(z | x, y))\]- 引數:
*modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule) – 一個有序的
TensorDictModule例項序列,通常以ProbabilisticTensorDictModule結束,用於順序執行。模組可以是 TensorDictModuleBase 的例項,也可以是任何符合此簽名的其他函式。請注意,如果使用了非 TensorDictModuleBase 的可呼叫物件,其輸入和輸出鍵將不會被跟蹤,因此不會影響 TensorDictSequential 的 in_keys 和 out_keys 屬性。- 關鍵字引數:
partial_tolerant (bool, optional) – 如果為
True,輸入 tensordict 可以缺少部分輸入鍵。在這種情況下,只有那些根據現有鍵可以執行的模組會被執行。此外,如果輸入的 tensordict 是 tensordicts 的惰性堆疊(lazy stack),並且 partial_tolerant 為True,並且該堆疊不包含所需的鍵,則 TensorDictSequential 將掃描子 tensordicts,查詢包含所需鍵(如果有的話)的 tensordicts。預設為False。return_composite (bool, optional) –
如果為 True,並且找到了多個
ProbabilisticTensorDictModule或ProbabilisticTensorDictSequential例項,則將使用一個CompositeDistribution例項。否則,將僅使用最後一個模組來構建分佈。預設為False。警告
`return_composite` 的行為將在 v0.9 中改變,並從那時起預設為 True。
- 丟擲:
ValueError – 如果輸入的模組序列為空。
TypeError – 如果最後一個模組不是
ProbabilisticTensorDictModule或ProbabilisticTensorDictSequential的例項。
示例
>>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq >>> import torch >>> # Typical usage: a single distribution is computed last in the sequence >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq, ... TensorDictModule as Mod >>> torch.manual_seed(0) >>> >>> module = Seq( ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), ... Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... ) >>> input = TensorDict(x=torch.ones(3)) >>> td = module(input.copy()) >>> print(td) TensorDict( fields={ loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(module.get_dist(input)) Normal(loc: torch.Size([3]), scale: torch.Size([3])) >>> print(module.log_prob(td)) tensor([-0.9189, -0.9189, -0.9189]) >>> # Intermediate distributions are ignored when return_composite=False >>> module = Seq( ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), ... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]), ... Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... return_composite=False, ... ) >>> td = module(TensorDict(x=torch.ones(3))) >>> print(td) TensorDict( fields={ loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(module.get_dist(input)) Normal(loc: torch.Size([3]), scale: torch.Size([3])) >>> print(module.log_prob(td)) tensor([-0.9189, -0.9189, -0.9189]) >>> # Intermediate distributions produce a CompositeDistribution when return_composite=True >>> module = Seq( ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), ... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]), ... Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... return_composite=True, ... ) >>> input = TensorDict(x=torch.ones(3)) >>> td = module(input.copy()) >>> print(td) TensorDict( fields={ loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(module.get_dist(input)) CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))}) >>> print(module.log_prob(td)) TensorDict( fields={ sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample1_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> # Even a single intermediate distribution is wrapped in a CompositeDistribution when >>> # return_composite=True >>> module = Seq( ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), ... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["y"]), ... return_composite=True, ... ) >>> td = module(TensorDict(x=torch.ones(3))) >>> print(td) TensorDict( fields={ loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), y: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(module.get_dist(input)) CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3]))}) >>> print(module.log_prob(td)) TensorDict( fields={ sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- build_dist_from_params(tensordict: TensorDictBase) Distribution¶
根據輸入引數構建分佈,而不評估序列中的其他模組。
此方法在序列中查詢最後一個
ProbabilisticTensorDictModule,並使用它來構建分佈。- 引數:
tensordict (TensorDictBase) – 包含分佈引數的輸入 tensordict。
- 返回:
構建的分佈物件。
- 返回型別:
D.Distribution
- 丟擲:
RuntimeError – 如果在序列中未找到
ProbabilisticTensorDictModule。
- property default_interaction_type¶
使用迭代啟發式方法返回模組的 default_interaction_type。
此屬性以反向順序迭代所有模組,嘗試從任何子模組中檢索 default_interaction_type 屬性。遇到的第一個非 None 值將被返回。如果未找到此類值,則返回預設的 interaction_type()。
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs) TensorDictBase¶
當未設定 tensordict 引數時,使用 kwargs 建立 TensorDict 例項。
- get_dist(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) Distribution¶
返回將輸入 tensordict 透過序列傳遞後得到的分佈。
如果 return_composite 為
False(預設),此方法將僅考慮序列中的最後一個機率性模組。否則,它將返回一個包含所有機率性模組分佈的
CompositeDistribution例項。- 引數:
tensordict (TensorDictBase) – 輸入 tensordict。
tensordict_out (TensorDictBase, optional) – 輸出 tensordict。如果為
None,將建立一個新的 tensordict。預設為None。
- 關鍵字引數:
**kwargs – 傳遞給底層模組的額外關鍵字引數。
- 返回:
結果分佈物件。
- 返回型別:
D.Distribution
- 丟擲:
RuntimeError – 如果在序列中未找到機率性模組。
注意
當 return_composite 為
True時,分佈是以前一個序列中的樣本為條件的。這意味著如果一個模組依賴於前一個機率性模組的輸出,其分佈將是條件分佈。
- get_dist_params(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) tuple[torch.distributions.distribution.Distribution, tensordict.base.TensorDictBase]¶
返回分佈引數和輸出 tensordict。
此方法執行 ProbabilisticTensorDictSequential 模組的確定性部分以獲取分佈引數。互動型別被設定為當前全域性互動型別(如果可用),否則預設為最後一個模組的互動型別。
- 引數:
tensordict (TensorDictBase) – 輸入 tensordict。
tensordict_out (TensorDictBase, optional) – 輸出 tensordict。如果為
None,將建立一個新的 tensordict。預設為None。
- 關鍵字引數:
**kwargs – 傳遞給模組確定性部分的額外關鍵字引數。
- 返回:
一個包含分佈物件和輸出 tensordict 的元組。
- 返回型別:
tuple[D.Distribution, TensorDictBase]
注意
在此方法的執行期間,互動型別被臨時設定為指定的值。
- log_prob(tensordict, tensordict_out: Optional[TensorDictBase] = None, *, dist: Optional[Distribution] = None, **kwargs)¶
返回輸入 tensordict 的對數機率。
如果 self.return_composite 為
True且分佈是一個CompositeDistribution,則此方法將返回整個複合分佈的對數機率。否則,它將僅考慮序列中的最後一個機率性模組。
- 引數:
tensordict (TensorDictBase) – 輸入 tensordict。
tensordict_out (TensorDictBase, optional) – 輸出 tensordict。如果為
None,將建立一個新的 tensordict。預設為None。
- 關鍵字引數:
dist (torch.distributions.Distribution, optional) – 分佈物件。如果為
None,將使用 get_dist 計算。預設為None。- 返回:
輸入 tensordict 的對數機率。
- 返回型別:
警告
在未來的版本(v0.9)中,aggregate_probabilities、inplace 和 include_sum 的預設值將發生變化。為避免警告,建議顯式地將這些引數傳遞給 log_prob 方法或在建構函式中設定它們。