快捷方式

模型可解释性示例

这是一个使用 captum 分析模型可解释性目的的输入的 TorchX 应用程序示例。它使用训练器应用程序示例中的训练模型和数据预处理应用程序示例中的预处理示例。输出是一系列带有集成梯度属性叠加在其上的图像。

有关使用 captum 的更多信息,请参见 https://captum.ai/tutorials/CIFAR_TorchVision_Interpret

使用

在本地将此主模块作为 Python 进程运行。下面的运行假设模型已使用 torchx/examples/apps/lightning/train.py 中的使用说明进行了训练。

$ torchx run -s local_cwd utils.python
    --script ./lightning/interpret.py
    --
    --load_path /tmp/torchx/train/last.ckpt
    --output_path /tmp/torchx/interpret

使用图像查看器可视化 *.png 文件,这些文件生成在 output_path 下。

注意

对于使用 TorchX 的 utils.python 内置功能的本地运行,实际上等效于直接运行主模块(例如 python ./interpret.py)。使用 TorchX 启动简单的单进程 Python 程序的好处是,可以通过将 -s local_cwd 替换为远程调度器(如 Kubernetes)来在远程调度器上启动,方法是指定 -s kubernetes

import argparse
import itertools
import os.path
import sys
import tempfile
from typing import List

import fsspec
import torch
from torchx.examples.apps.lightning.data import (
    create_random_data,
    download_data,
    TinyImageNetDataModule,
)
from torchx.examples.apps.lightning.model import TinyImageNetModel


# ensure data and module are on the path
sys.path.append(".")


# FIXME: captum must be imported after torch otherwise it causes python to crash
if True:
    import numpy as np
    from captum.attr import IntegratedGradients, visualization as viz


def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="example TorchX captum app")
    parser.add_argument(
        "--load_path",
        type=str,
        help="checkpoint path to load model weights from",
        required=True,
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="path to load the training data from, if not provided, random dataset will be created",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="path to place analysis results",
        required=True,
    )

    return parser.parse_args(argv)


def convert_to_rgb(arr: torch.Tensor) -> np.ndarray:  # pyre-ignore[24]
    """
    This converts the image from a torch tensor with size (1, 1, 64, 64) to
    numpy array with size (64, 64, 3).
    """
    out = arr.squeeze().swapaxes(0, 2)
    assert out.shape == (64, 64, 3), "invalid shape produced"
    return out.numpy()


def main(argv: List[str]) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        args = parse_args(argv)

        # Init our model
        model = TinyImageNetModel()

        print(f"loading checkpoint: {args.load_path}...")
        model.load_from_checkpoint(checkpoint_path=args.load_path)

        # Download and setup the data module
        if not args.data_path:
            data_path = os.path.join(tmpdir, "data")
            os.makedirs(data_path)
            create_random_data(data_path)
        else:
            data_path = download_data(args.data_path, tmpdir)
        data = TinyImageNetDataModule(
            data_dir=data_path,
            batch_size=1,
        )

        ig = IntegratedGradients(model)

        data.setup("test")
        dataloader = data.test_dataloader()

        # process first 5 images
        for i, (input, label) in enumerate(itertools.islice(dataloader, 5)):
            print(f"analyzing example {i}")
            # input = input.unsqueeze(dim=0)
            model.zero_grad()
            attr_ig, delta = ig.attribute(
                input,
                target=label,
                baselines=input * 0,
                return_convergence_delta=True,
            )

            if attr_ig.count_nonzero() == 0:
                # Our toy model sometimes has no IG results.
                print("skipping due to zero gradients")
                continue

            fig, axis = viz.visualize_image_attr(
                convert_to_rgb(attr_ig),
                convert_to_rgb(input),
                method="blended_heat_map",
                sign="all",
                show_colorbar=True,
                title="Overlayed Integrated Gradients",
            )
            out_path = os.path.join(args.output_path, f"ig_{i}.png")
            print(f"saving heatmap to {out_path}")
            with fsspec.open(out_path, "wb") as f:
                fig.savefig(f)


if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])


# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'

脚本的总运行时间: ( 0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源