作者:Logan Kilpatrick - 高级技术倡导者,Harshith Padigela - 机器学习工程师,Syed Ashar Javed - 机器学习技术负责人,Robert Egger - 生物医学数据科学家

PathAI 是领先的 AI 病理学(疾病研究)技术工具和服务提供商。我们的平台旨在通过利用机器学习中的现代方法(如图像分割、图神经网络和多实例学习)大幅提高复杂疾病的诊断准确性和治疗效果的衡量水平。

传统的纯人工病理学容易受到 主观性和观察者变异性 的影响,这可能对诊断和药物开发试验产生负面影响。在我们深入探讨如何使用 PyTorch 改进诊断流程之前,让我们先阐述一下不使用机器学习的传统模拟病理学工作流程。

传统生物制药的工作方式

生物制药公司可以采取多种途径来发现新型疗法或诊断方法。其中一种途径严重依赖对病理学载玻片的分析,以回答各种问题:特定的细胞通讯途径如何工作?特定疾病状态是否与特定蛋白质的存在或缺失有关?为什么临床试验中的某种特定药物对某些患者有效而对其他患者无效?患者预后是否与新型生物标志物存在关联?

为了帮助回答这些问题,生物制药公司依赖于专业病理学家来分析载玻片并帮助评估他们可能提出的问题。 

正如您可能想象的,做出准确解读和诊断需要经过专业委员会认证的病理学家。在一项研究中,将单一活检结果交给 36 位不同的病理学家,结果得出 18 种不同的诊断,严重程度从无需治疗到需要积极治疗不等。病理学家在处理疑难边缘病例时也经常征求同事的反馈。考虑到问题的复杂性,即使经过专家培训和协作,病理学家仍然可能难以做出正确诊断。这种潜在的差异可能导致药物获得批准或在临床试验中失败的结局。

PathAI 如何利用机器学习推动药物开发

PathAI 开发机器学习模型,为药物开发研发、临床试验以及诊断提供洞见。为此,PathAI 利用 PyTorch 通过多种方法进行载玻片层面的推理,包括图神经网络 (GNN) 和多实例学习。在这里,“载玻片”指的是扫描的全尺寸玻璃载玻片图像,即夹在两块玻璃之间的一薄片组织,经过染色以显示各种细胞结构。PyTorch 使我们使用这些不同方法学的团队能够共享一个通用的框架,该框架足够健壮,可以在我们所需的所有条件下工作。PyTorch 的高级、命令式和 Pythonic 语法使我们能够快速原型化模型,并在获得所需结果后将这些模型扩展到生产规模。 

在千兆字节图像上进行多实例学习

将机器学习应用于病理学的一个独特挑战是图像的巨大尺寸。这些数字载玻片的分辨率通常高达 100,000 x 100,000 像素或更高,尺寸达到千兆字节。将整个图像加载到 GPU 内存中并对其应用传统计算机视觉算法几乎是一项不可能完成的任务。对整个载玻片图像(100k x 100k)进行标注也需要大量时间和资源,尤其当标注者需要是领域专家(经委员会认证的病理学家)时。我们通常构建模型来预测图像级别的标签,例如患者载玻片上是否存在癌症,这可能只占整个图像的几千个像素。癌变区域有时只占整个载玻片的极小一部分,这使得机器学习问题类似于大海捞针。另一方面,某些问题,如预测某些组织学生物标志物,需要汇总来自整个载玻片的信息,由于图像尺寸巨大,这同样困难。所有这些因素都为将机器学习技术应用于病理学问题带来了显著的算法、计算和后勤复杂性。

将图像分解成更小的块(patch),学习块表示,然后汇集(pool)这些表示来预测图像级别的标签是解决这个问题的一种方法,如下图所示。一种常用的方法叫做多实例学习 (MIL)。每个块被视为一个“实例”,一组块构成一个“包”(bag)。单个块的表示被汇集在一起以预测最终的包级别标签。从算法上讲,包中的单个块实例不需要标签,因此允许我们以弱监督的方式学习包级别标签。它们还使用置换不变的汇集函数,这使得预测与块的顺序无关,并允许高效地聚合信息。通常使用基于注意力(attention)的汇集函数,这不仅允许高效聚合,还为包中的每个块提供注意力值。这些值表示对应块在预测中的重要性,可以可视化以便更好地理解模型预测。这种可解释性元素对于推动这些模型在现实世界中的应用非常重要,我们使用诸如加性 MIL 模型等变体来实现这种空间可解释性。在计算上,MIL 模型规避了将神经网络应用于大型图像尺寸的问题,因为块表示的获取独立于图像尺寸。

在 PathAI,我们使用基于深度网络的定制 MIL 模型来预测图像级别标签。此过程概述如下:

  1. 使用不同的采样方法从载玻片中选择块。
  2. 根据随机采样或启发式规则构建一个块包。
  3. 基于预训练模型或大规模表示学习模型为每个实例生成块表示。
  4. 应用置换不变汇集函数以获得最终的载玻片级别分数。

既然我们已经了解了 PyTorch 中 MIL 的一些高级细节,下面来看一些代码,看看从构思到生产代码是多么简单。我们首先定义一个采样器、转换以及我们的 MIL 数据集

# Create a bag sampler which randomly samples patches from a slide
bag_sampler = RandomBagSampler(bag_size=12)

# Setup the transformations
crop_transform = FlipRotateCenterCrop(use_flips=True)

# Create the dataset which loads patches for each bag
train_dataset = MILDataset(
  bag_sampler=bag_sampler,
  samples_loader=sample_loader,
  transform=crop_transform,
)

定义好采样器和数据集后,我们需要定义将用该数据集实际训练的模型。PyTorch 熟悉的模型定义语法使得这很容易实现,同时还允许我们创建定制模型。

classifier = DefaultPooledClassifier(hidden_dims=[256, 256], input_dims=1024, output_dims=1)

pooling = DefaultAttentionModule(
  input_dims=1024,
  hidden_dims=[256, 256],
  output_activation=StableSoftmax()
)

# Define the model which is a composition of the featurizer, pooling module and a classifier
model = DefaultMILGraph(featurizer=ShuffleNetV2(), classifier=classifier, pooling = pooling)

由于这些模型是端到端训练的,它们提供了一种强大的方式,可以直接从千兆像素的全载玻片图像得到一个单一标签。鉴于它们对不同生物问题的广泛适用性,它们的实现和部署有两个重要方面:

  1. 对管道的每个部分(包括数据加载器、模型的模块化部分以及它们之间的交互)进行可配置控制。
  2. 能够快速迭代“构思-实现-实验-生产化”循环。

PyTorch 在 MIL 建模方面具有各种优势。它提供了一种直观的方式来创建具有灵活控制流的动态计算图,这对于快速研究实验非常有用。map-style 数据集、可配置的采样器和 batch-samplers 使我们能够定制构建块包的方式,从而加快实验。由于 MIL 模型是 I/O 密集型的,数据并行性和 Pythonic 数据加载器使得任务非常高效且用户友好。最后,PyTorch 的面向对象特性使得构建可重用模块成为可能,这有助于快速实验、可维护的实现以及易于构建管道的组合组件。

使用 PyTorch 中的 GNN 探索空间组织结构

在健康和患病组织中,细胞的空间排列和结构通常与其自身一样重要。例如,在评估肺癌时,病理学家会尝试查看肿瘤细胞的整体分组和结构(它们形成实心片状吗?还是以更小、局限的簇形式出现?)以确定癌症是否属于特定亚型,因为不同亚型的预后可能大相径庭。细胞与其他组织结构之间的这种空间关系可以使用图来建模,以同时捕捉组织拓扑结构和细胞组成。图神经网络 (GNN) 允许学习这些图中的空间模式,这些模式与其他临床变量相关,例如某些癌症中基因的过度表达。

在 2020 年底,当 PathAI 开始在组织样本上使用 GNN 时,PyTorch 通过 PyG 包提供了最好、最成熟的 GNN 功能支持。这使得 PyTorch 成为我们团队的自然选择,因为 GNN 模型是我们知道要探索的重要机器学习概念。

GNN 在组织样本方面的主要附加价值之一是图本身可以揭示仅凭肉眼检查难以发现的空间关系。在我们最近的 AACR 出版物中,我们展示了通过使用 GNN,我们可以更好地理解肿瘤微环境中免疫细胞聚集体(特别是三级淋巴结构,或 TLS)的存在如何影响患者预后。在这种情况下,GNN 方法用于预测与 TLS 存在相关的基因表达,并识别 TLS 区域本身之外与 TLS 相关的组织学特征。在没有机器学习模型辅助的情况下,很难从组织样本图像中识别出此类基因表达的洞见。

我们取得成功的最有前景的 GNN 变体之一是自注意力图汇集。让我们看看如何使用 PyTorch 和 PyG 定义我们的自注意力图汇集 (SAGPool) 模型

class SAGPool(torch.nn.Module):
  def __init__(self, ...):
    super().__init__()
    self.conv1 = GraphConv(in_features, hidden_features, aggr='mean')
    self.convs = torch.nn.ModuleList()
    self.pools = torch.nn.ModuleList()
    self.convs.extend([GraphConv(hidden_features, hidden_features, aggr='mean') for i in range(num_layers - 1)])
    self.pools.extend([SAGPooling(hidden_features, ratio, GNN=GraphConv, min_score=min_score) for i in range((num_layers) // 2)])
    self.jump = JumpingKnowledge(mode='cat')
    self.lin1 = Linear(num_layers * hidden_features, hidden_features)
    self.lin2 = Linear(hidden_features, out_features)
    self.out_activation = out_activation
    self.dropout = dropout

在上面的代码中,我们首先定义一个卷积图层,然后添加两个模块列表层,这允许我们传入可变数量的层。然后我们取出空的模块列表并添加可变数量的 GraphConv 层,然后是可变数量的 SAGPooling 层。我们通过添加跳跃知识层 (JumpingKnowledge Layer)、两个线性层、我们的激活函数和 dropout 值来完成 SAGPool 的定义。PyTorch 直观的语法使我们能够抽象出处理 SAG Pooling 等最先进方法的复杂性,同时保留我们熟悉的一般模型开发方法。

像我们上面描述的 SAG Pooling 模型只是 PyTorch 的 GNN 如何让我们探索新的和新颖想法的一个例子。我们最近还探索了多模态 CNN - GNN 混合模型,其准确率比传统病理学家共识评分高出 20%。这些创新以及传统 CNN 与 GNN 之间的相互作用再次得益于从研究到生产的模型开发短循环。

改善患者预后

为了实现我们通过 AI 病理学改善患者预后的使命,PathAI 需要依赖一个机器学习开发框架,该框架 (1) 在开发和探索的初始阶段促进快速迭代和轻松扩展(即,模型配置即代码);(2) 将模型训练和推理扩展到海量图像;(3) 轻松可靠地为我们产品的生产用途(临床试验及其他)提供模型服务。正如我们所示,PyTorch 为我们提供了所有这些功能及更多。我们对 PyTorch 的未来感到无比兴奋,迫不及待地想看到我们能用该框架解决哪些其他有影响力的挑战。