Implementing MNIST Handwritten Digit Recognition
2022-06-24 19:50:00 【ㄣ知冷煖*】
Preface
This post implements MNIST handwritten digit recognition with Keras.
1. Code Implementation
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt  # imported here but not used in the code below
from tensorflow.keras import models
from tensorflow.keras import layers
(train_images,train_labels), (test_images, test_labels) = mnist.load_data()
# train_images.shape: (60000, 28, 28): 60,000 images, each a 28*28-pixel image.
# Build the neural network
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
# The output layer has one unit per class; this is a 10-class problem.
network.add(layers.Dense(10, activation='softmax'))
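# compile: specify the loss function, the optimizer, and the metrics to monitor during training and testing
# metrics: a list of metrics; for classification problems it is usually set to metrics=['accuracy']. For regression, the mean squared error loss is 'mse'.
# For multi-class classification use the 'categorical_crossentropy' loss; for binary classification use 'binary_crossentropy'.
network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])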
# Preprocessing: reshape the data into the shape the network expects and normalize pixel values to [0, 1]
train_images = train_images.reshape((60000, 28*28))
train_images = train_images.astype('float32')/255
test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32')/255
from tensorflow.keras.utils import to_categorical
# to_categorical: converts a vector of integer class labels into a binary matrix of 0s and 1s, i.e. one-hot encoding.
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
# Train
network.fit(train_images, train_labels, epochs=20, batch_size=128)
# Evaluate
test_loss, test_acc = network.evaluate(test_images, test_labels)
print(test_loss, test_acc)
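To see what the preprocessing (flattening plus scaling to [0, 1]), the one-hot encoding done by to_categorical, and the categorical_crossentropy loss actually compute, here is a minimal NumPy sketch on a tiny fake batch. It is an illustration under stated assumptions, not Keras's implementation: the helpers one_hot and categorical_crossentropy are hypothetical, and Keras's own loss handles details such as clipping differently.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Hypothetical helper: the kind of one-hot matrix to_categorical produces."""
    labels = np.asarray(labels, dtype=np.int64)
    out = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    out[np.arange(labels.shape[0]), labels] = 1.0   # a single 1 per row, at the label's index
    return out

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Hypothetical helper: batch mean of -sum(y_true * log(y_pred)) per sample."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)        # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

# A fake batch of 2 "images" standing in for MNIST (uint8 pixels, 28x28)
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(2, 28, 28)).astype(np.uint8)
x = images.reshape((2, 28 * 28)).astype('float32') / 255   # same steps as the tutorial
print(x.shape, x.min() >= 0.0, x.max() <= 1.0)

y = one_hot([3, 7], num_classes=10)
print(y[0])   # 1.0 at index 3, zeros elsewhere

# A uniform prediction (0.1 per class) gives a loss of log(10), about 2.3026
uniform = np.full((2, 10), 0.1, dtype=np.float32)
print(categorical_crossentropy(y, uniform))
```

An untrained 10-class network starts near this value, which is a handy sanity check for the first training epoch's loss.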
2. Some Points to Note
2-1. Ways to build the network
You can build it incrementally with the add() method:
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
network.add(layers.Dense(10, activation='softmax'))
Alternatively, pass a list of layers to the Sequential constructor:
network = models.Sequential([
    layers.Dense(512, activation='relu', input_shape=(28*28,)),
    layers.Dense(10, activation='softmax'),
])
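2-2. Specifying the shape of the model's input data
Only the first layer needs to be told the input shape through an argument. Later layers don't, because each one infers its input shape automatically from the previous layer's output.
Via the input_shape argument: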
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
Or via the input_dim argument, which means the same thing here:
network.add(layers.Dense(512, activation='relu', input_dim=28*28))
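Note: input_shape=(28*28,) means each input sample is a one-dimensional vector with 28*28 = 784 elements. input_shape must be a tuple, so a single dimension is written with a trailing comma, as (28*28,); without the comma, (28*28) is just the integer 784.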
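Shape inference can be checked by hand. A Dense layer with n inputs and m units holds an n×m weight matrix plus m biases, and each later layer's input size is simply the previous layer's unit count. The plain-Python sketch below walks the 784 → 512 → 10 network this way; dense_params is a hypothetical helper, and the totals should match the parameter counts network.summary() reports for this model.

```python
def dense_params(n_inputs, n_units):
    """Hypothetical helper: weights (n_inputs x n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

layer_units = [512, 10]   # the two Dense layers in the tutorial
n_in = 28 * 28            # only the first layer is told this, via input_shape
total = 0
for units in layer_units:
    p = dense_params(n_in, units)
    print(f"Dense({units}): input size {n_in} -> {p} parameters")
    total += p
    n_in = units          # the next layer's input size is inferred from this layer's output
print("total:", total)
```

This prints 401920 parameters for the first layer, 5130 for the second, and 407050 in total.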
2-3. Tensor operations inside a fully connected (Dense) layer
Example:
keras.layers.Dense(512, activation='relu')
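Note: the layer takes a 2D tensor of shape (batch_size, input_dim) as input and returns another 2D tensor of shape (batch_size, 512). The computation is:
output = relu(dot(input, W) + b)
That is: a dot product between the input tensor and the weight tensor W (a randomly initialized tensor of a given shape; in Keras the kernel has shape (input_dim, units)), then an addition between the resulting 2D tensor and the bias vector b, which is added to every row, and finally the relu activation, i.e. max(x, 0). Both the relu and the addition are element-wise operations.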
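The formula can be traced in a minimal NumPy sketch of a forward pass through Dense(512, activation='relu'). It assumes Keras's layout, in which the kernel W has shape (input_dim, units) so the product is dot(input, W). The normal-distribution weights are a stand-in; Keras's default kernel initializer is glorot_uniform and its default biases are zeros.

```python
import numpy as np

rng = np.random.default_rng(42)
batch_size, input_dim, units = 4, 28 * 28, 512

x = rng.random((batch_size, input_dim), dtype=np.float32)               # 2D input: (batch, features)
W = rng.normal(0.0, 0.05, size=(input_dim, units)).astype(np.float32)   # stand-in random weights
b = np.zeros(units, dtype=np.float32)                                    # biases, zeros as in Keras's default

z = np.dot(x, W) + b          # dot product, then b is broadcast-added to every row
output = np.maximum(z, 0.0)   # relu = max(x, 0), applied element-wise

print(output.shape)           # (4, 512): another 2D tensor
print((output >= 0).all())    # relu never produces negative values
```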
2-4. Understanding the dot product
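Note: the dot product of two vectors is a scalar, and it is only defined for vectors with the same number of elements: multiply them element-wise, then sum the products.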
import numpy as np
np.dot([1, 2],[3,4])
# Output
# 11
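The rule "multiply element-wise, then sum" can be written out explicitly and compared with np.dot. vector_dot is a hypothetical helper for illustration only.

```python
import numpy as np

def vector_dot(x, y):
    """Hypothetical helper: dot product of two equal-length vectors."""
    if len(x) != len(y):
        raise ValueError("vectors must have the same number of elements")
    total = 0
    for a, b in zip(x, y):
        total += a * b   # multiply element-wise and accumulate the sum
    return total

print(vector_dot([1, 2], [3, 4]))  # 11  (1*3 + 2*4)
print(np.dot([1, 2], [3, 4]))      # 11
```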
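More generally, for two matrices x and y, you can take their dot product if and only if x.shape[1] == y.shape[0]. The result is a matrix of shape (x.shape[0], y.shape[1]) whose element (i, j) is the dot product of row i of x with column j of y.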
np.dot([[1, 2],[1,2]], [[3, 4],[3,4]])
# Output
# array([[ 9, 12],
# [ 9, 12]])
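The matrix rule can be spelled out with explicit loops: check the shape condition, then fill each output element with "row i of x · column j of y". matrix_dot is a hypothetical helper written for clarity, not speed; np.dot does the same thing far faster.

```python
import numpy as np

def matrix_dot(x, y):
    """Hypothetical helper: naive matrix product built from row-column dot products."""
    x, y = np.asarray(x), np.asarray(y)
    if x.shape[1] != y.shape[0]:
        raise ValueError("x.shape[1] must equal y.shape[0]")
    z = np.zeros((x.shape[0], y.shape[1]), dtype=np.result_type(x, y))
    for i in range(x.shape[0]):          # each row of x
        for j in range(y.shape[1]):      # each column of y
            z[i, j] = np.dot(x[i, :], y[:, j])   # row . column = one scalar
    return z

a = [[1, 2], [1, 2]]
b = [[3, 4], [3, 4]]
print(matrix_dot(a, b))                                 # [[ 9 12]
                                                        #  [ 9 12]]
print(np.array_equal(matrix_dot(a, b), np.dot(a, b)))   # True

# Non-square case: (2, 3) . (3, 1) -> (2, 1)
c = np.arange(6).reshape(2, 3)
d = np.ones((3, 1))
print(matrix_dot(c, d).shape)                           # (2, 1)
```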
Summary
Some things are worth it as long as you tried, even if the results are painful to look at...