当前位置:网站首页>RESNET practice in tensorflow
ResNet practice in TensorFlow
2022-06-26 05:04:00 【Rain and dew touch the real king】
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
class BasicBlock(layers.Layer):
    """Residual "basic block": two 3x3 convolutions plus a shortcut connection.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3(stride 1) -> BN.
    Shortcut: the identity when ``stride == 1``; otherwise a 1x1 convolution
    with the same stride followed by batch normalization (projection
    shortcut). The two paths are summed and passed through a final ReLU.

    Args:
        filter_num: Number of output channels of both 3x3 convolutions (and
            of the projection shortcut).
        stride: Stride of the first 3x3 convolution and of the projection
            shortcut. Any value other than 1 spatially downsamples.

    Note:
        With ``stride == 1`` the shortcut is the identity, so the input must
        already have ``filter_num`` channels for the addition to be valid.
    """

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        if stride != 1:
            # Projection shortcut: matches the main path's spatial size and
            # channel count; BN keeps its scale consistent with the
            # batch-normalized main path it is added to.
            self.downsample = Sequential([
                layers.Conv2D(filter_num, (1, 1), strides=stride),
                layers.BatchNormalization(),
            ])
        else:
            # Identity shortcut. It accepts ``training`` so that both branches
            # can be called the same way in ``call``.
            self.downsample = lambda x, training=None: x
        # One persistent merge layer; calling ``layers.add([...])`` inside
        # ``call`` would construct a fresh Add layer on every forward pass.
        self.merge = layers.Add()

    def call(self, inputs, training=None):
        """Apply the block to ``inputs``.

        Args:
            inputs: Feature map; presumably channels-last ``[b, h, w, c]``
                (the Keras Conv2D default) — confirm against the data pipeline.
            training: Forwarded to every BatchNormalization layer, including
                the one in the projection shortcut.

        Returns:
            ReLU(main_path(inputs) + shortcut(inputs)).
        """
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        identity = self.downsample(inputs, training=training)
        output = self.merge([out, identity])
        return tf.nn.relu(output)
class ResNet(keras.Model):
    """ResNet classifier assembled from ``BasicBlock`` stages.

    Layout: stem (3x3 conv -> BN -> ReLU -> 2x2 max-pool with stride 1)
    -> four residual stages with 64/128/256/512 filters (stages 2-4 open with
    a stride-2 block) -> global average pooling -> Dense layer producing
    ``num_classes`` raw logits (no softmax is applied).

    Args:
        layer_dims: Exactly four positive ints giving the number of
            ``BasicBlock``s in each stage, e.g. ``[2, 2, 2, 2]`` for ResNet-18.
        num_classes: Length of the output logits vector.

    Raises:
        ValueError: If ``layer_dims`` does not hold exactly four positive ints.
    """

    def __init__(self, layer_dims, num_classes=100):  # e.g. [2, 2, 2, 2]
        super(ResNet, self).__init__()
        layer_dims = list(layer_dims)
        # Fail loudly instead of an IndexError on short input, silently
        # ignoring extra entries, or silently building one block for 0.
        if len(layer_dims) != 4 or any(n < 1 for n in layer_dims):
            raise ValueError(
                f"layer_dims must contain exactly 4 positive block counts, got {layer_dims!r}"
            )
        # Conv2D defaults to padding='valid', so the stem conv trims a 1-pixel
        # border; the stride-1 max-pool does not reduce the spatial size.
        self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')
                                ])
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
        # layer4 output (channels-last): [b, h', w', 512]
        self.avgpool = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training=None):
        """Return class logits for ``inputs``.

        Args:
            inputs: Image batch; presumably channels-last ``[b, h, w, c]`` —
                confirm against the data pipeline.
            training: Forwarded to every sub-layer so BatchNormalization
                switches between batch and moving statistics.

        Returns:
            Logits tensor of shape ``[b, num_classes]``.
        """
        x = self.stem(inputs, training=training)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        # [b, 512]
        x = self.avgpool(x)
        # [b, num_classes]
        x = self.fc(x)
        return x

    def build_resblock(self, filter_num, blocks, stride=1):
        """Build one residual stage as a ``Sequential`` of ``BasicBlock``s.

        Only the first block uses ``stride`` (and may therefore downsample);
        the remaining ``blocks - 1`` blocks use stride 1. A ``blocks`` value
        below 1 still yields a single block.

        Args:
            filter_num: Channel count for every block in the stage.
            blocks: Number of blocks in the stage.
            stride: Stride of the stage's first block.

        Returns:
            A ``Sequential`` model containing the stage's blocks.
        """
        res_blocks = Sequential()
        # The first block may downsample.
        res_blocks.add(BasicBlock(filter_num, stride))
        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, stride=1))
        return res_blocks
def resnet18(num_classes=100):
    """Build a ResNet-18: four stages of two ``BasicBlock``s each.

    Args:
        num_classes: Number of output logits. The default of 100 keeps the
            original behavior of this factory.

    Returns:
        A new ``ResNet`` model.
    """
    return ResNet([2, 2, 2, 2], num_classes=num_classes)
def resnet34(num_classes=100):
    """Build a ResNet-34: stages of 3, 4, 6 and 3 ``BasicBlock``s.

    Args:
        num_classes: Number of output logits. The default of 100 keeps the
            original behavior of this factory.

    Returns:
        A new ``ResNet`` model.
    """
    return ResNet([3, 4, 6, 3], num_classes=num_classes)
边栏推荐
猜你喜欢

Rsync common error messages (common errors on Windows)

ROS 笔记(07)— 客户端 Client 和服务端 Server 的实现

2.22.2.14

How MySQL deletes all redundant duplicate data

Guanghetong and anti international bring 5g R16 powerful performance to the AI edge computing platform based on NVIDIA Jetson Xavier nx

ROS notes (07) - Implementation of client and server

【Unity3D】碰撞体组件Collider

Differences between TCP and UDP

Dbeaver installation and configuration of offline driver

PyCharm package import error without a warning
随机推荐
Resample
5. <tag-栈和常规问题>补充: lt.946. 验证栈序列(同剑指 Offer 31. 栈的压入、弹出序列)
Image translation /gan:unsupervised image-to-image translation with self attention networks
Solution to back-off restarting failed container
【Unity3D】刚体组件Rigidbody
Numpy general function
LeetCode 19. Remove Nth Node From End of List
2022.2.11
UWB超高精度定位系统原理图
2.9 learning summary
A ZABBIX self discovery script (shell Basics)
Muke.com actual combat course
Some parameter settings and feature graph visualization of yolov5-6.0
[greedy college] recommended system engineer training plan
Multipass中文文档-使用实例命令别名
LISP programming language
2022.2.15
Sklearn Library -- linear regression model
Converting data types with astype
6.1 - 6.2 Introduction à la cryptographie à clé publique