PyTorch learning 2 (CNN)
2022-06-27 13:37:00 【Rat human decline】
Convolutional neural networks (CNN)
A CNN (Convolutional Neural Network) is mainly built from convolutional layers, activation layers, pooling layers, Dropout, Batch Normalization and similar layers, stacked in a certain order.
Convolution and convolutional layers
Convolution is widely used in image processing. Different convolution kernels extract different features, such as edges, lines and corners. In a deep convolutional neural network, stacked convolutions extract image features ranging from low-level to complex ones.
Single input channel: grayscale image
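To make the single-channel case concrete, here is a minimal pure-Python sketch (no padding, stride 1). Note that what CNN layers such as PyTorch's `nn.Conv2d` compute is technically cross-correlation: the kernel is not flipped. The image and the vertical-edge kernel below are made-up values for illustration.

```python
def conv2d_single(image, kernel):
    """Valid (no padding, stride 1) 2D cross-correlation, as used by CNN layers."""
    ih, iw = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    oh, ow = ih - kh + 1, iw - kw + 1
    out = [[0] * ow for _ in range(oh)]
    for r in range(oh):
        for c in range(ow):
            # Sum of element-wise products between the kernel and the window at (r, c)
            out[r][c] = sum(
                image[r + i][c + j] * kernel[i][j]
                for i in range(kh)
                for j in range(kw)
            )
    return out

# Hypothetical 4x6 grayscale image: dark on the left, bright on the right
image = [
    [0, 0, 0, 9, 9, 9],
    [0, 0, 0, 9, 9, 9],
    [0, 0, 0, 9, 9, 9],
    [0, 0, 0, 9, 9, 9],
]
# Hypothetical vertical-edge kernel: right column minus left column
kernel = [
    [-1, 0, 1],
    [-1, 0, 1],
    [-1, 0, 1],
]
print(conv2d_single(image, kernel))  # [[0, 27, 27, 0], [0, 27, 27, 0]]
```

The response is non-zero only where the window straddles the dark/bright boundary, which is what "this kernel extracts vertical edges" means in practice.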


Multi-channel input
Three-channel color image
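With several input channels, one kernel has one slice per channel: each channel is convolved with its own slice and the per-channel results are added (plus a bias) into a single output map. The sketch below assumes a tiny made-up 3x3 "RGB" input and a 2x2 kernel; the function names are illustrative only.

```python
def conv2d_single(image, kernel):
    """Valid (no padding, stride 1) 2D cross-correlation of one channel."""
    ih, iw = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j] for i in range(kh) for j in range(kw))
            for c in range(iw - kw + 1)
        ]
        for r in range(ih - kh + 1)
    ]

def conv2d_multi_channel(channels, kernels, bias=0):
    """One output map from an n-channel input: convolve each channel with its own
    kernel slice, then add the n results element-wise (plus a scalar bias)."""
    assert len(channels) == len(kernels), "need one kernel slice per input channel"
    maps = [conv2d_single(x, k) for x, k in zip(channels, kernels)]
    oh, ow = len(maps[0]), len(maps[0][0])
    return [[sum(m[r][c] for m in maps) + bias for c in range(ow)] for r in range(oh)]

# Hypothetical 3-channel input and the three slices of one 2x2 kernel
red   = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
green = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
blue  = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
k_r = [[1, 0], [0, 1]]
k_g = [[1, 1], [1, 1]]
k_b = [[0, 0], [0, 0]]
print(conv2d_multi_channel([red, green, blue], [k_r, k_g, k_b]))  # [[10, 12], [16, 18]]
```

However many input channels there are, one kernel always produces one output channel.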

Extending to n channels

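The essence of a convolutional layer
Input (n, iw, ih): n channels of size iw × ih. Convolutional layer (m, n, kw, kh): m kernels, each with n channels of size kw × kh. Output (m, ow, oh): one output channel per kernel.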
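The shape bookkeeping can be written down as a small helper. This sketch uses PyTorch's (channels, height, width) order rather than the (n, iw, ih) order above, assumes dilation 1, and counts one bias per kernel, as `nn.Conv2d` does by default. Per spatial axis, the output length is floor((size + 2·padding − kernel) / stride) + 1. The helper names are hypothetical.

```python
def conv_output_size(size, kernel, padding=0, stride=1):
    """Output length along one spatial axis (dilation 1): floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_layer_shapes(in_shape, m, kernel_hw, padding=0, stride=1):
    """in_shape = (n, ih, iw); the layer holds m kernels of shape (n, kh, kw).
    Returns (weight_shape, output_shape, parameter_count incl. one bias per kernel)."""
    n, ih, iw = in_shape
    kh, kw = kernel_hw
    weight_shape = (m, n, kh, kw)
    out_shape = (m,
                 conv_output_size(ih, kh, padding, stride),
                 conv_output_size(iw, kw, padding, stride))
    params = m * n * kh * kw + m
    return weight_shape, out_shape, params

# LeNet-5 C3: 6 input channels of 14x14, 16 kernels of 5x5, no padding
print(conv_layer_shapes((6, 14, 14), 16, (5, 5)))  # ((16, 6, 5, 5), (16, 10, 10), 2416)
```

The weight tensor shape (m, n, kh, kw) matches the layout PyTorch uses for `nn.Conv2d(n, m, kernel_size)` weights.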

1 × 1 convolution
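A 1 × 1 convolution does not look at neighbouring pixels: each output pixel is a weighted sum of the n input channels at the same position. It therefore changes the number of channels (n to m) while keeping height and width. The sketch below is illustrative pure Python with made-up weights.

```python
def conv1x1(x, weights, biases=None):
    """1x1 convolution. x: n channels of H x W; weights: m rows of n numbers.
    Each output pixel mixes the n input values at the same (r, c); H and W are unchanged."""
    n, h, w = len(x), len(x[0]), len(x[0][0])
    m = len(weights)
    biases = biases or [0] * m
    return [
        [
            [sum(weights[o][ch] * x[ch][r][c] for ch in range(n)) + biases[o] for c in range(w)]
            for r in range(h)
        ]
        for o in range(m)
    ]

# 3 input channels of 2x2 reduced to 2 output channels
x = [
    [[1, 2], [3, 4]],
    [[10, 20], [30, 40]],
    [[100, 200], [300, 400]],
]
w = [
    [1, 1, 1],    # output channel 0: sum of the three channels
    [0, 0, 0.5],  # output channel 1: half of channel 2
]
out = conv1x1(x, w)
print(len(out), len(out[0]), len(out[0][0]))  # 2 2 2
print(out[0])  # [[111, 222], [333, 444]]
```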



Padding
The case of padding = 1
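Padding surrounds the input with zeros before the kernel slides over it. With a 3 × 3 kernel and stride 1, padding = 1 makes the output the same size as the input, while without padding it shrinks by 2 in each direction. A pure-Python sketch with a made-up 3 × 3 image and a box-sum kernel:

```python
def zero_pad(image, p):
    """Surround a 2D list with p rings of zeros."""
    w = len(image[0]) + 2 * p
    top = [[0] * w for _ in range(p)]
    body = [[0] * p + list(row) + [0] * p for row in image]
    bottom = [[0] * w for _ in range(p)]
    return top + body + bottom

def conv2d_valid(image, kernel):
    """No-padding, stride-1 2D cross-correlation."""
    kh, kw = len(kernel), len(kernel[0])
    return [
        [sum(image[r + i][c + j] * kernel[i][j] for i in range(kh) for j in range(kw))
         for c in range(len(image[0]) - kw + 1)]
        for r in range(len(image) - kh + 1)
    ]

image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
ones = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]  # 3x3 box-sum kernel
padded = zero_pad(image, 1)               # 5x5
print(padded[0], padded[1])               # [0, 0, 0, 0, 0] [0, 1, 2, 3, 0]
print(conv2d_valid(image, ones))          # without padding: [[45]]
print(conv2d_valid(padded, ones))         # with padding=1: a 3x3 output, same size as the input
```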

Stride
A stride of 2
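The stride is how many cells the window moves between positions. The sketch below (illustrative pure Python, no padding) uses a hypothetical kernel that simply copies the window's top-left value, so the output shows exactly which positions are visited.

```python
def conv2d_strided(image, kernel, stride=1):
    """No-padding 2D cross-correlation that moves the window `stride` cells at a time."""
    ih, iw = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    oh = (ih - kh) // stride + 1
    ow = (iw - kw) // stride + 1
    return [
        [
            sum(image[r * stride + i][c * stride + j] * kernel[i][j]
                for i in range(kh) for j in range(kw))
            for c in range(ow)
        ]
        for r in range(oh)
    ]

# 5x5 input holding 0..24; the 3x3 kernel picks each window's top-left value
image = [[5 * r + c for c in range(5)] for r in range(5)]
pick = [[1, 0, 0], [0, 0, 0], [0, 0, 0]]
print(conv2d_strided(image, pick, stride=1))  # [[0, 1, 2], [5, 6, 7], [10, 11, 12]]
print(conv2d_strided(image, pick, stride=2))  # [[0, 2], [10, 12]]
```

With stride 2 every other position is skipped, so the 3 × 3 output shrinks to 2 × 2.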

Pooling layer
Example: a max pooling layer
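Max pooling keeps only the largest value in each window and has no learnable parameters. With a 2 × 2 window and stride 2, the setting used by `nn.MaxPool2d(kernel_size=2, stride=2)` in the LeNet code, it halves the height and width. A pure-Python sketch on a made-up 4 × 4 feature map:

```python
def max_pool2d(image, size=2, stride=2):
    """Max pooling on a 2D list: each output cell is the largest value in its window."""
    oh = (len(image) - size) // stride + 1
    ow = (len(image[0]) - size) // stride + 1
    return [
        [
            max(image[r * stride + i][c * stride + j]
                for i in range(size) for j in range(size))
            for c in range(ow)
        ]
        for r in range(oh)
    ]

feature_map = [
    [1, 3, 2, 1],
    [4, 6, 5, 0],
    [7, 2, 9, 8],
    [1, 0, 3, 4],
]
print(max_pool2d(feature_map))  # [[6, 5], [7, 9]]
```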

The overall structure of a convolutional neural network

Defining a neural network
LeNet-5
A classification task on the ten digits 0 to 9
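Before reading the model code, it helps to trace how a 1 × 28 × 28 MNIST image changes shape through the layers, which explains where the `16*5*5` input size of the first fully connected layer comes from. This sketch only does the size arithmetic (dilation 1); the layer table mirrors the `LeNet` class in the code that follows.

```python
def out_size(size, kernel, padding=0, stride=1):
    # Same formula for convolution and pooling windows (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

# (name, out_channels or None for pooling, kernel, padding, stride)
layers = [
    ("C1 conv", 6, 5, 2, 1),
    ("S2 pool", None, 2, 0, 2),
    ("C3 conv", 16, 5, 0, 1),
    ("S4 pool", None, 2, 0, 2),
]

channels, size = 1, 28  # MNIST images are 1 x 28 x 28
for name, out_ch, k, p, s in layers:
    size = out_size(size, k, p, s)
    if out_ch is not None:
        channels = out_ch
    print(f"{name}: {channels} x {size} x {size}")

flat = channels * size * size
print("flattened features:", flat)  # 400 == 16*5*5
```

The trace goes 6 × 28 × 28, 6 × 14 × 14, 16 × 10 × 10, 16 × 5 × 5, so flattening yields 400 features per image.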



```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import os
from torch import optim

'''
Dataset preparation
'''
batch_size = 64  # number of samples per training batch
DOWNLOAD_MNIST = False  # whether to download the data online
# MNIST digits dataset
if not os.path.exists('./mnist/') or not os.listdir('./mnist/'):  # check whether MNIST has already been downloaded
    # no mnist dir, or mnist is an empty dir
    DOWNLOAD_MNIST = True
train_dataset = datasets.MNIST(
    root='./mnist',
    train=True,  # training split
    transform=transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)
test_dataset = datasets.MNIST(
    root='./mnist',
    train=False,  # train=False selects the test split
    transform=transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)
# DataLoader wraps a dataset and yields batches of batch_size samples as tensors,
# which can be fed to the model directly (wrapping them in Variable is no longer needed)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  # shuffle: reorder the samples every epoch
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

'''
Build the neural network model
'''
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # Convolutional layer C1 and pooling layer S2
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        # Convolutional layer C3 and pooling layer S4
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        # Fully connected layer C5, fully connected layer F6, output layer
        self.fc = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    # Forward propagation, applying the layers in order
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # nn.Linear expects one flat feature vector per sample, so reshape (N, 16, 5, 5) to (N, 400)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is detected, otherwise the CPU
net = LeNet().to(device)  # instantiate the network and move it to the chosen device

'''
Loss and optimizer
'''
loss_fuc = nn.CrossEntropyLoss()  # multi-class classification, so use the cross-entropy loss
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # SGD with learning rate 0.001 and momentum 0.9

'''
Training process
'''
EPOCH = 8  # number of passes over the training set
for epoch in range(EPOCH):
    net.train()  # training mode
    sum_loss = 0.0
    # fetch the data batch by batch
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)  # move the batch to the same device as the network
        # clear the gradients left over from the previous step
        optimizer.zero_grad()
        # forward pass, loss, backpropagation, parameter update
        output = net(inputs)
        loss = loss_fuc(output, labels)
        loss.backward()
        optimizer.step()
        # print the average loss every 100 batches
        sum_loss += loss.item()
        if i % 100 == 99:
            print('[Epoch:%d, batch:%d] train loss: %.03f' % (epoch + 1, i + 1, sum_loss / 100))
            sum_loss = 0.0
    # evaluate on the test set after each epoch
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for data in test_loader:
            test_inputs, labels = data
            test_inputs, labels = test_inputs.to(device), labels.to(device)
            outputs_test = net(test_inputs)
            _, predicted = torch.max(outputs_test, 1)  # index of the highest-scoring class
            total += labels.size(0)  # running total of test images
            correct += (predicted == labels).sum().item()  # running total of correct classifications
    print('Epoch {} test accuracy: {}%'.format(epoch + 1, 100 * correct / total))

# Save the model parameters (change the path to a directory of your choice)
torch.save(net.state_dict(), 'ckpt.mdl')
# Load the model
# net.load_state_dict(torch.load('ckpt.mdl'))
```
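To see what the loss and the accuracy check in the training code compute, here is a pure-Python sketch. `nn.CrossEntropyLoss` takes the raw scores (logits) from the last `nn.Linear` layer, applies log-softmax, takes the negative log-probability of the true class and, by default, averages over the batch; `torch.max(outputs, 1)` picks the highest-scoring class. The function names and the 3-class scores are hypothetical, and no class weights or label smoothing are modelled.

```python
import math

def cross_entropy(logits, labels):
    """Mean of -log(softmax(scores)[label]) over the batch
    (mirrors nn.CrossEntropyLoss with its default reduction='mean', no class weights)."""
    total = 0.0
    for scores, y in zip(logits, labels):
        m = max(scores)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[y]
    return total / len(labels)

def accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    predicted = [max(range(len(s)), key=s.__getitem__) for s in logits]
    return sum(p == y for p, y in zip(predicted, labels)) / len(labels)

# Hypothetical 3-class scores for a batch of two samples
logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 1.5]]
labels = [0, 1]
print(round(cross_entropy(logits, labels), 4))
print(accuracy(logits, labels))  # 0.5
```

With all-equal scores over 10 classes the loss is log(10) ≈ 2.30, a useful sanity check for the first printed training loss of an untrained network.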