Training a One-Hidden-Layer Neural Network for Handwritten Digit Recognition with TensorFlow
2022-07-23 23:13:00 【TJMtaotao】
# coding:utf-8
# Import TensorFlow and the MNIST handwritten digit data set
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Number of training steps, learning rate and dropout keep probability
max_steps = 50000
learning_rate = 0.001
dropout = 0.9

# Where the downloaded data and the summaries/model checkpoints are saved
data_dir = "MNIST_data/"
log_dir = "/tmp/tensorflow/mnist/logs/mnist_with_summaries"

# Download the MNIST data set
mnist = input_data.read_data_sets(data_dir, one_hot=True)
sess = tf.InteractiveSession()

# Define the network inputs
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, 784], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')

with tf.name_scope('input_reshape'):
    image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
    tf.summary.image('input', image_shaped_input, 10)

# Initializers for the model weights and bias terms
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# Attach summaries to a Variable for TensorBoard
def variable_summaries(var):
    with tf.name_scope('summaries'):
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)
        with tf.name_scope('stddev'):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)

# Reusable fully connected layer for the MLP
def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
    with tf.name_scope(layer_name):
        with tf.name_scope('weights'):
            weights = weight_variable([input_dim, output_dim])
            variable_summaries(weights)
        with tf.name_scope('biases'):
            biases = bias_variable([output_dim])
            variable_summaries(biases)
        with tf.name_scope('Wx_plus_b'):
            preactivate = tf.matmul(input_tensor, weights) + biases
            tf.summary.histogram('pre_activations', preactivate)
        activations = act(preactivate, name='activation')
        tf.summary.histogram('activations', activations)
        return activations

# Hidden layer of the network
hidden1 = nn_layer(x, 784, 500, 'layer1')

with tf.name_scope('dropout'):
    keep_prob = tf.placeholder(tf.float32)
    tf.summary.scalar('dropout_keep_probability', keep_prob)
    dropped = tf.nn.dropout(hidden1, keep_prob)

# Output layer (identity activation; softmax is applied inside the loss)
y = nn_layer(dropped, 500, 10, 'layer2', act=tf.identity)

# Cross-entropy loss
with tf.name_scope('cross_entropy'):
    diff = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
    with tf.name_scope('total'):
        cross_entropy = tf.reduce_mean(diff)
tf.summary.scalar('cross_entropy', cross_entropy)

# Minimize the loss with the Adam optimizer and compute the prediction accuracy
with tf.name_scope('train'):
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    with tf.name_scope('accuracy'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)

# Merge all summaries, create tf.summary.FileWriter objects for the training
# and test logs, and initialize all variables
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)
test_writer = tf.summary.FileWriter(log_dir + '/test')
tf.global_variables_initializer().run()

# Build the feed dictionary: training batches or the full test set
def feed_dict(train):
    if train:
        xs, ys = mnist.train.next_batch(100)
        k = dropout
    else:
        xs, ys = mnist.test.images, mnist.test.labels
        k = 1.0
    return {x: xs, y_: ys, keep_prob: k}

# Create a Saver to checkpoint the model
saver = tf.train.Saver()

for i in range(max_steps):
    if i % 10 == 0:
        summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
        test_writer.add_summary(summary, i)
        print('Accuracy at step %s: %s' % (i, acc))
    else:
        if i % 100 == 99:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True),
                                  options=run_options, run_metadata=run_metadata)
            train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
            train_writer.add_summary(summary, i)
            saver.save(sess, log_dir + "/model.ckpt", i)
            print('Adding run metadata for', i)
        else:
            summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
            train_writer.add_summary(summary, i)

train_writer.close()
test_writer.close()

The recognition accuracy reaches about 98%:

Accuracy at step 12150: 0.9809
Accuracy at step 12160: 0.9809
Accuracy at step 12170: 0.9817
Accuracy at step 12180: 0.9815
Accuracy at step 12190: 0.9815
Adding run metadata for 12199
Accuracy at step 12200: 0.9827
Accuracy at step 12210: 0.9832
Accuracy at step 12220: 0.9829
Accuracy at step 12230: 0.9829
Accuracy at step 12240: 0.9826
Accuracy at step 12250: 0.9828
Accuracy at step 12260: 0.9822
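To inspect the scalar, histogram, and image summaries written by the FileWriter objects above, point TensorBoard at the log directory, for example: tensorboard --logdir=/tmp/tensorflow/mnist/logs/mnist_with_summaries. The checkpoints written by saver.save can also be reloaded later. Below is a minimal sketch (not part of the original listing) that restores the most recent checkpoint and re-checks test accuracy; it assumes the graph construction, sess, saver, and feed_dict from the listing above are still available in the same process.

# Minimal sketch: restore the most recent checkpoint written during training
# and re-evaluate accuracy on the test set. Assumes the graph, sess, saver,
# and feed_dict from the listing above are already defined in this process.
ckpt_path = tf.train.latest_checkpoint(log_dir)
if ckpt_path is not None:
    saver.restore(sess, ckpt_path)
    restored_acc = sess.run(accuracy, feed_dict=feed_dict(False))
    print('Accuracy after restoring %s: %s' % (ckpt_path, restored_acc))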