Model Interpretability App Example
This is an example TorchX app that uses captum to analyze inputs for model interpretability purposes. It uses the trained model from the trainer app example and the preprocessed samples from the data preprocessing app example. The output is a series of images with integrated gradients attributions overlaid on them.
See https://captum.ai/tutorials/CIFAR_TorchVision_Interpret for more information on using captum.
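For orientation, the snippet below is a minimal sketch of the captum workflow this app follows: wrap a model in IntegratedGradients and attribute a target class score back to the input pixels against an all-zeros baseline. The toy linear model and random input are placeholders for illustration only, not the TinyImageNetModel and data module used in the actual script.

import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

# Placeholder model: flattens a 3x64x64 image and maps it to 10 class scores.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 10))
toy_model.eval()

x = torch.rand(1, 3, 64, 64)  # one fake RGB image
ig = IntegratedGradients(toy_model)

# Attribute the class-0 score to each input pixel, using an all-zeros baseline.
attributions, delta = ig.attribute(
    x,
    target=0,
    baselines=x * 0,
    return_convergence_delta=True,
)
print(attributions.shape)  # same shape as the input: (1, 3, 64, 64)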
Usage
Run this main module as a Python process locally. The run below assumes that the model has been trained using the instructions in torchx/examples/apps/lightning/train.py.
$ torchx run -s local_cwd utils.python \
    --script ./lightning/interpret.py \
    -- \
    --load_path /tmp/torchx/train/last.ckpt \
    --output_path /tmp/torchx/interpret
Use an image viewer to visualize the *.png files generated under output_path.
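If you do not have an image viewer handy, one of the generated attribution images can also be inspected with matplotlib (which captum's visualization utilities already depend on); the path below assumes the --output_path from the usage example above.

import matplotlib.pyplot as plt

img = plt.imread("/tmp/torchx/interpret/ig_0.png")  # first analyzed example
plt.imshow(img)
plt.axis("off")
plt.show()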
Note
A local run using TorchX's utils.python builtin is effectively equivalent to running the main module directly (e.g. python ./interpret.py). The benefit of launching a simple single-process Python program with TorchX is that it can also be launched on a remote scheduler such as Kubernetes simply by replacing -s local_cwd with -s kubernetes.
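For example, a remote launch might look roughly like the following. This is only a sketch: the -cfg values (such as the Volcano queue) depend on your cluster setup, and in practice the checkpoint and output paths should point at storage reachable from the cluster rather than a local /tmp directory.

$ torchx run -s kubernetes -cfg queue=default utils.python \
    --script ./lightning/interpret.py \
    -- \
    --load_path /tmp/torchx/train/last.ckpt \
    --output_path /tmp/torchx/interpret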
import argparse
import itertools
import os.path
import sys
import tempfile
from typing import List
import fsspec
import torch
from torchx.examples.apps.lightning.data import (
    create_random_data,
    download_data,
    TinyImageNetDataModule,
)
from torchx.examples.apps.lightning.model import TinyImageNetModel
# ensure data and module are on the path
sys.path.append(".")
# FIXME: captum must be imported after torch otherwise it causes python to crash
if True:
    import numpy as np
    from captum.attr import IntegratedGradients, visualization as viz

def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="example TorchX captum app")
    parser.add_argument(
        "--load_path",
        type=str,
        help="checkpoint path to load model weights from",
        required=True,
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="path to load the training data from, if not provided, random dataset will be created",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="path to place analysis results",
        required=True,
    )
    return parser.parse_args(argv)

def convert_to_rgb(arr: torch.Tensor) -> np.ndarray:  # pyre-ignore[24]
    """
    Converts an image from a torch tensor with shape (1, 3, 64, 64) to a
    numpy array with shape (64, 64, 3).
    """
    out = arr.squeeze().swapaxes(0, 2)
    assert out.shape == (64, 64, 3), "invalid shape produced"
    return out.numpy()
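
# Doctest-style illustration of convert_to_rgb: a (1, 3, 64, 64) CHW image tensor
# becomes a (64, 64, 3) HWC numpy array, which is the layout that
# viz.visualize_image_attr expects. Note that swapaxes(0, 2) maps (C, H, W) to
# (W, H, C), so height and width are transposed; the shapes still match because
# Tiny ImageNet images are square (64x64).
#
#   >>> convert_to_rgb(torch.rand(1, 3, 64, 64)).shape
#   (64, 64, 3)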

def main(argv: List[str]) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        args = parse_args(argv)

        # Init our model from the trained checkpoint. Note that
        # load_from_checkpoint returns a new model instance, so capture the result.
        print(f"loading checkpoint: {args.load_path}...")
        model = TinyImageNetModel.load_from_checkpoint(checkpoint_path=args.load_path)

        # Download and setup the data module
        if not args.data_path:
            data_path = os.path.join(tmpdir, "data")
            os.makedirs(data_path)
            create_random_data(data_path)
        else:
            data_path = download_data(args.data_path, tmpdir)
        data = TinyImageNetDataModule(
            data_dir=data_path,
            batch_size=1,
        )

        ig = IntegratedGradients(model)

        data.setup("test")
        dataloader = data.test_dataloader()

        # process first 5 images
        for i, (input, label) in enumerate(itertools.islice(dataloader, 5)):
            print(f"analyzing example {i}")

            # input = input.unsqueeze(dim=0)
            model.zero_grad()
            attr_ig, delta = ig.attribute(
                input,
                target=label,
                baselines=input * 0,
                return_convergence_delta=True,
            )

            if attr_ig.count_nonzero() == 0:
                # Our toy model sometimes has no IG results.
                print("skipping due to zero gradients")
                continue

            fig, axis = viz.visualize_image_attr(
                convert_to_rgb(attr_ig),
                convert_to_rgb(input),
                method="blended_heat_map",
                sign="all",
                show_colorbar=True,
                title="Overlayed Integrated Gradients",
            )

            out_path = os.path.join(args.output_path, f"ig_{i}.png")
            print(f"saving heatmap to {out_path}")
            with fsspec.open(out_path, "wb") as f:
                fig.savefig(f)

if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])
# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'
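For intuition about what ig.attribute computes in the loop above, here is a rough manual approximation of integrated gradients: average the input gradients of the target score along the straight-line path from the baseline to the input, then scale by (input - baseline). This is a simple Riemann-sum sketch assuming a model that returns per-class scores of shape (batch, classes); captum's implementation handles step weighting, batching, and convergence checks more carefully.

import torch


def manual_integrated_gradients(model, x, baseline, target, steps=50):
    # Accumulate gradients of the target score at evenly spaced points on the
    # straight-line path from the baseline to the input.
    total_grads = torch.zeros_like(x)
    for k in range(1, steps + 1):
        point = (baseline + (k / steps) * (x - baseline)).detach().requires_grad_(True)
        score = model(point)[0, target]
        (grad,) = torch.autograd.grad(score, point)
        total_grads += grad
    # Scale the averaged gradients by the input-baseline difference.
    return (x - baseline) * total_grads / steps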