Site home > 3. Keras-style model training
3. Keras-style model training
2022-06-26 15:58:00 [X1996_]
Sequential model (built incrementally with model.add())
# Sequential model built incrementally with model.add().
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

# Toy data: 1000 samples with 32 features, 10 random target values each.
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

# Build the model layer by layer.
model = tf.keras.Sequential()
model.add(layers.Dense(64, activation='relu'))  # first layer
model.add(layers.Dense(64, activation='relu'))  # second layer
model.add(layers.Dense(10))                     # output layer (raw logits)

# Specify the optimizer, loss function and metrics.
# from_logits=True because the last layer applies no softmax.
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks.
callbacks = [
    # Early stopping: stop training when `val_loss` stops improving.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        # "No longer improving" means an improvement of less than 1e-2 ...
        min_delta=1e-2,
        # ... sustained for at least 2 consecutive epochs.
        patience=2,
        verbose=1),
    # Checkpointing: save the model weights.
    tf.keras.callbacks.ModelCheckpoint(
        # Checkpoint path; `{epoch}` is filled in per epoch.
        filepath='mymodel_{epoch}',
        # Only overwrite the current checkpoint when `val_loss` improves.
        save_best_only=True,
        monitor='val_loss',
        # Save only the weights, not the full model.
        save_weights_only=True,
        verbose=1),
    # Reduce the learning rate when `val_loss` plateaus.
    # BUGFIX: `val_loss` must be minimized, so mode is 'min' (the original
    # used mode='max', which would cut the LR whenever val_loss stopped
    # *increasing* — the opposite of the intent).
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         mode='min',
                                         factor=0.5,
                                         patience=3)
]

# Train, holding out 20% of the data for validation.
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
Sequential model (list-of-layers variant)
# Sequential model built from a list of layers (with explicit input shape).
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

# Toy data: 1000 samples with 32 features, 10 random target values each.
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

# Build the model in one expression; input_shape lets Keras build the
# weights immediately instead of deferring to the first fit() call.
model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)),  # first layer
    layers.Dense(64, activation='relu'),                     # second layer
    layers.Dense(10)                                         # output layer (raw logits)
])

# Specify the optimizer, loss function and metrics.
# from_logits=True because the last layer applies no softmax.
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks.
callbacks = [
    # Early stopping: stop training when `val_loss` stops improving.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        # "No longer improving" means an improvement of less than 1e-2 ...
        min_delta=1e-2,
        # ... sustained for at least 2 consecutive epochs.
        patience=2,
        verbose=1),
    # Checkpointing: save the model weights.
    tf.keras.callbacks.ModelCheckpoint(
        # Checkpoint path; `{epoch}` is filled in per epoch.
        filepath='mymodel_{epoch}',
        # Only overwrite the current checkpoint when `val_loss` improves.
        save_best_only=True,
        monitor='val_loss',
        # Save only the weights, not the full model.
        save_weights_only=True,
        verbose=1),
    # Reduce the learning rate when `val_loss` plateaus.
    # BUGFIX: `val_loss` must be minimized, so mode is 'min' (the original
    # used mode='max', which would cut the LR whenever val_loss stopped
    # *increasing* — the opposite of the intent).
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         mode='min',
                                         factor=0.5,
                                         patience=3)
]

# Train, holding out 20% of the data for validation.
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
Functional API
# Functional-API model: explicit Input tensor threaded through the layers.
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

# Toy data: 1000 samples with 32 features, 10 random target values each.
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

# Define the computation graph from input to output.
inputs = tf.keras.Input(shape=(32,))
x = layers.Dense(64, activation='relu')(inputs)  # first layer
x = layers.Dense(64, activation='relu')(x)       # second layer
predictions = layers.Dense(10)(x)                # output layer (raw logits)
model = tf.keras.Model(inputs=inputs, outputs=predictions)

# Specify the optimizer, loss function and metrics.
# from_logits=True because the last layer applies no softmax.
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks.
callbacks = [
    # Early stopping: stop training when `val_loss` stops improving.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        # "No longer improving" means an improvement of less than 1e-2 ...
        min_delta=1e-2,
        # ... sustained for at least 2 consecutive epochs.
        patience=2,
        verbose=1),
    # Checkpointing: save the model weights.
    tf.keras.callbacks.ModelCheckpoint(
        # Checkpoint path; `{epoch}` is filled in per epoch.
        filepath='mymodel_{epoch}',
        # Only overwrite the current checkpoint when `val_loss` improves.
        save_best_only=True,
        monitor='val_loss',
        # Save only the weights, not the full model.
        save_weights_only=True,
        verbose=1),
    # Reduce the learning rate when `val_loss` plateaus.
    # BUGFIX: `val_loss` must be minimized, so mode is 'min' (the original
    # used mode='max', which would cut the LR whenever val_loss stopped
    # *increasing* — the opposite of the intent).
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         mode='min',
                                         factor=0.5,
                                         patience=3)
]

# Train, holding out 20% of the data for validation.
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
Subclassed model
# Subclassed model: layers defined in __init__, forward pass in call().
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

# Toy data: 1000 samples with 32 features, 10 random target values each.
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))


class MyModel(tf.keras.Model):
    """Two-layer MLP emitting `num_classes` raw logits."""

    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # Define the layers used by call().
        self.dense_1 = layers.Dense(32, activation='relu')
        self.dense_2 = layers.Dense(num_classes)

    def call(self, inputs):
        # Forward pass using the layers defined in __init__.
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        return x


model = MyModel(num_classes=10)

# Specify the optimizer, loss function and metrics.
# from_logits=True because the last layer applies no softmax.
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks.
callbacks = [
    # Early stopping: stop training when `val_loss` stops improving.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        # "No longer improving" means an improvement of less than 1e-2 ...
        min_delta=1e-2,
        # ... sustained for at least 2 consecutive epochs.
        patience=2,
        verbose=1),
    # Checkpointing: save the model weights.
    tf.keras.callbacks.ModelCheckpoint(
        # Checkpoint path; `{epoch}` is filled in per epoch.
        filepath='mymodel_{epoch}',
        # Only overwrite the current checkpoint when `val_loss` improves.
        save_best_only=True,
        monitor='val_loss',
        # Save only the weights, not the full model.
        save_weights_only=True,
        verbose=1),
    # Reduce the learning rate when `val_loss` plateaus.
    # BUGFIX: `val_loss` must be minimized, so mode is 'min' (the original
    # used mode='max', which would cut the LR whenever val_loss stopped
    # *increasing* — the opposite of the intent).
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         mode='min',
                                         factor=0.5,
                                         patience=3)
]

# Train, holding out 20% of the data for validation.
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
Plotting the model architecture
# Render the model architecture to a PNG at high resolution (dpi=500);
# show_shapes=True annotates each layer with its input/output shapes.
# NOTE(review): requires the `pydot` package and Graphviz to be installed.
tf.keras.utils.plot_model(model, 'multi_input_and_output_model.png', show_shapes=True,dpi=500)

Model training: model.fit()
Model evaluation: model.evaluate()
Model prediction: model.predict()
# Evaluate the model on the test data using `evaluate`.
# NOTE(review): `x_test` / `y_test` are not defined anywhere in this page —
# presumably a held-out test split prepared elsewhere; confirm before running.
print('\n# Evaluate on test data')
results = model.evaluate(x_test, y_test, batch_size=128)
print('test loss, test acc:', results)
# Generate predictions (raw logits — the models above apply no softmax,
# so these are NOT probabilities) on new data using `predict`.
print('\n# Generate predictions for 3 samples')
predictions = model.predict(x_test[:3])
print('predictions shape:', predictions.shape)
边栏推荐
- AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy
- Keil4 opens the single-chip microcomputer project to a blank, and the problem of 100% program blocking of cpu4 is solved
- Use of abortcontroller
- 反射修改final
- js文本滚动分散动画js特效
- CNN优化trick
- Why are encoder and decoder structures often used in image segmentation tasks?
- el-dialog拖拽,边界问题完全修正,网上版本的bug修复
- NFT 平台安全指南(1)
- SVG大写字母A动画js特效
猜你喜欢
随机推荐
5000 word analysis: the way of container security attack and defense in actual combat scenarios
Nanopi duo2 connection WiFi
Interview pit summary I
HW safety response
[wechat applet] event binding, do you understand?
【leetcode】112. 路径总和 - 113. 路径总和 II
Audio and video learning (II) -- frame rate, code stream and resolution
Is it safe to buy stocks and open accounts through the QR code of the securities manager? Want to open an account for stock trading
svg上升的彩色气泡动画
11 cnn简介
el-dialog拖拽,边界问题完全修正,网上版本的bug修复
How do I open an account on my mobile phone? Is online account opening safe?
Transformation of zero knowledge QAP problem
【问题解决】新版webots纹理等资源文件加载/下载时间过长
AbortController的使用
Summary of students' learning career (2022)
H5 close the current page, including wechat browser (with source code)
Super double efficiency! Pycharm ten tips
Auto Sharding Policy will apply Data Sharding policy as it failed to apply file Sharding Policy
JS events






