快捷方式

TorchScript 简介

作者:James Reed (jamesreed@fb.com),Michael Suo (suo@fb.com),rev2

本教程介绍了 TorchScript,它是 PyTorch 模型(nn.Module 的子类)的中间表示,可以随后在 C++ 等高性能环境中运行。

在本教程中,我们将涵盖

  1. PyTorch 模型创作的基础知识,包括

  • 模块

  • 定义 forward 函数

  • 将模块组合成模块层次结构

  1. 将 PyTorch 模块转换为 TorchScript 的特定方法,我们的高性能部署运行时

  • 跟踪现有模块

  • 使用脚本直接编译模块

  • 如何组合这两种方法

  • 保存和加载 TorchScript 模块

我们希望您完成本教程后,可以继续学习后续教程,该教程将指导您完成一个从 C++ 实际调用 TorchScript 模型的示例。

import torch  # This is all you need to use both PyTorch and TorchScript!
print(torch.__version__)
torch.manual_seed(191009)  # set the seed for reproducibility
2.5.0+cu124

<torch._C.Generator object at 0x7f85676e3950>

PyTorch 模型创作的基础知识

让我们从定义一个简单的 Module 开始。 Module 是 PyTorch 中的基本组成单元。它包含

  1. 一个构造函数,它为调用模块做好准备

  2. 一组 Parameters 和子 Modules。这些由构造函数初始化,可以在模块调用期间使用。

  3. 一个 forward 函数。 这是模块被调用时运行的代码。

让我们看看一个简单的例子。

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))
(tensor([[0.8219, 0.8990, 0.6670, 0.8277],
        [0.5176, 0.4017, 0.8545, 0.7336],
        [0.6013, 0.6992, 0.2618, 0.6668]]), tensor([[0.8219, 0.8990, 0.6670, 0.8277],
        [0.5176, 0.4017, 0.8545, 0.7336],
        [0.6013, 0.6992, 0.2618, 0.6668]]))

所以我们已经

  1. 创建了一个继承自 torch.nn.Module 的类。

  2. 定义了一个构造函数。 构造函数没有做太多事情,只是调用了 super 的构造函数。

  3. 定义了一个 forward 函数,它接受两个输入并返回两个输出。 forward 函数的实际内容并不重要,但它有点像一个假的 RNN 单元——也就是说,它是一个在循环中应用的函数。

我们实例化了模块,并创建了 xh,它们只是 3x4 的随机值矩阵。 然后我们用 my_cell(x, h) 调用了单元。 这反过来调用了我们的 forward 函数。

让我们做一些更有趣的事情。

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))
MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8573,  0.6190,  0.5774,  0.7869],
        [ 0.3326,  0.0530,  0.0702,  0.8114],
        [ 0.7818, -0.0506,  0.4039,  0.7967]], grad_fn=<TanhBackward0>), tensor([[ 0.8573,  0.6190,  0.5774,  0.7869],
        [ 0.3326,  0.0530,  0.0702,  0.8114],
        [ 0.7818, -0.0506,  0.4039,  0.7967]], grad_fn=<TanhBackward0>))

我们重新定义了模块 MyCell,但这次我们添加了一个 self.linear 属性,并在 forward 函数中调用了 self.linear

这里究竟发生了什么? torch.nn.Linear 是来自 PyTorch 标准库的 Module。 就像 MyCell 一样,它可以使用调用语法来调用。 我们正在构建一个 Module 的层次结构。

Module 使用 print 将给出 Module 子类层次结构的可视化表示。 在我们的例子中,我们可以看到我们的 Linear 子类及其参数。

通过以这种方式组合 Module,我们可以简洁明了地编写具有可重用组件的模型。

您可能已经注意到输出上的 grad_fn。 这是 PyTorch 的自动微分方法的细节,称为 autograd。 简而言之,该系统允许我们计算可能很复杂的程序的导数。 该设计允许在模型编写中拥有大量的灵活性。

现在让我们看看这种灵活性。

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))
MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8346,  0.5931,  0.2097,  0.8232],
        [ 0.2340, -0.1254,  0.2679,  0.8064],
        [ 0.6231,  0.1494, -0.3110,  0.7865]], grad_fn=<TanhBackward0>), tensor([[ 0.8346,  0.5931,  0.2097,  0.8232],
        [ 0.2340, -0.1254,  0.2679,  0.8064],
        [ 0.6231,  0.1494, -0.3110,  0.7865]], grad_fn=<TanhBackward0>))

我们再次重新定义了 MyCell 类,但这里我们定义了 MyDecisionGate。 该模块利用了控制流。 控制流包括循环和 if 语句。

许多框架采用在给定完整程序表示的情况下计算符号导数的方法。 但是,在 PyTorch 中,我们使用了一个梯度带。 我们记录发生的运算,并在计算导数时反向回放。 通过这种方式,框架不必为语言中的所有结构显式定义导数。

How autograd works

autograd 如何工作

TorchScript 基础知识

现在让我们以我们正在运行的示例为例,看看如何应用 TorchScript。

简而言之,TorchScript 提供了工具来捕获模型的定义,即使考虑到 PyTorch 的灵活性和动态性。 让我们从检查我们所说的跟踪开始。

跟踪 Modules

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)

(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))

我们稍微回溯了一下,并采用了 MyCell 类的第二个版本。 和以前一样,我们已经实例化了它,但这次,我们调用了 torch.jit.trace,传入 Module,并传入网络可能看到的示例输入

这究竟做了什么? 它调用了 Module,记录了 Module 运行时发生的运算,并创建了一个 torch.jit.ScriptModule 实例(TracedModule 是其中一个实例)。

TorchScript 在中间表示(或 IR)中记录其定义,在深度学习中通常被称为。 我们可以使用 .graph 属性检查图。

graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)

但是,这是一个非常底层的表示,图中包含的大多数信息对最终用户没有用。 相反,我们可以使用 .code 属性给出代码的 Python 语法解释。

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)

那么为什么我们要做这一切? 有几个原因。

  1. TorchScript 代码可以在自己的解释器中调用,该解释器基本上是一个受限的 Python 解释器。 该解释器不会获取全局解释器锁,因此许多请求可以同时在同一个实例上处理。

  2. 这种格式允许我们将整个模型保存到磁盘,并将其加载到另一个环境中,例如在用除 Python 之外的语言编写的服务器中。

  3. TorchScript 为我们提供了一个表示,我们可以在其中对代码进行编译器优化,以提供更有效的执行。

  4. TorchScript 允许我们与许多后端/设备运行时交互,这些运行时需要比单个运算符更广泛的程序视图。

我们可以看到调用 traced_cell 会产生与 Python 模块相同的结果。

print(my_cell(x, h))
print(traced_cell(x, h))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))

使用脚本将模块转换为

我们使用模块的第二个版本,而不是带有控制流子模块的那个版本是有原因的。 现在让我们检查一下。

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)
/var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:261: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

def forward(self,
    argument_1: Tensor) -> NoneType:
  return None

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (linear).forward(x, )
  _1 = (dg).forward(_0, )
  _2 = torch.tanh(torch.add(_0, h))
  return (_2, _2)

查看 .code 输出,我们可以看到 if-else 分支无处可寻! 为什么? 跟踪正是我们所说的那样:运行代码,记录发生的运算,并构建一个ScriptModule,它正是这样做的。 不幸的是,像控制流这样的东西会被擦除。

我们如何忠实地将此模块表示在 TorchScript 中? 我们提供了一个脚本编译器,它对您的 Python 源代码进行直接分析,将其转换为 TorchScript。 让我们使用脚本编译器转换 MyDecisionGate

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)
def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)

好极了! 我们现在已经忠实地捕获了程序在 TorchScript 中的行为。 现在让我们尝试运行程序。

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))
(tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
        [ 0.5228,  0.7122,  0.6985, -0.0656],
        [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>), tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
        [ 0.5228,  0.7122,  0.6985, -0.0656],
        [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>))

混合使用脚本和跟踪

某些情况需要使用跟踪而不是脚本(例如,一个模块具有许多基于常量 Python 值的架构决策,我们希望这些值不会出现在 TorchScript 中)。 在这种情况下,脚本可以与跟踪组合使用:torch.jit.script 将内联跟踪模块的代码,而跟踪将内联脚本模块的代码。

第一种情况的示例。

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)

第二种情况的示例。

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

通过这种方式,脚本和跟踪可以在需要使用它们的情况下使用,也可以一起使用。

保存和加载模型

我们提供了 API 来将 TorchScript 模块保存到磁盘或从磁盘加载到磁盘,并使用存档格式。 此格式包括代码、参数、属性和调试信息,这意味着存档是模型的独立表示,可以加载到完全独立的进程中。 让我们保存和加载我们包装的 RNN 模块。

traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)
RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

正如您所见,序列化保留了模块层次结构和我们一直在检查的代码。 模型也可以加载,例如,加载到 C++ 中,以进行无 Python 执行。

进一步阅读

我们完成了教程! 为了更详细的演示,请查看使用 TorchScript 转换机器翻译模型的 NeurIPS 演示:https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ

脚本的总运行时间:(0 分钟 0.221 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源