当前位置:网站首页>8 user defined evaluation function
8 user defined evaluation function
2022-06-26 15:50:00 【X1996_】
A user-defined evaluation function (metric) is written in much the same way as a user-defined loss function. This article defines a metric that returns the number of correctly classified samples.
Custom training
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it
# up front; the original note associates this setting with avoiding OOM.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
# allow_growth=True requests incremental (as-needed) GPU memory allocation.
config.gpu_options.allow_growth = True
# Create the session with this configuration and keep a reference to it.
session = InteractiveSession(config=config)
# Data processing
# Read the four arrays inside a context manager so the .npz archive is
# closed as soon as they are in memory (np.load leaves it open otherwise).
with np.load("mnist.npz") as mnist:
    x_train, y_train = mnist['x_train'], mnist['y_train']
    x_test, y_test = mnist['x_test'], mnist['y_test']
# Rescale by 1/255 in float32 (assumes 8-bit pixel values -- TODO confirm).
# Casting first avoids float64 tensors, which take twice the memory and
# would be cast back to the layers' float32 dtype by Keras anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
# Shuffle only the training data; both datasets are served in batches of 32.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# Build a model
class MyModel(Model):
    """Small convolutional classifier.

    Layers, in order: a 3x3 convolution with 32 filters (ReLU), a flatten
    step, a 128-unit dense layer (ReLU) and a 10-unit softmax output.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(filters=32, kernel_size=3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(units=128, activation='relu')
        self.d2 = Dense(units=10, activation='softmax')

    def call(self, x):
        """Run the forward pass and return the softmax output for `x`."""
        features = self.flatten(self.conv1(x))
        hidden = self.d1(features)
        return self.d2(hidden)
# Custom evaluation function
class CatgoricalTruePositives(tf.keras.metrics.Metric):
    """Streaming count of correctly classified samples.

    `y_true` holds integer class ids and `y_pred` holds per-class scores;
    a sample counts as correct when the arg-max of its scores equals its
    label. The running (optionally weighted) total is kept in the `tp`
    weight and returned by `result()`.
    """

    def __init__(self, name='categorical_true_positives', **kwargs):
        super(CatgoricalTruePositives, self).__init__(name=name, **kwargs)
        # Scalar accumulator for the number of correct predictions.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the number of correct predictions in this batch.

        Args:
            y_true: Integer class labels, e.g. shape (batch,) or (batch, 1).
            y_pred: Per-class scores with the class axis last.
            sample_weight: Optional per-sample weights applied to each
                correct prediction before summing.
        """
        # Flatten both sides so (batch, 1) labels cannot broadcast against
        # (batch,) predictions into a (batch, batch) comparison.
        predicted = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])
        expected = tf.reshape(y_true, [-1])
        values = tf.equal(tf.cast(expected, 'int32'), tf.cast(predicted, 'int32'))
        values = tf.cast(values, 'float32')
        if sample_weight is not None:
            # Flatten for the same reason as the labels above.
            sample_weight = tf.reshape(tf.cast(sample_weight, 'float32'), [-1])
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        """Return the accumulated count of correct predictions."""
        return self.true_positives

    def reset_state(self):
        """Reset the accumulated count to zero."""
        self.true_positives.assign(0.)

    def reset_states(self):
        """Older spelling of `reset_state`, kept for existing callers."""
        self.reset_state()
model = MyModel()

# Loss function
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
# Optimizer
optimizer = tf.keras.optimizers.Adam()

# Evaluation functions for the training pass:
# mean loss, accuracy rate, and the number of correct predictions.
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")
train_tp = CatgoricalTruePositives(name="train_tp")

# The same three evaluation functions for the test pass.
test_loss = tf.keras.metrics.Mean(name="test_loss")
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")
test_tp = CatgoricalTruePositives(name="test_tp")
@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and update the training metrics.

    Args:
        images: Batch of input images fed to `model`.
        labels: Labels for `images`, passed to the loss and the metrics.
    """
    with tf.GradientTape() as tape:
        # Be explicit about training mode so layers whose behaviour depends
        # on it (e.g. Dropout) stay correct if they are added to the model.
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Accumulate the evaluation functions for this batch.
    train_loss(loss)
    train_accuracy(labels, predictions)
    train_tp(labels, predictions)
@tf.function
def test_step(images, labels):
    """Evaluate one batch and update the test metrics (no weight update).

    Args:
        images: Batch of input images fed to `model`.
        labels: Labels for `images`, passed to the loss and the metrics.
    """
    # Be explicit about inference mode so layers whose behaviour depends on
    # it (e.g. Dropout) stay correct if they are added to the model.
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)
    test_loss(t_loss)
    test_accuracy(labels, predictions)
    test_tp(labels, predictions)
EPOCHS = 5
for epoch in range(EPOCHS):
    # Reset every evaluation indicator at the start of each epoch so the
    # reported numbers cover this epoch only.
    for metric in (train_loss, train_accuracy, train_tp,
                   test_loss, test_accuracy, test_tp):
        metric.reset_states()

    # One full pass over the training data, then one over the test data.
    for images, labels in train_ds:
        train_step(images, labels)
    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    # Collect the epoch's results in the order the template expects;
    # accuracies are scaled by 100 before printing.
    epoch_stats = (
        train_loss.result(),
        train_accuracy.result() * 100,
        train_tp.result(),
        test_loss.result(),
        test_accuracy.result() * 100,
        test_tp.result(),
    )
    template = 'Epoch {}, Loss: {}, Accuracy: {}, TP: {},Test Loss: {}, Test Accuracy: {}, Test TP:{}'
    print(template.format(epoch + 1, *epoch_stats))
Training with fit()
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it
# up front; the original note associates this setting with avoiding OOM.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
# allow_growth=True requests incremental (as-needed) GPU memory allocation.
config.gpu_options.allow_growth = True
# Create the session with this configuration and keep a reference to it.
session = InteractiveSession(config=config)
# Read the four arrays inside a context manager so the .npz archive is
# closed as soon as they are in memory (np.load leaves it open otherwise).
with np.load("mnist.npz") as mnist:
    x_train, y_train = mnist['x_train'], mnist['y_train']
    x_test, y_test = mnist['x_test'], mnist['y_test']
# Rescale by 1/255 in float32 (assumes 8-bit pixel values -- TODO confirm).
# Casting first avoids float64 tensors, which take twice the memory and
# would be cast back to the layers' float32 dtype by Keras anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
# One-hot encode the labels into 10 classes for CategoricalCrossentropy.
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
# Validation data is not shuffled: the reported metrics are aggregated over
# the whole set, so sample order does not affect them.
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
class MyModel(Model):
    """Small convolutional classifier.

    Layers, in order: a 3x3 convolution with 32 filters (ReLU), a flatten
    step, a 128-unit dense layer (ReLU) and a 10-unit softmax output.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(filters=32, kernel_size=3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(units=128, activation='relu')
        self.d2 = Dense(units=10, activation='softmax')

    def call(self, x):
        """Run the forward pass and return the softmax output for `x`."""
        features = self.flatten(self.conv1(x))
        hidden = self.d1(features)
        return self.d2(hidden)
# Custom evaluation function
class CatgoricalTruePositives(tf.keras.metrics.Metric):
    """Streaming count of correctly classified samples (one-hot labels).

    Both `y_true` and `y_pred` carry one value per class on the last axis;
    a sample counts as correct when the arg-max of its scores equals the
    arg-max of its label vector. The running (optionally weighted) total is
    kept in the `tp` weight and returned by `result()`.
    """

    def __init__(self, name='categorical_true_positives', **kwargs):
        super(CatgoricalTruePositives, self).__init__(name=name, **kwargs)
        # Scalar accumulator for the number of correct predictions.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the number of correct predictions in this batch.

        Args:
            y_true: One-hot labels with the class axis last.
            y_pred: Per-class scores with the class axis last.
            sample_weight: Optional per-sample weights applied to each
                correct prediction before summing.
        """
        # Reduce both sides to class ids and flatten them so the comparison
        # is strictly element-wise.
        predicted = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])
        expected = tf.reshape(tf.argmax(y_true, axis=-1), [-1])
        values = tf.equal(tf.cast(expected, 'int32'), tf.cast(predicted, 'int32'))
        values = tf.cast(values, 'float32')
        if sample_weight is not None:
            # Flatten so (batch, 1) weights cannot broadcast against the
            # (batch,) matches into a (batch, batch) product.
            sample_weight = tf.reshape(tf.cast(sample_weight, 'float32'), [-1])
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        """Return the accumulated count of correct predictions."""
        return self.true_positives

    def reset_state(self):
        """Reset the accumulated count to zero."""
        self.true_positives.assign(0.)

    def reset_states(self):
        """Older spelling of `reset_state`, kept for existing callers."""
        self.reset_state()
model = MyModel()
# Configure training: Adam optimizer, categorical cross-entropy loss, and
# two evaluation functions (accuracy plus the custom correct-count metric).
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[
        tf.keras.metrics.CategoricalAccuracy(),
        CatgoricalTruePositives(),
    ],
)
# Train for 5 epochs, validating on the test dataset.
model.fit(train_ds, epochs=5, validation_data=test_ds)
边栏推荐
- 在哪个平台买股票开户安全?求指导
- [C language practice - printing hollow upper triangle and its deformation]
- 5000字解析:实战化场景下的容器安全攻防之道
- Evaluate:huggingface评价指标模块入门详细介绍
- Secure JSON protocol
- Solana扩容机制分析(1):牺牲可用性换取高效率的极端尝试 | CatcherVC Research
- Notes on brushing questions (19) -- binary tree: modification and construction of binary search tree
- How to handle 2gcsv files that cannot be opened? Use byzer
- JVM笔记
- 安全Json协议
猜你喜欢
[C language practice - printing hollow upper triangle and its deformation]
el-dialog拖拽,边界问题完全修正,网上版本的bug修复
svg上升的彩色气泡动画
Use of abortcontroller
AbortController的使用
5000 word analysis: the way of container security attack and defense in actual combat scenarios
[tcapulusdb knowledge base] Introduction to tcapulusdb data structure
Evaluate:huggingface评价指标模块入门详细介绍
JVM notes
IntelliJ idea -- Method for formatting SQL files
随机推荐
2Gcsv文件打不开怎么处理,使用byzer工具
NFT 平台安全指南(1)
Audio and video learning (II) -- frame rate, code stream and resolution
如何配置使用新的单线激光雷达
PCIe Capabilities List
反射修改final
【毕业季·进击的技术er】 什么是微信小程序,带你推开小程序的大门
Mr. Du said that the website was updated with illustrations
一篇博客彻底掌握:粒子滤波 particle filter (PF) 的理论及实践(matlab版)
Svg capital letter a animation JS effect
Seurat to h5ad summary
面试高频 | 你追我赶的Flink双流join
「幹貨」NFT 上中下遊產業鏈全景分析
HW safety response
selenium chrome 禁用js 禁用图片
【leetcode】112. 路径总和 - 113. 路径总和 II
【leetcode】701. Insert operation in binary search tree
/etc/profile、/etc/bashrc、~/. Bashrc differences
sqlite加载csv文件,并做数据分析
svg上升的彩色气泡动画