Training a one-hidden-layer neural network for handwritten digit recognition with TensorFlow
2022-07-23 23:13:00 【TJMtaotao】
```python
# coding:utf-8
# Import TensorFlow and the MNIST handwritten-digit dataset helper
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Number of training steps, learning rate, and dropout keep probability
max_steps = 50000
learning_rate = 0.001
dropout = 0.9

# Where the downloaded data and the model checkpoints/logs are saved
data_dir = "MNIST_data/"
log_dir = "/tmp/tensorflow/mnist/logs/mnist_with_summaries"

# Download the MNIST dataset
mnist = input_data.read_data_sets(data_dir, one_hot=True)
sess = tf.InteractiveSession()

# Define the network inputs
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, 784], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')

with tf.name_scope('input_reshape'):
    image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
    tf.summary.image('input', image_shaped_input, 10)

# Initializers for the model weights and biases
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# Attach summaries (mean, stddev, max, min, histogram) to a Variable
def variable_summaries(var):
    with tf.name_scope('summaries'):
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)
        with tf.name_scope('stddev'):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)

# A single fully connected layer of the MLP
def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
    with tf.name_scope(layer_name):
        with tf.name_scope('weights'):
            weights = weight_variable([input_dim, output_dim])
            variable_summaries(weights)
        with tf.name_scope('biases'):
            biases = bias_variable([output_dim])
            variable_summaries(biases)
        with tf.name_scope('Wx_plus_b'):
            preactivate = tf.matmul(input_tensor, weights) + biases
            tf.summary.histogram('pre_activations', preactivate)
        activations = act(preactivate, name='activation')
        tf.summary.histogram('activations', activations)
        return activations

# Hidden layer with 500 units
hidden1 = nn_layer(x, 784, 500, 'layer1')

with tf.name_scope('dropout'):
    keep_prob = tf.placeholder(tf.float32)
    tf.summary.scalar('dropout_keep_probability', keep_prob)
    dropped = tf.nn.dropout(hidden1, keep_prob)

# Output layer (identity activation; softmax is applied inside the loss)
y = nn_layer(dropped, 500, 10, 'layer2', act=tf.identity)

# Cross-entropy loss
with tf.name_scope('cross_entropy'):
    diff = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
    with tf.name_scope('total'):
        cross_entropy = tf.reduce_mean(diff)
tf.summary.scalar('cross_entropy', cross_entropy)

# Minimize the loss with the Adam optimizer, then count correct predictions
# and compute the accuracy
with tf.name_scope('train'):
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    with tf.name_scope('accuracy'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)

# Merge all summaries; tf.summary.FileWriter writes the training and test
# summaries to disk. Also initialize all variables.
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)
test_writer = tf.summary.FileWriter(log_dir + '/test')
tf.global_variables_initializer().run()

# Build the feed dict, switching between training batches and the test set
def feed_dict(train):
    if train:
        xs, ys = mnist.train.next_batch(100)
        k = dropout
    else:
        xs, ys = mnist.test.images, mnist.test.labels
        k = 1.0
    return {x: xs, y_: ys, keep_prob: k}

# Create a Saver for model checkpoints
saver = tf.train.Saver()
for i in range(max_steps):
    if i % 10 == 0:
        # Evaluate on the test set every 10 steps
        summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
        test_writer.add_summary(summary, i)
        print('Accuracy at step %s: %s' % (i, acc))
    else:
        if i % 100 == 99:
            # Every 100 steps also record run metadata and save a checkpoint
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            summary, _ = sess.run([merged, train_step],
                                  feed_dict=feed_dict(True),
                                  options=run_options,
                                  run_metadata=run_metadata)
            train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
            train_writer.add_summary(summary, i)
            saver.save(sess, log_dir + "/model.ckpt", i)
            print('Adding run metadata for', i)
        else:
            # Ordinary training step
            summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
            train_writer.add_summary(summary, i)

train_writer.close()
test_writer.close()
```

The recognition accuracy reaches about 98%:

```
Accuracy at step 12150: 0.9809
Accuracy at step 12160: 0.9809
Accuracy at step 12170: 0.9817
Accuracy at step 12180: 0.9815
Accuracy at step 12190: 0.9815
Adding run metadata for 12199
Accuracy at step 12200: 0.9827
Accuracy at step 12210: 0.9832
Accuracy at step 12220: 0.9829
Accuracy at step 12230: 0.9829
Accuracy at step 12240: 0.9826
Accuracy at step 12250: 0.9828
Accuracy at step 12260: 0.9822
```
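Because all the scalar, histogram, and image summaries are merged and written to `log_dir` by the two `tf.summary.FileWriter` instances, the training and test curves can be inspected by pointing TensorBoard at that directory, for example with `tensorboard --logdir=/tmp/tensorflow/mnist/logs/mnist_with_summaries`, and then opening the address it prints in a browser.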
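Since the training loop periodically calls `saver.save(sess, log_dir + "/model.ckpt", i)`, the trained weights can later be reloaded for evaluation. Below is a minimal sketch (not part of the original script) that assumes the graph-construction code above has already been run in the current process and that at least one `model.ckpt-*` checkpoint exists under `log_dir`:

```python
# Hypothetical restore-and-evaluate sketch: reuses x, y_, keep_prob, accuracy,
# saver, mnist, and log_dir defined by the script above.
ckpt_path = tf.train.latest_checkpoint(log_dir)  # newest model.ckpt-* file
with tf.Session() as eval_sess:
    saver.restore(eval_sess, ckpt_path)          # load the trained weights
    test_acc = eval_sess.run(accuracy,
                             feed_dict={x: mnist.test.images,
                                        y_: mnist.test.labels,
                                        keep_prob: 1.0})  # no dropout at test time
    print('Restored test accuracy: %s' % test_acc)
```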