MNIST model training (with code)
2022-06-21 14:51:00 【Clan return】
GitHub reference 1: https://github.com/Att100/CIFAR10_Pytorch
GitHub reference 2: https://github.com/chenyaofo/pytorch-cifar-models
Training demonstration
Environment:
python=3.6 ~ 3.8
Package      Version
-----------  -------
torch        1.10.1
torchvision  0.11.2
tqdm         4.62.3
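To compare an existing installation against the versions listed above, a small check along these lines may help. It is only a sketch: `EXPECTED`, `installed_version` and `report` are names made up for this illustration, and `importlib.metadata` is assumed to be available (it joined the standard library in Python 3.8, so older interpreters would need a different approach).

```python
from importlib import metadata  # standard library since Python 3.8

# Versions taken from the environment table in this post
EXPECTED = {"torch": "1.10.1", "torchvision": "0.11.2", "tqdm": "4.62.3"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def report(expected):
    """Build one human-readable line per package comparing wanted and found versions."""
    lines = []
    for name, wanted in expected.items():
        found = installed_version(name)
        status = "ok" if found == wanted else "differs"
        lines.append("%-12s wanted %-8s found %s (%s)" % (name, wanted, found, status))
    return lines


for line in report(EXPECTED):
    print(line)
```

A line marked "differs" is not necessarily a problem. CUDA builds of torch, for example, often report a version with a local suffix appended.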
Directory layout:
Root directory:
The data folder is created and filled automatically the first time the training code runs. The model folder contains the AlexNet .py file.
Model code: AlexNet.py (placed in the model folder)
import torch
import torch.nn as nn
import torch.nn.functional as F


class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        # block 1: 1x28x28 -> 96x28x28 -> 96x14x14
        self.Conv2d_1 = nn.Conv2d(kernel_size=3, in_channels=1, out_channels=96, padding=1)
        self.bn_1 = nn.BatchNorm2d(96)
        self.maxpool_1 = nn.MaxPool2d((3, 3), stride=2, padding=1)
        # block 2: 96x14x14 -> 256x14x14 -> 256x7x7
        self.Conv2d_2 = nn.Conv2d(kernel_size=5, in_channels=96, out_channels=256, padding=2)
        self.bn_2 = nn.BatchNorm2d(256)
        self.maxpool_2 = nn.MaxPool2d((3, 3), stride=2, padding=1)
        # block 3: three 3x3 convolutions at 7x7, then pool to 256x4x4
        self.Conv2d_3 = nn.Conv2d(kernel_size=3, in_channels=256, out_channels=384, padding=1)
        self.Conv2d_4 = nn.Conv2d(kernel_size=3, in_channels=384, out_channels=384, padding=1)
        self.Conv2d_5 = nn.Conv2d(kernel_size=3, in_channels=384, out_channels=256, padding=1)
        self.bn_3 = nn.BatchNorm2d(256)
        self.maxpool_3 = nn.MaxPool2d((3, 3), stride=2, padding=1)
        # classifier: 4*4*256 flattened features -> 10 digit classes
        self.fc_1 = nn.Linear(4*4*256, 2048)
        self.dp_1 = nn.Dropout()
        self.fc_2 = nn.Linear(2048, 1024)
        self.dp_2 = nn.Dropout()
        self.fc_3 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.Conv2d_1(x)
        x = self.bn_1(x)
        x = F.relu(x)
        x = self.maxpool_1(x)
        x = self.Conv2d_2(x)
        x = self.bn_2(x)
        x = F.relu(x)
        x = self.maxpool_2(x)
        x = F.relu(self.Conv2d_3(x))
        x = F.relu(self.Conv2d_4(x))
        x = F.relu(self.Conv2d_5(x))
        x = self.bn_3(x)
        x = F.relu(x)
        x = self.maxpool_3(x)
        # flatten to (batch, 4*4*256) for the fully connected layers
        x = x.view(-1, 4*4*256)
        x = F.relu(self.fc_1(x))
        x = self.dp_1(x)
        x = F.relu(self.fc_2(x))
        x = self.dp_2(x)
        # raw logits; CrossEntropyLoss applies log-softmax itself
        x = self.fc_3(x)
        return x
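The input size of fc_1, 4*4*256, follows from how each layer changes the spatial size of a 28x28 MNIST image. The sketch below redoes that arithmetic in plain Python with the usual output-size formula for convolution and pooling layers (dilation 1, floor rounding). `conv_out` is a helper name invented for this illustration, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a conv or max-pool layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28  # MNIST images are 28x28
size = conv_out(size, kernel=3, padding=1)            # Conv2d_1  -> 28
size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool_1 -> 14
size = conv_out(size, kernel=5, padding=2)            # Conv2d_2  -> 14
size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool_2 -> 7
for _ in range(3):                                    # Conv2d_3 to Conv2d_5 keep 7
    size = conv_out(size, kernel=3, padding=1)
size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool_3 -> 4

# The last convolution has 256 output channels
flat_features = size * size * 256
print(size, flat_features)  # 4 4096
```

If the input resolution or any kernel, stride or padding value changes, rerunning this arithmetic gives the new size that fc_1 and the view call would need.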
Training code: train_alexnet.py
import os
import torch
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# ToTensor scales pixels to [0, 1]; Normalize([0.5], [0.5]) then maps them to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_mnist = DataLoader(datasets.MNIST('data', train=True, download=True, transform=transform), batch_size=128, shuffle=True)

alexNet = AlexNet.AlexNet()
learning_rate = 0.001
momentum = 0.9
epoches = 200
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
alexNet = alexNet.to(device)

# optimizer and loss function
optimizer = optim.SGD(alexNet.parameters(), lr=learning_rate, momentum=momentum, weight_decay=1e-5)
loss_func = torch.nn.CrossEntropyLoss()

# torch.save does not create missing directories
os.makedirs('./checkpoint', exist_ok=True)

for epoch in range(epoches):
    # wrap the loader in a fresh progress bar each epoch instead of re-wrapping the old bar
    progress = tqdm(train_mnist)
    running_loss = 0.0
    for inputs, labels in progress:
        inputs, labels = inputs.to(device), labels.to(device)
        # ============= forward =============
        outputs = alexNet(inputs)
        # ============= backward ============
        optimizer.zero_grad()
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
        # ============= log information =====
        # running_loss is the sum of batch losses so far in this epoch
        running_loss += loss.item()
        description = 'epoch: %d , current_loss: %.4f, running_loss: %.4f' % (epoch, loss.item(), running_loss)
        # tqdm advances on its own while iterating, so no manual update() is needed
        progress.set_description(description)
    # save one checkpoint at the end of every epoch
    torch.save(alexNet.state_dict(), './checkpoint/checkpoint_' + str(epoch) + '.pt')
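The running_loss value logged during training is a sum of batch losses, so it keeps growing through the epoch and depends on the number of batches. A per-sample mean is often easier to compare between epochs. The sketch below shows one way to track it. `AverageMeter` is a hypothetical helper written for this illustration, and the loss values in the demo loop are made up. It assumes each batch loss is already a mean over the batch, which is the default behaviour of CrossEntropyLoss.

```python
class AverageMeter:
    """Tracks the sample-weighted mean of the values it has seen so far."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # value is the mean loss over a batch of n samples
        self.total += value * n
        self.count += n

    @property
    def average(self):
        # Return 0.0 before any update to avoid dividing by zero
        return self.total / self.count if self.count else 0.0


meter = AverageMeter()
# Made-up (batch_loss, batch_size) pairs standing in for a training loop
for batch_loss, batch_size in [(2.0, 128), (1.0, 128), (0.5, 64)]:
    meter.update(batch_loss, batch_size)

print('mean loss: %.4f' % meter.average)  # mean loss: 1.3000
```

In the training script this would mean calling meter.update(loss.item(), labels.size(0)) once per batch and creating a new meter at the start of each epoch.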
Code for testing model generalization: test_model_accuracy.py
import torch
from tqdm import tqdm
from model import AlexNet
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# train=False selects the MNIST test split; shuffling is not needed for evaluation
test_mnist = DataLoader(datasets.MNIST('data', train=False, download=True, transform=transform), batch_size=128, shuffle=False)

alexNet = AlexNet.AlexNet()
# load the checkpoint written by the training script for epoch index 10
alexNet.load_state_dict(torch.load("./checkpoint/checkpoint_10.pt", map_location=device))
alexNet = alexNet.to(device)
# eval mode disables dropout and makes BatchNorm use its running statistics
alexNet = alexNet.eval()

with torch.no_grad():
    progress = tqdm(test_mnist)
    total = 0
    correct = 0
    for inputs, labels in progress:
        inputs, labels = inputs.to(device), labels.to(device)
        # ============= forward =============
        outputs = alexNet(inputs)
        # ============= accuracy ============
        # the predicted class is the index of the largest logit
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        description = 'correct: %d, total: %d, accuracy: %.4f' % (correct, total, correct/total)
        # tqdm advances on its own while iterating, so no manual update() is needed
        progress.set_description(description)
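The accuracy computed by the test script is the fraction of images whose largest output score sits at the index of the true label. The plain-Python sketch below shows the same calculation on a tiny hand-made batch, without tensors. `argmax` and `accuracy` are helper names made up for this illustration and the scores are invented.

```python
def argmax(scores):
    """Index of the largest score; ties resolve to the first occurrence."""
    return max(range(len(scores)), key=scores.__getitem__)


def accuracy(batch_scores, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    correct = sum(1 for scores, label in zip(batch_scores, labels)
                  if argmax(scores) == label)
    return correct / len(labels)


# Invented scores for four samples over three classes
batch_scores = [
    [0.1, 2.3, -1.0],   # predicts class 1
    [1.5, 0.2, 0.3],    # predicts class 0
    [-0.4, 0.0, 0.9],   # predicts class 2
    [0.7, 0.6, 0.1],    # predicts class 0
]
labels = [1, 0, 2, 1]   # the last sample is misclassified

print('accuracy: %.4f' % accuracy(batch_scores, labels))  # accuracy: 0.7500
```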
A very large collection of vision models: https://github.com/52CV/CVPR-2021-Papers