快捷方式

torch.jit.interface

torch.jit.interface(obj)[source][source]

使用此裝飾器來標註不同型別的類或模組。

此裝飾器可用於定義一個介面,該介面可用於標註不同型別的類或模組。這可以用來標註子模組或屬性類,這些子模組或屬性類可能具有實現同一介面的不同型別,或者可在執行時進行交換;或者用來儲存不同型別的模組或類的列表。

有時用於實現“可呼叫物件”(Callables)——即實現某個介面但實現方式不同且可被替換的函式或模組。

示例:.. testcode

import torch
from typing import List

@torch.jit.interface
class InterfaceType:
    def run(self, x: torch.Tensor) -> torch.Tensor:
        pass

# implements InterfaceType
@torch.jit.script
class Impl1:
    def run(self, x: torch.Tensor) -> torch.Tensor:
        return x.relu()

class Impl2(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.val = torch.rand(())

    @torch.jit.export
    def run(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.val

def user_fn(impls: List[InterfaceType], idx: int, val: torch.Tensor) -> torch.Tensor:
    return impls[idx].run(val)

user_fn_jit = torch.jit.script(user_fn)

impls = [Impl1(), torch.jit.script(Impl2())]
val = torch.rand(4, 4)
user_fn_jit(impls, 0, val)
user_fn_jit(impls, 1, val)

文件

查閱全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源