Current location: Home > Handwritten digit recognition based on TensorFlow
Handwritten Digit Recognition Based on TensorFlow
2022-06-26 18:17:00 【Little fox dreams of going to fairy tale town】
# NOTE(review): numpy (np) and matplotlib (plt) are imported but not referenced
# anywhere in the code shown here.
import numpy as np
# Use the TF1-style graph API (placeholders, sessions) through tf.compat.v1.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior() # switch off eager/TF2 behaviour; otherwise tf.placeholder raises an error
import matplotlib.pyplot as plt
import input_data # presumably a local copy of TensorFlow's MNIST tutorial loader (input_data.py) — confirm; the author notes it can download the data locally
# Load the MNIST splits from the 'data/' directory, with labels one-hot encoded.
mnist = input_data.read_data_sets('data/',one_hot=True)
# Network topology (layer widths): input -> hidden layer 1 -> hidden layer 2 -> output classes.
n_input, n_hidden_1, n_hidden_2, n_classes = 784, 256, 128, 10
# Graph inputs, fed at run time: a batch of rows with n_input features each,
# and the matching labels with one column per class (one-hot, see read_data_sets).
# The leading None dimension lets any batch size be fed.
x = tf.placeholder(dtype="float", shape=[None, n_input])
y = tf.placeholder(dtype="float", shape=[None, n_classes])
# Trainable network parameters.
stddev = 0.1  # standard deviation used for the weight initialisers


def _normal_variable(shape, **init_kwargs):
    """Return a tf.Variable initialised from tf.random_normal(shape, **init_kwargs)."""
    return tf.Variable(tf.random_normal(shape, **init_kwargs))


# Weight matrices, drawn with the configured `stddev`.
# dict(...) evaluates its keyword arguments left to right, so the variables are
# created in the same order as before: w1, w2, out, then b1, b2, out.
weights = dict(
    w1=_normal_variable([n_input, n_hidden_1], stddev=stddev),
    w2=_normal_variable([n_hidden_1, n_hidden_2], stddev=stddev),
    out=_normal_variable([n_hidden_2, n_classes], stddev=stddev),
)
# Bias vectors. Note: these are NOT scaled by `stddev`; they use
# tf.random_normal's default standard deviation.
biases = dict(
    b1=_normal_variable([n_hidden_1]),
    b2=_normal_variable([n_hidden_2]),
    out=_normal_variable([n_classes]),
)
print("NETWORK READY")
def multilayer_perceptron(_X, _weights, _biases):
    """Build the forward pass of a two-hidden-layer perceptron.

    Each hidden layer computes sigmoid(input @ W + b). The output layer is
    linear: it returns raw scores (logits), and no softmax is applied here.

    Args:
        _X: input tensor; its second dimension must match the rows of
            ``_weights['w1']``.
        _weights: dict holding weight matrices under 'w1', 'w2' and 'out'.
        _biases: dict holding bias vectors under 'b1', 'b2' and 'out'.

    Returns:
        The logits tensor ``hidden_2 @ _weights['out'] + _biases['out']``.
    """
    activations = _X
    for w_key, b_key in (('w1', 'b1'), ('w2', 'b2')):
        pre_activation = tf.add(tf.matmul(activations, _weights[w_key]), _biases[b_key])
        activations = tf.nn.sigmoid(pre_activation)
    return tf.matmul(activations, _weights['out']) + _biases['out']
# Forward pass: unnormalised class scores (logits) for the batch fed into `x`.
pred = multilayer_perceptron(x, weights, biases)

# Loss: mean softmax cross-entropy over the batch.
# softmax_cross_entropy_with_logits (v1) is deprecated and emits a warning.
# The _v2 op differs only in that gradients may flow into `labels`. Here the
# labels come from the placeholder `y`, so the results are identical.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=y))

# Optimiser: plain gradient descent on `cost`.
learning_rate = 0.001
optm = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

# Accuracy: the fraction of rows whose highest-scoring class equals the
# position of the 1 in the one-hot label row.
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(corr, "float"))

# Op that initialises every variable created above (weights and biases).
init = tf.global_variables_initializer()
print("FUNCTIONS READY")
# Training schedule.
training_epochs = 20
batch_size = 100
display_step = 4  # print progress every `display_step` epochs

# Launch the graph. The `with` block closes the session (releasing its
# resources) once training and evaluation are finished.
with tf.Session() as sess:
    sess.run(init)
    # Number of full mini-batches per epoch. It does not change between
    # epochs, so it is computed once.
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(training_epochs):
        avg_cost = 0.
        for _step in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            feeds = {x: batch_xs, y: batch_ys}
            # Fetch the update op and the loss in a single run. This needs one
            # forward pass per batch instead of two, and the recorded loss is
            # the one the update was computed from.
            _, batch_cost = sess.run([optm, cost], feed_dict=feeds)
            avg_cost += batch_cost
        avg_cost = avg_cost / total_batch
        if (epoch + 1) % display_step == 0:
            # `epoch` is 0-based; report the 1-based count so the final line reads 020/020.
            print("Epoch:%03d/%03d cost:%.9f" % (epoch + 1, training_epochs, avg_cost))
            # Measure accuracy on the whole training split, not just the last mini-batch.
            feeds = {x: mnist.train.images, y: mnist.train.labels}
            training_acc = sess.run(accr, feed_dict=feeds)
            print("Train Accuracy:%.3f" % (training_acc))
            feeds = {x: mnist.test.images, y: mnist.test.labels}
            test_acc = sess.run(accr, feed_dict=feeds)
            print("Test Accuracy:%.3f" % (test_acc))
print("Optimization Finished")
边栏推荐
猜你喜欢
随机推荐
sql中的几种删除操作
Applet setting button sharing function
交叉编译环境出现.so链接文件找不到问题
ROS query topic specific content common instructions
pycharm如何修改多行注释快捷键
Static registration and dynamic registration of JNI
MYSQL的下载与配置 mysql远程操控
Temporarily turn off MySQL cache
RSA encryption and decryption details
ISO文件
How I got torch.cuda.is_available() to change from False to True
transforms.RandomCrop()的输入只能是PIL image 不能是tensor
数字签名标准(DSS)
[Unity] Use C# in Unity to run external files, such as .exe or .bat files
RuntimeError: CUDA error: out of memory — my own solution (a special case; probably not applicable to most people)
行锁与隔离级别案例分析
Properties file garbled
DoS and details of the attack method
Comparing the size relationship between two objects turns out to be so fancy
LeetCode interview question 29: print a matrix in clockwise order