Saliency map的实现
作者:互联网
import sys

import matplotlib.pyplot as plt
import pandas as pd
import PIL.Image  # `import PIL` alone does not load the Image submodule
import torch
import torchvision


def normalize(image):
    """Min-max scale *image* into the range [0, 1].

    A constant input (max == min) has no range to scale by, so it is mapped
    to all zeros instead of dividing by zero and producing NaN.
    """
    lowest = image.min()
    span = image.max() - lowest
    if span == 0:
        return image - lowest
    return (image - lowest) / span


def show_saliency_map(img_path, model, size=100, cmap=plt.cm.hot,
                      class_dict_path='./1000class_dict.csv'):
    """Plot an image next to its gradient-based saliency maps.

    The image is resized to ``size`` x ``size``, passed through ``model``, and
    the score of the top-scoring class is back-propagated to the input pixels.
    Three panels are shown: the resized image, an RGB map of the per-channel
    normalized absolute gradient, and a gray map holding the maximum absolute
    gradient across the colour channels.

    Side effects: switches ``model`` to eval mode, runs a backward pass
    through it, prints the predicted class labels, and opens a matplotlib
    window. Nothing is returned.

    NOTE(review): only resize + ``ToTensor`` is applied to the input; no
    mean/std normalization. Confirm that matches what ``model`` was trained on.

    Args:
        img_path: Path of the image file to explain.
        model: Callable taking a ``(1, 3, size, size)`` float tensor and
            returning per-class scores of shape ``(1, num_classes)``.
        size: Side length, in pixels, the image is resized to.
        cmap: Matplotlib colormap used for the gray map.
        class_dict_path: CSV file with the class names; rows are looked up by
            predicted class index. Presumably column 2 holds the Chinese name
            and column 3 the English name -- verify against the CSV.
    """
    model.eval()

    # Image transforms: tensor for the model, plain resize for display,
    # tensor -> PIL for rendering the RGB saliency map.
    to_tensor = torchvision.transforms.Compose([
        torchvision.transforms.Resize((size, size)),
        torchvision.transforms.ToTensor(),
    ])
    resize = torchvision.transforms.Resize((size, size))
    to_pil = torchvision.transforms.ToPILImage()

    # convert() produces an independent RGB image, so the file handle can be
    # released as soon as the block exits.
    with PIL.Image.open(img_path) as source:
        img = source.convert("RGB")

    # Add the batch dimension and track gradients with respect to the pixels.
    timg = to_tensor(img).view(1, 3, size, size)
    timg.requires_grad = True

    # Forward pass, then pick the index of the highest-scoring class.
    output = model(timg)
    timg_class = output.argmax(dim=1).item()

    # Class-index -> name lookup tables.
    pd_data = pd.read_csv(class_dict_path)
    class_index_en = pd_data.iloc[:, 3].to_dict()
    class_index_zh = pd_data.iloc[:, 2].to_dict()
    print(class_index_zh[timg_class], class_index_en[timg_class])

    # Back-propagate the winning class score down to the input pixels.
    s = output[0, timg_class]
    s.backward()

    with torch.no_grad():
        # Absolute gradient of the single image in the batch: (3, size, size).
        grad = timg.grad[0].abs().cpu()

        # Gray map: maximum magnitude across the colour channels. The former
        # `- lambd * ||timg||^2` term is dropped: it is one scalar per image,
        # so it cannot change the picture imshow renders after autoscaling,
        # yet it can be far larger than the gradients and rounds away their
        # float32 precision. In the referenced paper that L2 penalty belongs
        # to class-model image generation, not to the saliency map.
        saliency_map_gray = grad.max(dim=0)[0].numpy()

        # RGB map: normalize every colour channel separately. Iterating the
        # (3, H, W) gradient walks the channels; iterating the batched
        # (1, 3, H, W) tensor would normalize the whole image at once.
        saliency_map_rgb = torch.stack([normalize(channel) for channel in grad])

    _, ax = plt.subplots(1, 3)
    ax[0].imshow(resize(img))
    ax[0].set_title(class_index_en[timg_class])
    ax[1].imshow(to_pil(saliency_map_rgb))
    ax[1].set_title('RGB map')
    ax[2].imshow(saliency_map_gray, cmap=cmap)
    ax[2].set_title('gray map')
    plt.show()


if __name__ == "__main__":
    img = './panda.png'
    # NOTE(review): newer torchvision releases prefer the `weights=` argument
    # over `pretrained=True` -- confirm against the installed version.
    model = torchvision.models.resnet18(pretrained=True)
    show_saliency_map(img, model, size=224)
参考:Karen Simonyan, Andrea Vedaldi, Andrew Zisserman. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps. arXiv:1312.6034, https://arxiv.org/abs/1312.6034
标签:map,img,Saliency,class,saliency,timg,实现,size 来源: https://www.cnblogs.com/mydrizzle/p/13977924.html