
Building MobileNetV2 with PyTorch and training it with transfer learning

2022-06-25 08:13:00 @BangBang

The MobileNetV2 network structure is shown below. For a detailed explanation of the network, see the blog post "MobileNet series (2): MobileNet-V2 Network details".
Figure 1: MobileNet V2 network architecture

The table shows that the model consists of five parts, in order:

- an initial 3x3 convolution with stride 2;
- a stack of inverted residual blocks (bottleneck);
- a 1x1 standard convolution;
- a 7x7 average pooling layer;
- a final 1x1 convolution that produces the output.

In the code below, the pooling is done by AdaptiveAvgPool2d, and the final 1x1 convolution is implemented as a Linear layer. This is equivalent because the pooled feature map is 1x1. The key to building this network is the inverted residual block. Once it is implemented, the rest of the network is easy to assemble.
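The pooling window is 7x7 because of the strides in the table. The sketch below traces the spatial size through them without PyTorch. It assumes a 224x224 input and "same"-style padding, so each stride-2 layer halves the resolution. `STAGES` copies the t, c, n, s rows of the table, and only the stride of the first block in each stage matters.

```python
# (t, c, n, s) rows of the MobileNetV2 table; only s affects the resolution
STAGES = [
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
]


def feature_map_sizes(input_size=224):
    """Return the spatial size after the first conv and after each bottleneck stage."""
    size = input_size // 2  # first 3x3 conv, stride 2
    sizes = [size]
    for _t, _c, _n, s in STAGES:
        # only the first block of a stage uses stride s; the others use stride 1
        size = size // s
        sizes.append(size)
    return sizes


print(feature_map_sizes())  # [112, 112, 56, 28, 14, 14, 7, 7]
```

The last entry is the 7x7 map that the average pooling layer reduces to 1x1.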

Building the network in PyTorch

In model.py, first define the basic building blocks of the network.
In MobileNetV2, almost every convolution is composed of Conv + BN + ReLU6.
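ReLU6 is an ordinary ReLU with its output capped at 6. The MobileNetV2 paper motivates the cap by robustness under low-precision computation. Below is a plain-Python sketch of the elementwise formula; `nn.ReLU6` applies the same rule to every element of a tensor.

```python
def relu6(x: float) -> float:
    """ReLU6 clips an activation to the range [0, 6]: min(max(0, x), 6)."""
    return min(max(0.0, x), 6.0)


print([relu6(v) for v in (-3.0, 0.5, 4.0, 9.0)])  # [0.0, 0.5, 4.0, 6.0]
```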

Convolution component

Conv+BN+ReLU6

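# model.py
import torch
import torch.nn as nn


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1):
        # "same" padding for odd kernel sizes (1 for 3x3, 0 for 1x1)
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            # no bias: the following BatchNorm layer has its own shift term
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU6(inplace=True)
        )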

Note that groups=1 builds an ordinary convolution. If groups equals in_channel, the layer becomes a depthwise (DW) convolution. The convolution is followed by a BN layer, which makes a bias term redundant, so bias is set to False.
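The `groups` argument directly changes the weight count. In PyTorch's Conv2d, each of the out_channel filters only sees in_channel / groups input channels, so the weight tensor holds out × (in / groups) × k × k values. Below is a small arithmetic sketch; the helper `conv_params` is hypothetical.

```python
def conv_params(in_ch: int, out_ch: int, k: int, groups: int = 1) -> int:
    """Weight count of a k x k Conv2d with bias=False.

    Each of the out_ch filters only sees in_ch // groups input channels.
    """
    assert in_ch % groups == 0 and out_ch % groups == 0
    return out_ch * (in_ch // groups) * k * k


print(conv_params(32, 32, 3))             # standard 3x3 conv: 9216
print(conv_params(32, 32, 3, groups=32))  # depthwise 3x3 conv: 288
```

With groups equal to the channel count, the cost drops by a factor equal to the number of channels. This is why MobileNet relies on depthwise convolutions.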

Inverted residual structure

Define an InvertedResidual class that inherits from nn.Module. The structure of the inverted residual block is shown below:
(Figure: inverted residual block)
The inverted residual block resembles the ordinary residual block, with the widths reversed:

- The ordinary residual block is wide at both ends and narrow in the middle.
- The inverted residual block is narrow at both ends and wide in the middle.

See "MobileNet series (2): MobileNet-V2 Network details" for details. In a DW convolution, the number of kernels equals the number of input channels, and each kernel handles exactly one channel. A DW convolution therefore does not change the channel count. The shortcut connection is added only when stride is 1 and the input and output channel counts match. Otherwise, x and the block output have different shapes and cannot be summed.

class InvertedResidual(nn.Module):
    def __init__(self, in_channel, out_channel, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        hidden_channel = in_channel * expand_ratio
        # the shortcut is only valid when the input and output shapes match
        self.use_shortcut = stride == 1 and in_channel == out_channel

        layers = []
        if expand_ratio != 1:
            # 1x1 pointwise conv (expansion)
            layers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1))
        layers.extend([
            # 3x3 depthwise conv
            ConvBNReLU(hidden_channel, hidden_channel, kernel_size=3, stride=stride, groups=hidden_channel),
            # 1x1 pointwise conv (linear: no activation afterwards)
            nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_shortcut:
            return x + self.conv(x)
        else:
            return self.conv(x)
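The channel bookkeeping of InvertedResidual can be traced without PyTorch. The sketch below lists the channel count after each stage of one block and applies the same shortcut rule as `use_shortcut`. The helper name `inverted_residual_plan` is hypothetical.

```python
def inverted_residual_plan(in_ch, out_ch, stride, expand_ratio):
    """Return the channel counts through one inverted residual block and whether
    the shortcut is used (same rule as InvertedResidual.use_shortcut)."""
    hidden = in_ch * expand_ratio
    channels = [in_ch]
    if expand_ratio != 1:
        channels.append(hidden)  # 1x1 expansion conv
    channels.append(hidden)      # 3x3 depthwise conv keeps the channel count
    channels.append(out_ch)      # 1x1 linear projection
    use_shortcut = stride == 1 and in_ch == out_ch
    return channels, use_shortcut


print(inverted_residual_plan(24, 24, stride=1, expand_ratio=6))
# ([24, 144, 144, 24], True)   narrow -> wide -> wide -> narrow
print(inverted_residual_plan(32, 16, stride=1, expand_ratio=1))
# ([32, 32, 16], False)        no expansion layer when t = 1
```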

MobileNet V2 Network structure

Define the MobileNetV2 class, which inherits from nn.Module. The complete network construction code is as follows:

class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = _make_divisible(32 * alpha, round_nearest)
        last_channel = _make_divisible(1280 * alpha, round_nearest)

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        features = []
        # conv1 layer
        features.append(ConvBNReLU(3, input_channel, kernel_size=3, stride=2))
        # build inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            # round the number of kernels to an integer multiple of round_nearest
            output_channel = _make_divisible(c * alpha, round_nearest)
            for i in range(n):
                # only the first block of each stage downsamples
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, last_channel, kernel_size=1))
        # combine feature layers
        self.features = nn.Sequential(*features)

        # building classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channel, num_classes)
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    # forward pass
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
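The nested loop in `__init__` can be checked by expanding the table by hand. The sketch below assumes alpha = 1.0, where `_make_divisible` leaves every channel count in the table unchanged. It lists each bottleneck block as (in_ch, out_ch, stride, t, use_shortcut). `SETTING` and `expand_blocks` are illustrative names.

```python
SETTING = [  # t, c, n, s (alpha = 1.0, so the channel counts need no rounding)
    [1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2],
    [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1],
]


def expand_blocks(setting, input_channel=32):
    """List (in_ch, out_ch, stride, t, use_shortcut) for every bottleneck block."""
    blocks = []
    for t, c, n, s in setting:
        for i in range(n):
            stride = s if i == 0 else 1
            blocks.append((input_channel, c, stride, t, stride == 1 and input_channel == c))
            input_channel = c
    return blocks


blocks = expand_blocks(SETTING)
print(len(blocks))                 # 17
print(sum(b[-1] for b in blocks))  # 10 blocks use the shortcut
print(blocks[:3])
# [(32, 16, 1, 1, False), (16, 24, 2, 6, False), (24, 24, 1, 6, True)]
```

The first block of every stage has no shortcut, because it changes either the stride or the channel count.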

The _make_divisible function comes from the official TensorFlow implementation:

def _make_divisible(ch, divisor=8, min_ch=None):
    """
    Round ch to the nearest multiple of divisor (and at least min_ch).
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch

Model training

First, let's look at how to download the official pretrained weights, taking MobileNet as an example:

import torchvision.models.mobilenet

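Open the definition of torchvision.models.mobilenet, for example by Ctrl+clicking it in the IDE. It contains a model_urls dictionary whose URL is the download link for the model's pretrained weights. This applies to older torchvision releases; newer releases expose the same weights through torchvision.models.MobileNet_V2_Weights instead.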

model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}

Copy this URL into a download manager (the author used Xunlei) to download the file. Save it in the current project directory and name it mobilenet_v2.pth.
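Without a download manager, the file can also be fetched with Python's standard library. Below is a minimal sketch using `urllib.request.urlretrieve`. The helper name `download_weights` is hypothetical, and it skips the download if the file already exists. `torch.hub.load_state_dict_from_url` is another option, but it caches the file in the torch hub directory rather than the project directory.

```python
import os
import urllib.request


def download_weights(url: str, dst: str) -> str:
    """Download url to dst unless dst already exists; return dst."""
    if not os.path.exists(dst):
        tmp = dst + ".part"
        urllib.request.urlretrieve(url, tmp)  # fetch to a temporary name first
        os.replace(tmp, dst)                  # then move it into place
    return dst


# Example call:
# download_weights("https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
#                  "mobilenet_v2.pth")
```

Writing to a `.part` file first means an interrupted download never leaves a truncated `mobilenet_v2.pth` behind.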

Training script

train.py

1. Import Python packages

import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm  # progress bars used in the training loop

from model import MobileNetV2

2. Data preparation

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),  # Normalize expects a tensor
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}
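`transforms.Normalize` subtracts a per-channel mean and divides by a per-channel std. The values above are the standard ImageNet statistics, which is why the train and val pipelines use the same ones. Below is a plain-Python sketch for a single pixel that `ToTensor` has already scaled from 0..255 to [0, 1].

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (x - mean) / std per channel to one RGB pixel in [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, MEAN, STD)]


print([round(v, 4) for v in normalize_pixel([0.5, 0.5, 0.5])])
# [0.0655, 0.1964, 0.4178]
```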

data_root = os.path.abspath(os.path.join(os.getcwd(),'../..')) #get data root path
image_path=data_root +"/data_set/flower_data/" #flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",transform=data_transform["train"])
train_num=len(train_dataset)

# e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str=json.dumps(cla_dict,indent=4)
with open('class_indices.json','w') as json_file:
	json_file.write(json_str)
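This matters when the file is read back, for example in a prediction script. JSON object keys are always strings, so the integer class indices come back as '0', '1', and so on. The sketch below writes the inverted mapping to a temporary directory rather than the working directory. It then looks up a prediction with `str(index)`; `predicted_index` is a hypothetical value.

```python
import json
import os
import tempfile

class_to_idx = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
cla_dict = {idx: name for name, idx in class_to_idx.items()}  # invert: index -> name

path = os.path.join(tempfile.mkdtemp(), 'class_indices.json')
with open(path, 'w') as f:
    json.dump(cla_dict, f, indent=4)

with open(path) as f:
    loaded = json.load(f)

print(loaded)  # {'0': 'daisy', '1': 'dandelion', '2': 'roses', '3': 'sunflowers', '4': 'tulips'}
predicted_index = 3                  # hypothetical: e.g. the argmax of the network output
print(loaded[str(predicted_index)])  # sunflowers
```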

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
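The training loop later divides the summed loss by `train_steps = len(train_loader)`. For a map-style dataset with the default sampler, `len(DataLoader)` is the number of batches per epoch:

- ceil(N / batch_size) by default;
- floor(N / batch_size) when `drop_last=True`.

The illustration below uses a hypothetical dataset of 1000 images.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of iterations per epoch, i.e. len(DataLoader) for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(num_batches(1000, 16))                  # 63 (the last batch has 8 images)
print(num_batches(1000, 16, drop_last=True))  # 62
```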

3. Load model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = MobileNetV2(num_classes=5)
model_weight_path = "./mobilenet_v2.pth"
# load pretrained weights
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
pre_weights = torch.load(model_weight_path, map_location=device)
# drop weights whose size does not match (the 1000-class classifier)
net_state = net.state_dict()
pre_dict = {k: v for k, v in pre_weights.items()
            if k in net_state and net_state[k].numel() == v.numel()}
# strict=False loads only the matching weights and reports the rest
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

# freeze the feature extractor; only the classifier is trained
for param in net.features.parameters():
    param.requires_grad = False
net.to(device)
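The weight filtering step is plain dictionary logic, so it can be illustrated without loading real tensors. In the sketch below, shape tuples stand in for tensors and `math.prod` plays the role of `numel()`. The three key names are examples in the style of the model above, not a full state dict. The returned "missing" list corresponds to what `load_state_dict(..., strict=False)` reports as `missing_keys`.

```python
from math import prod


def filter_by_shape(pretrained_shapes, model_shapes):
    """Keep pretrained entries whose key exists in the model with the same element count.

    Both arguments map parameter names to shape tuples (stand-ins for tensors).
    Returns (kept_keys, keys_the_model_will_report_missing).
    """
    kept = [k for k, shape in pretrained_shapes.items()
            if k in model_shapes and prod(model_shapes[k]) == prod(shape)]
    missing = [k for k in model_shapes if k not in kept]
    return kept, missing


pretrained = {'features.0.0.weight': (32, 3, 3, 3),
              'classifier.1.weight': (1000, 1280), 'classifier.1.bias': (1000,)}
model = {'features.0.0.weight': (32, 3, 3, 3),
         'classifier.1.weight': (5, 1280), 'classifier.1.bias': (5,)}
print(filter_by_shape(pretrained, model))
# (['features.0.0.weight'], ['classifier.1.weight', 'classifier.1.bias'])
```

Comparing full shapes (`net_state[k].shape == v.shape`) would be a slightly stricter test than comparing element counts. For this model, both tests drop exactly the classifier weights.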

4. Model training

# define loss function
loss_function=nn.CrossEntropyLoss()

# construct an optimizer
params=[p for p in net.parameters() if p.requires_grad]
optimizer=optim.Adam(params,lr=0.0001)

epochs = 5  # illustrative value; the number of epochs is not given in the original post
best_acc = 0.0
save_path = './MobileNetV2.pth'
train_steps = len(train_loader)

for epoch in range(epochs):
    # train
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader)
    for step, data in enumerate(train_bar):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        train_bar.desc = "train epoch [{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss.item())

    # validate
    net.eval()
    acc = 0.0  # accumulated number of correct predictions in this epoch
    with torch.no_grad():
        val_bar = tqdm(validate_loader)
        for val_data in val_bar:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
            val_bar.desc = "valid epoch [{}/{}]".format(epoch + 1, epochs)

    val_accurate = acc / val_num
    print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
          (epoch + 1, running_loss / train_steps, val_accurate))

    # keep only the checkpoint with the best validation accuracy
    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net.state_dict(), save_path)

print('Finished Training')
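The validation bookkeeping in the loop reduces to two small pieces of logic:

- taking the arg-max over class scores and comparing it with the labels;
- saving a checkpoint only when accuracy improves.

Below is a pure-Python sketch with made-up scores and accuracies. These are illustrative numbers, not training results.

```python
def argmax(row):
    """Index of the largest score, like torch.max(outputs, dim=1)[1] for one sample."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(logits, labels):
    correct = sum(argmax(row) == y for row, y in zip(logits, labels))
    return correct / len(labels)


logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 3.0, 0.5], [1.0, 0.9, 0.8]]
labels = [0, 2, 2, 0]
print(accuracy(logits, labels))  # 0.75

# only epochs that beat the best accuracy so far would trigger torch.save(...)
best_acc, saved_epochs = 0.0, []
for epoch, val_acc in enumerate([0.81, 0.86, 0.84, 0.90], start=1):
    if val_acc > best_acc:
        best_acc = val_acc
        saved_epochs.append(epoch)
print(saved_epochs, best_acc)  # [1, 2, 4] 0.9
```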

Copyright notice
This article was written by [@BangBang]. Please include a link to the original when reprinting. Thank you.
https://yzsam.com/2022/176/202206250642353158.html