作者:Logan Kilpatrick - 高级技术倡导者,Harshith Padigela - 机器学习工程师,Syed Ashar Javed - 机器学习技术主管,Robert Egger - 生物医学数据科学家

​PathAI 是 AI 驱动的病理学(疾病研究)技术工具和服务的领先提供商。我们的平台旨在通过利用机器学习中的现代方法(如图像分割、图神经网络和多实例学习),大幅提高复杂疾病诊断的准确性和治疗效果的衡量标准。

传统的手工病理学容易出现主观性和观察者变异性,这对诊断和药物开发试验可能产生负面影响。在我们深入探讨如何使用 PyTorch 改进我们的诊断工作流程之前,我们先来了解一下不使用机器学习的传统模拟病理学工作流程。

传统生物制药公司的工作方式

生物制药公司采取多种途径来发现新型治疗方法或诊断方法。其中一种途径严重依赖于病理切片的分析,以回答各种问题:特定的细胞通讯通路如何工作?特定的疾病状态是否与特定蛋白质的存在或缺失有关?为什么临床试验中的特定药物对某些患者有效,而对另一些患者无效?患者预后与新型生物标志物之间是否存在关联?

为了帮助回答这些问题,生物制药公司依靠专业的病理学家来分析切片,并帮助评估他们可能提出的问题。

您可以想象,需要一位专家级认证病理学家才能做出准确的解读和诊断。在一项研究中,将单个活检结果提供给 36 位不同的病理学家,结果出现了 18 种不同的诊断结果,严重程度从无需治疗到必须进行积极治疗不等。病理学家在遇到困难的边缘病例时,也经常征求同事的反馈。鉴于问题的复杂性,即使经过专业的培训和协作,病理学家仍然难以做出正确的诊断。这种潜在的差异可能是药物获得批准与临床试验失败之间的差别。

PathAI 如何利用机器学习为药物开发提供动力

PathAI 开发机器学习模型,为药物开发研发、临床试验和诊断提供见解。为此,PathAI 利用 PyTorch 进行切片级推理,使用了多种方法,包括图神经网络 (GNN) 和多实例学习。在此上下文中,“切片”指的是玻璃切片的完整扫描图像,玻璃切片是用玻璃制成的,中间夹着薄薄的组织切片,染色后可以显示各种细胞形态。PyTorch 使我们使用这些不同方法学的团队能够共享一个通用框架,该框架足够强大,可以在我们所需的所有条件下工作。PyTorch 的高级、命令式和 Pythonic 语法使我们能够快速构建模型原型,并在获得所需结果后将这些模型扩展。

千兆字节图像上的多实例学习

将 ML 应用于病理学的独特挑战之一是图像的巨大尺寸。这些数字切片的分辨率通常可以达到 100,000 x 100,000 像素或更高,大小可达千兆字节。将完整图像加载到 GPU 内存并在其上应用传统的计算机视觉算法几乎是不可能的任务。对完整切片图像 (100k x 100k) 进行注释也需要大量时间和资源,尤其是在注释者需要是领域专家(获得认证的病理学家)的情况下。我们经常构建模型来预测患者切片上的图像级标签(例如癌症的存在),这在整个图像中只覆盖了几千像素。癌变区域有时只是整个切片的一小部分,这使得 ML 问题类似于大海捞针。另一方面,某些问题(如某些组织学生物标志物的预测)需要来自整个切片的信息聚合,这再次因图像尺寸而变得困难。所有这些因素都增加了将 ML 技术应用于病理学问题时的算法、计算和后勤复杂性。

将图像分解成更小的图块,学习图块表示,然后汇集这些表示以预测图像级标签,这是一种解决此问题的方法,如下图所示。一种流行的做法是多实例学习 (MIL)。每个图块都被视为一个“实例”,一组图块构成一个“包”。单个图块表示被汇集在一起,以预测最终的包级标签。在算法上,包中的单个图块实例不需要标签,因此允许我们以弱监督的方式学习包级标签。它们还使用置换不变池化函数,这使得预测与图块的顺序无关,并允许有效地聚合信息。通常,使用基于注意力的池化函数,这不仅可以实现高效的聚合,还可以为包中的每个图块提供注意力值。这些值指示相应图块在预测中的重要性,并且可以可视化以更好地理解模型预测。这种可解释性元素对于推动这些模型在现实世界中的应用非常重要,我们使用加性 MIL 模型等变体来实现这种空间可解释性。在计算方面,MIL 模型规避了将神经网络应用于大型图像尺寸的问题,因为图块表示的获取与图像尺寸无关。

在 PathAI,我们使用基于深度网络的自定义 MIL 模型来预测图像级标签。此过程的概述如下

  1. 使用不同的采样方法从切片中选择图块。
  2. 基于随机采样或启发式规则构建图块包。
  3. 基于预训练模型或大规模表示学习模型,为每个实例生成图块表示。
  4. 应用置换不变池化函数以获得最终的切片级分数。

现在我们已经了解了 PyTorch 中 MIL 的一些高级细节,让我们看一些代码,了解使用 PyTorch 从构思到生产代码有多么简单。我们首先定义采样器、转换和我们的 MIL 数据集

# Create a bag sampler which randomly samples patches from a slide
bag_sampler = RandomBagSampler(bag_size=12)

# Setup the transformations
crop_transform = FlipRotateCenterCrop(use_flips=True)

# Create the dataset which loads patches for each bag
train_dataset = MILDataset(
  bag_sampler=bag_sampler,
  samples_loader=sample_loader,
  transform=crop_transform,
)

在我们定义了采样器和数据集之后,我们需要定义我们将使用所述数据集实际训练的模型。PyTorch 熟悉的模型定义语法使这变得容易,同时也允许我们同时创建定制模型。

classifier = DefaultPooledClassifier(hidden_dims=[256, 256], input_dims=1024, output_dims=1)

pooling = DefaultAttentionModule(
  input_dims=1024,
  hidden_dims=[256, 256],
  output_activation=StableSoftmax()
)

# Define the model which is a composition of the featurizer, pooling module and a classifier
model = DefaultMILGraph(featurizer=ShuffleNetV2(), classifier=classifier, pooling = pooling)

由于这些模型是端到端训练的,因此它们提供了一种强大的方法,可以直接从千兆像素的全切片图像转到单个标签。由于它们对不同生物学问题的广泛适用性,因此它们的实施和部署的两个方面非常重要

  1. 可配置地控制管道的每个部分,包括数据加载器、模型的模块化部分及其彼此之间的交互。
  2. 能够快速迭代构思-实施-实验-产品化循环。

在 MIL 建模方面,PyTorch 具有多种优势。它提供了一种直观的方式来创建具有灵活控制流的动态计算图,这非常适合快速研究实验。地图式数据集、可配置的采样器和批采样器允许我们自定义如何构建图块包,从而加快实验速度。由于 MIL 模型是 IO 密集型的,因此数据并行性和 Pythonic 数据加载器使任务非常高效且用户友好。最后,PyTorch 的面向对象性质支持构建可重用的模块,这有助于快速实验、可维护的实施以及轻松构建管道的组成部分。

使用 PyTorch 中的 GNN 探索空间组织组织

在健康和患病组织中,细胞的空间排列和结构通常与细胞本身一样重要。例如,在评估肺癌时,病理学家会尝试查看肿瘤细胞的整体分组和结构(它们是否形成实体片?还是以较小的局部簇形式出现?)以确定癌症是否属于特定的亚型,这些亚型可能具有非常不同的预后。细胞和其他组织结构之间的这种空间关系可以使用图来建模,以同时捕获组织拓扑和细胞组成。图神经网络 (GNN) 允许学习这些图中与其他临床变量相关的空间模式,例如某些癌症中基因的过度表达。

在 2020 年底,当 PathAI 开始在组织样本上使用 GNN 时,PyTorch 通过 PyG 包对 GNN 功能提供了最佳和最成熟的支持。鉴于 GNN 模型是我们知道将成为我们想要探索的重要 ML 概念的东西,这使得 PyTorch 成为我们团队的自然选择。

GNN 在组织样本背景下的主要附加值之一是,图本身可以揭示仅凭视觉检查很难发现的空间关系。在我们最近的AACR 出版物中,我们表明,通过使用 GNN,我们可以更好地了解肿瘤微环境中免疫细胞聚集体(特别是三级淋巴结构,或 TLS)的存在如何影响患者的预后。在本例中,GNN 方法用于预测与 TLS 存在相关的基因表达,并识别与 TLS 相关的 TLS 区域本身之外的组织学特征。如果不借助 ML 模型,很难从组织样本图像中识别出对基因表达的此类见解。

我们已经成功使用过的最有前途的 GNN 变体之一是自注意力图池化。让我们看看如何使用 PyTorch 和 PyG 定义我们的自注意力图池化 (SAGPool) 模型

class SAGPool(torch.nn.Module):
  def __init__(self, ...):
    super().__init__()
    self.conv1 = GraphConv(in_features, hidden_features, aggr='mean')
    self.convs = torch.nn.ModuleList()
    self.pools = torch.nn.ModuleList()
    self.convs.extend([GraphConv(hidden_features, hidden_features, aggr='mean') for i in range(num_layers - 1)])
    self.pools.extend([SAGPooling(hidden_features, ratio, GNN=GraphConv, min_score=min_score) for i in range((num_layers) // 2)])
    self.jump = JumpingKnowledge(mode='cat')
    self.lin1 = Linear(num_layers * hidden_features, hidden_features)
    self.lin2 = Linear(hidden_features, out_features)
    self.out_activation = out_activation
    self.dropout = dropout

在上面的代码中,我们首先定义一个卷积图层,然后添加两个模块列表层,这使我们可以传入可变数量的层。然后,我们采用我们的空模块列表并附加可变数量的 GraphConv 层,后跟可变数量的 SAGPooling 层。我们通过添加 JumpingKnowledge 层、两个线性层、我们的激活函数和我们的 dropout 值来完成我们的 SAGPool 定义。PyTorch 直观的语法使我们能够抽象出使用 SAG Pooling 等最先进方法的复杂性,同时保持我们熟悉的模型开发常用方法。

上面描述的 SAG Pool 模型只是 GNN 与 PyTorch 如何使我们能够探索新的和新颖想法的一个示例。我们最近还探索了多模态 CNN - GNN 混合模型,其准确度比传统病理学家共识评分高 20%。这些创新以及传统 CNN 和 GNN 之间的相互作用再次得益于从研究到生产的模型开发循环的缩短。

改善患者预后

为了实现我们通过 AI 驱动的病理学改善患者预后的使命,PathAI 需要依赖一个 ML 开发框架,该框架 (1) 在开发和探索的初始阶段促进快速迭代和轻松扩展(即模型配置即代码),(2) 将模型训练和推理扩展到海量图像,(3) 轻松而稳健地为我们产品的生产用途(在临床试验及其他方面)提供模型服务。正如我们所展示的,PyTorch 为我们提供了所有这些功能以及更多功能。我们对 PyTorch 的未来感到非常兴奋,并且迫不及待地想看看我们还可以使用该框架解决哪些有影响力的挑战。