当前位置:网站首页 > PyTorch learning notes — Advanced CNN (using an Inception module) to classify the MNIST dataset (with comments and results)
PyTorch learning notes — Advanced CNN (using an Inception module) to classify the MNIST dataset (with comments and results)
2022-07-25 15:41:00 【whut_ L】
Catalog
Program code
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
# Mini-batch size shared by both data loaders.
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),  # (H, W, C) PIL image -> (C, H, W) float tensor scaled to [0, 1]
    transforms.Normalize((0.1307, ), (0.3081, ))  # per-channel standardization with the MNIST mean/std
])
# MNIST training split: shuffled every epoch.
train_dataset = datasets.MNIST(root = '../Pycharm/dataset/mnist/', train = True, download = True, transform = transform)
train_loader = DataLoader(train_dataset, shuffle = True, batch_size = batch_size)
# MNIST test split: fixed order for evaluation.
test_dataset = datasets.MNIST(root = '../Pycharm/dataset/mnist/', train = False, download = True, transform = transform)
test_loader = DataLoader(test_dataset, shuffle = False, batch_size = batch_size)
class InceptionA(torch.nn.Module):
    """Inception-style block with four parallel branches concatenated on the
    channel axis.

    Every branch preserves the spatial size, so an (N, C, H, W) input becomes
    an (N, 88, H, W) output (16 + 24 + 24 + 24 = 88 channels).
    """

    def __init__(self, in_channels):
        super(InceptionA, self).__init__()
        Conv2d = torch.nn.Conv2d
        # Branch 1: a single 1x1 convolution.
        self.branch1x1 = Conv2d(in_channels, 16, kernel_size=1)
        # Branch 2: 1x1 reduction, then 5x5 conv (padding=2 keeps H and W).
        self.branch5x5_1 = Conv2d(in_channels, 16, kernel_size=1)
        self.branch5x5_2 = Conv2d(16, 24, kernel_size=5, padding=2)
        # Branch 3: 1x1 reduction, then two 3x3 convs (padding=1 keeps H and W).
        self.branch3x3_1 = Conv2d(in_channels, 16, kernel_size=1)
        self.branch3x3_2 = Conv2d(16, 24, kernel_size=3, padding=1)
        self.branch3x3_3 = Conv2d(24, 24, kernel_size=3, padding=1)
        # Branch 4: 3x3 average pooling followed by a 1x1 convolution.
        self.branch_pool = Conv2d(in_channels, 24, kernel_size=1)

    def forward(self, x):
        """Run all four branches on x and concatenate along dim 1 (channels)."""
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return torch.cat([out_1x1, out_5x5, out_3x3, out_pool], dim=1)
class Net(torch.nn.Module):
    """CNN for 28x28 MNIST digits: two conv + max-pool stages, each followed
    by an InceptionA block, then a single fully-connected classifier.

    Shape trace for a (N, 1, 28, 28) input:
    conv1/pool -> (N, 10, 12, 12) -> incep1 -> (N, 88, 12, 12)
    conv2/pool -> (N, 20, 4, 4)   -> incep2 -> (N, 88, 4, 4)
    flatten    -> (N, 1408)       -> fc     -> (N, 10) logits.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        # 88 input channels = concatenated InceptionA output (16+24+24+24).
        self.conv2 = torch.nn.Conv2d(88, 20, kernel_size=5)
        self.incep1 = InceptionA(in_channels=10)
        self.incep2 = InceptionA(in_channels=20)
        self.mp = torch.nn.MaxPool2d(2)
        # 1408 = 88 channels * 4 * 4 spatial positions after the second stage.
        self.fc = torch.nn.Linear(1408, 10)

    def forward(self, x):
        batch = x.size(0)
        # Stage 1: conv -> max-pool -> ReLU, then the first inception block.
        x = self.incep1(F.relu(self.mp(self.conv1(x))))
        # Stage 2: same pattern with the second conv / inception pair.
        x = self.incep2(F.relu(self.mp(self.conv2(x))))
        # Flatten and classify; raw logits are returned (CrossEntropyLoss
        # applies log-softmax internally).
        return self.fc(x.view(batch, -1))
# Instantiate the network and move it to the GPU when one is available.
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss()  # combines log-softmax + NLL loss, expects raw logits
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.5)  # SGD; lr = learning rate, momentum = impulse factor
def train(epoch):
    """Run one training epoch over train_loader, printing the average loss
    every 300 mini-batches.

    Args:
        epoch: zero-based epoch index (used only for the progress printout).
    """
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):  # yields (index, (inputs, labels)) starting at 0
        inputs, target = data  # features, labels
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()  # clear gradients accumulated by the previous step
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)  # cross-entropy on raw logits
        loss.backward()  # back-propagate gradients
        optimizer.step()  # update parameters
        running_loss += loss.item()  # .item() converts the 0-d loss tensor to a Python float
        if batch_idx % 300 == 299:  # report every 300 batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
def test():
    """Evaluate the model on test_loader and print the overall accuracy."""
    correct = 0
    total = 0
    with torch.no_grad():  # inference only: skip gradient bookkeeping
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim = 1)  # index of the max logit per row = predicted class
            total += labels.size(0)  # running count of samples
            correct += (predicted == labels).sum().item()  # running count of correct predictions
    print('Accuracy on test set: %d %%' % (100 * correct / total))
def main():
    """Train for 10 epochs, evaluating on the test set after every epoch."""
    for epoch in range(10):
        train(epoch)
        test()


if __name__ == "__main__":
    # Guard the entry point so importing this module does not start training
    # (the original called main() unconditionally at import time).
    main()
Inception block diagram

Execution results

边栏推荐
- MySQL - Summary of common SQL statements
- Beyond Compare 4 实现class文件对比【最新】
- MySQL - user and permission control
- No tracked branch configured for branch xxx or the branch doesn‘t exist. To make your branch trac
- Leetcode - 303 area and retrieval - array immutable (design prefix and array)
- Flink-1.13.6版本的 Flink sql以yarn session 模式运行,怎么禁用托管
- CF888G-巧妙字典树+暴力分治(异或最小生成树)
- Games101 review: linear algebra
- Week303 of leetcode
- Idea eye care settings
猜你喜欢

CVPR 2022 | in depth study of batch normalized estimation offset in network

JVM knowledge brain map sharing

GAMES101复习:线性代数

Beyond Compare 4 实现class文件对比【最新】

《图书馆管理系统——“借书还书”模块》项目研发阶段性总结

不愧是阿里内部“千亿级并发系统架构设计笔记”面面俱到,太全了
SQL cultivation manual from scratch - practical part

Node learning

CF888G-巧妙字典树+暴力分治(异或最小生成树)

LeetCode - 362 敲击计数器(设计)
随机推荐
Leetcode - 232 realize queue with stack (design double stack to realize queue)
二进制补码
Understanding the difference between wait() and sleep()
带你创建你的第一个C#程序(建议收藏)
MySQL优化总结二
Matlab randInt, matlab randInt function usage "recommended collection"
C # fine sorting knowledge points 10 generic (recommended Collection)
MySQL - user and permission control
Leetcode - 622 design cycle queue (Design)
Use cpolar to build a business website (how to buy a domain name)
How to solve cross domain problems
CVPR 2022 | 网络中批处理归一化估计偏移的深入研究
CF365-E - Mishka and Divisors,数论+dp
Pytorch学习笔记--Pytorch常用函数总结1
LeetCode - 641 设计循环双端队列(设计)*
var、let、const之间的区别
Notes on inputview and inputaccessoryview of uitextfield
2021 Shanghai sai-d-cartland number variant, DP
盒子躲避鼠标
Redis分布式锁,没它真不行