(Part 1) Handwritten digit recognition with Keras, and recognizing digits I wrote myself
2022-06-26 15:30:00 【X1996_】
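The training data is MNIST. The training code is the official Keras example: I ran it once, saved the trained model, and then used the saved model to recognize digits in my own images.
The program has three main parts:
- Run the official example and save the model
- Make your own digit images and convert them to the required format
- Load the model and test it on your own images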
1. Training the model and saving its parameters
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
batch_size = 128
num_classes = 10
epochs = 12
# input image dimensions
img_rows, img_cols = 28, 28
# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Reshape to the 4-D tensor Conv2D expects; where the channel axis goes
# depends on the backend's image_data_format() setting
if K.image_data_format() == 'channels_first':  # (samples, channels, rows, cols)
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:  # channels_last, the TensorFlow default: (samples, rows, cols, channels)
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)
# Convert to float32 and scale pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
# Number of training and test samples
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary class matrices (one-hot encoding of the 10 digits)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
# Build the network
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))
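# Loss function and optimizer
# Adadelta gets learning_rate=1.0 explicitly: older Keras used 1.0 as the default,
# while tf.keras defaults to 0.001, which trains much more slowly
# (older Keras versions call this argument lr)
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(learning_rate=1.0),
              metrics=['accuracy'])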
# Train
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
# Evaluate on the test set
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# Save the model (architecture, weights and optimizer state) to an HDF5 file
model.save('model.h5')
The trained model is saved as model.h5 in the current working directory (normally the folder the script is run from).
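To see what the network above actually computes, the sketch below works out each layer's output shape and trainable parameter count by hand. It assumes channels_last input, stride 1 and "valid" (no) padding for the convolutions, which are the Keras defaults the code relies on, and a 2×2 max pool with stride 2. The helper names are hypothetical; in Keras itself, `model.summary()` prints the same information.

```python
# Hand-computed layer shapes and parameter counts for the MNIST CNN above.
# Assumptions: channels_last, Conv2D with stride 1 and 'valid' padding,
# MaxPooling2D with a 2x2 window and stride 2. Dropout has no parameters.

def conv2d(shape, filters, k):
    """Return (output_shape, params) for a k x k 'valid' convolution."""
    h, w, c = shape
    params = k * k * c * filters + filters  # kernel weights + one bias per filter
    return (h - k + 1, w - k + 1, filters), params

def maxpool2d(shape, p):
    """Non-overlapping p x p pooling: spatial size shrinks, no parameters."""
    h, w, c = shape
    return (h // p, w // p, c), 0

def dense(units_in, units_out):
    """Return (units, params) for a fully connected layer."""
    return units_out, units_in * units_out + units_out  # weights + biases

def summarize():
    rows = []
    shape = (28, 28, 1)                       # input_shape for channels_last
    shape, n = conv2d(shape, 32, 3)
    rows.append(("Conv2D(32, 3x3)", shape, n))
    shape, n = conv2d(shape, 64, 3)
    rows.append(("Conv2D(64, 3x3)", shape, n))
    shape, n = maxpool2d(shape, 2)
    rows.append(("MaxPooling2D(2x2)", shape, n))
    flat = shape[0] * shape[1] * shape[2]     # Flatten has no parameters
    rows.append(("Flatten", (flat,), 0))
    units, n = dense(flat, 128)
    rows.append(("Dense(128)", (units,), n))
    units, n = dense(units, 10)
    rows.append(("Dense(10)", (units,), n))
    return rows

rows = summarize()
for name, shape, n in rows:
    print(f"{name:<18} {str(shape):<14} {n:>9,}")
print("Total params:", sum(n for _, _, n in rows))  # Total params: 1199882
```

The arithmetic shows that the Flatten output of 12 × 12 × 64 = 9216 values makes the first Dense layer hold almost all of the roughly 1.2 million parameters.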
2. Making your own images
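The images must be 28×28 pixels with a black background and a white digit, like the MNIST data. You can make them in Paint on your computer: write a digit with a black pen, invert the colors, take a screenshot of each digit, and then convert it to a 28×28 grayscale image.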
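The color inversion described above can also be done in code. The sketch below is not the author's method: it assumes an 8-bit grayscale image held in a NumPy array and uses a simple heuristic (a bright border means black-on-white input) to decide whether to invert. The function name `invert_if_light_background` is hypothetical.

```python
import numpy as np

def invert_if_light_background(img):
    """Invert an 8-bit grayscale image if its border pixels are mostly light.

    Heuristic (an assumption): the border of a digit image is background,
    so a bright border means a dark digit on a white canvas.
    """
    img = np.asarray(img, dtype=np.uint8)
    border = np.concatenate([img[0, :], img[-1, :], img[:, 0], img[:, -1]])
    if border.mean() > 127:
        return 255 - img          # white background -> black background
    return img                    # already white-on-black, leave unchanged

# Example: a white 28x28 canvas with a black vertical stroke
canvas = np.full((28, 28), 255, dtype=np.uint8)
canvas[4:24, 13:15] = 0
out = invert_if_light_background(canvas)
print(out[0, 0], out[10, 13])     # background is now 0, stroke is now 255
```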
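I did the image processing with OpenCV.
I made 10 images named 0.png to 9.png in the same directory as the program (other tools would work just as well). The script below converts each one in place: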
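import cv2 as cv

# Convert 0.png ... 9.png to 28x28 grayscale images, overwriting the originals
for i in range(10):
    img = cv.imread(f"{i}.png")  # loaded as a 3-channel BGR image
    if img is None:
        raise FileNotFoundError(f"could not read {i}.png")
    img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
    # INTER_CUBIC as in the original; cv.INTER_AREA is usually recommended for shrinking
    res = cv.resize(img, (28, 28), interpolation=cv.INTER_CUBIC)
    cv.imwrite(f"{i}.png", res)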
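The MNIST digits were size-normalized to fit a 20×20 box (preserving aspect ratio) and then centered in the 28×28 frame, whereas resizing a whole screenshot keeps whatever margins the crop happened to have. The sketch below approximates that preprocessing with plain NumPy: crop to the bounding box of non-zero pixels, pad to a square, shrink to 20×20, and paste the result into the middle of a black 28×28 image. The function name, the nearest-neighbour resize, and bounding-box centering are assumptions for illustration; MNIST itself used anti-aliased scaling and centered digits by center of mass.

```python
import numpy as np

def mnist_style(img, box=20, size=28, threshold=0):
    """Crop the digit, scale it into a box x box square, center it in size x size.

    Assumes a grayscale image with a dark background and a light digit.
    """
    img = np.asarray(img)
    ys, xs = np.nonzero(img > threshold)
    if ys.size == 0:                      # blank image: nothing to center
        return np.zeros((size, size), dtype=img.dtype)
    digit = img[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

    # Pad the shorter side so the aspect ratio is preserved when scaling
    h, w = digit.shape
    side = max(h, w)
    square = np.zeros((side, side), dtype=img.dtype)
    top, left = (side - h) // 2, (side - w) // 2
    square[top:top + h, left:left + w] = digit

    # Nearest-neighbour resize to box x box by sampling source indices
    idx = np.arange(box) * side // box
    small = square[idx[:, None], idx[None, :]]

    # Paste into the center of a black size x size canvas
    out = np.zeros((size, size), dtype=img.dtype)
    off = (size - box) // 2
    out[off:off + box, off:off + box] = small
    return out

# Example: a small blob near the corner of a 100x100 image
raw = np.zeros((100, 100), dtype=np.uint8)
raw[2:8, 3:7] = 255
out = mnist_style(raw)
print(out.shape)                          # (28, 28)
```

On a real screenshot, stray gray pixels can stretch the bounding box, so a `threshold` somewhat above 0 may be needed; the right value depends on the images.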

3. Recognizing your own images
First load the saved model, then run each of your images through it and print the predicted digit.
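from __future__ import print_function
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

# Load the model saved in part 1
model = tf.keras.models.load_model('./model.h5')
for i in range(10):
    # For PNG files plt.imread returns floats in [0, 1], matching the training scaling
    img = plt.imread(f'{i}.png')
    # Add batch and channel axes: (28, 28) -> (1, 28, 28, 1), channels_last as in training
    p = model.predict(img.reshape(1, 28, 28, 1))
    print(np.argmax(p[0]))  # the class with the highest probability is the predicted digit
    plt.imshow(img, cmap='gray')  # show in grayscale instead of the default colormap
    plt.show()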
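For PNG files, `plt.imread` returns a 2-D array for a grayscale image but a 3-D (height, width, channels) array for RGB or RGBA, so `reshape(1, 28, 28, 1)` only works on single-channel 28×28 files. A hedged sketch of a helper (the name `to_model_input` is hypothetical) that accepts either form, converts color to gray by averaging the RGB channels (a simplification of proper luminance weighting), and returns the (1, 28, 28, 1) float32 batch the model expects:

```python
import numpy as np

def to_model_input(img):
    """Turn an image array from plt.imread into a (1, 28, 28, 1) float32 batch.

    Assumes values already in [0, 1] (what plt.imread returns for PNG files)
    and a 28x28 image with a dark background, as the model was trained on.
    """
    img = np.asarray(img, dtype=np.float32)
    if img.ndim == 3:                     # (H, W, 3) RGB or (H, W, 4) RGBA
        img = img[..., :3].mean(axis=-1)  # drop alpha, average RGB to gray
    if img.shape != (28, 28):
        raise ValueError(f"expected a 28x28 image, got {img.shape}")
    return img.reshape(1, 28, 28, 1)

# Example: a fully opaque, all-black RGBA image
rgba = np.zeros((28, 28, 4), dtype=np.float32)
rgba[..., 3] = 1.0
batch = to_model_input(rgba)
print(batch.shape, batch.dtype)           # (1, 28, 28, 1) float32
```

Inside the prediction loop this would replace the reshape, e.g. `p = model.predict(to_model_input(img))`.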
The predicted values and the input images:
The 9 was predicted as a 7. A person recognizes it easily, but the model has some difficulty telling the two apart.
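When a digit is misclassified, it helps to look at more than the `argmax`: the softmax output `p[0]` holds one probability per class, so listing the top few shows how close the runner-up was. A minimal sketch (the helper name `top_k` is hypothetical); the probability vector in the example is made up for illustration and is not output from the author's model.

```python
import numpy as np

def top_k(probs, k=3):
    """Return the k most likely (digit, probability) pairs, highest first."""
    probs = np.asarray(probs)
    order = np.argsort(probs)[::-1][:k]   # indices sorted by descending probability
    return [(int(d), float(probs[d])) for d in order]

# Made-up probability vector for the 10 digit classes (sums to 1)
p0 = np.array([0.05, 0.0, 0.6, 0.25, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0])
for digit, prob in top_k(p0):
    print(f"{digit}: {prob:.2f}")
```

In the prediction loop, `print(top_k(p[0]))` would show the candidates for each image.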
Recognizing many digits in one large image: https://blog.csdn.net/X1996_/article/details/108889366