6 Custom layers
2022-06-26 15:58:00 【X1996_】
The name of a custom layer should not be the same as the name of a built-in Keras layer (for example, do not call your class Dense).
```python
from sklearn import datasets
import tensorflow as tf
import numpy as np

iris = datasets.load_iris()
data = iris.data      # 150 samples x 4 features
labels = iris.target  # class index 0, 1 or 2

# Define a fully connected (dense) layer
class MyDense(tf.keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        self.units = units

    # build() is where the layer's trainable parameters are usually created.
    # trainable=True: the weight is updated during training; False: it is frozen.
    # Give every weight a name, otherwise an error occurs when saving the model.
    def build(self, input_shape):
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True,
                                 name='w')
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True,
                                 name='b')
        super(MyDense, self).build(input_shape)  # equivalent to setting self.built = True

    # call() defines the forward pass; Layer.__call__ invokes it.
    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    # To serialize a custom layer that is combined into a model with the
    # Functional API, override get_config().
    # Without it, the model cannot be saved.
    def get_config(self):
        config = super(MyDense, self).get_config()
        config.update({'units': self.units})
        return config
```
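Because build() receives input_shape, the layer does not need to know its input width when it is constructed: MyDense(units=16) only learns that each sample has 4 features on its first call. The sketch below imitates that build-once-then-call lifecycle in plain Python so it runs without TensorFlow. ToyDense and its __call__ logic are hypothetical stand-ins for what tf.keras.layers.Layer does, not the Keras implementation; the 0.05 standard deviation mirrors the default of the 'random_normal' initializer.

```python
import random


class ToyDense:
    """Hypothetical plain-Python imitation of the build()/call() split."""

    def __init__(self, units=32):
        self.units = units
        self.built = False
        self.build_count = 0

    def build(self, input_shape):
        # Weight shapes depend on the last input dimension, known only now.
        in_dim = input_shape[-1]
        rng = random.Random(0)
        self.w = [[rng.gauss(0.0, 0.05) for _ in range(self.units)]
                  for _ in range(in_dim)]
        self.b = [rng.gauss(0.0, 0.05) for _ in range(self.units)]
        self.built = True
        self.build_count += 1

    def call(self, inputs):
        # inputs is a list of rows; returns inputs @ w + b, row by row.
        return [[sum(x * self.w[i][j] for i, x in enumerate(row)) + self.b[j]
                 for j in range(self.units)]
                for row in inputs]

    def __call__(self, inputs):
        # Build once, on the first call (shape = (batch, features)), then run call().
        if not self.built:
            self.build((len(inputs), len(inputs[0])))
        return self.call(inputs)


layer = ToyDense(units=16)
out = layer([[5.1, 3.5, 1.4, 0.2]])  # one iris-like sample with 4 features
layer([[6.2, 2.9, 4.3, 1.3]])         # second call reuses the same weights
print(len(layer.w), len(layer.w[0]), len(out[0]), layer.build_count)  # 4 16 16 1
```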
```python
# Build the model with the Functional API
inputs = tf.keras.Input(shape=(4,))
x = MyDense(units=16)(inputs)  # 16 neurons
x = tf.nn.tanh(x)              # the dense layer is followed by an activation function
x = tf.keras.layers.Dense(8)(x)
x = tf.nn.relu(x)
x = MyDense(units=3)(x)        # three classes
predictions = tf.nn.softmax(x)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
```
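```python
# Shuffle: append the labels as a 5th column so each row keeps its label
data = np.concatenate((data, labels.reshape(150, 1)), axis=-1)
np.random.shuffle(data)
labels = data[:, -1].astype('int64')  # back to integer class indices
data = data[:, :4]
```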
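Gluing the labels onto the data before shuffling is one way to keep every sample paired with its label. An equivalent approach is to draw a single permutation of the row indices and apply it to both arrays (with NumPy: perm = np.random.permutation(len(data)), then data[perm] and labels[perm]). The sketch below shows that idea in plain Python; shuffle_together is a hypothetical helper, not part of NumPy or Keras.

```python
import random


def shuffle_together(rows, labels, seed=None):
    """Shuffle rows and labels with one shared permutation so pairs stay aligned."""
    if len(rows) != len(labels):
        raise ValueError("rows and labels must have the same length")
    order = list(range(len(rows)))
    random.Random(seed).shuffle(order)
    return [rows[i] for i in order], [labels[i] for i in order]


rows = [[5.1, 3.5], [7.0, 3.2], [6.3, 3.3]]
labels = [0, 1, 2]
s_rows, s_labels = shuffle_together(rows, labels, seed=42)
# Every (row, label) pair from the input still appears together after shuffling.
print(sorted(zip(map(tuple, s_rows), s_labels)) == sorted(zip(map(tuple, rows), labels)))  # True
```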
```python
# Optimizer: Adam
# Loss: sparse categorical cross-entropy. The model already ends with softmax,
# so its outputs are probabilities, not logits: from_logits must be False.
# Metric: accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
```
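With from_logits=True the loss applies softmax to the model output itself before taking the log. If the model already ends in tf.nn.softmax, the probabilities get squashed a second time and the loss no longer measures what it should. The plain-Python sketch below compares the two settings for one sample; softmax and sparse_ce are illustrative helpers that mimic the meaning of the flag, not the Keras code.

```python
import math


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce(output, label, from_logits):
    """Cross-entropy for one sample; from_logits=True means 'apply softmax first'."""
    probs = softmax(output) if from_logits else output
    return -math.log(probs[label])


logits = [4.0, 1.0, -2.0]  # raw scores from the last dense layer
probs = softmax(logits)    # what the model in this post actually outputs

correct = sparse_ce(probs, 0, from_logits=False)
wrong = sparse_ce(probs, 0, from_logits=True)  # softmax applied twice
print(round(correct, 4), round(wrong, 4))      # the second loss is much larger
```

Passing raw logits with from_logits=True gives the same loss as passing probabilities with from_logits=False; mixing the two (softmax output plus from_logits=True) does not.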
```python
# Training
model.fit(data, labels, batch_size=32, epochs=100, shuffle=True)
```
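With 150 samples and batch_size=32, each epoch runs ceil(150 / 32) = 5 steps: four full batches of 32 and a final batch of 22 samples, so 100 epochs amount to 500 weight updates. The small sketch below reproduces that arithmetic; the helper names are illustrative only.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of batches per epoch when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)


def batch_sizes(num_samples, batch_size):
    """Size of each batch in one epoch; only the last one may be smaller."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


print(steps_per_epoch(150, 32))        # 5
print(batch_sizes(150, 32))            # [32, 32, 32, 32, 22]
print(steps_per_epoch(150, 32) * 100)  # 500 updates over 100 epochs
```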
Show the network structure:
```python
model.summary()
```
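The parameter count in the summary can be checked by hand. A dense layer with n inputs and m units holds an n x m weight matrix plus m biases, and the tf.nn activations add no parameters. For 4 -> 16 -> 8 -> 3 that gives 80 + 136 + 27 = 243 trainable parameters. A minimal sketch of the calculation (helper names are illustrative):

```python
def dense_params(in_dim, units):
    """Weights (in_dim x units) plus one bias per unit."""
    return in_dim * units + units


def count_params(input_dim, layer_units):
    per_layer, total, in_dim = [], 0, input_dim
    for units in layer_units:
        n = dense_params(in_dim, units)
        per_layer.append(n)
        total += n
        in_dim = units  # the next layer's input width is this layer's output width
    return per_layer, total


print(count_params(4, [16, 8, 3]))  # ([80, 136, 27], 243)
```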
Save the model (HDF5 format):
```python
model.save('keras_model_tf_version.h5')
```
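Saving records each layer's class name together with the dictionary returned by get_config(), and loading rebuilds the layer by passing that dictionary back to the constructor (the default Layer.from_config calls cls(**config)). That is why get_config() must include units: without it, MyDense(units=3) would come back as MyDense(units=32) and the saved weights would no longer fit. The plain-Python sketch below models only this config round trip through JSON; ToyConfigDense and round_trip are hypothetical names.

```python
import json


class ToyConfigDense:
    """Hypothetical stand-in for MyDense; only the config logic is modelled."""

    def __init__(self, units=32, name="toy_dense"):
        self.units = units
        self.name = name

    def get_config(self):
        # Include every constructor argument needed to rebuild the layer.
        return {"name": self.name, "units": self.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class ToyConfigDenseNoUnits(ToyConfigDense):
    def get_config(self):
        return {"name": self.name}  # forgets 'units'


def round_trip(layer):
    """Serialize the config to JSON text and rebuild the layer from it."""
    text = json.dumps(layer.get_config())
    return type(layer).from_config(json.loads(text))


print(round_trip(ToyConfigDense(units=3)).units)         # 3
print(round_trip(ToyConfigDenseNoUnits(units=3)).units)  # 32: silently back to the default
```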
Load the model and predict:
```python
# Register the custom layer name in a dictionary before loading the model.
# The MyDense class must already be defined (or imported) in this session,
# otherwise the network cannot be rebuilt.
_custom_objects = {
    "MyDense": MyDense
}
new_model = tf.keras.models.load_model("keras_model_tf_version.h5",
                                       custom_objects=_custom_objects)
```
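The saved file refers to the custom layer only by the string "MyDense", so the loader has to map that string back to a Python class; built-in layers are known to Keras, custom ones must be supplied through custom_objects. The plain-Python sketch below models that lookup with a dictionary. BUILTINS and resolve_class are hypothetical, the error text is illustrative rather than Keras's exact message, and in this sketch the custom mapping overrides the built-in one, which shows how a custom class registered under a built-in name would shadow it.

```python
class Dense:    # stands in for a built-in layer
    pass


class MyDense:  # stands in for the custom layer
    pass


BUILTINS = {"Dense": Dense}  # hypothetical: names the loader knows by default


def resolve_class(class_name, custom_objects=None):
    """Map a saved class name back to a class; custom objects take priority here."""
    lookup = dict(BUILTINS)
    lookup.update(custom_objects or {})
    try:
        return lookup[class_name]
    except KeyError:
        raise ValueError(f"Unknown layer: {class_name}") from None


print(resolve_class("Dense").__name__)                                         # Dense
print(resolve_class("MyDense", custom_objects={"MyDense": MyDense}).__name__)  # MyDense
try:
    resolve_class("MyDense")  # custom_objects not passed
except ValueError as err:
    print(err)                # Unknown layer: MyDense
```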
```python
y_pred = new_model.predict(data)  # class probabilities, shape (150, 3)
np.argmax(y_pred, axis=1)         # predicted class index for each sample
```
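np.argmax(y_pred, axis=1) picks, for each row of three probabilities, the index of the largest one, and comparing those indices with labels gives the accuracy. Since the model is evaluated on the same 150 samples it was trained on, that number is optimistic. The sketch below does the same in plain Python; argmax_rows and accuracy are illustrative helpers, and the probabilities are made-up values, not output of the model above.

```python
def argmax_rows(rows):
    """Index of the largest value in each row; ties go to the first index, as with np.argmax."""
    return [max(range(len(r)), key=r.__getitem__) for r in rows]


def accuracy(pred, true):
    return sum(p == t for p, t in zip(pred, true)) / len(true)


y_pred = [[0.90, 0.08, 0.02],  # illustrative probabilities, not model output
          [0.10, 0.70, 0.20],
          [0.05, 0.55, 0.40]]
labels = [0, 1, 2]
classes = argmax_rows(y_pred)
print(classes)                    # [0, 1, 1]
print(accuracy(classes, labels))  # 2 of 3 correct
```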