​PathAI 是领先的病理学(疾病研究)AI 驱动技术工具和服务提供商。我们的平台旨在通过利用图像分割、图神经网络和多示例学习等现代机器学习方法,大幅提高复杂疾病的诊断准确性和疗效评估水平。

传统的病理学手动操作容易受到主观性和观察者差异的影响,这可能会对诊断和药物开发试验产生负面影响。在深入探讨我们如何使用 PyTorch 改进诊断工作流程之前,让我们先介绍一下没有机器学习支持的传统模拟病理学工作流程。

传统生物制药是如何运作的

生物制药公司发现新型疗法或诊断方法有很多途径。其中一条途径在很大程度上依赖于对病理切片的分析,以回答各种问题:特定的细胞通信通路是如何工作的?某种疾病状态是否与特定蛋白质的存在或缺失有关?为什么临床试验中的某种特定药物对某些患者有效,而对另一些患者无效?患者预后与新型生物标志物之间是否存在关联?

为了帮助回答这些问题,生物制药公司依靠专业病理学家来分析切片并协助评估他们可能遇到的疑问。

正如你所想象的,需要一位获得专业委员会认证的病理学家才能做出准确的解读和诊断。在一项研究中,将同一个活检结果交给 36 位不同的病理学家,结果得出了 18 种不同的诊断结论,严重程度从无需治疗到必须进行激进治疗不等。在处理困难的边缘病例时,病理学家也经常征求同事的反馈。考虑到问题的复杂性,即使有专家培训和协作,病理学家在做出正确诊断时仍可能感到困难。这种潜在的偏差可能就是药物获批与临床试验失败之间的区别。

PathAI 如何利用机器学习助力药物开发

PathAI 开发机器学习模型,为药物研发提供见解,助力临床试验并辅助诊断。为此,PathAI 利用 PyTorch 进行切片级别的推断,采用了包括图神经网络 (GNN) 以及多示例学习在内的多种方法。在此背景下,“切片”是指玻璃切片的全尺寸扫描图像,即两片玻璃之间夹着一层薄薄的组织切片,经过染色以显示各种细胞结构。PyTorch 使我们的团队能够利用这些不同的方法共享一个通用框架,该框架足够稳健,可以在我们需要的所有条件下工作。PyTorch 高级、命令式且符合 Python 习惯的语法使我们能够快速对模型进行原型设计,并在获得我们想要的结果后将这些模型扩展到生产规模。

针对千兆像素图像的多示例学习

将机器学习应用于病理学面临的一个独特挑战是图像极其巨大。这些数字切片的分辨率通常可达 100,000 x 100,000 像素甚至更高,且大小可达千兆字节。将完整图像加载到 GPU 内存中并在其上应用传统的计算机视觉算法几乎是不可能的任务。此外,对整张切片图像 (100k x 100k) 进行标注需要相当多的时间和资源,尤其是当标注者必须是领域专家(委员会认证的病理学家)时。我们通常构建模型来预测患者切片上的图像级别标签(如是否存在癌症),这在整张图像中仅覆盖几千个像素。癌变区域有时仅占整张切片的一小部分,这使得机器学习问题类似于“大海捞针”。另一方面,某些问题(如预测特定的组织学生物标志物)需要聚合整张切片的信息,由于图像尺寸巨大,这同样困难。所有这些因素在将机器学习技术应用于病理学问题时,增加了显著的算法、计算和后勤复杂性。

将图像分解为较小的补丁 (patch),学习补丁表示,然后聚合这些表示以预测图像级别标签是解决此问题的一种方法,如下图所示。一种流行的方法称为多示例学习 (MIL)。每个补丁被视为一个“示例”,一组补丁构成一个“包”。将各个补丁的表示聚合在一起以预测最终的包级别标签。从算法上讲,包中的单个补丁示例不需要标签,因此允许我们以弱监督的方式学习包级别标签。它们还使用置换不变池化函数,使得预测与补丁的顺序无关,并允许高效的信息聚合。通常使用基于注意力的池化函数,这不仅可以实现高效聚合,还可以为包中的每个补丁提供注意力值。这些值指示了相应补丁在预测中的重要性,并且可以进行可视化以更好地理解模型预测。这种可解释性元素对于推动这些模型在现实世界中的应用非常重要,我们使用诸如加性 MIL 模型等变体来实现这种空间可解释性。在计算上,MIL 模型避开了在大型图像尺寸上应用神经网络的问题,因为补丁表示是独立于图像尺寸获得的。

在 PathAI,我们使用基于深度网络的自定义 MIL 模型来预测图像级别标签。该过程的概述如下:

  1. 使用不同的采样方法从切片中选择补丁。
  2. 基于随机采样或启发式规则构建一个补丁包。
  3. 基于预训练模型或大规模表示学习模型为每个示例生成补丁表示。
  4. 应用置换不变池化函数以获得最终的切片级别分数。

现在我们已经了解了 PyTorch 中 MIL 的一些高层细节,让我们看看一些代码,了解如何使用 PyTorch 从构思到生产代码是多么简单。我们首先定义一个采样器、转换和我们的 MIL 数据集。

# Create a bag sampler which randomly samples patches from a slide
bag_sampler = RandomBagSampler(bag_size=12)

# Setup the transformations
crop_transform = FlipRotateCenterCrop(use_flips=True)

# Create the dataset which loads patches for each bag
train_dataset = MILDataset(
  bag_sampler=bag_sampler,
  samples_loader=sample_loader,
  transform=crop_transform,
)

定义好采样器和数据集后,我们需要定义将使用所述数据集进行训练的模型。PyTorch 熟悉的模型定义语法使这变得容易,同时也允许我们同时创建定制模型。

classifier = DefaultPooledClassifier(hidden_dims=[256, 256], input_dims=1024, output_dims=1)

pooling = DefaultAttentionModule(
  input_dims=1024,
  hidden_dims=[256, 256],
  output_activation=StableSoftmax()
)

# Define the model which is a composition of the featurizer, pooling module and a classifier
model = DefaultMILGraph(featurizer=ShuffleNetV2(), classifier=classifier, pooling = pooling)

由于这些模型是端到端训练的,它们提供了一种直接从千兆像素全切片图像得出单个标签的强大方法。鉴于它们对不同生物学问题的广泛适用性,其实现和部署的两个方面至关重要:

  1. 对管道的每个部分进行可配置的控制,包括数据加载器、模型的模块化部分及其相互作用。
  2. 能够快速迭代“构思-实现-实验-生产”循环。

PyTorch 在 MIL 建模方面具有多项优势。它提供了一种直观的方法来创建具有灵活控制流的动态计算图,这非常适合快速研究实验。地图样式数据集、可配置的采样器和批量采样器允许我们定制构建补丁包的方式,从而加快实验速度。由于 MIL 模型 IO 密集,数据并行和 Pythonic 数据加载器使任务非常高效且用户友好。最后,PyTorch 的面向对象特性支持构建可重用模块,这有助于快速实验、可维护的实现以及轻松构建管道的组合组件。

使用 PyTorch 中的 GNN 探索空间组织结构

在健康和患病组织中,细胞的空间排列和结构往往与细胞本身一样重要。例如,在评估肺癌时,病理学家会尝试观察肿瘤细胞的整体分组和结构(它们是形成实心片状?还是以较小的局部簇形式出现?),以确定癌症是否属于特定的亚型,而这可能具有截然不同的预后。细胞与其他组织结构之间的这种空间关系可以使用图进行建模,以同时捕捉组织拓扑结构和细胞组成。图神经网络 (GNN) 允许学习这些图中的空间模式,这些模式与其他临床变量相关联,例如某些癌症中基因的过度表达。

2020 年底,当 PathAI 开始在组织样本上使用 GNN 时,PyTorch 通过 PyG 包对 GNN 功能提供了最好且最成熟的支持。鉴于 GNN 模型是我们知道想要探索的重要机器学习概念,这使得 PyTorch 成为我们团队的自然选择。

在组织样本的背景下,GNN 的主要价值在于图本身可以揭示仅通过目视检查很难发现的空间关系。在我们最近的 AACR 出版物中,我们展示了通过使用 GNN,我们可以更好地了解肿瘤微环境中免疫细胞聚集体(特别是三级淋巴结构,或称 TLS)的存在如何影响患者预后。在这种情况下,GNN 方法被用于预测与 TLS 存在相关的基因表达,并识别出 TLS 区域本身之外与 TLS 相关的组织学特征。在没有机器学习模型辅助的情况下,从组织样本图像中识别此类基因表达见解是非常困难的。

我们成功尝试过的最有希望的 GNN 变体之一是 自注意力图池化 (Self Attention Graph Pooling)。让我们看看如何使用 PyTorch 和 PyG 定义我们的自注意力图池化 (SAGPool) 模型。

class SAGPool(torch.nn.Module):
  def __init__(self, ...):
    super().__init__()
    self.conv1 = GraphConv(in_features, hidden_features, aggr='mean')
    self.convs = torch.nn.ModuleList()
    self.pools = torch.nn.ModuleList()
    self.convs.extend([GraphConv(hidden_features, hidden_features, aggr='mean') for i in range(num_layers - 1)])
    self.pools.extend([SAGPooling(hidden_features, ratio, GNN=GraphConv, min_score=min_score) for i in range((num_layers) // 2)])
    self.jump = JumpingKnowledge(mode='cat')
    self.lin1 = Linear(num_layers * hidden_features, hidden_features)
    self.lin2 = Linear(hidden_features, out_features)
    self.out_activation = out_activation
    self.dropout = dropout

在上面的代码中,我们首先定义一个卷积图层,然后添加两个 module list 层,允许我们传入可变数量的层。然后,我们获取空的 module list 并扩展可变数量的 GraphConv 层,随后是可变数量的 SAGPooling 层。我们通过添加 JumpingKnowledge 层、两个线性层、激活函数和 dropout 值来完成 SAGPool 的定义。PyTorch 直观的语法使我们能够在抽象掉 SAG Pooling 等前沿方法复杂性的同时,保持我们熟悉的模型开发通用方法。

像上面描述的 SAG Pool 这样的模型只是 GNN 与 PyTorch 如何让我们探索新想法的一个例子。我们最近还探索了多模态 CNN – GNN 混合模型,其准确率比传统的病理学家共识分数高出 20%。这些创新以及传统 CNN 和 GNN 之间的相互作用,再次通过简短的研究到生产的模型开发循环得以实现。

改善患者预后

为了实现我们通过 AI 驱动的病理学改善患者预后的使命,PathAI 需要依赖一个机器学习开发框架,该框架能够:(1) 在开发和探索的初始阶段促进快速迭代和轻松扩展(即作为代码的模型配置);(2) 将模型训练和推理扩展到海量图像;(3) 轻松且稳健地为我们的产品提供生产用途(在临床试验及其他领域)。正如我们所展示的,PyTorch 为我们提供了所有这些功能,甚至更多。我们对 PyTorch 的未来感到非常兴奋,迫不及待地想看看我们还能利用该框架解决哪些具有影响力的挑战。