当前位置:网站首页>MXNet对DenseNet(稠密连接网络)的实现
MXNet对DenseNet(稠密连接网络)的实现
2022-07-25 13:36:00 【寅恪光潜】
论文地址:Densely Connected Convolutional Networks
DenseNet其实跟前面的ResNet是很相似的,我们知道ResNet的梯度可以直接通过身份函数(激活函数前的输出与前面跨层的一个相加)从后面的层流到前面的层。但是通过求和相结合,可能会阻碍网络中的信息流。所以DenseNet做了改善,也就是说,每层的输入都会跨层到后面的每一层,或者说后面的每一层都会有来自前面的每一层的直接输入,然后它们不是进行相加,而是在通道维进行一个连结。我们直观的来看图就容易明白了,从图中我们可以知道对于任何层的模型,它们的连接数是可以表示为 L(L+1)/2,L是层数,比如3层,其连接数是6,4层的连接数是10;传统的就是多少层就是多少条连接数。

因为每层都跟其他层都有着非常紧密的连接,所以对于这样的一种模型,我们就称之为“稠密连结网络”或叫“密集卷积网络”。
对于论文中也特别提到了“瓶颈”设计(这个在ResNet中一样)对于DenseNet模型也是很有效的,就是在3x3卷积之间引入一个1x1的卷积,这样的模型称之为DenseNet-B。
为了进一步提高模型的紧凑性,我们可以减少过渡层的特征图数量。比如说密集块包含m个特征图,则让下面的过渡层生成 θm个输出特征图,其中0<θ≤1,θ为压缩因子。当θ=1,跨过过渡层的特征图数量保持不变,叫做DenseNet-C。
如果同时使用了瓶颈层以及一个θ<1的过渡层,叫做DenseNet-BC。对于DenseNet-BC的模型,参数量非常的少,而且性能非常好,0.8M的参数就达到了跟10.2M参数的1001层(预激活)ResNet相当的精度。
比较了多种数据集,尤其是最接近的ResNet的比较,图中可以看出DenseNets使用了更少的参数,而且实现了更低的错误率。如下图:

对于整个稠密网络的架构图,如下: 
构建稠密块
import d2lzh as d2l
from mxnet import gluon,init,nd
from mxnet.gluon import nn
#ResNet改良版的卷积块
#BN--ReLU--3x3卷积
def conv_block(num_channels):
blk=nn.Sequential()
blk.add(nn.BatchNorm(),nn.Activation('relu'),nn.Conv2D(num_channels,kernel_size=3,padding=1))
return blk
#稠密块
#多个conv_block组成,每块使用相同的通道数
#在前向计算时,将每块的输入和输出在通道维上连结(也就是当前块都会跟前面所有的块连结)
class DenseBlock(nn.Block):
def __init__(self,num_convs,num_channels,**kwargs):
super(DenseBlock,self).__init__(**kwargs)
self.net=nn.Sequential()
for _ in range(num_convs):
self.net.add(conv_block(num_channels))
def forward(self,X):
for blk in self.net:
Y=blk(X)
X=nd.concat(X,Y,dim=1)
return X
#观察下形状变化,尤其是通道数
blk=DenseBlock(4,10)#通道数为10的4个卷积块
blk.initialize()
X=nd.random.uniform(shape=(4,5,22,22))
XX=blk(X)
print(XX.shape)#4*10+5=45
#(4, 45, 22, 22)我们可以看出这个通道数是增加了,如果过多的话会让模型变得复杂,这里我们使用过渡层来处理,使用一个1x1的卷积层来减小通道数,并使用步幅为2的平均池化层让宽高减半,从而进一步降低模型复杂度。
过渡层
def transition_block(num_channels):
blk=nn.Sequential()
blk.add(nn.BatchNorm(),nn.Activation('relu'),nn.Conv2D(num_channels,kernel_size=1),
nn.AvgPool2D(pool_size=2,strides=2))
return blk
blk=transition_block(10)
blk.initialize()
print(blk(XX).shape)
#(4, 10, 11, 11)DenseNet模型构建与训练
#DenseNet模型
net=nn.Sequential()
net.add(nn.Conv2D(64,kernel_size=7,strides=2,padding=3),
nn.BatchNorm(),nn.Activation('relu'),
nn.MaxPool2D(pool_size=3,strides=2,padding=1))
#num_channels为当前通道数,后面将通过过渡层减半,growth_rate增长率为稠密块里的卷积块的通道数
num_channels,growth_rate=64,32
num_convs_in_dense_blocks=[4,4,4,4]#4个稠密块,每个稠密块里4个卷积层
for i,num_convs in enumerate(num_convs_in_dense_blocks):
net.add(DenseBlock(num_convs,growth_rate))
#上一个稠密块的输出通道数
num_channels+=num_convs*growth_rate
#在稠密块之间加入通道数减半的过渡层
if i!=len(num_convs_in_dense_blocks)-1:
num_channels //= 2
net.add(transition_block(num_channels))
#最后接全局池化层和全连接层
net.add(nn.BatchNorm(),nn.Activation('relu'),nn.GlobalAvgPool2D(),nn.Dense(10))
#训练模型,由于模型比较深,宽高224降为48来简化计算,不然报内存溢出错误
lr,num_epochs,batch_size,ctx=0.1,5,256,d2l.try_gpu()
net.initialize(force_reinit=True,ctx=ctx,init=init.Xavier())
trainer=gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':lr})
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size,resize=48)
d2l.train_ch5(net,train_iter,test_iter,batch_size,trainer,ctx,num_epochs)
'''
epoch 1, loss 0.4878, train acc 0.821, test acc 0.871, time 35.3 sec
epoch 2, loss 0.3063, train acc 0.885, test acc 0.862, time 32.2 sec
epoch 3, loss 0.2618, train acc 0.902, test acc 0.865, time 32.2 sec
epoch 4, loss 0.2367, train acc 0.911, test acc 0.909, time 31.9 sec
epoch 5, loss 0.2146, train acc 0.919, test acc 0.905, time 31.8 sec
'''边栏推荐
猜你喜欢

Sports luxury or safety luxury? Which type of Asian Dragon and Volvo S60 should we start with?

【CTR】《Towards Universal Sequence Representation Learning for Recommender Systems》 (KDD‘22)
TCP的拥塞控制

刷题-洛谷-P1059 明明的随机数

从输入网址到网页显示

G027-op-ins-rhel-04 RedHat openstack creates a customized qcow2 format image

刷题-洛谷-P1089 津津的储蓄计划

The migration of arm architecture to alsa lib and alsa utils is smooth

MLIR原理与应用技术杂谈

Design and principle of thread pool
随机推荐
Mutex lock, spin lock, read-write lock... Clarify their differences and applications
Jupyter Notebook介绍
stable_baselines快速入门
Brpc source code analysis (III) -- the mechanism of requesting other servers and writing data to sockets
刷题-洛谷-P1059 明明的随机数
0717RHCSA
The simplest solution of the whole network 1045 access denied for user [email protected] (using password:YES)
MLIR原理与应用技术杂谈
我的创作纪念日
Uncaught SyntaxError: Octal literals are not allowed in strict mode.
ThreadLocal&Fork/Join
Peripheral system calls SAP's webapi interface
0713RHCSA
刷题-洛谷-P1085 不高兴的津津
AQS of concurrent programming
2022年下半年软考信息安全工程师如何备考?
mujoco+spinningup进行强化学习训练快速入门
Congestion control of TCP
stable_ Baselines quick start
电脑里一辈子都不想删的神仙软件