快捷方式

merge_tensordicts

class tensordict.merge_tensordicts(*tensordicts: T, callback_exist: Optional[Union[Callable[[Any], Any], Dict[NestedKey, Callable[[Any], Any]]]] = None)

合併 tensordict。

引數:

*tensordicts (TensorDict 或等效物件的序列) – 要合併的 tensordict 列表。

關鍵字引數:

callback_exist (可呼叫物件Dict[str, 可呼叫物件], 可選) – 當每個 tensordict 中都存在某個條目時使用的可呼叫物件。如果條目存在於部分 tensordict 中但並非全部,或者如果 callback_exist 未傳遞,則使用 update 方法,並使用 tensordict 序列中的第一個非 None 值。如果傳遞了一個可呼叫物件字典,它將包含傳遞給函式的一些巢狀鍵的關聯回撥函式。

示例

>>> from tensordict import merge_tensordicts, TensorDict
>>> td0 = TensorDict({"a": {"b0": 0}, "c": {"d": {"e": 0}}, "common": 0})
>>> td1 = TensorDict({"a": {"b1": 1}, "f": {"g": {"h": 1}}, "common": 1})
>>> td2 = TensorDict({"a": {"b2": 2}, "f": {"g": {"h": 2}}, "common": 2})
>>> td = merge_tensordicts(td0, td1, td2, callback_exist=lambda *v: torch.stack(list(v)))
>>> print(td)
TensorDict(
    fields={
        a: TensorDict(
            fields={
                b0: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                b1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                b2: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        c: TensorDict(
            fields={
                d: TensorDict(
                    fields={
                        e: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        common: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int64, is_shared=False),
        f: TensorDict(
            fields={
                g: TensorDict(
                    fields={
                        h: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(td["common"])
tensor([0, 1, 2])

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源