快捷方式

Ordinal

class torchrl.modules.Ordinal(scores: Tensor)[source]

一種用於學習從有限有序集合中取樣的離散分佈。

與 Categorical 分佈不同,Categorical 分佈對其支援元素的原子不施加任何鄰近性或順序概念。Ordinal 分佈明確編碼了這些概念,這對於從連續集合中學習離散取樣非常有用。詳情請參閱 `Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`_ 的 §5。

注意

當您想學習基於一個透過對連續集合離散化獲得的有限集合上的分佈時,此類別特別有用。

引數:

scores (torch.Tensor) – 形狀為 […, N] 的張量,其中 N 是支援分佈的集合大小。通常是引數化分佈的神經網路的輸出。

示例

>>> num_atoms, num_samples = 5, 20
>>> mean = (num_atoms - 1) / 2  # Target mean for samples, centered around the middle atom
>>> torch.manual_seed(42)
>>> logits = torch.ones((num_atoms), requires_grad=True)
>>> optimizer = torch.optim.Adam([logits], lr=0.1)
>>>
>>> # Perform optimisation loop to minimise deviation from `mean`
>>> for _ in range(20):
>>>     sampler = Ordinal(scores=logits)
>>>     samples = sampler.sample((num_samples,))
>>>     # Define loss to encourage samples around the mean by penalising deviation from mean
>>>     loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
>>>     loss.backward()
>>>     optimizer.step()
>>>     optimizer.zero_grad()
>>>
>>> sampler.probs
tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
>>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
>>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
torch.return_types.histogram(
    hist=tensor([ 24., 158., 478., 228., 112.]),
    bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))

© Copyright 2022, Meta.

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的問題

檢視資源