Interpretation of swin transformer source code
2022-06-24 16:15:00 【languageX】
In May 2020, Facebook AI released DETR (Detection Transformer) for object detection and panoptic segmentation.
In October 2020, Google proposed ViT (Vision Transformer), which classifies images with a Transformer instead of a convolutional network.
In January 2021, OpenAI proposed two models: DALL·E, which generates images directly from text, and CLIP, which maps images to text-described categories. Both use the Transformer.
In March 2021, Microsoft proposed Swin Transformer, which swept the leaderboards of the major CV tasks....
Can I just let that go? I can't... So here is a summary of the Swin Transformer framework and implementation, based on the paper and code I read a while ago.
The paper : https://arxiv.org/abs/2103.14030
Code : https://github.com/microsoft/Swin-Transformer
Introduction to swin_transformer
1. What swin_transformer optimizes
Compared with the earlier ViT, swin_transformer makes two improvements:
1. It introduces a CNN-like hierarchical, multi-scale transformer structure
ViT produces features at a single, constant scale, which makes it awkward to connect to downstream tasks. For example, the encoder stage of a segmentation network can easily plug into a backbone such as ResNet, whereas ViT's feature map size never changes (b). swin_transformer introduces a hierarchical structure by merging image patches, as shown in (a).
2. It reduces computational complexity and memory consumption
The grey blocks in the figure above are patches and the red blocks are windows. swin_transformer splits the feature map into windows and computes self_attention only within these local, non-overlapping windows. The computational complexity of the original MSA and of the paper's W-MSA is given by the formulas below, where M is the window size (each window contains M*M patches) and is much smaller than h and w, so the complexity becomes linear in hw. The complexity derivation will be clearer after we analyze the source code.
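For reference, the two complexity formulas from the paper (h, w: feature map height/width; C: channel dimension; M: window size):

\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C

\Omega(\mathrm{W\text{-}MSA}) = 4hwC^2 + 2M^2hwC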
2. How swin_transformer optimizes
For the first optimization point, the paper uses the following network architecture:
The structure is divided into 4 stages, and across the stages the feature map size is reduced to 1/4, 1/8, 1/16 and 1/32 of the input.
For the second optimization point, the paper points out that only splitting the feature map into windows and computing self_attention within each window has a drawback: there is no information exchange between windows. So it proposes to run W-MSA and SW-MSA in series.
W-MSA computes self_attention within non-overlapping windows, and the cyclic shift, shown below, shifts those windows. A naive shift turns the original 2*2 windows into 3*3 windows, which increases the computation by 1.5*1.5 times. The alternative the authors propose is a roll operation: move the content of the 2*2 windows left and up, so that the shifted windows contain information from other windows of the previous layout. However, the A, B and C regions should not be treated as adjacent regions, so a mask operation is needed.
Finally, remember the reverse shift to move the windows back ~
3. What results does swin_transformer get?
The result: it swept several major CV tasks..
swin_transformer source code analysis
Now let's take an in-depth look at swin_transformer from the code's perspective.
First, get to know the main classes: BasicLayer implements the stage pipeline, SwinTransformerBlock is the main logic module inside BasicLayer and the core module of the paper, and WindowAttention is the attention module used by SwinTransformerBlock.
depths=(2,2,6,4) determines how many times SwinTransformerBlock is executed in each layer.
The paper proposes 4 model configurations; we take Swin-T as the example below.
Code module logic :
patch_embed + pos_embed
stage1
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage2
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage3
-BasicLayer
--SwinTransformerBlock(*6)
---WindowAttention
stage4
-BasicLayer
--SwinTransformerBlock(*4)
---WindowAttention
Code logic of main modules :
1.patch_embed:PatchEmbed
Let's start with patch_embed. patch_embed simply maps each input patch to a vector. I think of it as a convolution (the thing is called swin_transformer, yet the very first step is a convolution ~ convolution yyds).
Assume the input is (3,256,256), patch_size=4, embedding_dim=96
(1) If the resolution is not divisible by 4, pad it up to a multiple of 4
(2) An ordinary convolution with kernel=4, stride=4 maps the image to non-overlapping 4*4 patches: (96,64,64)
(3) If norm is required, apply a LayerNorm
(4) (3,256,256) goes through patch_embed and the feature becomes (96,64,64)
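To make these four steps concrete, here is a minimal, self-contained sketch of a conv-based patch embedding (PatchEmbedSketch is a hypothetical name, not the repo's exact PatchEmbed class; defaults are chosen to match the shapes above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedSketch(nn.Module):
    """Minimal sketch: a stride-4 conv maps each 4x4 patch to a 96-dim vector."""
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, use_norm=True):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim) if use_norm else None

    def forward(self, x):
        _, _, H, W = x.shape
        # (1) pad H/W up to a multiple of patch_size
        x = F.pad(x, (0, (-W) % self.patch_size, 0, (-H) % self.patch_size))
        # (2) non-overlapping 4x4 convolution: (B,3,256,256) -> (B,96,64,64)
        x = self.proj(x)
        # (3) optional LayerNorm over the channel dimension
        if self.norm is not None:
            Wh, Ww = x.shape[2:]
            x = self.norm(x.flatten(2).transpose(1, 2))
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x

print(PatchEmbedSketch()(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 96, 64, 64])
```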
2.absolute_pos_embed
If the position_embedding step is enabled, a (96,64,64) pos_embed parameter is learned and combined with the patch_embed output (in the official code this is an element-wise add).
The embedded matrix then goes through flatten + transpose --> (64*64, 96)
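A minimal sketch of this step, assuming the absolute position embedding is enabled and the patch grid is 64x64 (shapes for illustration only):

```python
import torch
import torch.nn as nn

embed_dim, Wh, Ww = 96, 64, 64
absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, Wh, Ww))  # learned position table

x = torch.randn(1, embed_dim, Wh, Ww)   # output of patch_embed
x = x + absolute_pos_embed              # add the learned position embedding
x = x.flatten(2).transpose(1, 2)        # (1, 64*64, 96): tokens fed to the stages
```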
3.stages
The feature map, with its resolution reduced by 4x, then goes through the 4 stages, each implemented as a BasicLayer.
BasicLayer
1.attn_mask
Set window_size=7. Taking stage1 as the example, the input feature map size is (64,64). img_mask is initialized (after padding) as (70,70), and window_partition then splits it into 100 windows of 7*7.
```python
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
```
The purpose of the above code is to obtain 100 attn_masks of shape 49*49.
This attn_mask is used later by the cyclic shift, i.e. by SW-MSA.
First, the 70*70 img_mask is assigned values over 9 blocks (each entry is width*height of the region, and the value after = is its region id):
63*63=0   4*63=1   3*63=2
63*4=3    4*4=4    3*4=5
63*3=6    4*3=7    3*3=8
Then window_partition cuts the mask into 100 windows of 7*7 and the data is flattened to 100*49. Each position is subtracted from every other position in the same window, giving 100*49*49, and the non-zero entries are set to -100. A non-zero entry means the two relative positions do not come from the same region in the figure above. Combined with the cyclic shift, this means that inside one window, positions belonging to sub_windows that are not actually adjacent have to be masked out (a toy version of this region assignment is shown below).
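As a toy check (hypothetical Hp=Wp=4, window_size=2, shift_size=1, not taken from the original post), the region assignment and the resulting mask look like this:

```python
import torch

Hp = Wp = 4
window_size, shift_size = 2, 1
img_mask = torch.zeros((1, Hp, Wp, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
print(img_mask[0, :, :, 0])
# tensor([[0., 0., 1., 2.],
#         [0., 0., 1., 2.],
#         [3., 3., 4., 5.],
#         [6., 6., 7., 8.]])

# split into 2x2 windows, flatten, pairwise subtract, mask the non-zero entries
mw = img_mask.view(1, 2, window_size, 2, window_size, 1).permute(0, 1, 3, 2, 4, 5)
mw = mw.reshape(-1, window_size * window_size)            # (4 windows, 4 positions)
attn_mask = mw.unsqueeze(1) - mw.unsqueeze(2)             # (4, 4, 4)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
```

Only the top-left window consists of a single region; all the other windows mix regions, so their masks contain -100 entries.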
2.SwinTransformerBlock(*n)
(1)reshape+pad
The input (64*64, 96) goes through layer_norm + reshape + pad. The pad makes H and W of the feature map multiples of window_size. For stage1: (64*64, 96) --> (70,70,96)
(2)window_mask_self_attention(W-MSA/SW-MSA)
Let's first look at the first block, the W-MSA block, i.e. without cyclic shift.
(a) window_partition: the feature map is split into window_size*window_size patches; (1,70,70,96) is divided into (100,7,7,96) and then reshaped to (100,49,96)
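The window_partition helper is essentially a reshape + permute; a sketch reproduced from memory (check the repo for the exact version):

```python
def window_partition(x, window_size):
    """Split (B, H, W, C) into non-overlapping (window_size, window_size) windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows  # (B * num_windows, window_size, window_size, C)
```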
(b) WindowAttention
Compute self_attention.
step1: obtain the Q, K, V matrices. X: (100,49,96) --> Q, K, V: (100,3,49,32)
```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
```
Concretely: a fully connected layer expands the C channels to 3C, the feature is split into num_heads heads, and the result is finally sliced to get the q, k, v matrices. (100,3,49,32) stands for (number of windows, attention heads, window length, C/num_heads).
step2: compute attention.
attn = (q @ k.transpose(-2, -1))
(100,3,49,32) * (100,3,32,49) --> (100,3,49,49). (In the official code q is first scaled by self.scale = head_dim ** -0.5.) For the principle of self_attention, see the original Transformer paper; it is not repeated here.
step3: compute relative_position_bias
The paper finds that adding a relative position encoding works better. That is, the attn computed in step2 is added with relative_position_bias, which, like attn, should be a (3,49,49) matrix per window.
Let's see how relative_position_bias is calculated.
```python
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # (2*Wh-1) * (2*Ww-1), nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
    self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
```
To make the relative position encoding logic easier to follow, let's assume the window size is 2.
First, establish the coordinate system :
Then compute relative_coords in the X and Y directions. The first step is to add (window_size-1) so that all values become non-negative; the X direction is then multiplied by (2*window_size-1) so that the subsequent summation can distinguish coordinates such as (0,1) and (1,0).
Finally, sum the coordinate values of the X and Y directions to obtain relative_position_index.
From the calculation above you can also see that our relative_position_bias_table (the parameter to learn) needs (window_size+(window_size-1))*(2*window_size-1) = (2*window_size-1)^2 entries per head.
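Running the indexing code above with window_size=2 (a toy check) gives the following index matrix, whose values 0..8 cover exactly the (2*2-1)^2 = 9 table entries:

```python
import torch

ws = 2
coords = torch.stack(torch.meshgrid(torch.arange(ws), torch.arange(ws)))      # 2, 2, 2
coords_flatten = torch.flatten(coords, 1)                                     # 2, 4
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]     # 2, 4, 4
relative_coords = relative_coords.permute(1, 2, 0).contiguous()               # 4, 4, 2
relative_coords[:, :, 0] += ws - 1
relative_coords[:, :, 1] += ws - 1
relative_coords[:, :, 0] *= 2 * ws - 1
relative_position_index = relative_coords.sum(-1)                             # 4, 4
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
```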
With relative_position_index and relative_position_bias_table, relative_position_bias is obtained by a table lookup.
```python
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
    self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
```
step4: compute attn_out
```python
attn = attn + relative_position_bias.unsqueeze(0)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
```
According to the self_attention formula:
softmax(q*K^T)*V --> (100,3,49,49)*(100,3,49,32) --> (100,3,49,32)
step5: apply a fully connected projection
reshape + proj --> (100,49,96)
Computing self_attention here is the same attention mechanism as in the Transformer. In NLP, the input is (B,L,C) and the computed attn is L*L, i.e. the attention value of each token position to every other position. Here in CV, the feature map has already been split into windows of size window_size*window_size, so L corresponds to window_size*window_size: each point in a window attends to every other point in the same window, and self_attention is computed independently for each window.
(3)window_reverse
Since the process above went through window_partition, a corresponding window_reverse is needed to restore (100,49,96) back to (1,70,70,96).
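For completeness, window_reverse is the inverse reshape/permute; again a sketch reproduced from memory:

```python
def window_reverse(windows, window_size, H, W):
    """Merge (B * num_windows, window_size, window_size, C) windows back into (B, H, W, C)."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
```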
(4)short_cut
The reversed feature map is then added to the initial input of SwinTransformerBlock through a shortcut. Is the SwinTransformerBlock module finished? Not yet. We skipped the cyclic shift earlier.
When the blocks are built, shift_size is set per block:
shift_size=0 if (i % 2 == 0) else window_size // 2,
So every second block needs the cyclic shift.
Its execution logic is the same as (1)-(4) above; the main difference is in step (2). Here we explain what step (2) does when shift_size is not 0.
Now look at the second block, the SW-MSA block, i.e. with cyclic shift.
(a) The cyclic shift is applied first, and then the same window_partition as before produces the (100,49,96) windows (per batch element):
```python
# cyclic shift
if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    attn_mask = mask_matrix
else:
    shifted_x = x
    attn_mask = None
```
This code rolls x to the left by shift_size and up by shift_size, i.e. the cyclic shift in the figure below. The purpose: with W-MSA the windows produced by window_partition never overlap, while SW-MSA lets the windows exchange information. The remaining problem is that the A, B and C regions in the figure are not actually adjacent to their neighboring windows; they only end up there because of the roll operation.
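A tiny torch.roll demo on a hypothetical 4x4 map with shift_size=1 shows how the top rows and left columns wrap around to the bottom/right:

```python
import torch

x = torch.arange(16).view(1, 4, 4)
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))
print(shifted)
# tensor([[[ 5,  6,  7,  4],
#          [ 9, 10, 11,  8],
#          [13, 14, 15, 12],
#          [ 1,  2,  3,  0]]])
```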
(b)windowAttention
The attention computation is the same as the process described above, except that, as mentioned in step (a), the A, B and C regions need to be masked out when computing attention. The mask here is exactly the attn_mask (100,49,49) obtained in the first step of BasicLayer ~
```python
if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)
```
The main logic of the mask: suppose attn is currently (200,3,49,49) (bs=2, 100 windows). The attn_mask we computed is (100,49,49). Because the mask only depends on positions inside a window and is independent of bs and head_num, it is enough to reshape attn to (2,100,3,49,49) and the mask to (1,100,1,49,49) and add them (broadcast).
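A quick shape check of that broadcast with toy tensors (bs=2 assumed):

```python
import torch

attn = torch.zeros(200, 3, 49, 49)   # (bs*nW, heads, N, N) with bs=2, nW=100
mask = torch.zeros(100, 49, 49)      # attn_mask: (nW, N, N)
out = attn.view(2, 100, 3, 49, 49) + mask.unsqueeze(1).unsqueeze(0)  # mask -> (1,100,1,49,49)
print(out.shape)                     # torch.Size([2, 100, 3, 49, 49])
```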
Finally, after window_reverse, remember to roll shifted_x back (the reverse shift):

```python
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
```

That wraps up the most complicated module, SwinTransformerBlock ~
3.down_sample
downsample (not needed for the last stage) uses PatchMerging. It samples the feature map at intervals to downsample, concatenates the resulting low-resolution feature maps, and then a fully connected layer reduces the channels from 4C to 2C. It is like the reverse of pixelShuffle.
```python
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

x = x.view(B, H, W, C)

# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
    x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

x = self.norm(x)
x = self.reduction(x)
```
That is the logic of one BasicLayer. Through the four stages we get feature maps at different scales (Swin-T):
stage1-->96, 64, 64
stage2-->192, 32, 32
stage3-->384, 16, 16
stage4--> 768, 8, 8
With these four feature maps we can, just like with ResNet and other backbones, connect to downstream tasks ~