3 Keras model training
2022-06-26 15:30:00 【X1996_】
Sequential model
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

# Build the model
model = tf.keras.Sequential()
model.add(layers.Dense(64, activation='relu'))  # first layer
model.add(layers.Dense(64, activation='relu'))  # second layer
model.add(layers.Dense(10))                     # third layer

# Specify the optimizer, loss function and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks
callbacks = [
    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        # stop training when 'val_loss' stops improving
        monitor='val_loss',
        # "no longer improving" means "an improvement of less than 1e-2"
        min_delta=1e-2,
        # "no longer improving" is further defined as "for at least 2 epochs"
        patience=2,
        verbose=1),
    # Save weights
    tf.keras.callbacks.ModelCheckpoint(
        # path where the checkpoint is saved
        filepath='mymodel_{epoch}',
        # the two arguments below mean the current checkpoint is overwritten
        # only when the `val_loss` score improves
        save_best_only=True,
        monitor='val_loss',
        # save only the model weights, not the whole model
        save_weights_only=True,
        verbose=1),
    # Adjust the learning rate dynamically
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         mode='min',  # 'val_loss' should decrease, so monitor in 'min' mode
                                         factor=0.5,
                                         patience=3)
]

# Train
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
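
Since the ModelCheckpoint above saves weights only, the best checkpoint can later be restored onto a model with the same architecture. A minimal sketch, assuming the checkpoint files were written to the current working directory:

# `model.save_weights` (used internally when save_weights_only=True) maintains a
# 'checkpoint' manifest; with save_best_only=True the most recent entry is also
# the best one, so it can be located and restored like this:
latest = tf.train.latest_checkpoint('.')
if latest is not None:
    model.load_weights(latest)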
Sequential model 2
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

# Build the model
model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)),  # first layer
    layers.Dense(64, activation='relu'),                     # second layer
    layers.Dense(10)                                         # third layer
])

# Specify the optimizer, loss function and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks
callbacks = [
    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        # stop training when 'val_loss' stops improving
        monitor='val_loss',
        # "no longer improving" means "an improvement of less than 1e-2"
        min_delta=1e-2,
        # "no longer improving" is further defined as "for at least 2 epochs"
        patience=2,
        verbose=1),
    # Save weights
    tf.keras.callbacks.ModelCheckpoint(
        # path where the checkpoint is saved
        filepath='mymodel_{epoch}',
        # the two arguments below mean the current checkpoint is overwritten
        # only when the `val_loss` score improves
        save_best_only=True,
        monitor='val_loss',
        # save only the model weights, not the whole model
        save_weights_only=True,
        verbose=1),
    # Adjust the learning rate dynamically
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         mode='min',  # 'val_loss' should decrease, so monitor in 'min' mode
                                         factor=0.5,
                                         patience=3)
]

# Train
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
Functional API
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

inputs = tf.keras.Input(shape=(32,))
x = layers.Dense(64, activation='relu')(inputs)  # first layer
x = layers.Dense(64, activation='relu')(x)       # second layer
predictions = layers.Dense(10)(x)                # third layer
model = tf.keras.Model(inputs=inputs, outputs=predictions)

# Specify the optimizer, loss function and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks
callbacks = [
    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        # stop training when 'val_loss' stops improving
        monitor='val_loss',
        # "no longer improving" means "an improvement of less than 1e-2"
        min_delta=1e-2,
        # "no longer improving" is further defined as "for at least 2 epochs"
        patience=2,
        verbose=1),
    # Save weights
    tf.keras.callbacks.ModelCheckpoint(
        # path where the checkpoint is saved
        filepath='mymodel_{epoch}',
        # the two arguments below mean the current checkpoint is overwritten
        # only when the `val_loss` score improves
        save_best_only=True,
        monitor='val_loss',
        # save only the model weights, not the whole model
        save_weights_only=True,
        verbose=1),
    # Adjust the learning rate dynamically
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         mode='min',  # 'val_loss' should decrease, so monitor in 'min' mode
                                         factor=0.5,
                                         patience=3)
]

# Train
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
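
Because a functional model knows its input shape at construction time, its architecture can be inspected even before training; a quick sketch:

# Print a textual overview of layers, output shapes and parameter counts
model.summary()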
Subclassed model
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

class MyModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # define the layers you need
        self.dense_1 = layers.Dense(32, activation='relu')
        self.dense_2 = layers.Dense(num_classes)

    def call(self, inputs):
        # define the forward pass,
        # using the layers defined in `__init__`
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        return x

model = MyModel(num_classes=10)

# Specify the optimizer, loss function and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Callbacks
callbacks = [
    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        # stop training when 'val_loss' stops improving
        monitor='val_loss',
        # "no longer improving" means "an improvement of less than 1e-2"
        min_delta=1e-2,
        # "no longer improving" is further defined as "for at least 2 epochs"
        patience=2,
        verbose=1),
    # Save weights
    tf.keras.callbacks.ModelCheckpoint(
        # path where the checkpoint is saved
        filepath='mymodel_{epoch}',
        # the two arguments below mean the current checkpoint is overwritten
        # only when the `val_loss` score improves
        save_best_only=True,
        monitor='val_loss',
        # save only the model weights, not the whole model
        save_weights_only=True,
        verbose=1),
    # Adjust the learning rate dynamically
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         verbose=1,
                                         mode='min',  # 'val_loss' should decrease, so monitor in 'min' mode
                                         factor=0.5,
                                         patience=3)
]

# Train
model.fit(data, labels,
          epochs=30,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)
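
One caveat of the subclassed approach: the model's weights are only created the first time call() runs, so inspecting the model before training requires a forward pass on dummy data first. A minimal sketch (the batch of zeros below is just a placeholder):

# Run one forward pass so the variables get built, then print the summary
_ = model(np.zeros((1, 32), dtype=np.float32))
model.summary()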
Plotting the model
tf.keras.utils.plot_model(model, 'multi_input_and_output_model.png', show_shapes=True,dpi=500)
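Note that tf.keras.utils.plot_model requires the pydot package and the Graphviz binaries to be installed. It is most informative for Sequential and functional models; a subclassed model has no static layer graph to draw, so model.summary() is usually the better choice there.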

Model training: model.fit()
Model evaluation: model.evaluate()
Model prediction: model.predict()
# `x_test` and `y_test` are assumed to be held-out test data with the same
# shapes as `data` and `labels` above.
# Evaluate the model on the test data using `evaluate`
print('\n# Evaluate on test data')
results = model.evaluate(x_test, y_test, batch_size=128)
print('test loss, test acc:', results)

# Generate predictions on new data using `predict`
# (the output of the last layer, which here is raw logits)
print('\n# Generate predictions for 3 samples')
predictions = model.predict(x_test[:3])
print('predictions shape:', predictions.shape)
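
Because the model was compiled with from_logits=True and its last layer has no activation, predict returns raw logits rather than probabilities; if probabilities are needed they can be obtained with a softmax, for example:

# Convert the predicted logits into class probabilities
probabilities = tf.nn.softmax(predictions).numpy()
print('probabilities of the first sample:', probabilities[0])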