Building MobileNetV2 with PyTorch and training it with transfer learning
2022-06-25 08:13:00 【@BangBang】
The MobileNetV2 network structure is as follows. For a detailed explanation of the network, see the blog post: MobileNet series (2): MobileNet-V2 Network details
From the structure table we can see that the model is essentially a stack of inverted residual blocks (bottleneck), followed by an ordinary 1x1 convolution, then an average pooling layer with a 7x7 kernel, and finally a 1x1 convolution that produces the final output. The key to building this network is the inverted residual structure: once that block is implemented, assembling the rest of the network is straightforward.
Building the network in PyTorch
In the model.py file, first define the basic building blocks of the network.
Convolutions in the MobileNetV2 network are almost always composed of Conv+BN+ReLU6.
Convolution component
Conv+BN+ReLU6
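import torch
from torch import nn


class ConvBNReLU(nn.Sequential):
    # kernel_size defaults to 3 because later calls omit it for 3x3 convolutions
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1):
        # padding that keeps the spatial size unchanged for odd kernels at stride 1
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU6(inplace=True)
        )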
Note that groups=1 builds an ordinary convolution, while setting groups equal to in_channel builds a depthwise (DW) convolution. Because a BN layer follows the convolution, the convolution bias is unnecessary and is set to False.
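To make the effect of groups concrete, the sketch below counts the weights of a bias-free convolution. It relies on the fact that nn.Conv2d stores its weight with shape (out_channels, in_channels // groups, k, k). The helper name conv_weight_count is made up for this illustration and is not part of PyTorch.

```python
def conv_weight_count(in_channel, out_channel, kernel_size, groups=1):
    """Number of weights in a bias-free 2D convolution.

    Hypothetical helper for illustration; it mirrors the nn.Conv2d weight
    shape (out_channel, in_channel // groups, kernel_size, kernel_size).
    """
    if in_channel % groups or out_channel % groups:
        raise ValueError("in_channel and out_channel must be divisible by groups")
    return out_channel * (in_channel // groups) * kernel_size * kernel_size


normal = conv_weight_count(32, 32, 3, groups=1)      # ordinary 3x3 convolution
depthwise = conv_weight_count(32, 32, 3, groups=32)  # DW convolution: groups == in_channel
print(normal, depthwise, normal // depthwise)
```

With 32 channels the depthwise version needs 32 times fewer weights, which is where most of MobileNet's savings come from.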
Inverse residual structure
Define an InvertedResidual class that inherits from the nn.Module parent class. The network diagram of the inverted residual structure is as follows:
The inverted residual structure resembles an ordinary residual structure. An ordinary residual block is wide at both ends and narrow in the middle, whereas the inverted residual block is narrow at both ends and wide in the middle. See: MobileNet series (2): MobileNet-V2 Network details. The number of DW convolution kernels equals the number of input channels, and each DW kernel is responsible for exactly one channel, so the channel count does not change after the DW convolution.
class InvertedResidual(nn.Module):
    def __init__(self, in_channel, out_channel, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        hidden_channel = in_channel * expand_ratio
        # the shortcut is only valid when input and output shapes match
        self.use_shortcut = stride == 1 and in_channel == out_channel

        layers = []
        if expand_ratio != 1:
            # 1x1 Conv
            layers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1))
        layers.extend([
            # 3x3 depthwise conv
            ConvBNReLU(hidden_channel, hidden_channel, kernel_size=3, stride=stride, groups=hidden_channel),
            # 1x1 Conv (linear)
            nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channel)
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_shortcut:
            return x + self.conv(x)
        else:
            return self.conv(x)
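The "narrow at both ends, wide in the middle" shape can be checked by pushing a tensor through the three stages one at a time. The following is a minimal sketch built directly from torch.nn layers rather than the class above, so that it runs on its own; the channel count 24, the expansion factor 6 and the 56x56 input are example values chosen for illustration.

```python
import torch
from torch import nn

in_channel, out_channel, expand_ratio = 24, 24, 6
hidden_channel = in_channel * expand_ratio  # 144: the wide middle

expand = nn.Sequential(  # 1x1 conv widens the channels
    nn.Conv2d(in_channel, hidden_channel, kernel_size=1, bias=False),
    nn.BatchNorm2d(hidden_channel),
    nn.ReLU6(inplace=True),
)
depthwise = nn.Sequential(  # 3x3 DW conv: one kernel per channel, width unchanged
    nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, stride=1,
              padding=1, groups=hidden_channel, bias=False),
    nn.BatchNorm2d(hidden_channel),
    nn.ReLU6(inplace=True),
)
project = nn.Sequential(  # linear 1x1 conv narrows the channels again, no activation
    nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),
    nn.BatchNorm2d(out_channel),
)

x = torch.randn(1, in_channel, 56, 56)
wide = expand(x)
after_dw = depthwise(wide)
y = x + project(after_dw)  # shortcut applies: stride is 1 and in_channel == out_channel

print(x.shape, wide.shape, after_dw.shape, y.shape)
```

The channel dimension goes 24, 144, 144, 24, while the spatial size stays 56x56 because the stride is 1.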
MobileNet V2 Network structure
Define the MobileNetV2 class, which inherits from nn.Module. The complete network construction code is as follows:
class MobileNetV2(nn.Module):
    def __init__(self, num_classes=100, alpha=1.0, round_nearest=8):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = _make_divisible(32 * alpha, round_nearest)
        last_channel = _make_divisible(1280 * alpha, round_nearest)

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],  # t is 6 in the MobileNetV2 paper
            [6, 160, 3, 2],
            [6, 320, 1, 1]
        ]

        features = []
        # conv1 layer
        features.append(ConvBNReLU(3, input_channel, kernel_size=3, stride=2))
        # build inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            # use _make_divisible to round the number of kernels to a multiple of round_nearest
            output_channel = _make_divisible(c * alpha, round_nearest)
            for i in range(n):
                # only the first block of each stage uses stride s
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # build the last several layers
        features.append(ConvBNReLU(input_channel, last_channel, 1))
        # combine feature layers
        self.features = nn.Sequential(*features)

        # build classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channel, num_classes)
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # the (0, 0.01) arguments belong to normal_, not ones_
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    # forward pass
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
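To see how the t, c, n, s table turns a 224x224 image into the 7x7 feature map that the average pooling layer consumes, the channel count and spatial size can be traced with plain arithmetic. This is a rough sketch for alpha = 1.0; conv_out_size and trace_shapes are hypothetical helper names, and the output size formula assumes padding of (kernel_size - 1) // 2 as in ConvBNReLU.

```python
def conv_out_size(size, kernel_size=3, stride=1):
    """Spatial output size of a convolution with padding (kernel_size - 1) // 2."""
    padding = (kernel_size - 1) // 2
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_shapes(input_size=224):
    """Return (stage name, channels, spatial size) rows for alpha = 1.0."""
    setting = [  # t, c, n, s
        [1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2],
        [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1],
    ]
    size = conv_out_size(input_size, stride=2)  # conv1 uses stride 2
    rows = [("conv1", 32, size)]
    for t, c, n, s in setting:
        # only the first block of a stage uses stride s; the others use stride 1
        size = conv_out_size(size, stride=s)
        rows.append(("bottleneck x%d" % n, c, size))
    rows.append(("conv 1x1", 1280, size))  # a 1x1 convolution keeps the spatial size
    return rows


for name, channels, size in trace_shapes():
    print("%-15s %5d  %dx%d" % (name, channels, size, size))
```

For a 224x224 input the last row is 1280 channels at 7x7, so nn.AdaptiveAvgPool2d((1, 1)) behaves like the 7x7 average pooling in the structure table while also accepting other input sizes.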
The _make_divisible function comes from the official TensorFlow implementation:
def _make_divisible(ch, divisor=8, min_ch=None):
    """ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py """
    if min_ch is None:
        min_ch = divisor
    # round ch to the nearest multiple of divisor, but never below min_ch
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch
Model training
First, how to download the official pretrained model parameters, taking the MobileNet pretrained model as an example:
import torchvision.models.mobilenet
Click through torchvision.models.mobilenet to reach the official definition. It contains a model_urls dictionary, and the URL in it is the download link for the model's pretrained weights:
model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
}
Copy the model URL into a download tool such as Xunlei. After downloading, save the file in the current project directory and name it mobilenet_v2.pth.
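If a scripted download is preferred over a download tool, something along the following lines should work. It is only a sketch: fetch_weights is a hypothetical helper, not a torchvision function, and it uses the standard-library urllib.request.urlretrieve with the URL and file name given above.

```python
import os
import urllib.request

MODEL_URL = "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth"


def fetch_weights(url=MODEL_URL, dest="./mobilenet_v2.pth"):
    """Download url to dest unless dest already exists, then return dest.

    Hypothetical helper for illustration only.
    """
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)  # standard-library download
    return dest
```

Calling fetch_weights() once before running train.py leaves mobilenet_v2.pth in the project directory; later calls return immediately because the file already exists.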
Training scripts
train.py
1. Import Python packages
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm  # progress bars used in the training loop

from model import MobileNetV2
2. Data preparation
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),  # Normalize expects a tensor
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}

data_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train", transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
# invert the mapping: index -> class name
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
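The class index file deserves a quick check, because JSON turns the integer keys into strings. The sketch below uses the class mapping from the comment above as a hard-coded stand-in for train_dataset.class_to_idx, so it runs without the data set.

```python
import json

# stand-in for train_dataset.class_to_idx on the flower data set
flower_list = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}

# invert the mapping: index -> class name
cla_dict = {idx: name for name, idx in flower_list.items()}
json_str = json.dumps(cla_dict, indent=4)

# JSON object keys are always strings, so convert them back to int after loading
restored = {int(k): v for k, v in json.loads(json_str).items()}
print(restored[2])
```

A prediction script that reads class_indices.json therefore has to look up str(index), or convert the keys back to int as shown.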
3. Load model
# device is not defined in the original; using the GPU when available is an assumption
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = MobileNetV2(num_classes=5)
model_weight_path = "./mobilenet_v2.pth"
# load pretrain weights
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
pre_weights = torch.load(model_weight_path, map_location=device)

# delete classifier weights: keep only tensors whose size matches the new model
pre_dict = {
    k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
# strict=False means that only the matching weights are loaded
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

# freeze features weights
for param in net.features.parameters():
    param.requires_grad = False

net.to(device)
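The filter-then-freeze pattern is easier to follow on a toy model. In this sketch a small nn.Sequential stands in for MobileNetV2 and a second instance with 10 outputs plays the role of the downloaded weights; make_net and the layer sizes are invented for the illustration.

```python
import torch
from torch import nn


def make_net(num_classes):
    # toy stand-in: one feature layer followed by a classifier layer
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, num_classes))


pretrained = make_net(num_classes=10)  # plays the role of the downloaded weights
net = make_net(num_classes=3)          # new task with a different number of classes

pre_weights = pretrained.state_dict()
own_state = net.state_dict()
# keep only tensors that exist in the new model with the same element count
pre_dict = {k: v for k, v in pre_weights.items()
            if k in own_state and own_state[k].numel() == v.numel()}
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

# freeze the feature layer, as done above for net.features
for param in net[0].parameters():
    param.requires_grad = False

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(missing_keys, unexpected_keys, trainable)
```

Only the classifier tensors end up in missing_keys and remain trainable. Comparing v.shape instead of numel() would be a slightly stricter check, since two tensors can have the same element count but different shapes.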
4. Model training
# define loss function
loss_function = nn.CrossEntropyLoss()

# construct an optimizer over the parameters that are not frozen
params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)

epochs = 5  # assumption: the original does not state the number of epochs
best_acc = 0.0
save_path = './MobileNetV2.pth'
train_steps = len(train_loader)
for epoch in range(epochs):
    # train
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader)
    for step, data in enumerate(train_bar):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        train_bar.desc = "train epoch [{} / {}] loss:{:.3f}".format(epoch + 1, epochs, loss)

    # validate
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        val_bar = tqdm(validate_loader)
        for val_data in val_bar:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
            val_bar.desc = "valid epoch [{}/{}]".format(epoch + 1, epochs)

    val_accurate = acc / val_num
    print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
          (epoch + 1, running_loss / train_steps, val_accurate))

    # keep only the weights with the best validation accuracy
    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net.state_dict(), save_path)

print('Finished Training')
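The accuracy bookkeeping in the validation loop can be checked by hand on a tiny batch. The logits and labels below are made-up numbers for a 3-class problem, used only to show what torch.max(outputs, dim=1)[1] and torch.eq compute.

```python
import torch

# made-up logits for 4 samples and 3 classes
outputs = torch.tensor([[2.0, 0.1, 0.3],
                        [0.2, 0.1, 1.5],
                        [0.9, 1.1, 0.4],
                        [0.3, 0.2, 0.1]])
val_labels = torch.tensor([0, 2, 0, 0])

predict_y = torch.max(outputs, dim=1)[1]  # index of the largest logit in each row
correct = torch.eq(predict_y, val_labels).sum().item()  # number of matching predictions
accuracy = correct / len(val_labels)
print(predict_y.tolist(), correct, accuracy)
```

Three of the four predictions match the labels, giving an accuracy of 0.75; in the training script the same count is accumulated over all validation batches and divided by val_num.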