当前位置:网站首页>Deeplab V3 code structure diagram
Deeplab V3 code structure diagram
2022-06-23 07:11:00 【Houston pineapple】
Paper:
Rethinking Atrous Convolution for Semantic Image Segmentation
Paper URL: https://arxiv.org/abs/1706.05587
DeepLab v3 architecture

DeepLab v3 code implementation
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph import Conv2D
from paddle.fluid.dygraph import BatchNorm
from paddle.fluid.dygraph import Dropout

from resnet_multi_grid import ResNet50, ResNet101, ResNet152
class ASPPPooling(Layer):
    """Image-level pooling branch of ASPP.

    Pipeline: adaptive average pool to 1x1 -> 1x1 Conv -> BN + ReLU ->
    resize back to the spatial size of the input.

    Args:
        num_channels: channel count of the incoming feature map.
        num_filters: channel count produced by the 1x1 convolution.
    """

    def __init__(self, num_channels, num_filters):
        super(ASPPPooling, self).__init__()
        # 1x1 Conv followed by BatchNorm with a fused ReLU activation.
        self.features = fluid.dygraph.Sequential(
            Conv2D(num_channels=num_channels, num_filters=num_filters, filter_size=1),
            BatchNorm(num_channels=num_filters, act='relu'),
        )

    def forward(self, inputs):
        """Pool ``inputs`` to 1x1, project it, and resize to the input's H x W."""
        # Only height/width are used; the 4-way unpack still requires a
        # 4-element shape, exactly as before.
        _, _, height, width = inputs.shape
        pooled = paddle.nn.functional.adaptive_avg_pool2d(inputs, (1, 1))
        projected = self.features(pooled)
        # NOTE(review): no `mode` is passed, so the library's default
        # interpolation mode applies -- confirm that is the intended one.
        return paddle.nn.functional.interpolate(
            projected, (height, width), align_corners=False
        )
class ASPPConv(fluid.dygraph.Sequential):
    """One atrous branch of ASPP: dilated 3x3 Conv followed by BN + ReLU.

    Args:
        num_channels: channel count of the incoming feature map.
        num_filters: channel count produced by the convolution.
        dilation: dilation rate of the 3x3 convolution; the same value is
            also used as its padding.
    """

    def __init__(self, num_channels, num_filters, dilation):
        atrous_conv = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=3,
            padding=dilation,
            dilation=dilation,
        )
        norm_relu = BatchNorm(num_filters, act='relu')
        super(ASPPConv, self).__init__(atrous_conv, norm_relu)
class ASPPModule(Layer):
    """Atrous Spatial Pyramid Pooling: parallel branches, concat, project.

    Branches, in order:
      1. 1x1 Conv + BN + ReLU
      2. one ``ASPPConv`` (dilated 3x3 Conv + BN + ReLU) per dilation rate
      3. ``ASPPPooling`` (image-level pooling branch)

    Every branch receives the same input. Their outputs are concatenated
    along axis 1 and reduced back to ``num_filters`` channels by a final
    1x1 Conv + BN + ReLU projection.

    Args:
        num_channels: channel count of the incoming feature map.
        num_filters: channel count of every branch and of the output.
        dilation_rates: sequence of dilation rates, one ``ASPPConv`` each.
    """

    def __init__(self, num_channels, num_filters, dilation_rates):
        super(ASPPModule, self).__init__()

        # 1x1 Conv + BN + ReLU
        branches = [
            fluid.dygraph.Sequential(
                Conv2D(num_channels=num_channels, num_filters=num_filters, filter_size=1),
                BatchNorm(num_channels=num_filters, act='relu'),
            )
        ]
        # One dilated 3x3 Conv + BN + ReLU branch per requested rate.
        branches.extend(
            ASPPConv(num_channels, num_filters, rate) for rate in dilation_rates
        )
        # Image-level pooling branch.
        branches.append(ASPPPooling(num_channels, num_filters))

        # The branches must live in a LayerList rather than a plain Python
        # list: a plain list is not registered with the parent Layer, so the
        # branch parameters would be absent from parameters()/state_dict()
        # and train()/eval() would not propagate to them.
        self.features = fluid.dygraph.LayerList(branches)

        # Projection: the concatenated input has one num_filters-wide slice
        # per branch, i.e. (1x1 branch + pooling branch + len(dilation_rates)).
        self.project = fluid.dygraph.Sequential(
            Conv2D(
                num_channels=num_filters * (2 + len(dilation_rates)),
                num_filters=num_filters,
                filter_size=1,
            ),
            BatchNorm(num_filters, act='relu'),
        )

    def forward(self, inputs):
        """Run every branch on ``inputs``, concatenate on axis 1, and project."""
        branch_outputs = [branch(inputs) for branch in self.features]
        merged = paddle.concat(branch_outputs, axis=1)
        return self.project(merged)
class DeepLabHead(fluid.dygraph.Sequential):
    """Segmentation head: ASPP -> 3x3 Conv -> BN + ReLU -> 1x1 Conv.

    Args:
        num_channels: channel count of the backbone feature map fed to ASPP.
        num_classes: number of output channels of the final 1x1 convolution.
    """

    def __init__(self, num_channels, num_classes):
        aspp = ASPPModule(num_channels, 256, [12, 24, 36])
        refine_conv = Conv2D(num_channels=256, num_filters=256, filter_size=3, padding=1)
        refine_norm = BatchNorm(256, act='relu')
        # Final 1x1 convolution producing one channel per class.
        class_conv = Conv2D(256, num_classes, 1)
        super(DeepLabHead, self).__init__(aspp, refine_conv, refine_norm, class_conv)
class DeepLab(Layer):
    """DeepLab v3: multi-grid ResNet50 backbone followed by ``DeepLabHead``.

    Args:
        num_classes: number of output channels of the head (default 59).
    """

    def __init__(self, num_classes=59):
        super(DeepLab, self).__init__()

        # Dilation is already configured inside resnet_multi_grid.py; the
        # per-stage dilation values below are the original author's notes.
        backbone = ResNet50(pretrained=False)

        self.layer0 = fluid.dygraph.Sequential(backbone.conv, backbone.pool2d_max)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3  # dilation = 2
        self.layer4 = backbone.layer4  # dilation = 4

        # Multi-grid copies of layer4.
        self.layer5 = backbone.layer5  # copy 1, dilation = 4, 8, 16
        self.layer6 = backbone.layer6  # copy 2, dilation = 8, 16, 32
        self.layer7 = backbone.layer7  # copy 3, dilation = 16, 32, 64

        # Channel count passed to the head as its input width.
        feature_dim = 2048
        self.classifier = DeepLabHead(feature_dim, num_classes)

    def forward(self, inputs):
        """Run backbone and head, then resize the result to the input's H x W."""
        _, _, height, width = inputs.shape

        x = inputs
        backbone_stages = (
            self.layer0,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.layer5,
            self.layer6,
            self.layer7,
        )
        for stage in backbone_stages:
            x = stage(x)

        x = self.classifier(x)
        # NOTE(review): no `mode` is passed, so the library's default
        # interpolation mode applies -- confirm that is the intended one.
        return paddle.nn.functional.interpolate(x, (height, width), align_corners=False)
def main():
    """Smoke test: run a random batch through DeepLab and print the output shape."""
    with fluid.dygraph.guard():
        # Random float32 batch of shape (2, 3, 512, 512).
        batch = np.random.rand(2, 3, 512, 512).astype(np.float32)
        tensor = to_variable(batch)

        net = DeepLab(num_classes=59)
        net.eval()

        prediction = net(tensor)
        print(prediction.shape)


if __name__ == '__main__':
    main()
# Code reference: https://blog.csdn.net/qq_39804263/article/details/120954082
DeepLab v3 results

边栏推荐
- 307. 区域和检索 - 数组可修改
- 318. 最大单词长度乘积
- 如何在 PHP 中进行日期格式验证检查(正则)
- 898. subarray bitwise OR operation
- paddle版本问题
- [STL] summary of stack and queue usage of container adapter
- 产品-Axure9(英文版),原型设计 制作下拉二级菜单
- Open source oauth2 framework for SSO single sign on
- Analyzing the creation principle in maker Education
- Xiaobai must see in investment and wealth management: illustrated fund buying and selling rules
猜你喜欢

【日常训练】513. 找树左下角的值

How to migrate virtual machines from VirtualBox to hype-v

Learning and using quartz scheduling framework

EndNote20使用教程分享(未完

Unet代码实现

Idea installing the cloudtoolkit plug-in

Advanced drawing skills of Excel lecture 100 (VIII) -excel drawing WiFi diagram

Badly placed()'s problem

Quartz调度框架的学习使用

Analyzing the creation principle in maker Education
随机推荐
What you need to know about five insurances and one fund
[STL] summary of deque usage of sequential containers
[project training] multi segment line expanded to parallel line
正则表达式图文超详细总结不用死记硬背(上篇)
TP6+Redis+think-queue+Supervisor实现进程常驻消息队列/job任务
【项目实训】线形箭头的变化
MySQL index
897. 递增顺序搜索树
技术文章写作指南
How to verify date format in PHP (regular)
315. calculate the number of elements on the right that are smaller than the current element
深度学习系列46:人脸图像超分GFP-GAN
898. 子数组按位或操作
Eureka
MySQL mvcc multi version concurrency control
Run typescript code directly using TS node
JS dynamically creates a href circular download file. Only 10 or a fixed number of files can be downloaded
312. poke the balloon
306. Addenda
The illustration shows three handshakes and four waves. Xiaobai can understand them