torch.jit.interface¶
- torch.jit.interface(obj)[来源][来源]¶
装饰器,用于注释不同类型的类或模块。
此装饰器可用于定义接口,该接口可用于注释不同类型的类或模块。这可以用于注释子模块或属性类,这些子模块或属性类可能具有实现相同接口的不同类型,或者可以在运行时交换;或者用于存储不同类型的模块或类的列表。
它有时用于实现“可调用对象” - 实现接口但实现方式不同且可以换出的函数或模块。
示例: .. testcode
import torch from typing import List @torch.jit.interface class InterfaceType: def run(self, x: torch.Tensor) -> torch.Tensor: pass # implements InterfaceType @torch.jit.script class Impl1: def run(self, x: torch.Tensor) -> torch.Tensor: return x.relu() class Impl2(torch.nn.Module): def __init__(self) -> None: super().__init__() self.val = torch.rand(()) @torch.jit.export def run(self, x: torch.Tensor) -> torch.Tensor: return x + self.val def user_fn(impls: List[InterfaceType], idx: int, val: torch.Tensor) -> torch.Tensor: return impls[idx].run(val) user_fn_jit = torch.jit.script(user_fn) impls = [Impl1(), torch.jit.script(Impl2())] val = torch.rand(4, 4) user_fn_jit(impls, 0, val) user_fn_jit(impls, 1, val)