当前位置:网站首页>Unet代码实现
U-Net 代码实现
2022-06-23 06:19:00 【休斯顿凤梨】
原理:
U-Net 由编码器(逐级卷积 + 下采样)、解码器(逐级上采样 + 卷积)以及把编码器各级特征拼接到对应解码器层的跳跃连接(skip connection)组成。
实现:
#代码参考官方
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph import Conv2D
from paddle.fluid.dygraph import BatchNorm
from paddle.fluid.dygraph import Pool2D
from paddle.fluid.dygraph import Conv2DTranspose
class Encoder(Layer):
    """One U-Net contracting stage: two conv + BN(ReLU) units, then 2x2 max-pooling.

    Args:
        num_channels: number of channels of the incoming feature map.
        num_filters: number of channels produced by both 3x3 convolutions.

    ``forward`` returns ``(features, pooled)``: the un-pooled features (kept by
    the caller as a skip connection) and the max-pooled map for the next stage.
    """

    def __init__(self, num_channels, num_filters):
        super(Encoder, self).__init__()
        # 3x3 kernels with stride 1 and padding 1 leave H and W unchanged.
        self.conv1 = Conv2D(num_channels=num_channels, num_filters=num_filters,
                            filter_size=3, stride=1, padding=1)
        self.bn1 = BatchNorm(num_filters, act='relu')
        self.conv2 = Conv2D(num_channels=num_filters, num_filters=num_filters,
                            filter_size=3, stride=1, padding=1)
        self.bn2 = BatchNorm(num_filters, act='relu')
        # ceil_mode=True rounds odd sizes up (e.g. 123 -> 62) rather than
        # dropping the last row/column.
        self.pool = Pool2D(pool_size=2, pool_stride=2, pool_type='max',
                           ceil_mode=True)

    def forward(self, inputs):
        features = inputs
        for stage in (self.conv1, self.bn1, self.conv2, self.bn2):
            features = stage(features)
        pooled = self.pool(features)
        return features, pooled
class Decoder(Layer):
    """One U-Net expanding stage: 2x upsample, align to the skip tensor, concat, two conv + BN(ReLU).

    Args:
        num_channels: channels of the deeper map passed as ``inputs``. ``conv1``
            also expects this many channels after concatenation, so the skip
            tensor ``inputs_prev`` must carry ``num_channels - num_filters`` channels.
        num_filters: output channels of the transposed conv and of both 3x3 convs.
    """

    def __init__(self, num_channels, num_filters):
        super(Decoder, self).__init__()
        # kernel 2 / stride 2 transposed conv doubles H and W.
        self.up = Conv2DTranspose(num_channels=num_channels,
                                  num_filters=num_filters,
                                  filter_size=2,
                                  stride=2)
        self.conv1 = Conv2D(num_channels=num_channels, num_filters=num_filters,
                            filter_size=3, stride=1, padding=1)
        self.bn1 = BatchNorm(num_filters, act='relu')
        self.conv2 = Conv2D(num_channels=num_filters, num_filters=num_filters,
                            filter_size=3, stride=1, padding=1)
        self.bn2 = BatchNorm(num_filters, act='relu')

    def forward(self, inputs_prev, inputs):
        upsampled = self.up(inputs)
        # Bring the upsampled map to the skip tensor's H/W. Any odd remainder
        # goes to the bottom/right side. pad2d's order is [top, bottom, left, right].
        # NOTE(review): the differences can be negative when the encoder's
        # ceil-mode pooling rounded up; confirm pad2d treats negative
        # paddings as cropping.
        dh = inputs_prev.shape[2] - upsampled.shape[2]
        dw = inputs_prev.shape[3] - upsampled.shape[3]
        top = dh // 2
        left = dw // 2
        upsampled = fluid.layers.pad2d(
            upsampled, paddings=[top, dh - top, left, dw - left])
        # Channel-wise (axis=1) concatenation: skip features first.
        out = fluid.layers.concat([inputs_prev, upsampled], axis=1)
        for stage in (self.conv1, self.bn1, self.conv2, self.bn2):
            out = stage(out)
        return out
class UNet(Layer):
    """U-Net: four encoder stages, a 1x1-conv bottleneck, four decoder stages and a 1x1 classifier.

    Args:
        num_classes: number of output channels (one score map per class). Default 59.
        in_channels: number of channels of the input image. Default 3, matching
            the previously hard-coded value.
        debug: when True, ``forward`` prints intermediate tensor shapes. Defaults
            to False so normal training/inference does not write to stdout on
            every forward pass.

    ``forward`` takes an NCHW tensor and returns a ``num_classes``-channel map.
    Each decoder aligns its upsampled map to the matching skip tensor, and the
    first skip tensor has the input's H/W, so the output keeps the input's
    spatial size.
    """

    def __init__(self, num_classes=59, in_channels=3, debug=False):
        super(UNet, self).__init__()
        self._debug_shapes = debug
        # Contracting path: channels 64 -> 128 -> 256 -> 512, halving H/W each stage.
        self.down1 = Encoder(num_channels=in_channels, num_filters=64)
        self.down2 = Encoder(num_channels=64, num_filters=128)
        self.down3 = Encoder(num_channels=128, num_filters=256)
        self.down4 = Encoder(num_channels=256, num_filters=512)
        # Bottleneck: two 1x1 conv + BN(ReLU) units widening 512 -> 1024 channels.
        self.mid_conv1 = Conv2D(512, 1024, filter_size=1, padding=0, stride=1)
        self.mid_bn1 = BatchNorm(1024, act='relu')
        self.mid_conv2 = Conv2D(1024, 1024, filter_size=1, padding=0, stride=1)
        self.mid_bn2 = BatchNorm(1024, act='relu')
        # Expanding path: each Decoder(c, c // 2) upsamples c -> c // 2 and then
        # concatenates a c // 2-channel skip tensor, giving c channels for conv1.
        self.up4 = Decoder(1024, 512)
        self.up3 = Decoder(512, 256)
        self.up2 = Decoder(256, 128)
        self.up1 = Decoder(128, 64)
        self.last_conv = Conv2D(num_channels=64, num_filters=num_classes,
                                filter_size=1)

    def _log_shapes(self, tag, *tensors):
        """Print the shapes of ``tensors`` under ``tag``, only in debug mode."""
        if self._debug_shapes:
            print(tag, [t.shape for t in tensors])

    def forward(self, inputs):
        x1, x = self.down1(inputs)
        self._log_shapes('down1 (skip, pooled)', x1, x)
        x2, x = self.down2(x)
        self._log_shapes('down2 (skip, pooled)', x2, x)
        x3, x = self.down3(x)
        self._log_shapes('down3 (skip, pooled)', x3, x)
        x4, x = self.down4(x)
        self._log_shapes('down4 (skip, pooled)', x4, x)

        # middle layers
        x = self.mid_conv1(x)
        x = self.mid_bn1(x)
        x = self.mid_conv2(x)
        x = self.mid_bn2(x)
        self._log_shapes('bottleneck', x)

        x = self.up4(x4, x)
        self._log_shapes('up4', x)
        x = self.up3(x3, x)
        self._log_shapes('up3', x)
        x = self.up2(x2, x)
        self._log_shapes('up2', x)
        x = self.up1(x1, x)
        self._log_shapes('up1', x)
        return self.last_conv(x)
def main():
    """Smoke test: run UNet (59 classes) on one random 1x3x123x123 float32 batch on CPU and print the output shape."""
    with fluid.dygraph.guard(fluid.CPUPlace()):
        net = UNet(num_classes=59)
        batch = np.random.rand(1, 3, 123, 123).astype(np.float32)
        logits = net(to_variable(batch))
        print(logits.shape)


if __name__ == "__main__":
    main()
效果

边栏推荐
- [STL] summary of deque usage of sequential containers
- What are the pension financial products in 2022? Low risk
- 小白投资理财必看:图解基金买入与卖出规则
- 数据统计与分析基础 实验一 基本语法及运算
- Configuration and compilation of mingw-w64, msys and ffmpeg
- 20220620 uniformly completely observable (UCO)
- 994. rotten oranges - non recursive method
- 407 stack and queue (232. implementing queue with stack, 225. implementing stack with queue)
- What you need to know about five insurances and one fund
- 深度学习系列47:styleGAN总结
猜你喜欢

Analyzing the creation principle in maker education

【STL】pair用法总结

【项目实训】线形箭头的变化

【BULL中文文档】用于在 NodeJS 中处理分布式作业和消息的队列包

XXL-SSO enables SSO (single sign-on)

MySQL MVCC多版本并发控制

MySQL optimization

994. 腐烂的橘子-非递归法

别找了诸位 【十二款超级好用的谷歌插件都在这】(确定不来看看?)

Intentional shared lock, intentional exclusive lock and deadlock of MySQL
随机推荐
A small method of debugging equipment serial port information with ADB
How to migrate virtual machines from VirtualBox to Hyper-V
[shell] tree command
Verilog syntax explanation
深度学习系列46:人脸图像超分GFP-GAN
Concepts and differences of DQL, DML, DDL and DCL
Centos7 MySQL records
Lombok的使用
Idea automatically generates serialVersionUID
excel高级绘图技巧100讲(八)-Excel绘制WIFI图
407 stack and queue (232. implementing queue with stack, 225. implementing stack with queue)
MySQL basic query
产品-Axure9(英文版),原型设计 制作下拉二级菜单
Idea installing the cloudtoolkit plug-in
The "Badly placed ()'s" problem
Add IPAD control function into shairplay
994. 腐烂的橘子-非递归法
Qt: how to compile projects using multiple threads
Influence of STEAM education on domestic college students
Solve the mining virus sshd2 (redis does not set a password and clear the crontab scheduled task)