Pytorch learning 02: handwritten digit recognition
2022-06-22 00:58:00 【HMTT】
Description
Handwritten digit recognition is a classic example in deep learning. PyTorch provides a simple way to load the MNIST dataset (a handwritten digit dataset), which makes it very suitable for beginners to try.
With the help of various tutorials, I implemented handwritten digit recognition.
Introduction
Utils
Implements a few helper methods for processing and displaying data.
Load data
The dataset is loaded with PyTorch's built-in methods, with some preprocessing applied, such as converting the data format and normalizing.
Creating the network
The network used here consists of three fully connected layers. The first two layers use ReLU as the activation function, and the last layer uses Softmax.
Training
The optimizer is Adam with the learning rate set to 0.01, and the loss function is MSE (mean squared error). I actually wanted to use cross entropy for the loss, but having only just started learning PyTorch, I am not yet comfortable with it.
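For reference, switching to cross entropy would be a small change: `F.cross_entropy` takes raw logits and integer class labels directly, so both the final softmax and the one-hot conversion used below would be dropped. A minimal sketch with made-up logits:

```python
import torch
import torch.nn.functional as F

# Made-up logits for a batch of 4 samples and 10 classes.
logits = torch.randn(4, 10)
labels = torch.tensor([3, 7, 0, 1])  # integer class labels, no one-hot needed

# F.cross_entropy applies log-softmax internally, so the network's
# last layer should return raw scores (no softmax) when using it.
loss = F.cross_entropy(logits, labels)
print(loss.item())
```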
Implementation code
```python
import torch
from matplotlib import pyplot as plt
# Without the two lines below, the following error occurs
# and figures are not displayed:
# Error #15: Initializing libiomp5md.dll...
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


# Utils
def plot_curve(data):
    '''
    Draw a curve.
    :param data: the data to plot
    :return: nothing
    '''
    fig = plt.figure()
    plt.plot(range(len(data)), data, color="blue")
    plt.legend(["value"], loc="upper right")
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()


def plot_image(img, label, name):
    '''
    Draw the first six images of a batch with their labels.
    :param img: image batch
    :param label: label batch
    :param name: title prefix
    :return: nothing
    '''
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        # Undo the Normalize transform (mean 0.1307, std 0.3081) for display
        plt.imshow(img[i][0] * 0.3081 + 0.1307,
                   cmap="gray",
                   interpolation="none")
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    '''
    Convert labels to one-hot encoding.
    :param label: label tensor
    :param depth: dimension of the one-hot code
    :return: one-hot codes corresponding to the labels
    '''
    out = torch.zeros(label.size(0), depth)
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out
# end Utils
```
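The scatter_-based conversion above can be cross-checked against PyTorch's built-in `F.one_hot` (assuming a PyTorch version recent enough to provide it):

```python
import torch
import torch.nn.functional as F

label = torch.tensor([2, 0, 9])

# Manual version: the same scatter_ approach as one_hot above.
out = torch.zeros(label.size(0), 10)
out.scatter_(dim=1, index=label.view(-1, 1), value=1)

# Built-in equivalent; F.one_hot returns int64, so cast to float for mse_loss.
built_in = F.one_hot(label, num_classes=10).float()
print(torch.equal(out, built_in))  # True
```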
```python
# code
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision


# 1 Load data
def load_data():
    '''
    Load the MNIST data.
    :return: training set loader, test set loader
    '''
    batch_size = 512
    ## 1.1 Load the MNIST training set
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "mnist.data",    # data storage path
            train=True,      # whether this is the training set
            download=True,   # download the data if it is not already present
            transform=torchvision.transforms.Compose([  # transforms applied to the loaded data
                torchvision.transforms.ToTensor(),      # convert the PIL image to a torch Tensor
                torchvision.transforms.Normalize(       # normalize so the data is centered around 0; the raw data is in 0~1
                    (0.1307,), (0.3081,)
                )
            ])
        ),
        batch_size=batch_size,  # number of images loaded per batch
        shuffle=True            # shuffle the data
    )
    ## 1.2 Load the test set
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "mnist.data",    # data storage path
            train=False,     # the test set this time
            download=True,   # download the data if it is not already present
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    (0.1307,), (0.3081,)
                )
            ])
        ),
        batch_size=batch_size,
        shuffle=False           # no need to shuffle the test set
    )
    return train_loader, test_loader
```
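Normalize((0.1307,), (0.3081,)) simply computes (x - mean) / std per channel, which is also why plot_image multiplies by 0.3081 and adds 0.1307 before display. A quick sketch with a fake image (plain tensor arithmetic, no download needed):

```python
import torch

mean, std = 0.1307, 0.3081
img = torch.full((1, 28, 28), mean)  # fake image with every pixel at the dataset mean
out = (img - mean) / std             # what Normalize((0.1307,), (0.3081,)) computes
restored = out * std + mean          # what plot_image does before displaying
print(out.abs().max().item())        # close to 0: pixels at the mean map to zero
```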
```python
# 2 Create the network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # First layer:
        # the input is a flattened 28*28 image, the output dimension is 256
        self.fc1 = nn.Linear(28 * 28, 256)
        # Second layer:
        # the input dimension is the previous layer's 256, the output dimension is 64
        self.fc2 = nn.Linear(256, 64)
        # Third layer:
        # the input is the previous layer's output of 64, the output dimension is 10,
        # because we need 10-way classification
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # x: [batch_size, 28*28] (flattened before being passed in)
        # h1 = relu(x w1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1 w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = softmax(h2 w3 + b3); dim=1 applies softmax across the 10 classes
        x = F.softmax(self.fc3(x), dim=1)
        return x
```
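As an aside, the same three-layer structure can also be written with nn.Sequential; this is only an equivalent sketch, not the class used in this post:

```python
import torch
from torch import nn

# Equivalent three-layer network written as a Sequential (sketch only).
model = nn.Sequential(
    nn.Linear(28 * 28, 256), nn.ReLU(),
    nn.Linear(256, 64), nn.ReLU(),
    nn.Linear(64, 10), nn.Softmax(dim=1),
)

x = torch.randn(8, 28 * 28)  # 8 fake flattened 28x28 images
out = model(x)
print(out.shape)  # torch.Size([8, 10]); each row sums to 1 after softmax
```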
```python
# 3 Start training
if __name__ == "__main__":
    # Load the data
    train_loader, test_loader = load_data()
    # Instantiate the model
    net = Net()
    # Set up the optimizer
    # net.parameters() returns [w1, b1, w2, b2, w3, b3]
    optimizer = optim.Adam(net.parameters(),
                           lr=0.01
                           )
    # Record the training losses
    train_loss = []
    # Start training
    for epoch in range(3):
        for batch_idx, (x, y) in enumerate(train_loader):
            # x: [batch_size, 1, 28, 28]
            # y: [batch_size]
            # Flatten the images:
            # [b, 1, 28, 28] => [b, 28*28]
            x = x.view(x.size(0), 28 * 28)
            # out: [batch_size, 10]
            out = net(x)
            # Convert y to one-hot vectors
            y_onehot = one_hot(y)
            # loss = mse(out, y_onehot)
            loss = F.mse_loss(out, y_onehot)
            # Clear the gradients from the previous step
            optimizer.zero_grad()
            # Compute the gradients
            loss.backward()
            # Update the weights: w' = w - lr * grad
            optimizer.step()
            # Record the loss
            train_loss.append(loss.item())
            if batch_idx % 10 == 0:
                print(epoch, batch_idx, loss.item())
    # We now have a better set of [w1, b1, w2, b2, w3, b3]
    # Plot the loss curve
    plot_curve(train_loss)

    # Evaluate the model on the test set and compute the accuracy
    total_correct = 0
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # Count the correct predictions
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().item()
        total_correct += correct
    total_num = len(test_loader.dataset)
    acc = total_correct / total_num
    # Print the accuracy on the test set
    print("test acc:", acc)

    # Visualize some predictions
    x, y = next(iter(test_loader))
    out = net(x.view(x.size(0), 28 * 28))
    pred = out.argmax(dim=1)
    plot_image(x, pred, "test")
```
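The accuracy computation above (argmax, eq, sum) can be checked in isolation with made-up network outputs:

```python
import torch

# Made-up network outputs for 4 samples and 3 classes, plus the true labels.
out = torch.tensor([[0.1, 0.8, 0.1],
                    [0.7, 0.2, 0.1],
                    [0.2, 0.3, 0.5],
                    [0.6, 0.3, 0.1]])
y = torch.tensor([1, 0, 2, 2])

pred = out.argmax(dim=1)           # predicted class per sample
correct = pred.eq(y).sum().item()  # number of correct predictions
print(correct / y.size(0))         # 0.75 (3 of 4 correct)
```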
Results
Loss curve: [image]
Example test predictions: [image]
Part of the training log and the test-set accuracy: [image]