SafeProbabilisticModule¶
- class torchrl.modules.tensordict_module.SafeProbabilisticModule(*args, **kwargs)[源]¶
是
tensordict.nn.ProbabilisticTensorDictModule的子類,接受一個TensorSpec引數來控制輸出域。SafeProbabilisticModule 是一個非引數模組,內嵌一個機率分佈構造器。它使用指定的 in_keys 從輸入 TensorDict 讀取分佈引數,並輸出該分佈的一個樣本(廣義上)。
輸出的“樣本”是根據某個規則生成的,該規則由輸入的
default_interaction_type引數和interaction_type()全域性函式指定。SafeProbabilisticModule 可用於構造分佈(透過
get_dist()方法)和/或從該分佈中取樣(透過對模組的常規__call__()呼叫)。一個 SafeProbabilisticModule 例項有兩個主要特性
它讀寫 TensorDict 物件;
它使用一個實數對映 R^n -> R^m 來建立 R^d 中的分佈,可以從中取樣或計算值。
當呼叫
__call__()和forward()方法時,會建立一個分佈並計算一個值(取決於interaction_type值,可以使用 ‘dist.mean’、‘dist.mode’、‘dist.median’ 屬性,以及 ‘dist.rsample’、‘dist.sample’ 方法)。如果提供的 TensorDict 中已包含所有期望的鍵值對,則跳過取樣步驟。預設情況下,SafeProbabilisticModule 的分佈類是一個
Delta分佈,這使得 SafeProbabilisticModule 成為一個確定性對映函式的簡單包裝器。此類與
tensordict.nn.ProbabilisticTensorDictModule不同之處在於它接受一個spec關鍵字引數,可用於控制樣本是否屬於該分佈。`safe` 關鍵字引數控制是否應根據 spec 檢查樣本值。- 引數:
in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]) – 將從輸入 TensorDict 中讀取並用於構建分佈的鍵。重要提示:如果它是 NestedKey 列表或 NestedKey,則這些鍵的葉子(最後一個元素)必須與感興趣的分佈類使用的關鍵字匹配,例如
"loc"和"scale"用於Normal分佈等。如果 in_keys 是字典,則鍵是分佈的鍵,值是 tensordict 中將與相應分佈鍵匹配的鍵。out_keys (NestedKey | List[NestedKey] | None) – 寫入取樣值的鍵。重要提示:如果在輸入 TensorDict 中找到了這些鍵,則會跳過取樣步驟。
spec (TensorSpec) – 第一個輸出張量的規範。在呼叫 td_module.random() 時用於在目標空間生成隨機值。
- 關鍵字引數:
safe (bool, optional) – 如果為
True,則會根據輸入 spec 檢查樣本的值。由於探索策略或數值下溢/上溢問題,可能會發生超出域的取樣。與spec引數一樣,此檢查僅針對分佈樣本進行,而不針對輸入模組返回的其他張量。如果樣本超出界限,則使用 TensorSpec.project 方法將其投影回期望的空間。預設值為False。default_interaction_type (InteractionType, optional) –
僅關鍵字引數。用於獲取輸出值的預設方法。應為 InteractionType 之一:MODE、MEDIAN、MEAN 或 RANDOM(在這種情況下,值是從分佈中隨機取樣的)。預設值為 MODE。
注
抽取樣本時,
ProbabilisticTensorDictModule例項將首先查詢由interaction_type()全域性函式指定的互動模式。如果返回 None(其預設值),則將使用 ProbabilisticTDModule 例項的 default_interaction_type。請注意,DataCollectorBase例項預設將 set_interaction_type 設定為tensordict.nn.InteractionType.RANDOM。注
在某些情況下,mode、median 或 mean 值可能無法透過相應的屬性直接獲得。為了彌補這一點,
ProbabilisticTensorDictModule將首先嚐試透過呼叫get_mode()、get_median()或get_mean()來獲取值,如果方法存在。distribution_class (Type or Callable[[Any], Distribution], optional) –
僅關鍵字引數。一個
torch.distributions.Distribution類,用於取樣。預設值為Delta。注
如果分佈類是
CompositeDistribution型別,則out_keys可以直接從透過此類distribution_kwargs關鍵字引數提供的"distribution_map"或"name_map"關鍵字引數推斷出來,在這種情況下out_keys是可選的。distribution_kwargs (dict, optional) –
僅關鍵字引數。要傳遞給分佈的關鍵字引數對。
注
如果您的 kwargs 包含您想隨模組一起轉移到裝置上的張量,或者在呼叫 module.to(dtype) 時應修改其 dtype 的張量,您可以將 kwargs 包裝在
TensorDictParams中以自動完成此操作。return_log_prob (bool, optional) – 僅關鍵字引數。如果為
True,則分佈樣本的對數機率將寫入 tensordict 中,鍵為 log_prob_key。預設值為False。log_prob_keys (List[NestedKey], optional) –
如果
return_log_prob=True,寫入 log_prob 的鍵。預設值為 ‘<sample_key_name>_log_prob’,其中 <sample_key_name> 是每個out_keys。注
僅當
composite_lp_aggregate()設定為False時可用。log_prob_key (NestedKey, optional) –
如果
return_log_prob=True,寫入 log_prob 的鍵。當composite_lp_aggregate()設定為 True 時,預設值為 ‘sample_log_prob’,否則為 ‘<sample_key_name>_log_prob’。注
當存在多個樣本時,僅當
composite_lp_aggregate()設定為True時可用。cache_dist (bool, optional) – 僅關鍵字引數。實驗性功能:如果為
True,則分佈的引數(即模組的輸出)將與樣本一起寫入 tensordict 中。這些引數可以在之後用於重新計算原始分佈(例如,在 PPO 中計算用於取樣動作的分佈與更新後的分佈之間的散度)。預設值為False。n_empirical_estimate (int, optional) – 僅關鍵字引數。計算經驗平均值時使用的樣本數量,當經驗平均值不可用時。預設值為 1000。
警告
執行檢查會花費時間!使用 safe=True 將保證樣本在
project()中編碼的啟發式方法給定的 spec 界限內,但這需要檢查值是否在 spec 空間內,這將引入一些開銷。另請參閱
:class:`tensordict 中的組合分佈 <~tensordict.nn.CompositeDistribution>` 可用於建立多頭策略。
示例
>>> from torchrl.modules import SafeProbabilisticModule >>> from torchrl.data import Bounded >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import InteractionType >>> mod = SafeProbabilisticModule( ... in_keys=["loc", "scale"], ... out_keys=["action"], ... distribution_class=torch.distributions.Normal, ... safe=True, ... spec=Bounded(low=-1, high=1, shape=()), ... default_interaction_type=InteractionType.RANDOM ... ) >>> _ = torch.manual_seed(0) >>> data = TensorDict( ... loc=torch.zeros(10, requires_grad=True), ... scale=torch.full((10,), 10.0), ... batch_size=(10,)) >>> data = mod(data) >>> print(data["action"]) # All actions are within bound tensor([ 1., -1., -1., 1., -1., -1., 1., 1., -1., -1.], grad_fn=<ClampBackward0>) >>> data["action"].mean().backward() >>> print(data["loc"].grad) # clamp anihilates gradients tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])