pytorch_grad_cam: a library for visualizing model features (Class Activation Mapping, CAM) in PyTorch
2022-06-27 01:58:00 【万里鹏程转瞬至】
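Deep learning is a "black box" system. It works end to end, and what happens in between is not directly observable, so visualizing intermediate features gives a partial explanation of what the model does with its data. The earliest feature visualization, CAM, applies Global Average Pooling after the model's last conv layer and uses a single fully connected layer as the classifier. The weight that this layer assigns to each pooled feature map for a given class is used to weight that feature map, and the weighted maps are summed to produce the visualization. A family of methods followed that obtain the weights from gradients back-propagated from a specific class label, starting with Grad-CAM. For a more detailed account of how these methods developed, see the Zhihu article 万字长文:特征可视化技术(CAM).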
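The weighted sum behind the original CAM can be sketched on toy arrays. This is a minimal illustration, assuming a network that ends in Global Average Pooling followed by one fully connected layer; the function name `classic_cam` and all numbers are made up for the example.

```python
import numpy as np

def classic_cam(feature_maps, fc_weights, class_idx):
    """Weight each feature map by the FC weight of one class and sum.

    feature_maps: (K, H, W) output of the last conv layer for one image
    fc_weights:   (num_classes, K) weights of the FC layer that follows GAP
    """
    weights = fc_weights[class_idx]                     # (K,)
    return np.tensordot(weights, feature_maps, axes=1)  # (H, W)

# Toy data (hypothetical): two 2x2 feature maps and two classes.
feature_maps = np.array([[[1.0, 0.0], [0.0, 0.0]],
                         [[0.0, 0.0], [0.0, 2.0]]])
fc_weights = np.array([[1.0, 0.0],   # class 0 only uses map 0
                       [0.0, 3.0]])  # class 1 only uses map 1

cam0 = classic_cam(feature_maps, fc_weights, 0)
cam1 = classic_cam(feature_maps, fc_weights, 1)

# With the bias ignored, the class score equals the spatial mean of its CAM,
# because pooling and the weighted sum can be swapped.
pooled = feature_maps.mean(axis=(1, 2))
scores = fc_weights @ pooled
```

Each class lights up only the location where "its" feature map is active, which is why the map can be read as evidence for that class.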
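Here I only want to share how to use a CAM visualization library for PyTorch. With it, a CAM visualization takes just a few lines of code. I have also implemented GradCAM myself on top of hooks; that code is at the end of this post, and its two snippets can simply be pasted together. By visualizing the CAM of samples that were classified or localized incorrectly, we can quickly see where the model goes wrong, adjust the data selectively, and so improve the model's prediction accuracy.

The library: GitHub - jacobgil/pytorch-grad-cam: Many Class Activation Map methods implemented in Pytorch for CNNs and Vision Transformers. Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM

I have not read the library's source code, but I expect it to be built on hooks: during the forward and backward passes of a PyTorch model, a hook can be set on any layer to capture the state of its data.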
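The hook mechanism can be sketched on a toy model. This is only an illustration of the idea, not the library's code: the tiny `nn.Sequential` model and the `captured` dictionary are made up for the example, and a tensor-level `register_hook` is used here to pick up the gradient.

```python
import torch
import torch.nn as nn

# Tiny stand-in model (hypothetical); any layer of a real model is hooked the same way.
model = nn.Sequential(
    nn.Conv2d(1, 2, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(2, 3),
)
model.eval()

captured = {}

def save_gradient(grad):
    # Tensor hook: runs when the gradient w.r.t. the hooked output is computed.
    captured["gradient"] = grad.detach()

def save_activation(module, inputs, output):
    # Forward hook: runs right after the layer's forward pass.
    captured["activation"] = output.detach()
    output.register_hook(save_gradient)

handle = model[1].register_forward_hook(save_activation)

x = torch.randn(1, 1, 4, 4)
scores = model(x)
scores[0, 2].backward()  # back-propagate the score of class 2
handle.remove()          # always remove hooks once the data is collected
```

After the backward pass, `captured` holds the feature map and its gradient for the hooked layer, which is all a gradient-based CAM method needs.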
The library supports the following CAM methods, and it also supports on-the-fly image augmentation to make the CAM result smoother.
| Method | What it does |
|---|---|
| GradCAM | Weight the 2D activations by the average gradient |
| GradCAM++ | Like GradCAM but uses second order gradients |
| XGradCAM | Like GradCAM but scale the gradients by the normalized activations |
| AblationCAM | Zero out activations and measure how the output drops (this repository includes a fast batched implementation) |
| ScoreCAM | Perturb the image by the scaled activations and measure how the output drops |
| EigenCAM | Takes the first principal component of the 2D Activations (no class discrimination, but seems to give great results) |
| EigenGradCAM | Like EigenCAM but with class discrimination: First principal component of Activations*Grad. Looks like GradCAM, but cleaner |
| LayerCAM | Spatially weight the activations by positive gradients. Works better especially in lower layers |
| FullGrad | Computes the gradients of the biases from all over the network, and then sums them |
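The GradCAM row ("weight the 2D activations by the average gradient") can be written out in a few lines. The sketch below works on plain arrays and assumes the activations and gradients of the target layer have already been collected; the function name `grad_cam` and the toy values are hypothetical, and the final ReLU and scaling follow the usual Grad-CAM recipe.

```python
import numpy as np

def grad_cam(activations, gradients):
    """activations, gradients: arrays of shape (K, H, W) for one image.

    Returns an (H, W) heat map scaled to [0, 1].
    """
    weights = gradients.mean(axis=(1, 2))             # one weight per channel
    cam = np.tensordot(weights, activations, axes=1)  # weighted sum of the maps
    cam = np.maximum(cam, 0.0)                        # keep positive evidence only
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam

# Toy data (hypothetical): channel 0 supports the class, channel 1 opposes it.
activations = np.array([[[1.0, 2.0], [0.0, 0.0]],
                        [[0.0, 0.0], [4.0, 0.0]]])
gradients = np.array([[[1.0, 1.0], [1.0, 1.0]],
                      [[-1.0, -1.0], [-1.0, -1.0]]])

heatmap = grad_cam(activations, gradients)
```

The location where only the opposing channel fires is clipped to zero, so the heat map highlights just the region that raises the class score.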
1. Installation
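pip install grad-cam
The package is published on PyPI under the name grad-cam; the module it installs is imported as pytorch_grad_cam.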
2. Usage
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
import torchvision
import torch
from matplotlib import pyplot as plt
import numpy as np

def myimshows(imgs, titles=False, fname="test.jpg", size=6):
    # Show the images side by side in one row, save the figure to fname, then display it
    lens = len(imgs)
    fig = plt.figure(figsize=(size * lens, size))
    if titles is False:
        titles = "0123456789"
    for i in range(1, lens + 1):
        plt.subplot(1, lens, i)
        if len(imgs[i - 1].shape) == 2:
            plt.imshow(imgs[i - 1], cmap='Reds')  # 2-D arrays are drawn as heat maps
        else:
            plt.imshow(imgs[i - 1])
        plt.title(titles[i - 1])
        plt.xticks(())
        plt.yticks(())
    plt.savefig(fname, bbox_inches='tight')
    plt.show()

def tensor2img(tensor, heatmap=False, shape=(224, 224)):
    # Convert a CHW tensor to an HWC float array in [0, 1]; heatmap and shape are unused in this version
    np_arr = tensor.detach().numpy()
    # Min-max normalize when the values fall outside [0, 1]
    if np_arr.max() > 1 or np_arr.min() < 0:
        np_arr = np_arr - np_arr.min()
        np_arr = np_arr / np_arr.max()
    if np_arr.shape[0] == 1:
        # Repeat a single channel three times to get a 3-channel image
        np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)
    np_arr = np_arr.transpose((1, 2, 0))
    return np_arr

path = r"D:\daxiang.jpg"
bin_data = torchvision.io.read_file(path)  # read the raw bytes
img = torchvision.io.decode_image(bin_data) / 255  # decode to a CHW image scaled to [0, 1]
img = img.unsqueeze(0)  # add the batch dimension: BCHW with B == 1
input_tensor = torchvision.transforms.functional.resize(img, [224, 224])
# Flip the image horizontally so that the batch holds two samples
input_tensors = torch.cat([input_tensor, input_tensor.flip(dims=(3,))], dim=0)
model = resnet50(pretrained=True)
target_layers = [model.layer4[-1]]  # if several layers are passed, the CAM outputs are averaged
# use_cuda matches the library release this post was written for; later releases may not accept it
#cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:
    targets = [ClassifierOutputTarget(386), ClassifierOutputTarget(386)]  # show the heat map for class index 386
    # Passing aug_smooth=True, eigen_smooth=True to cam(...) makes the heat map smoother
    grayscale_cams = cam(input_tensor=input_tensors, targets=targets)  # targets=None uses the highest-scoring class
    for grayscale_cam, tensor in zip(grayscale_cams, input_tensors):
        # Blend the heat map with the original image
        rgb_img = tensor2img(tensor)
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])
Running the code produces the output shown in Figure 1.
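The idea of smoothing a heat map with image augmentation can be illustrated with a single horizontal flip. This sketch is not the library's implementation: `flip_averaged_cam` and the stand-in `toy_cam` are hypothetical names, and `toy_cam` deliberately adds a position-dependent artefact so that the effect of averaging is visible.

```python
import numpy as np

def flip_averaged_cam(image, cam_fn):
    """image: (H, W, C) array; cam_fn maps an image to an (H, W) heat map."""
    cam_plain = cam_fn(image)
    cam_flipped = cam_fn(image[:, ::-1])   # heat map of the mirrored image
    cam_restored = cam_flipped[:, ::-1]    # mirror it back onto the original layout
    return (cam_plain + cam_restored) / 2.0

def toy_cam(image):
    # Hypothetical stand-in for a real CAM method: the channel mean,
    # plus an artefact that always sits on the left edge.
    heat = image.mean(axis=2)
    heat[:, 0] += 1.0
    return heat

image = np.zeros((2, 3, 3))
image[0, 2, :] = 1.0                       # one bright pixel in the top-right corner

smoothed = flip_averaged_cam(image, toy_cam)
```

The response at the bright pixel survives the averaging, while the edge artefact is halved and spread over both sides, which is the sense in which augmentation makes the map smoother.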

3. Implementing GradCAM yourself
3.1 Imports and function definitions
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.transforms import transforms
from torchvision import models
from torchsummary import summary  # only needed for the commented-out summary() call in 3.2
from matplotlib import pyplot as plt
import numpy as np
import cv2

def myimshows(imgs, titles=False, fname="test.jpg", size=6):
    # Show the images side by side in one row, save the figure to fname, then display it
    lens = len(imgs)
    fig = plt.figure(figsize=(size * lens, size))
    if titles is False:
        titles = "0123456789"
    for i in range(1, lens + 1):
        plt.subplot(1, lens, i)
        if len(imgs[i - 1].shape) == 2:
            plt.imshow(imgs[i - 1], cmap='Reds')  # 2-D arrays are drawn as heat maps
        else:
            plt.imshow(imgs[i - 1])
        plt.title(titles[i - 1])
        plt.xticks(())
        plt.yticks(())
    plt.savefig(fname, bbox_inches='tight')
    plt.show()

def tensor2img(tensor, heatmap=False, shape=(224, 224)):
    # Convert the first sample of a BCHW tensor to an HWC float array in [0, 1]
    np_arr = tensor.detach().numpy()[0]
    # Min-max normalize when the values fall outside [0, 1]
    if np_arr.max() > 1 or np_arr.min() < 0:
        np_arr = np_arr - np_arr.min()
        np_arr = np_arr / np_arr.max()
    np_arr = (np_arr * 255).astype(np.uint8)
    if np_arr.shape[0] == 1:
        # Repeat a single channel three times to get a 3-channel image
        np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)
    np_arr = np_arr.transpose((1, 2, 0))
    if heatmap:
        np_arr = cv2.resize(np_arr, shape)
        # Color the map with the JET palette; note that OpenCV returns the channels in BGR order
        np_arr = cv2.applyColorMap(np_arr, cv2.COLORMAP_JET)
    return np_arr / 255

def backward_hook(module, grad_in, grad_out):
    # Backward hook: store the gradient with respect to the layer's output
    grad_block.append(grad_out[0].detach())
    print("backward_hook:", grad_in[0].shape, grad_out[0].shape)

def farward_hook(module, input, output):
    # Forward hook: store the layer's output feature map
    fmap_block.append(output)
    print("forward_hook:", input[0].shape, output.shape)
3.2 Implementing GradCAM
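# Load the model
model = models.resnet18(pretrained=True)
model.eval()  # evaluation mode
#summary(model,input_size=(3,512,512))
# Lists that the hooks fill with feature maps and gradients
fmap_block = list()
grad_block = list()
# Register the hooks on layer4
fh = model.layer4.register_forward_hook(farward_hook)
# register_backward_hook is deprecated in recent PyTorch in favor of register_full_backward_hook
bh = model.layer4.register_backward_hook(backward_hook)
# Load the image and run the prediction
path = r"D:\daxiang.jpg"
bin_data = torchvision.io.read_file(path)  # read the raw bytes
img = torchvision.io.decode_image(bin_data) / 255  # decode to a CHW image scaled to [0, 1]
img = img.unsqueeze(0)  # add the batch dimension: BCHW with B == 1
img = torchvision.transforms.functional.resize(img, [224, 224])
preds = model(img)
print("pred type:", preds.argmax(1))
# Back-propagate from the raw score of the target class, as Grad-CAM does.
# The gradient of a cross-entropy loss points the other way for this class,
# because lowering the loss means raising its score.
clas = 386
preds[0, clas].backward()
# Remove the hooks
fh.remove()
bh.remove()
# Fetch the stored gradient and feature map
layer1_grad = grad_block[-1]  # shape [1, 512, 7, 7] for resnet18 layer4 with a 224x224 input
layer1_fmap = fmap_block[-1].detach()
# Multiply gradient and feature map element-wise and sum over the channels.
# Note: the Grad-CAM paper first averages each channel's gradient over the spatial dimensions.
cam = (layer1_grad * layer1_fmap).sum(1, keepdim=True)  # shape [1, 1, 7, 7]
layer1_grad = layer1_grad.sum(1, keepdim=True)  # shape [1, 1, 7, 7]
layer1_fmap = layer1_fmap.sum(1, keepdim=True)  # same layout, so tensor2img can be reused
# Visualize
img_np = tensor2img(img)
#layer1_fmap=torchvision.transforms.functional.resize(layer1_fmap,[224, 224])
layer1_grad_np = tensor2img(layer1_grad, heatmap=True, shape=(224, 224))
layer1_fmap_np = tensor2img(layer1_fmap, heatmap=True, shape=(224, 224))
cam_np = tensor2img(cam, heatmap=True, shape=(224, 224))
# cv2.applyColorMap returns BGR; reverse the channel axis so matplotlib shows the true JET colors
layer1_grad_np, layer1_fmap_np, cam_np = (a[..., ::-1] for a in (layer1_grad_np, layer1_fmap_np, cam_np))
print("The redder the color, the larger the value in that region")
myimshows([img_np, cam_np, cam_np * 0.4 + img_np * 0.6], ['image', 'cam', 'cam + image'])
The output of running the code is shown in Figure 2.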
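The code in this section multiplies gradient and feature map pixel by pixel before summing over the channels, whereas Grad-CAM as usually described first averages each channel's gradient over space. The difference can be seen on a toy example; the function names and values below are hypothetical and only contrast the two weightings.

```python
import numpy as np

def elementwise_map(activations, gradients):
    # Weighting used in this section: gradient times activation at every pixel,
    # summed over the channel axis.
    return (gradients * activations).sum(axis=0)

def gap_weighted_map(activations, gradients):
    # Grad-CAM style weighting: one weight per channel, the spatial mean of its gradient.
    weights = gradients.mean(axis=(1, 2), keepdims=True)
    return (weights * activations).sum(axis=0)

# Toy data (hypothetical): one channel with constant activation,
# and a gradient concentrated in a single pixel.
activations = np.ones((1, 2, 2))
gradients = np.array([[[4.0, 0.0], [0.0, 0.0]]])

pixelwise = elementwise_map(activations, gradients)
averaged = gap_weighted_map(activations, gradients)
```

The element-wise version follows the gradient's spatial pattern, while the averaged version follows the activation's pattern, so the two maps can look quite different on the same layer.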
