Classic network learning: ResNet code implementation
2022-07-25 02:34:00 【csdn__ Dong】
Preface
Building on the previous theoretical analysis, this post walks through the code implementation of ResNet. If you haven't read <<Classic Network Learning - ResNet>>, it is recommended to read that first. Before writing this post, I also surveyed other implementations online; most of them closely follow the official PyTorch source code, so this post explains ResNet as implemented in the official version.
ResNet architecture
Here is the architecture diagram from the paper:

Each layer in the figure is actually a BasicBlock or Bottleneck structure. The structure diagram shown here is for ResNet-34. The dashed shortcut connections in the figure indicate that the numbers of channels differ between input and output, so the channels must be adjusted, either with zero padding or with a 1x1 convolution.


In code, the forward path of a basic block is:
conv->bn->relu->conv->bn->shortcut->relu
BasicBlock structure

import torch
import torch.nn as nn
from torch import Tensor
from typing import Callable, List, Optional, Type, Union

# Basic residual block used by resnet18 and resnet34
# downsample: the module implementing the dashed (projection) shortcut
# Downsampling is performed by conv3_1, conv4_1, and conv5_1 with a stride of 2
class BasicBlock(nn.Module):
    # Channel expansion factor; the base channel count is 64
expansion: int = 1
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(in_channels, out_channels, stride)
self.bn1 = norm_layer(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(out_channels, out_channels)
self.bn2 = norm_layer(out_channels)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
        # Dashed shortcut in the paper's architecture diagram: the identity must be downsampled to match
if self.downsample is not None:
identity = self.downsample(x)
        # shortcut connection
out += identity
out = self.relu(out)
return out
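As a quick sanity check, here is a minimal sketch (assuming BasicBlock and the conv3x3 helper defined just below are in scope); the 64→128 shapes follow the conv2_x→conv3_x transition of ResNet-34:

# Hypothetical smoke test: first (dashed-shortcut) block of conv3_x in ResNet-34
downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),  # 1x1 projection shortcut
    nn.BatchNorm2d(128),
)
block = BasicBlock(64, 128, stride=2, downsample=downsample)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 128, 28, 28])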
The conv3x3 helper used in the code above is defined as:
# 3x3 convolution with padding and no bias
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)
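The Bottleneck block and _make_layer below also use a conv1x1 helper; in the official torchvision source it is defined as:

# 1x1 convolution, used by Bottleneck and by the projection shortcut
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)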
- Note the bias=False: when the convolution is immediately followed by a BN layer, it is better not to add a bias, because BN subtracts the per-channel mean, which cancels any constant offset; the bias would have no effect and only occupy GPU memory.
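A minimal sketch illustrating this point (the layer sizes here are arbitrary, chosen only for illustration):

# In training mode, BN normalizes with batch statistics, so a constant
# per-channel bias added by the preceding conv is subtracted right back out.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)
x = torch.randn(4, 3, 16, 16)
y1 = bn(conv(x))
with torch.no_grad():
    conv.bias += 5.0  # shift the bias by a large constant
y2 = bn(conv(x))
print(torch.allclose(y1, y2, atol=1e-5))  # True: the bias had no effect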
Bottleneck structure

class Bottleneck(nn.Module):
    # The PyTorch implementation of Bottleneck places stride = 2 on the 3x3 convolution (self.conv2),
    # while the original paper (https://arxiv.org/abs/1512.03385) places stride = 2 on the first
    # 1x1 convolution (self.conv1). The PyTorch placement improves accuracy; this variant is known as
    # ResNet V1.5, see https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    # Channel expansion factor
    expansion: int = 4
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(out_channels * (base_width / 64.0)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(in_channels, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, out_channels * self.expansion)
self.bn3 = norm_layer(out_channels * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
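As with BasicBlock, here is a minimal shape check for the first block of conv2_x in ResNet-50 (assuming Bottleneck and the conv1x1 helper above are in scope); note that here the downsample branch only widens the channels, with stride = 1:

# First block of conv2_x in ResNet-50: 64 -> 256 output channels, spatial size unchanged
downsample = nn.Sequential(
    conv1x1(64, 64 * Bottleneck.expansion),
    nn.BatchNorm2d(64 * Bottleneck.expansion),
)
block = Bottleneck(64, 64, downsample=downsample)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 256, 56, 56])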
ResNet Code implementation
class ResNet(nn.Module):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.in_channels = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.in_channels)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
    # Build the conv2_x, conv3_x, conv4_x, and conv5_x stages
    # channels: the number of kernels (output channels) of the first convolution on the
    # main branch of the residual blocks in this stage
    # blocks: the number of residual blocks in the stage
def _make_layer(
self,
        # Residual block type: either BasicBlock or Bottleneck
block: Type[Union[BasicBlock, Bottleneck]],
        # Base channel count of the blocks in this stage (output channels of the first convolution on the main branch)
channels: int,
# Number of residual blocks
blocks: int,
stride: int = 1,
dilate: bool = False,
) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
        # For resnet50/101/152, the first block of each stage uses the dashed shortcut and needs downsampling
        # For resnet18/34, layer1 skips this branch, because its input and output shapes already match
        # For conv2_x the downsample branch only needs to increase the channels, not the height and width
        # (stride = 1), since input and output feature maps are both 56x56 for a 224x224 input
if stride != 1 or self.in_channels != channels * block.expansion:
downsample = nn.Sequential(
conv1x1(self.in_channels, channels * block.expansion, stride),
norm_layer(channels * block.expansion),
)
layers = []
layers.append(
block(
self.in_channels, channels, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
)
)
self.in_channels = channels * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.in_channels,
channels,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer,
)
)
        # nn.Sequential chains the blocks into a simple sequential container
return nn.Sequential(*layers)
def _forward_impl(self, x: Tensor) -> Tensor:
        '''Forward pass implementation.'''
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# Global average pooling
x = self.avgpool(x)
x = torch.flatten(x, 1)
        # Final fully connected layer
x = self.fc(x)
return x
def forward(self, x: Tensor) -> Tensor:
        '''Forward pass.'''
return self._forward_impl(x)
The code above is taken from the PyTorch source, with some parts removed for readability; everything above has been pushed to my GitHub repository.
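For completeness, here is a minimal sketch of how the classic variants can be instantiated; the layer counts below match the official factory functions, but pretrained-weight loading is omitted:

# Factory functions pairing each depth with its block type and stage sizes
def resnet18(num_classes: int = 1000) -> ResNet:
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def resnet34(num_classes: int = 1000) -> ResNet:
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)

def resnet50(num_classes: int = 1000) -> ResNet:
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)

def resnet101(num_classes: int = 1000) -> ResNet:
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)

def resnet152(num_classes: int = 1000) -> ResNet:
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)

# Quick smoke test
model = resnet34()
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 1000])
# The projection shortcut that _make_layer created for the first block of conv3_x:
print(model.layer2[0].downsample)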