Resuming TensorFlow training from a checkpoint (principle + code explanation)
2022-06-22 08:58:00 【Ghost crap】
1. What is checkpoint resume training
Checkpoint resume training (breakpoint continuation training) means that when training is interrupted for some reason before it finishes, the next run can continue from where the previous one left off instead of starting over. This is very convenient for models that need long training times.
2. Model file parsing

The checkpoint file records which models have been saved, so the most recently saved model can be located;
The .meta file saves the current network (graph) structure: tf.train.import_meta_graph('MODEL_NAME.ckpt-1174.meta')
The .data file saves the current parameter names and values: network weights, biases, operations, etc.
The .index file holds auxiliary index information and is an immutable string table.
The number after the file name (1174 here) marks the training step at which the model was saved; usually only the latest one is needed.
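For readers who want to poke at these files, here is a minimal sketch (assuming checkpoints were already written into a ./ckpt/ folder, as in the code later in this article) that prints which checkpoint the checkpoint file currently points to:
import tensorflow.compat.v1 as tf

ckpt = tf.train.get_checkpoint_state('./ckpt/')
if ckpt:
    print(ckpt.model_checkpoint_path)        # the newest checkpoint, e.g. ./ckpt/MODEL_NAME-1174
    print(ckpt.all_model_checkpoint_paths)   # every checkpoint that is still kept on disk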
3. How to implement checkpoint resume training
Two prerequisites must be met:
(1) Snapshots of the model being trained are saved to disk (saving the checkpoint data);
(2) The training state can be restored by reading a snapshot back (restoring the checkpoint data).
Both operations rely on TensorFlow's tf.train.Saver class (see the official documentation).
1. Create a tf.train.Saver object
saver = tf.train.Saver(max_to_keep=1)  # maximum number of recent checkpoints to keep; the default is 5, use 1 to keep only the latest
2. Save the model with the Saver object's save method
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=epoch)
# sess is the session to save, MODEL_SAVE_PATH is the directory to save into, MODEL_NAME is the file name prefix, global_step is the current training step
3. Restoring the checkpoint data
3.1 Load only the parameters, not the graph
saver.restore(sess, ckpt.model_checkpoint_path)  # restore the current session: the values stored in the checkpoint are assigned to the weights and biases
This is the approach usually used for resuming training: when the network is not very complex, rebuilding the graph in code takes only milliseconds, and the graph only has to be built once anyway, since the network structure does not change during training.
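To make the pattern concrete, here is a minimal sketch of option 3.1. The variables w and b and the directory './ckpt_toy/' are illustrative, and it assumes any checkpoint found there was written by this same graph: the graph is rebuilt by running the same graph-building code, and only the variable values come from the checkpoint.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

w = tf.Variable(tf.zeros([3072, 10]), name='w')  # same graph-building code as in the original training script
b = tf.Variable(tf.zeros([10]), name='b')
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./ckpt_toy/')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)  # w and b now hold the values saved last time
    else:
        sess.run(tf.global_variables_initializer())      # no checkpoint yet: start from scratch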
3.2 Load both the graph structure and the parameters
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')  # load the graph structure, i.e. the network definition
saver.restore(sess, ckpt.model_checkpoint_path)  # then restore the current session: the stored values are assigned to the weights and biases
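And a minimal sketch of option 3.2. The tensor names 'x:0' and 'accuracy:0' are placeholders for whatever names your graph actually uses, and nothing here is specific to the model built later in this article: after import_meta_graph rebuilds the graph from the .meta file, tensors and operations are looked up by name instead of through Python variables.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./ckpt/')
    saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')  # rebuild the graph from the .meta file
    saver.restore(sess, ckpt.model_checkpoint_path)                           # then load the saved weights into it
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name('x:0')                # placeholder name used when the graph was built
    accuracy = graph.get_tensor_by_name('accuracy:0')  # any tensor can be fetched by its name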
4. Code walkthrough
In this part I attach a complete example for people who are new to checkpoint resume training and want to use it in their own programs. The code has four parts: importing the required packages, loading the data set, building the network model, and training the model. The first three parts have little to do with resuming training, so readers with some background can jump straight to the last part. Running the full code requires tensorflow >= 2.0 (used through the v1 compatibility API) and the cifar10 data set.
Reminder:
If you only save and load the model parameters, each new run will indeed start from the last saved parameters and keep improving on the previous result, but the step counter restarts from zero, so the true number of training steps can no longer be read off directly. In some cases we need an exact step count, for example to compare fairly against other baseline models. So we also have to remember the step reached in the last run and subtract it when starting a new round of training (as shown below); only then does the total number of training steps stay consistent.
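A compact illustration of that bookkeeping, using the names defined later in section 4.1 (a simpler variant of what section 4.4 does with a regular expression; it assumes the checkpoint name follows saver.save's usual 'name-<global_step>' pattern):
ckpt_path = tf.train.latest_checkpoint(MODEL_SAVE_PATH)     # e.g. './ckpt/vgg model-3000', or None if nothing was saved yet
steped = int(ckpt_path.split('-')[-1]) if ckpt_path else 0  # steps already finished in earlier runs
remaining_steps = train_steps - steped                       # only this many steps are still needed in this run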
4.1 Import the required packages
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # placeholder is gone in tf2.x, so fall back to tf1.x-style graphs and sessions
import os  # used to build the checkpoint folder and file paths
import pickle
import numpy as np
import re  # regular expressions, used to pull the step number out of the checkpoint name
MODEL_SAVE_PATH = './ckpt/'  # saver.save will write the checkpoint files into this folder
if not os.path.exists(MODEL_SAVE_PATH):
    os.makedirs(MODEL_SAVE_PATH)  # make sure the checkpoint folder exists before saver.save writes into it
MODEL_NAME = 'vgg model'
batch_size = 20
train_steps = 10000
test_steps = 100
CIFAR_DIR = r"D:\Dataset\cifar-10-python\cifar-10-batches-py"  # change this to your own cifar10 data set directory
print(os.listdir(CIFAR_DIR))
4.2 Load the data (this part has nothing to do with resuming training)
def load_data(filename):
    """read data from data file."""
    with open(filename, 'rb') as f:
        data = pickle.load(f, encoding='bytes')
    return data[b'data'], data[b'labels']
# tensorflow.Dataset.
class CifarData:
    def __init__(self, filenames, need_shuffle):
        all_data = []
        all_labels = []
        for filename in filenames:
            data, labels = load_data(filename)
            all_data.append(data)
            all_labels.append(labels)
        self._data = np.vstack(all_data)
        self._data = self._data / 127.5 - 1
        self._labels = np.hstack(all_labels)
        print(self._data.shape)
        print(self._labels.shape)
        self._num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        self._indicator = 0
        if self._need_shuffle:
            self._shuffle_data()

    def _shuffle_data(self):
        # [0,1,2,3,4,5] -> [5,3,2,4,0,1]
        p = np.random.permutation(self._num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]

    def next_batch(self, batch_size):
        """return batch_size examples as a batch."""
        end_indicator = self._indicator + batch_size
        if end_indicator > self._num_examples:
            if self._need_shuffle:
                self._shuffle_data()
                self._indicator = 0
                end_indicator = batch_size
            else:
                raise Exception("have no more examples")
        if end_indicator > self._num_examples:
            raise Exception("batch size is larger than all examples")
        batch_data = self._data[self._indicator: end_indicator]
        batch_labels = self._labels[self._indicator: end_indicator]
        self._indicator = end_indicator
        return batch_data, batch_labels
train_filenames = [os.path.join(CIFAR_DIR, 'data_batch_%d' % i) for i in range(1, 6)]
test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]
train_data = CifarData(train_filenames, True)
test_data = CifarData(test_filenames, False)
4.3 Build the network model (it has nothing to do with resuming training)
x = tf.placeholder(tf.float32, [None, 3072])
y = tf.placeholder(tf.int64, [None])
# [None], eg: [0,5,6,3]
x_image = tf.reshape(x, [-1, 3, 32, 32])
# 32*32
x_image = tf.transpose(x_image, perm=[0, 2, 3, 1])
# conv1: Neuron diagram , feature_map, Output image
conv1_1 = tf.layers.conv2d(x_image,
                           32,  # output channel number
                           (3, 3),  # kernel size
                           padding='same',
                           activation=tf.nn.relu,
                           name='conv1_1')
conv1_2 = tf.layers.conv2d(conv1_1,
                           32,  # output channel number
                           (3, 3),  # kernel size
                           padding='same',
                           activation=tf.nn.relu,
                           name='conv1_2')
# 16 * 16
pooling1 = tf.layers.max_pooling2d(conv1_2,
                                   (2, 2),  # kernel size
                                   (2, 2),  # stride
                                   name='pool1')
conv2_1 = tf.layers.conv2d(pooling1,
                           32,  # output channel number
                           (3, 3),  # kernel size
                           padding='same',
                           activation=tf.nn.relu,
                           name='conv2_1')
conv2_2 = tf.layers.conv2d(conv2_1,
                           32,  # output channel number
                           (3, 3),  # kernel size
                           padding='same',
                           activation=tf.nn.relu,
                           name='conv2_2')
# 8 * 8
pooling2 = tf.layers.max_pooling2d(conv2_2,
                                   (2, 2),  # kernel size
                                   (2, 2),  # stride
                                   name='pool2')
conv3_1 = tf.layers.conv2d(pooling2,
                           32,  # output channel number
                           (3, 3),  # kernel size
                           padding='same',
                           activation=tf.nn.relu,
                           name='conv3_1')
conv3_2 = tf.layers.conv2d(conv3_1,
                           32,  # output channel number
                           (3, 3),  # kernel size
                           padding='same',
                           activation=tf.nn.relu,
                           name='conv3_2')
# 4 * 4 * 32
pooling3 = tf.layers.max_pooling2d(conv3_2,
                                   (2, 2),  # kernel size
                                   (2, 2),  # stride
                                   name='pool3')
# [None, 4 * 4 * 32]
flatten = tf.layers.flatten(pooling3)
y_ = tf.layers.dense(flatten, 10)
loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=y_)
# y_ -> softmax
# y -> one_hot
# loss = y * log(y_)
# indices
predict = tf.argmax(y_, 1)
# [1,0,1,1,1,0,0,0]
correct_prediction = tf.equal(predict, y)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))
with tf.name_scope('train_op'):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
4.4 Model training and checkpoint resume training (the key part)
init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=1)  # the Saver class provides the save and restore methods
# train 10k: 73.4%
with tf.Session() as sess:
    sess.run(init)
    ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
    # tf.train.latest_checkpoint(MODEL_SAVE_PATH) is similar, but returns only the path string of the newest checkpoint
    print(ckpt)  # the most recently saved checkpoint state
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)  # restore the session: the checkpoint values are assigned to the weights and biases
        steped = re.findall(r"\d+\.?\d*", str(ckpt))  # extract the step number from the checkpoint path
        print('the step finished last time is ' + steped[0])  # print how many steps the last run finished
        steped = int(steped[0])  # so that the total number of training steps stays fixed
        print('Model restored...')
    else:
        steped = 0
        print('No model')
    for step in range(train_steps - steped):
        batch_data, batch_labels = train_data.next_batch(batch_size)
        loss_val, acc_val, _ = sess.run(
            [loss, accuracy, train_op],
            feed_dict={
                x: batch_data,
                y: batch_labels})
        # write a checkpoint after every step, so an interruption loses at most one step
        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=step + steped + 1)
        if (step + steped + 1) % 100 == 0:  # the step counter continues from where the last run stopped
            print('[Train] Step: %d, loss: %4.5f, acc: %4.5f'
                  % (step + steped + 1, loss_val, acc_val))
        if (step + steped + 1) % 1000 == 0:
            test_data = CifarData(test_filenames, False)
            all_test_acc_val = []
            for j in range(test_steps):
                test_batch_data, test_batch_labels \
                    = test_data.next_batch(batch_size)
                test_acc_val = sess.run(
                    [accuracy],
                    feed_dict={
                        x: test_batch_data,
                        y: test_batch_labels
                    })
                all_test_acc_val.append(test_acc_val)
            test_acc = np.mean(all_test_acc_val)
            print('[Test ] Step: %d, acc: %4.5f' % (step + steped + 1, test_acc))
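A note on an alternative design choice: instead of parsing the step count out of the checkpoint file name, the counter can be stored in a non-trainable variable so that saver.restore brings it back automatically. This is not what the code above does; the following is a self-contained toy sketch (the directory './ckpt_demo/', the name global_step and the target of 10 steps are all illustrative).
import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

global_step_var = tf.Variable(0, trainable=False, name='global_step')  # stored inside the checkpoint with the weights
increment_step = tf.assign_add(global_step_var, 1)
saver = tf.train.Saver(max_to_keep=1)
os.makedirs('./ckpt_demo', exist_ok=True)

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./ckpt_demo/')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)   # also restores global_step_var
    else:
        sess.run(tf.global_variables_initializer())
    steped = sess.run(global_step_var)                    # how many steps earlier runs already finished
    for step in range(10 - steped):                       # toy target of 10 total steps across all runs
        sess.run(increment_step)                          # a real training step would also run train_op here
        saver.save(sess, './ckpt_demo/demo', global_step=sess.run(global_step_var))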
4.5 Training results after resuming from a checkpoint

Writing all this up was not easy; if it helped you, please give it a like. Thank you~