Implementing MNIST Handwritten Digit Recognition
2022-06-24 19:50:00 【ㄣ知冷煖*】
Preface
This post implements handwritten digit recognition on the MNIST dataset with tf.keras.
1. Code Implementation
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt  # optional, for visualizing images; not used below
from tensorflow.keras import models
from tensorflow.keras import layers
(train_images,train_labels), (test_images, test_labels) = mnist.load_data()
# train_images.shape: (60000, 28, 28), i.e. 60,000 images, each a 28x28 grayscale pixel image
# Build the neural network
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
# The output layer has one unit per class; MNIST has 10 classes (digits 0-9)
network.add(layers.Dense(10, activation='softmax'))
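# compile: configure the loss function, the optimizer, and the metrics to monitor during training and testing
# metrics: a list of metrics; for classification we usually set metrics=['accuracy']
# Loss: 'categorical_crossentropy' for multi-class classification (with one-hot labels),
# 'binary_crossentropy' for binary classification, and 'mse' (mean squared error) for regression
network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])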
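# Preprocessing: reshape the images into the (samples, 784) form the network expects and scale pixel values from [0, 255] to [0, 1]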
train_images = train_images.reshape((60000, 28*28))
train_images = train_images.astype('float32')/255
test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32')/255
from tensorflow.keras.utils import to_categorical
# to_categorical converts a vector of integer class labels into a binary (0/1) matrix, i.e. one-hot encoding
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
# Train
network.fit(train_images, train_labels, epochs=20, batch_size=128)
# Evaluate on the test set
test_loss, test_acc = network.evaluate(test_images, test_labels)
print(test_loss, test_acc)
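The preprocessing and the loss in the listing can be reproduced with plain NumPy. This makes the shapes easy to check without downloading MNIST. The sketch below makes a few assumptions for illustration:

- It uses a small random fake batch in place of the real images.
- `one_hot` is a hypothetical helper that mimics what `to_categorical` returns for integer labels.
- `categorical_crossentropy` is a hand-written version of the loss, not the Keras implementation.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Hypothetical helper: integer labels -> one-hot float32 matrix, like to_categorical."""
    labels = np.asarray(labels, dtype=np.int64)
    out = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    out[np.arange(labels.shape[0]), labels] = 1.0
    return out

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Hand-written loss: mean over the batch of -sum(y_true * log(y_pred))."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

# Fake uint8 "images" laid out like mnist.load_data(): (N, 28, 28)
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Same preprocessing as the listing: flatten to (N, 784) and scale to [0, 1]
x = images.reshape((images.shape[0], 28 * 28)).astype('float32') / 255
print(x.shape)   # (5, 784)

y = one_hot([3, 0, 9, 1, 7], num_classes=10)
print(y[0])      # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]

# A uniform prediction over 10 classes gives a loss of ln(10), about 2.3026
uniform = np.full((5, 10), 0.1, dtype=np.float32)
print(round(categorical_crossentropy(y, uniform), 4))
```

The uniform-prediction case is a handy sanity check. An untrained 10-class softmax model should start with a loss near ln(10).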
2. Some Points to Note
2-1. Two ways to build the network
Either incrementally, with the add method:
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
network.add(layers.Dense(10, activation='softmax'))
or by passing a list of layers to the Sequential constructor:
network = models.Sequential([
    layers.Dense(512, activation='relu', input_shape=(28*28,)),
    layers.Dense(10, activation='softmax'),
])
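Both versions build the same two-layer model. Only the input size of the first layer is given explicitly. Each later layer's input size is its predecessor's number of units.

The pure-Python sketch below walks through that chain by hand. For each layer it computes the output shape and the parameter count (weights plus biases). `dense_stack_summary` is a hypothetical helper, not a Keras function. For this model the totals should agree with what `network.summary()` reports.

```python
def dense_stack_summary(input_dim, units_per_layer):
    """Hypothetical helper: (output shape, param count) for a stack of Dense layers.

    Only input_dim is supplied; every later layer's input size is taken from
    the previous layer's number of units, mirroring Keras shape inference.
    """
    rows = []
    in_features = input_dim
    for units in units_per_layer:
        params = in_features * units + units   # weight matrix + bias vector
        rows.append(((None, units), params))   # None = unknown batch size
        in_features = units                    # becomes the next layer's input size
    return rows

summary = dense_stack_summary(28 * 28, [512, 10])
for shape, params in summary:
    print(shape, params)
# (None, 512) 401920
# (None, 10) 5130
print("total:", sum(p for _, p in summary))   # total: 407050
```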
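2-2. Specifying the shape of the input data
Only the first layer needs to be told the shape of the input data through an argument. Later layers do not, because their input shapes are inferred automatically from the output of the preceding layer.
Via the input_shape argument: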
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
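Alternatively, via the input_dim argument, which has the same meaning here (for a 1D input, input_dim=28*28 is equivalent to input_shape=(28*28,)):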
network.add(layers.Dense(512, activation='relu', input_dim=28*28))
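Note: input_shape=(28*28,) means each input sample is a 1D vector with 28*28 = 784 elements; the batch dimension is not included. Because input_shape must be a tuple, it has to be written as (28*28,) with a trailing comma: (28*28) without the comma is just the integer 784.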
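The trailing-comma rule is plain Python tuple syntax, so it can be checked without Keras. The snippet below only inspects the values. Prepending a batch size of 128 is an illustrative assumption, matching the batch_size used in the listing.

```python
# (28*28) is just a parenthesised expression: the integer 784
not_a_tuple = (28 * 28)
# The trailing comma is what makes a one-element tuple
shape = (28 * 28,)

print(type(not_a_tuple).__name__, not_a_tuple)   # int 784
print(type(shape).__name__, shape)               # tuple (784,)
print(len(shape))                                # 1: a single axis, i.e. a 1D vector per sample

# input_shape excludes the batch dimension; a batch of 128 samples
# fed to this layer has the full shape (128,) + shape
print((128,) + shape)                            # (128, 784)
```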
2-3. Tensor operations inside a fully connected (Dense) layer
Example:
layers.Dense(512, activation='relu')
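Explanation: the layer takes a 2D tensor as input and returns another 2D tensor, computed by the following function.
Formula: output = relu(dot(input, W) + b)
That is, the layer proceeds in three steps:
- It takes the dot product of the input tensor and the weight tensor W. W has a fixed shape, is randomly initialized, and is then learned during training. In Keras the input has shape (batch, input_dim) and W has shape (input_dim, units), so the product is taken in the order dot(input, W).
- It adds the bias vector b to the resulting 2D tensor, broadcast across every row of the batch.
- It applies the ReLU activation, max(x, 0).
Both the addition and the ReLU are element-wise operations.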
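The formula can be traced step by step in NumPy. The sketch below makes some assumptions for illustration:

- `dense_forward` is a hypothetical helper, not how Keras implements the layer internally.
- The batch of 4 random inputs and the small normal initialization of W are arbitrary choices.

The point is the shapes: (batch, 784) times (784, 512), plus a broadcast (512,) bias, gives (batch, 512).

```python
import numpy as np

def relu(x):
    # Element-wise max(x, 0)
    return np.maximum(x, 0.0)

def dense_forward(inputs, W, b):
    # inputs: (batch, input_dim), W: (input_dim, units), b: (units,)
    # np.dot gives (batch, units); b is broadcast-added to every row
    return relu(np.dot(inputs, W) + b)

rng = np.random.default_rng(42)
batch = rng.random((4, 28 * 28), dtype=np.float32)                 # 4 flattened "images"
W = rng.normal(0.0, 0.05, size=(28 * 28, 512)).astype(np.float32)  # random initial weights
b = np.zeros(512, dtype=np.float32)                                # bias vector

out = dense_forward(batch, W, b)
print(out.shape)          # (4, 512)
print((out >= 0).all())   # True: ReLU never outputs negatives
```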
2-4. Understanding the dot product
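Note: the dot product of two vectors is a scalar, and it is only defined for vectors with the same number of elements: multiply them element-wise, then sum the results.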
import numpy as np
np.dot([1, 2],[3,4])
# 输出
# 11
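Written as an explicit loop, the "multiply element-wise, then sum" rule looks like the sketch below. `naive_vector_dot` is a hypothetical name used only for illustration. np.dot does the same thing far faster.

```python
def naive_vector_dot(x, y):
    # Dot product of two equal-length vectors: multiply element-wise, then sum
    if len(x) != len(y):
        raise ValueError("vectors must have the same number of elements")
    total = 0
    for a, b in zip(x, y):
        total += a * b
    return total

print(naive_vector_dot([1, 2], [3, 4]))   # 11, i.e. 1*3 + 2*4
```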
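More generally, for two matrices x and y, the dot product is defined only when x.shape[1] == y.shape[0]. The result is a matrix of shape (x.shape[0], y.shape[1]) whose entry (i, j) is the dot product of row i of x with column j of y.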
np.dot([[1, 2],[1,2]], [[3, 4],[3,4]])
# 输出
# array([[ 9, 12],
# [ 9, 12]])
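The row-times-column rule can likewise be spelled out with nested loops over plain Python lists. `naive_matrix_dot` is a hypothetical helper for illustration. The second call uses a non-square pair, (2, 3) times (3, 1), to show how the result shape (x.shape[0], y.shape[1]) comes about.

```python
def naive_matrix_dot(x, y):
    # x: n rows x k cols, y: k rows x m cols -> result: n x m
    n, k = len(x), len(x[0])
    if len(y) != k:
        raise ValueError("x.shape[1] must equal y.shape[0]")
    m = len(y[0])
    result = [[0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            # Entry (i, j) is the dot product of row i of x and column j of y
            result[i][j] = sum(x[i][t] * y[t][j] for t in range(k))
    return result

print(naive_matrix_dot([[1, 2], [1, 2]], [[3, 4], [3, 4]]))   # [[9, 12], [9, 12]]
print(naive_matrix_dot([[1, 2, 3], [4, 5, 6]], [[1], [0], [2]]))   # [[7], [16]]
```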
Summary
Some things are worth doing just for the effort, even if the results are dismal...