Implementing MNIST handwritten digit recognition
2022-06-24 19:50:00 【ㄣ知冷煖*】
Preface
This post implements MNIST handwritten digit recognition.
1. Code implementation
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt  # pyplot is the plotting module; not used below
from tensorflow.keras import models
from tensorflow.keras import layers
(train_images,train_labels), (test_images, test_labels) = mnist.load_data()
# train_images.shape: (60000, 28, 28), i.e. 60,000 images, each a 28x28-pixel image
# Build the neural network
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
# The output layer has one unit per class; this is a 10-class problem.
network.add(layers.Dense(10, activation='softmax'))
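# compile: specify the loss function, the optimizer, and the metrics to monitor during training and testing
# metrics: a list of metrics; for classification we usually set metrics=['accuracy']. For regression, the mean-squared-error loss is 'mse'
# Loss: 'categorical_crossentropy' for multi-class classification (with one-hot labels), 'binary_crossentropy' for binary classification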
network.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
# Preprocessing: reshape the data into the shape the network expects and normalize pixel values to [0, 1]
train_images = train_images.reshape((60000, 28*28))
train_images = train_images.astype('float32')/255
test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32')/255
from tensorflow.keras.utils import to_categorical
# to_categorical: converts a vector of class indices into a binary (0/1) matrix, i.e. one-hot encodes the labels.
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
# Train
network.fit(train_images, train_labels, epochs=20, batch_size=128)
# Evaluate on the test set
test_loss, test_acc = network.evaluate(test_images, test_labels)
print(test_loss, test_acc)
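To make the preprocessing and the softmax/cross-entropy steps less of a black box, here is a minimal NumPy sketch of what they compute. It runs on a tiny random batch that stands in for MNIST. The fake data and the helpers `one_hot`, `softmax`, `categorical_crossentropy` and `accuracy` are illustrative assumptions, not Keras APIs. The steps mirror `reshape` + `/255`, `to_categorical`, the softmax output layer, the categorical cross-entropy loss and the accuracy metric.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed stand-in for MNIST: 4 grayscale 28x28 images (uint8, 0..255) and their labels
images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
labels = np.array([3, 0, 7, 3])

# Same preprocessing as the script: flatten to 784-dim vectors, scale to [0, 1]
x = images.reshape((images.shape[0], 28 * 28)).astype("float32") / 255

def one_hot(y, num_classes=10):
    """Hypothetical helper: one-hot encodes integer labels, like to_categorical."""
    return np.eye(num_classes, dtype="float32")[y]

def softmax(z):
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Batch mean of -sum(y_true * log(y_pred)); clipping avoids log(0)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

def accuracy(y_true, y_pred):
    """Fraction of samples whose most probable class equals the true class."""
    return float(np.mean(y_true.argmax(axis=1) == y_pred.argmax(axis=1)))

y = one_hot(labels)
logits = rng.normal(size=(4, 10))  # random scores standing in for a trained network
probs = softmax(logits)

print(x.shape)                # (4, 784)
print(y[0])                   # one-hot vector for label 3
print(probs.sum(axis=1))      # each row sums to 1
print(categorical_crossentropy(y, probs), accuracy(y, probs))
```

The loss is smallest when the predicted probability of the true class approaches 1. This is why the labels must be one-hot encoded before they are used with 'categorical_crossentropy'.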
2. Some things to note
2-1. Ways to build the network
It can be built by calling the add method:
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
network.add(layers.Dense(10, activation='softmax'))
or by passing a list of layers to the Sequential constructor:
network = models.Sequential([
layers.Dense(512, activation='relu', input_shape=(28*28,)),
layers.Dense(10, activation='softmax'),
])
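2-2. Specifying the input shape
Only the first layer must be told the shape of the input data through an argument. Later layers don't need it, because each one infers its input shape from the output of the layer before it.
Using the input_shape argument: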
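Shape inference can be checked with a little arithmetic. The sketch below assumes only the two Dense layers from the script. Each Dense layer holds a weight matrix of shape (inputs, units) plus one bias per unit, and every layer after the first simply takes the previous layer's unit count as its input size. `dense_param_counts` is a hypothetical helper, not a Keras function.

```python
def dense_param_counts(input_dim, units_per_layer):
    """Return (input_dim, units, params) for each Dense layer in a stack.

    Only the first layer needs input_dim; every later layer uses the
    previous layer's output size as its input size (shape inference).
    """
    rows = []
    in_dim = input_dim
    for units in units_per_layer:
        params = in_dim * units + units  # weight matrix + bias vector
        rows.append((in_dim, units, params))
        in_dim = units                   # input size of the next layer
    return rows

layers_info = dense_param_counts(28 * 28, [512, 10])
for in_dim, units, params in layers_info:
    print(f"Dense({units}) input={in_dim} params={params}")
print("total:", sum(p for _, _, p in layers_info))
```

This gives 401920 and 5130 parameters (407050 in total). These are the counts network.summary() should report for the model in the script.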
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
Alternatively, use the input_dim argument, which means the same thing here:
network.add(layers.Dense(512, activation='relu', input_dim=28*28))
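Note: input_shape=(28*28,) means that each input sample is a 1D vector with 28*28 = 784 elements. input_shape must be a tuple, so a one-element shape has to be written with a trailing comma, as (28*28,).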
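The trailing-comma rule is plain Python and has nothing to do with Keras. A short check:

```python
not_a_tuple = (28 * 28)  # parentheses only group the expression: this is the int 784
a_tuple = (28 * 28,)     # the trailing comma makes a one-element tuple

print(type(not_a_tuple).__name__, not_a_tuple)  # int 784
print(type(a_tuple).__name__, a_tuple)          # tuple (784,)
print(len(a_tuple))                             # 1: a single axis of length 784
```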
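2-3. Tensor operations inside a Dense (fully connected) layer
Example:
keras.layers.Dense(512, activation='relu')
Explanation: the layer takes a 2D tensor as input and returns another 2D tensor. As a formula:
output = relu(dot(input, W) + b)
That is: a dot product between the input tensor and the weight tensor W (initialized randomly with a given shape and then learned during training), plus the bias vector b (added to every row of the resulting 2D tensor), and finally the relu activation, max(x, 0). In Keras, W has shape (input_dim, units), so the input comes first in the dot product. Texts that treat a single sample as a column vector write the same operation as dot(W, input) with W transposed. Both the relu and the addition are element-wise operations.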
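Below is a minimal NumPy sketch of one forward pass through Dense(512, activation='relu'), assuming a batch of 3 flattened images. The random inputs and the normal-distribution initialization are stand-ins. Keras's own default kernel initializer is glorot_uniform, and its bias starts at zeros.

```python
import numpy as np

rng = np.random.default_rng(42)

batch_size, input_dim, units = 3, 28 * 28, 512

# Assumed stand-ins: random inputs, a randomly initialised kernel W and a zero bias b
inputs = rng.random((batch_size, input_dim), dtype=np.float32)
W = rng.normal(0.0, 0.05, size=(input_dim, units)).astype(np.float32)
b = np.zeros(units, dtype=np.float32)

def relu(x):
    """Element-wise max(x, 0)."""
    return np.maximum(x, 0)

# (3, 784) dot (784, 512) -> (3, 512); b with shape (512,) is broadcast over the rows
output = relu(np.dot(inputs, W) + b)

print(output.shape)         # (3, 512)
print((output >= 0).all())  # True: relu never returns negative values
```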
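2-4. Understanding the dot product
Note: the dot product of two vectors is a scalar, and it is only defined for vectors with the same number of elements: multiply them element by element, then sum the products.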
import numpy as np
np.dot([1, 2],[3,4])
# Output
# 11
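The same vector dot product written out in plain Python makes the "multiply element-wise, then sum" rule explicit. `vector_dot` is a hypothetical helper for illustration. It rejects vectors of different lengths.

```python
def vector_dot(a, b):
    """Dot product of two equal-length vectors: element-wise products, then their sum."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return sum(x * y for x, y in zip(a, b))

print(vector_dot([1, 2], [3, 4]))  # 1*3 + 2*4 = 11
```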
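More generally, for two matrices x and y, the dot product is defined if and only if x.shape[1] == y.shape[0]. The result is a matrix of shape (x.shape[0], y.shape[1]) whose entry (i, j) is the dot product of row i of x with column j of y.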
np.dot([[1, 2],[1,2]], [[3, 4],[3,4]])
# Output
# array([[ 9, 12],
# [ 9, 12]])
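A plain-Python sketch of the matrix rule: entry (i, j) is row i of x dotted with column j of y. `matrix_dot` is a hypothetical helper that works on nested lists and checks the shape condition first. It reproduces the np.dot result above.

```python
def matrix_dot(x, y):
    """Matrix product of nested lists; requires len(x[0]) == len(y)."""
    n_rows, n_inner, n_cols = len(x), len(y), len(y[0])
    if len(x[0]) != n_inner:
        raise ValueError("x.shape[1] must equal y.shape[0]")
    # result[i][j] = sum over k of x[i][k] * y[k][j]
    return [[sum(x[i][k] * y[k][j] for k in range(n_inner)) for j in range(n_cols)]
            for i in range(n_rows)]

print(matrix_dot([[1, 2], [1, 2]], [[3, 4], [3, 4]]))  # [[9, 12], [9, 12]]
print(matrix_dot([[1, 2, 3]], [[1], [2], [3]]))        # (1, 3) x (3, 1) -> [[14]]
```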
Summary
Sometimes it's enough to have put in the effort, even if the results are a disaster...