快捷方式

使用 Captum 进行模型可解释性

创建日期:2020年4月14日 | 最后更新:2023年9月26日 | 最后验证:未验证

Captum 帮助您理解数据特征如何影响您的模型预测或神经元激活,阐明模型的工作原理。

使用 Captum,您可以以统一的方式应用各种最先进的特征归因算法,例如 Guided GradCamIntegrated Gradients

在本攻略中,您将学习如何使用 Captum 来

  • 将图像分类器的预测归因于相应的图像特征。

  • 可视化归因结果。

开始之前

请确保 Captum 已安装在您当前的 Python 环境中。Captum 可从 GitHub、pip 包或 conda 包获取。详细说明请参阅安装指南:https://captum.ai/

对于模型,我们使用 PyTorch 内置的图像分类器。Captum 可以揭示样本图像的哪些部分支持模型做出的特定预测。

import torchvision
from torchvision import models, transforms
from PIL import Image
import requests
from io import BytesIO

model = torchvision.models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval()

response = requests.get("https://image.freepik.com/free-photo/two-beautiful-puppies-cat-dog_58409-6024.jpg")
img = Image.open(BytesIO(response.content))

center_crop = transforms.Compose([
 transforms.Resize(256),
 transforms.CenterCrop(224),
])

normalize = transforms.Compose([
    transforms.ToTensor(),               # converts the image to a tensor with values between 0 and 1
    transforms.Normalize(                # normalize to follow 0-centered imagenet pixel RGB distribution
     mean=[0.485, 0.456, 0.406],
     std=[0.229, 0.224, 0.225]
    )
])
input_img = normalize(center_crop(img)).unsqueeze(0)

计算归因

模型的前 3 个预测中,类别 208 和 283 分别对应狗和猫。

让我们使用 Captum 的 Occlusion 算法将这些预测分别归因于输入的相应部分。

from captum.attr import Occlusion

occlusion = Occlusion(model)

strides = (3, 9, 9)               # smaller = more fine-grained attribution but slower
target=208,                       # Labrador index in ImageNet
sliding_window_shapes=(3,45, 45)  # choose size enough to change object appearance
baselines = 0                     # values to occlude the image with. 0 corresponds to gray

attribution_dog = occlusion.attribute(input_img,
                                       strides = strides,
                                       target=target,
                                       sliding_window_shapes=sliding_window_shapes,
                                       baselines=baselines)


target=283,                       # Persian cat index in ImageNet
attribution_cat = occlusion.attribute(input_img,
                                       strides = strides,
                                       target=target,
                                       sliding_window_shapes=sliding_window_shapes,
                                       baselines=0)

除了 Occlusion,Captum 还提供了许多算法,例如 Integrated GradientsDeconvolutionGuidedBackpropGuided GradCamDeepLiftGradientShap。所有这些算法都是 Attribution 的子类,它们在初始化时需要您的模型作为一个可调用的 forward_func,并且都有一个 attribute(...) 方法,以统一的格式返回归因结果。

让我们可视化图像的计算归因结果。

可视化结果

Captum 的 visualization 工具提供了开箱即用的方法来可视化图像输入和文本输入的归因结果。

import numpy as np
from captum.attr import visualization as viz

# Convert the compute attribution tensor into an image-like numpy array
attribution_dog = np.transpose(attribution_dog.squeeze().cpu().detach().numpy(), (1,2,0))

vis_types = ["heat_map", "original_image"]
vis_signs = ["all", "all"] # "positive", "negative", or "all" to show both
# positive attribution indicates that the presence of the area increases the prediction score
# negative attribution indicates distractor areas whose absence increases the score

_ = viz.visualize_image_attr_multiple(attribution_dog,
                                      np.array(center_crop(img)),
                                      vis_types,
                                      vis_signs,
                                      ["attribution for dog", "image"],
                                      show_colorbar = True
                                     )


attribution_cat = np.transpose(attribution_cat.squeeze().cpu().detach().numpy(), (1,2,0))

_ = viz.visualize_image_attr_multiple(attribution_cat,
                                      np.array(center_crop(img)),
                                      ["heat_map", "original_image"],
                                      ["all", "all"], # positive/negative attribution or all
                                      ["attribution for cat", "image"],
                                      show_colorbar = True
                                     )

如果您的数据是文本,visualization.visualize_text() 提供了一个专门的视图来探索输入文本之上的归因。了解更多请访问:http://captum.ai/tutorials/IMDB_TorchText_Interpret

最终注意事项

Captum 可以处理 PyTorch 中的大多数模型类型,涵盖视觉、文本等模态。使用 Captum,您可以: * 将特定输出归因于模型输入,如上所示。 * 将特定输出归因于隐藏层神经元(参阅 Captum API 参考)。 * 将隐藏层神经元响应归因于模型输入(参阅 Captum API 参考)。

有关支持方法的完整 API 和教程列表,请查阅我们的网站 http://captum.ai

Gilbert Tanner 的另一篇有用文章:https://gilberttanner.com/blog/interpreting-pytorch-models-with-captum

脚本总运行时间: ( 0 分钟 0.000 秒)

画廊由 Sphinx-Gallery 生成

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源