torchvision.models._utils.IntermediateLayerGetter tutorial
2022-06-27 09:35:00 【jjw_zyfx】
Without further ado, let's look at an example:
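import torch
import torchvision
# pretrained=True downloads ImageNet weights; since torchvision 0.13 it is deprecated
# in favor of weights=torchvision.models.ResNet18_Weights.DEFAULT
m = torchvision.models.resnet18(pretrained=True)
print(m)
# Truncate resnet18 after layer3 and return the outputs of layer1 and layer3
new_m = torchvision.models._utils.IntermediateLayerGetter(m, {'layer1': 'feat1', 'layer3': 'feat2'})
out = new_m(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])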
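# Output:
# [('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]
# Since the printed model is long, the explanation goes here. The first argument of
# IntermediateLayerGetter is the model, here resnet18. The second argument is a dict:
# each key names a top-level child module of resnet18 whose output you want, and each
# value is the key that output gets in the returned OrderedDict. So 'layer1': 'feat1'
# returns the output of resnet18's layer1 as out['feat1'], and 'layer3': 'feat2'
# returns the output of layer3 as out['feat2']. The new model keeps resnet18's
# children from the beginning up to and including the last requested layer (layer3)
# and drops the rest (layer4, avgpool, fc); a single forward pass yields both features.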
Full output of the opening example (the `print(m)` model structure, then the extracted feature shapes):
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
[('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]
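To make the mechanism behind those shapes concrete, here is a simplified re-implementation of the same idea. It is a sketch, not torchvision's source: the class name `SimpleLayerGetter`, its `body` attribute and the toy model are all hypothetical. It copies the model's top-level children in registration order until every requested layer has been collected, then runs them one after another and records the requested outputs. This also shows the main assumption of the approach: the model's own `forward` must be a plain sequential chain of its children. For ResNet that holds up to `avgpool`, but the `torch.flatten` call before `fc` is not a child module, so requesting `'fc'` would not reproduce the real forward pass.

```python
from collections import OrderedDict

import torch
from torch import nn


class SimpleLayerGetter(nn.Module):
    """Hypothetical, simplified sketch of the idea behind IntermediateLayerGetter."""

    def __init__(self, model: nn.Module, return_layers: dict):
        super().__init__()
        child_names = [name for name, _ in model.named_children()]
        if not set(return_layers).issubset(child_names):
            raise ValueError("return_layers are not present in model")
        remaining = dict(return_layers)
        layers = OrderedDict()
        # Copy top-level children in registration order, stopping after the last requested one
        for name, module in model.named_children():
            layers[name] = module
            remaining.pop(name, None)
            if not remaining:
                break
        self.body = nn.ModuleDict(layers)
        self.return_layers = dict(return_layers)

    def forward(self, x):
        out = OrderedDict()
        # Run the kept children sequentially (assumes the original forward is a plain chain)
        for name, module in self.body.items():
            x = module(x)
            if name in self.return_layers:
                out[self.return_layers[name]] = x
        return out


# Hypothetical toy model: each stride-2 conv halves the spatial size
toy = nn.Sequential(OrderedDict([
    ('stem', nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)),
    ('block1', nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)),
    ('block2', nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)),
    ('head', nn.AdaptiveAvgPool2d(1)),
]))
getter = SimpleLayerGetter(toy, {'stem': 'feat1', 'block2': 'feat2'})
out = getter(torch.rand(1, 3, 64, 64))
print([(k, tuple(v.shape)) for k, v in out.items()])
# [('feat1', (1, 8, 32, 32)), ('feat2', (1, 32, 8, 8))]
```

As with resnet18 above, everything after the last requested layer (here `head`) is dropped, and one forward pass returns both features.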