当前位置:网站首页>6 自定义层
6 自定义层
2022-06-26 15:30:00 【X1996_】
自定义的层名不要与自带的层重名
from sklearn import datasets
import tensorflow as tf
import numpy as np
iris = datasets.load_iris()
data = iris.data
labels = iris.target
# 定义一个全连接层
class MyDense(tf.keras.layers.Layer):
def __init__(self, units=32, **kwargs):
self.units = units
super(MyDense, self).__init__(**kwargs)
# build方法一般定义Layer需要被训练的参数
# trainable=True 参与训练 False 不参与训练
# name需要命名,不然模型保存会出现错误
def build(self, input_shape):
self.w = self.add_weight(shape=(input_shape[-1], self.units),
initializer='random_normal',
trainable=True,
name='w')
self.b = self.add_weight(shape=(self.units,),
initializer='random_normal',
trainable=True,
name='b')
super(MyDense,self).build(input_shape) # 相当于设置self.built = True
#call方法一般定义正向传播运算逻辑,__call__方法调用了它。
def call(self, inputs):
return tf.matmul(inputs, self.w) + self.b
#如果要让自定义的Layer通过Functional API 组合成模型时可以序列化,需要自定义get_config方法。
# 不定义不能保存模型
def get_config(self):
config = super(MyDense, self).get_config()
config.update({
'units': self.units})
return config
# 函数式编程
inputs = tf.keras.Input(shape=(4,))
x = MyDense(units=16)(inputs) # 神经元个数设置为16
x = tf.nn.tanh(x) # 全连接层后接一个激活函数
x = tf.keras.layers.Dense(8)(x)
x = tf.nn.relu(x)
x = MyDense(units=3)(x) #三分类
predictions = tf.nn.softmax(x)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
# 打乱
data = np.concatenate((data,labels.reshape(150,1)),axis=-1)
np.random.shuffle(data)
labels = data[:,-1]
data = data[:,:4]
#优化器 Adam
#损失函数 交叉熵损失函数
#评估函数 #acc
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
# 训练
model.fit(data, labels, batch_size=32, epochs=100,shuffle=True)
显示网络结构
model.summary()
保存模型
model.save('keras_model_tf_version.h5')
加载模型预测
# 加载模型之前要把自定义层名称添加到字典里
# 需要把MyDense的网络写出来才能定义
_custom_objects = {
"MyDense" : MyDense
}
new_model = tf.keras.models.load_model("keras_model_tf_version.h5",custom_objects=_custom_objects)
y_pred = new_model.predict(data)
np.argmax(y_pred,axis=1)
边栏推荐
- AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy
- Selenium chrome disable JS disable pictures
- nanoPi Duo2连接wifi
- 人人都当科学家之免Gas体验mint爱死机
- STEPN 新手入门及进阶
- How to handle 2gcsv files that cannot be opened? Use byzer
- 10 tf.data
- Svg savage animation code
- AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy
- 「幹貨」NFT 上中下遊產業鏈全景分析
猜你喜欢

Svg capital letter a animation JS effect
![[CEPH] MKDIR | mksnap process source code analysis | lock state switching example](/img/4a/0aeb69ae6527c65a67be535828b48a.jpg)
[CEPH] MKDIR | mksnap process source code analysis | lock state switching example

PCIe Capabilities List

5000 word analysis: the way of container security attack and defense in actual combat scenarios

Evaluation - TOPSIS

【C语言练习——打印空心上三角及其变形】

NFT合约基础知识讲解

10 tf.data

IntelliJ idea -- Method for formatting SQL files

NFT交易原理分析(2)
随机推荐
安全Json协议
【leetcode】331. Verifying the preorder serialization of a binary tree
NFT 平台安全指南(2)
Why are encoder and decoder structures often used in image segmentation tasks?
面试踩坑总结一
全面解析Discord安全问题
Golang 1.18 go work usage
SVG大写字母A动画js特效
Mr. Du said that the website was updated with illustrations
js文本滚动分散动画js特效
Development, deployment and online process of NFT project (1)
[tcapulusdb knowledge base] tcapulusdb system user group introduction
JVM笔记
How to handle 2gcsv files that cannot be opened? Use byzer
Evaluation - TOPSIS
[thinking] what were you buying when you bought NFT?
Keil4 opens the single-chip microcomputer project to a blank, and the problem of 100% program blocking of cpu4 is solved
Summary of data interface API used in word search and translation applications
手机上怎么开户?在线开户安全么?
NFT 项目的开发、部署、上线的流程(1)