当前位置:网站首页>7 自定义损失函数
7 自定义损失函数
2022-06-26 15:30:00 【X1996_】
自定义损失函数
本实验需要用到 mnist.npz 数据集（代码通过 np.load("mnist.npz") 从当前工作目录读取）。
下面分别用自定义训练循环和 Keras 内置的 fit() 函数进行训练，两者的结果看起来差不多。
自定义训练
导入所需的库
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# Allocate GPU memory on demand instead of reserving all of it up front,
# to avoid out-of-memory (OOM) failures.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
# allow_growth=True: let the GPU memory pool grow as it is needed.
config.gpu_options.allow_growth = True
# NOTE(review): `session` is not referenced again in the visible code; it
# appears to be created only for the side effect of applying `config`.
session = InteractiveSession(config=config)
载入数据集并处理
# Load MNIST from the local archive. Using np.load as a context manager closes
# the underlying .npz file handle once the four arrays have been read out.
with np.load("mnist.npz") as mnist:
    x_train, y_train = mnist['x_train'], mnist['y_train']
    x_test, y_test = mnist['x_test'], mnist['y_test']
# Normalise pixel values (assumed 8-bit, 0..255) to [0, 1]. Casting to float32
# first avoids silently materialising float64 copies of the whole dataset.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
# Add a trailing channels dimension for Conv2D.
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
# One-hot encode the labels to match the 10-way softmax output of the model.
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
搭建网络
class MyModel(Model):
    """Small CNN classifier: Conv2D -> Flatten -> Dense(128) -> Dense(10).

    The last layer uses a softmax activation, so the model outputs class
    probabilities rather than logits.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        # Push the input through every layer, in order.
        for layer in (self.conv1, self.flatten, self.d1, self.d2):
            x = layer(x)
        return x
定义损失函数：下面给出两种实现方式，一种用类实现，一种用函数实现，两者都可以使用。
# #多分类的focal loss 损失函数,类的实现
# class FocalLoss(tf.keras.losses.Loss):
# def __init__(self,gamma=2.0,alpha=0.25):
# self.gamma = gamma
# self.alpha = alpha
# super(FocalLoss, self).__init__()
# def call(self,y_true,y_pred):
# y_pred = tf.nn.softmax(y_pred,axis=-1)
# epsilon = tf.keras.backend.epsilon()#1e-7
# y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
# y_true = tf.cast(y_true,tf.float32)
# loss = - y_true * tf.math.pow(1 - y_pred, self.gamma) * tf.math.log(y_pred)
# loss = tf.math.reduce_sum(loss,axis=1)
# return loss
# Function-based implementation.
def FocalLoss(gamma=2.0, alpha=0.25, from_logits=False):
    """Build a multi-class focal loss for one-hot targets.

    Per example:  -sum_c alpha_c * y_c * (1 - p_c) ** gamma * log(p_c)

    Args:
        gamma: Focusing parameter; larger values down-weight well-classified
            examples. gamma=0 reduces to (alpha-weighted) cross-entropy.
        alpha: Class weight. A scalar scales every class equally; a sequence
            of length num_classes weights each class separately.
        from_logits: Set True only if ``y_pred`` holds raw logits. The models
            in this file end in a softmax layer and already output
            probabilities, so the default (False) must not apply softmax a
            second time (doing so squashes the probabilities toward uniform).

    Returns:
        A function ``(y_true, y_pred) -> loss`` returning one value per
        example (the last, class axis is summed out).
    """
    def focal_loss_fixed(y_true, y_pred):
        if from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        # Clip away exact zeros so log(p) stays finite.
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
        y_true = tf.cast(y_true, y_pred.dtype)
        alpha_t = tf.cast(alpha, y_pred.dtype)
        loss = -alpha_t * y_true * tf.math.pow(1.0 - y_pred, gamma) * tf.math.log(y_pred)
        return tf.math.reduce_sum(loss, axis=-1)
    return focal_loss_fixed
选择优化器、损失函数和评估指标
model = MyModel()

# Loss. The built-in alternative is kept for reference:
# loss_object = tf.keras.losses.CategoricalCrossentropy()
loss_object = FocalLoss(gamma=2.0, alpha=0.25)  # custom focal loss
optimizer = tf.keras.optimizers.Adam()

# Running metrics; the epoch loop resets them before every epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
@tf.function
def train_step(images, labels):
    """Run one optimisation step on a batch and update the training metrics.

    The focal loss returns one value per example. Differentiating that vector
    directly would implicitly use the batch *sum*, making the update size grow
    with batch size and differ from the mean reduction ``fit()`` uses. The
    gradient is therefore taken of the batch mean, while the metric still
    receives the per-example values.
    """
    with tf.GradientTape() as tape:
        predictions = model(images)
        per_example_loss = loss_object(labels, predictions)
        # No-op for losses that already return a scalar (e.g. the built-in
        # tf.keras.losses.CategoricalCrossentropy()).
        loss = tf.reduce_mean(per_example_loss)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(per_example_loss)
    train_accuracy(labels, predictions)
@tf.function
def test_step(images, labels):
    """Evaluate one batch and accumulate the test loss and accuracy metrics."""
    probs = model(images)
    test_loss(loss_object(labels, probs))
    test_accuracy(labels, probs)
训练
epochs = 5
for epoch in range(epochs):
    # Start each epoch with fresh metric accumulators.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

    for batch_images, batch_labels in train_ds:
        train_step(batch_images, batch_labels)

    for batch_images, batch_labels in test_ds:
        test_step(batch_images, batch_labels)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(
        epoch + 1,
        train_loss.result(),
        train_accuracy.result() * 100,
        test_loss.result(),
        test_accuracy.result() * 100,
    ))
使用 fit() 训练
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# Allocate GPU memory on demand instead of reserving all of it up front,
# to avoid out-of-memory (OOM) failures.
config = ConfigProto()
# allow_growth=True: let the GPU memory pool grow as it is needed.
config.gpu_options.allow_growth = True
# NOTE(review): `session` is not referenced again in the visible code; it
# appears to be created only for the side effect of applying `config`.
session = InteractiveSession(config=config)
# Load MNIST from the local archive. Using np.load as a context manager closes
# the underlying .npz file handle once the four arrays have been read out.
with np.load("mnist.npz") as mnist:
    x_train, y_train = mnist['x_train'], mnist['y_train']
    x_test, y_test = mnist['x_test'], mnist['y_test']
# Normalise pixel values (assumed 8-bit, 0..255) to [0, 1] in float32, avoiding
# float64 copies of the whole dataset.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
# Integer class indices for tf.one_hot.
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
# Add a channels dimension so images match the model's (28, 28, 1) input.
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
# Evaluation order does not affect the aggregated validation metrics, so the
# test set is not shuffled.
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# 定义模型
def MyModel():
    """Build the CNN classifier with the Keras functional API.

    Input named 'digits' of shape (28, 28, 1); output layer 'predictions' is a
    10-way softmax, so the model returns class probabilities.
    """
    inputs = tf.keras.Input(shape=(28, 28, 1), name='digits')
    layer_stack = [
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax', name='predictions'),
    ]
    x = inputs
    for layer in layer_stack:
        x = layer(x)
    return tf.keras.Model(inputs=inputs, outputs=x)
# Multi-class focal loss, implemented as a tf.keras.losses.Loss subclass.
class FocalLoss(tf.keras.losses.Loss):
    """Multi-class focal loss for one-hot targets.

    Per example:  -sum_c alpha_c * y_c * (1 - p_c) ** gamma * log(p_c)

    Args:
        gamma: Focusing parameter; larger values down-weight well-classified
            examples. gamma=0 reduces to (alpha-weighted) cross-entropy.
        alpha: Class weight. A scalar scales every class equally; a sequence
            of length num_classes weights each class separately.
        from_logits: Set True only if ``y_pred`` holds raw logits. The model
            in this file ends in a softmax layer and already outputs
            probabilities, so the default (False) must not apply softmax a
            second time (doing so squashes the probabilities toward uniform).
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``,
            ``name``).
    """

    def __init__(self, gamma=2.0, alpha=0.25, from_logits=False, **kwargs):
        super(FocalLoss, self).__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        """Return the focal loss per example (last, class axis summed out)."""
        if self.from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        # Clip away exact zeros so log(p) stays finite.
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
        y_true = tf.cast(y_true, y_pred.dtype)
        alpha = tf.cast(self.alpha, y_pred.dtype)
        loss = -alpha * y_true * tf.math.pow(1.0 - y_pred, self.gamma) * tf.math.log(y_pred)
        return tf.math.reduce_sum(loss, axis=-1)

    def get_config(self):
        """Add the focal-loss hyper-parameters so the loss can be re-created from config."""
        config = super(FocalLoss, self).get_config()
        config.update({
            'gamma': self.gamma,
            'alpha': self.alpha,
            'from_logits': self.from_logits,
        })
        return config
# def FocalLoss(gamma=2.0,alpha=0.25):
# def focal_loss_fixed(y_true, y_pred):
# y_pred = tf.nn.softmax(y_pred,axis=-1)
# epsilon = tf.keras.backend.epsilon()
# y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
# y_true = tf.cast(y_true,tf.float32)
# loss = - y_true * tf.math.pow(1 - y_pred, gamma) * tf.math.log(y_pred)
# loss = tf.math.reduce_sum(loss,axis=1)
# return loss
# return focal_loss_fixed
# Optimizer, loss and evaluation metrics; the loss may be a user-defined one.
model = MyModel()
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),           # optimizer
    loss=FocalLoss(gamma=2.0, alpha=0.25),               # custom focal loss
    metrics=[tf.keras.metrics.CategoricalAccuracy()],    # evaluation metric
)
# Train for 5 epochs, evaluating on test_ds as validation data.
model.fit(train_ds, epochs=5, validation_data=test_ds)
边栏推荐
- # 粒子滤波 PF——三维匀速运动CV目标跟踪(粒子滤波VS扩展卡尔曼滤波)
- Summary of data interface API used in word search and translation applications
- [tcapulusdb knowledge base] tcapulusdb doc acceptance - transaction execution introduction
- 全面解析Discord安全问题
- Use of abortcontroller
- Golang 1.18 go work usage
- Unable to download Plug-in after idea local agent
- OpenSea上如何创建自己的NFT(Polygon)
- Super double efficiency! Pycharm ten tips
- TweenMax+SVG切换颜色动画场景
猜你喜欢
Restcloud ETL resolves shell script parameterization
9 Tensorboard的使用
[file] The four VFS structs: file, dentry, inode and super_block — what are they? Differences? Relationships? (being edited)
查词翻译类应用使用数据接口api总结
Audio and video learning (III) -- SIP protocol
Keil4 opens the single-chip microcomputer project to a blank, and the problem of 100% program blocking of cpu4 is solved
Development, deployment and online process of NFT project (1)
HW安全响应
还存在过有键盘的kindle?
How to handle 2gcsv files that cannot be opened? Use byzer
随机推荐
When a project with cmake is cross compiled to a link, an error cannot be found So dynamic library file
Panoramic analysis of upstream, middle and downstream industrial chain of "dry goods" NFT
反射修改final
TweenMax+SVG切换颜色动画场景
Using restcloud ETL shell component to schedule dataX offline tasks
[tcapulusdb knowledge base] tcapulusdb doc acceptance - transaction execution introduction
2022 Beijing Shijingshan District specializes in the application process for special new small and medium-sized enterprises, with a subsidy of 100000-200000 yuan
Nanopi duo2 connection WiFi
sqlite加载csv文件,并做数据分析
Mr. Du said that the website was updated with illustrations
音视频学习(二)——帧率、码流和分辨率
canvas三个圆点闪烁动画
【leetcode】331. 验证二叉树的前序序列化
【问题解决】新版webots纹理等资源文件加载/下载时间过长
Evaluation - TOPSIS
CNN optimized trick
selenium chrome 禁用js 禁用图片
音视频学习(一)——PTZ控制原理
Evaluate:huggingface detailed introduction to the evaluation index module
Binding method of multiple sub control signal slots under QT