Interpretation of swin transformer source code
2022-06-24 16:15:00 【languageX】
In May 2020, Facebook AI released DETR (Detection Transformer) for object detection and panoptic segmentation.
In October 2020, Google proposed ViT (Vision Transformer), which uses a Transformer instead of a convolutional network to classify images.
In January 2021, OpenAI proposed two models: DALL·E, which generates images directly from text, and CLIP, which maps images to the categories described by text. Both models use a Transformer.
In March 2021, Microsoft proposed Swin Transformer, which topped the leaderboards of all the major CV tasks....
Could I just let that go? Of course not... So this post summarizes the Swin Transformer framework and implementation, based on the paper and code I read some time ago.
The paper : https://arxiv.org/abs/2103.14030
Code : https://github.com/microsoft/Swin-Transformer
Introduction to swin_transformer
1. What swin_transformer optimizes
Compared with the earlier ViT, swin_transformer makes two improvements:
1. It introduces a CNN-like multi-level (hierarchical) transformer structure
ViT keeps a single feature scale, which makes it hard to plug into downstream tasks. For example, the encoder stage of a segmentation network can easily use ResNet or another backbone, but ViT's feature map size never changes, as in (b). swin_transformer introduces the multi-level structure by merging image patches, as shown in (a).
2. Reduced computational complexity and memory consumption
In the figure above, the grey blocks are patches and the red blocks are windows. swin_transformer partitions the feature map into windows and computes self_attention only within these local, non-overlapping windows. The computational complexity of the original MSA and of the paper's W-MSA is given by the formulas below, where M is the window size (each window contains M*M patches) and is much smaller than h and w, so the complexity becomes linear in hw. The complexity calculation is easier to follow after going through the source code.
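For reference, the two complexity formulas from the paper (h*w is the number of patches and C the channel dimension):

$$\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C$$

$$\Omega(\text{W-MSA}) = 4hwC^2 + 2M^2hwC$$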
2. How swin_transformer optimizes
For the first optimization point, the network architecture used in the paper is as follows:
The structure is divided into 4 stages, and the feature map size is reduced to 1/4, 1/8, 1/16 and 1/32 of the input resolution across the stages.
For the second optimization point, the paper points out that simply partitioning the feature map into windows and computing self_attention inside each window has a drawback: there is no communication between windows. So the paper proposes cascading W-MSA and SW-MSA.
W-MSA computes self_attention within non-overlapping windows, while cyclic shift, shown below, shifts the windows. A naive shifted partition turns the 2*2 windows into 3*3 windows, which increases the computation by about 1.5*1.5 times. The alternative proposed by the authors is a roll operation: move the 2*2 windows left and up, so that each shifted window contains information from other windows of the original partition. However, the A, B and C regions should not be treated as adjacent regions, so a mask operation is needed.
Finally, remember to reverse the shift and move the whole feature map back~
3. What results does swin_transformer achieve?
The result: it sweeps the leaderboards of several major CV tasks..
swin_transformer Source code analysis
Here's an in-depth look at swin_transformer from a code perspective.
First, the main classes: BasicLayer implements the flow of one stage; SwinTransformerBlock is the main logic module of BasicLayer and the core module of the paper; WindowAttention implements the attention module used inside SwinTransformerBlock.
depths: (2, 2, 6, 2) determines how many times SwinTransformerBlock is executed in each layer.
The paper proposes 4 model configurations; here we take Swin-T as the example.
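For reference, the Swin-T hyper-parameters, as they appear in the official configs:

```python
# Swin-T configuration (listed here for reference)
swin_tiny = dict(
    embed_dim=96,              # C in the paper
    depths=[2, 2, 6, 2],       # number of SwinTransformerBlocks per stage
    num_heads=[3, 6, 12, 24],  # attention heads per stage
    window_size=7,
)
```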
Code module logic :
patch_embed + pos_embed
stage1
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage2
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage3
-BasicLayer
--SwinTransformerBlock(*6)
---WindowAttention
stage4
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
Code logic of main modules :
1.patch_embed:PatchEmbed
Let's start with patch_embed. patch_embed simply maps the input, patch by patch, into an embedding vector. To me it is just a convolution (the model is called swin_transformer, yet the first step is a convolution~ convolution forever).
Assume the input is (3, 256, 256), patch_size=4, embed_dim=96:
(1) If the resolution is not divisible by 4, pad it to a multiple of 4
(2) A convolution with kernel=4, stride=4 maps the image into non-overlapping 4*4 patches: (96, 64, 64)
(3) If norm is enabled, apply a LayerNorm
(4) So (3, 256, 256) becomes a (96, 64, 64) feature map after patch_embed
2.absolute_pos_embed
If the absolute position embedding step is enabled, a learnable (96, 64, 64) pos_embed parameter is needed, which is added to the patch_embed output.
The embedding matrix is then flattened and transposed --> (64*64, 96).
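A minimal sketch of patch_embed + absolute position embedding under the assumptions above (input (3, 256, 256), patch_size=4, embed_dim=96); the names are illustrative, not the exact library classes:

```python
import torch
import torch.nn as nn

patch_size, embed_dim = 4, 96
proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)  # the "convolution"
norm = nn.LayerNorm(embed_dim)
pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 64, 64))                # learnable (96, 64, 64)

x = torch.randn(1, 3, 256, 256)
x = proj(x)                          # (1, 96, 64, 64): non-overlapping 4x4 patches
x = x + pos_embed                    # add the absolute position embedding
x = x.flatten(2).transpose(1, 2)     # (1, 64*64, 96)
x = norm(x)
print(x.shape)                       # torch.Size([1, 4096, 96])
```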
3.stages
The 1/4-resolution feature map then goes through the 4 stages, each built from a BasicLayer.
BasicLayer
1.attn_mask
Set window_size=7 and take stage1 as the example, where the input feature map size is (64, 64). img_mask is initialized as (70, 70) (padded to a multiple of 7), and window_partition then splits it into 100 windows of 7*7.
```python
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
```
The purpose of the above code is to obtain 100 attn_mask matrices of size 49*49.
This attn_mask is used later by the cyclic shift, i.e. by SW-MSA.
First, the 70*70 img_mask is assigned values in 9 blocks (each cell below is w_size*h_size = label):
63*63=0  4*63=1  3*63=2
63*4=3   4*4=4   3*4=5
63*3=6   4*3=7   3*3=8
Then window_partition cuts the mask into 100 windows of 7*7 and flattens them into 100*49. Each window is broadcast-subtracted against itself, giving 100*49*49, and every non-zero entry is set to -100. A non-zero position means the two relative positions do not belong to the same region in the figure above. Combined with cyclic shift, this marks, inside one window, the positions that belong to non-adjacent sub_windows, which therefore need to be masked out.
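A tiny hand-made example (not from the repo) of the subtraction trick, using a single flattened 2*2 window whose four positions carry the hypothetical region labels 0, 1, 3, 4:

```python
import torch

# region label of each of the 4 positions inside one flattened 2x2 window
mask_windows = torch.tensor([[0., 1., 3., 4.]])                      # (nW=1, window_size*window_size=4)

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)    # (1, 4, 4): pairwise label differences
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0))     # different region -> -100
attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0))        # same region      -> 0
print(attn_mask[0])
# tensor([[   0., -100., -100., -100.],
#         [-100.,    0., -100., -100.],
#         [-100., -100.,    0., -100.],
#         [-100., -100., -100.,    0.]])
```

Only positions with the same region label may attend to each other; everything else gets -100 added before the softmax.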
2.SwinTransformerBlock(*n)
(1)reshape+pad
The (64*64, 96) input goes through layer_norm + reshape + pad. The padding makes the feature map's H and W multiples of window_size. For stage1: (64*64, 96) --> (70, 70, 96).
(2)window_mask_self_attention(W-MSA/SW-MSA)
Let's first look at the W-MSA block, i.e. the one without cyclic shift.
(a) window_partition splits the feature map into windows of window_size*window_size patches: (1, 70, 70, 96) is split into (100, 7, 7, 96), then reshaped to (100, 49, 96).
(b) WindowAttention
Compute self_attention
step1: obtain the Q, K, V matrices. X: (100, 49, 96) --> Q, K, V: (100, 3, 49, 32) each
```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
```
Specifically: a fully-connected layer expands the C input channels to 3C, the result is split into num_heads shares for multi-head attention, and a final slice gives the q, k, v matrices. (100, 3, 49, 32) means (number of windows, attention heads, window length, C/num_heads).
step2: compute attention.
attn = (q @ k.transpose(-2, -1))
(100, 3, 49, 32) * (100, 3, 32, 49) --> (100, 3, 49, 49). (In the official code, q is first multiplied by self.scale = head_dim ** -0.5.) For the principle of self_attention, see the original Transformer paper; I won't go into details here.
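For reference, this computes the standard scaled dot-product attention, with Swin's relative position bias B (next step) added inside the softmax:

$$\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + B\right)V$$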
step3: compute relative_position_bias
The paper finds that adding relative position encoding works better, i.e. the attn computed in step2 is added with relative_position_bias. Like attn, it should be a matrix of size (3, 49, 49).
Let's see how to calculate relative_position_bias.
```python
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # (2*Wh-1) * (2*Ww-1), nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
    self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
```
To make the relative position encoding logic easier to follow, assume the window size is 2.
First, establish the coordinate system:
Then compute relative_coords in the X and Y directions. The first step of computing relative_coords is to add (window_size - 1) so that all values become non-negative; the X direction is then multiplied by (2*window_size - 1) so that the subsequent summation can still distinguish coordinates such as (0, 1) and (1, 0).
Finally, the X and Y coordinate values are summed to obtain relative_position_index.
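A minimal runnable check of this index computation for window_size=2 (written for illustration; the commented output is what the formula above produces):

```python
import torch

window_size = 2
coords = torch.stack(torch.meshgrid([torch.arange(window_size), torch.arange(window_size)]))  # 2, 2, 2
coords_flatten = torch.flatten(coords, 1)                                   # 2, 4
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # 2, 4, 4
relative_coords = relative_coords.permute(1, 2, 0).contiguous()             # 4, 4, 2
relative_coords[:, :, 0] += window_size - 1      # shift X to start from 0
relative_coords[:, :, 1] += window_size - 1      # shift Y to start from 0
relative_coords[:, :, 0] *= 2 * window_size - 1  # scale X before summing
relative_position_index = relative_coords.sum(-1)
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
# indices range over 0..8, i.e. (2*window_size - 1)**2 = 9 distinct values
```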
From the above calculation you can also see that relative_position_index takes at most (window_size + (window_size - 1)) * (2*window_size - 1) = (2*window_size - 1)^2 distinct values, which is exactly the number of rows in relative_position_bias_table (the parameter to be learned).
Once relative_position_index and relative_position_bias_table are available, relative_position_bias is obtained by a simple table lookup:
```python
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
    self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
```
step4: compute attn_out
```python
attn = attn + relative_position_bias.unsqueeze(0)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
```
According to the self_attention formula:
softmax(Q*K^T)*V --> (100, 3, 49, 49) * (100, 3, 49, 32) --> (100, 3, 49, 32)
step5: the final projection (a fully-connected layer)
reshape+proj -->100,49,96
The self_attention computed here is the same attention mechanism as in the original Transformer. In NLP, the input is (B, L, C), and the computed attn is L*L, giving each position's token an attention value to every other position. Here in CV, the feature map was first split into windows of size window_size*window_size, so L corresponds to window_size*window_size: attn gives, for each window, the attention value of every point to every other point in that window, i.e. self_attention is computed per window.
(3)window_reverse
Since the above was computed on window_partition outputs, a window_reverse is needed to restore (100, 49, 96) back to (1, 70, 70, 96).
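For reference, the two helpers, roughly as they appear in the official repo:

```python
def window_partition(x, window_size):
    """(B, H, W, C) -> (num_windows*B, window_size, window_size, C)"""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """(num_windows*B, window_size, window_size, C) -> (B, H, W, C)"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
```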
(4)short_cut
The reversed feature map is then added to the SwinTransformerBlock's original input as a shortcut. Is the SwinTransformerBlock module finished? Not yet: we skipped cyclic shift earlier.
When the blocks are built, shift_size is
shift_size=0 if (i % 2 == 0) else window_size // 2,
So every second block performs the cyclic shift.
Its execution logic is the same as steps (1)-(4) above; the main difference is step (2), so here we only explain step (2) when shift_size is not 0.
Now look at the second block, the SW-MSA block, i.e. the one with cyclic shift.
(a) window_partition is done the same way, giving a (B*100, 49, 96) feature map, but this time the feature map is first cyclically shifted:
```python
# cyclic shift
if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    attn_mask = mask_matrix
else:
    shifted_x = x
    attn_mask = None
```
This code shifts x left by shift_size and up by shift_size, i.e. the cyclic shift in the figure below. The purpose: after window_partition, W-MSA has no overlap between windows, and SW-MSA lets windows communicate. The problem is that the A, B and C regions in the figure below are not actually adjacent to their neighboring windows; they only end up there because of the roll operation.
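A quick hand-made illustration of the roll on a 4*4 grid with shift_size=1 (the values are just position labels):

```python
import torch

x = torch.arange(16).view(1, 4, 4)                   # (B=1, H=4, W=4), values are position ids
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))
print(shifted[0])
# tensor([[ 5,  6,  7,  4],
#         [ 9, 10, 11,  8],
#         [13, 14, 15, 12],
#         [ 1,  2,  3,  0]])
# The last row/column now holds values that came from the opposite border,
# which is exactly the kind of content that must be masked.
```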
(b)windowAttention
The attention computation is the same as the process above, except that, as mentioned in step (a), the A, B and C regions need to be masked out when computing attention. The mask here is the attn_mask (100, 49, 49) obtained in the first step of BasicLayer~
```python
if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)
```
The main logic of the mask: suppose attn is now (200, 3, 49, 49) and the attn_mask we computed is (100, 49, 49). Since the mask depends only on window position and is independent of batch size and head_num, it suffices to reshape attn to (2, 100, 3, 49, 49) and the mask to (1, 100, 1, 49, 49) and let broadcasting do the rest.
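A quick shape check of that broadcasting (hand-made example: 2 images, 100 windows, 3 heads):

```python
import torch

B_, nW, num_heads, N = 200, 100, 3, 49            # 200 = 2 images * 100 windows
attn = torch.randn(B_, num_heads, N, N)           # (200, 3, 49, 49)
mask = torch.randn(nW, N, N)                      # (100, 49, 49)

attn = attn.view(B_ // nW, nW, num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
#        (2, 100, 3, 49, 49)                    +        (1, 100, 1, 49, 49)
attn = attn.view(-1, num_heads, N, N)
print(attn.shape)                                 # torch.Size([200, 3, 49, 49])
```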
Finally, after window_reverse, remember to reverse the shift and roll shifted_x back:
```python
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
```
That wraps up the most complicated module, SwinTransformerBlock~
3.down_sample
Downsampling (not needed after the last stage) uses PatchMerging. It downsamples the feature map by strided (interval) sampling, concatenates the four low-resolution feature maps, and then uses a fully-connected layer to reduce the channel dimension. It is like the reverse of pixelShuffle.
```python
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

x = x.view(B, H, W, C)

# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
    x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

x = self.norm(x)
x = self.reduction(x)
```
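A minimal, self-contained sketch of the merging step above just to check the shapes (assumes even H and W, norm omitted; it is not the library class itself):

```python
import torch
import torch.nn as nn

def patch_merge_sketch(x, H, W, reduction):
    B, L, C = x.shape            # L == H * W
    x = x.view(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]     # B, H/2, W/2, C
    x1 = x[:, 1::2, 0::2, :]
    x2 = x[:, 0::2, 1::2, :]
    x3 = x[:, 1::2, 1::2, :]
    x = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)  # B, H/2*W/2, 4C
    return reduction(x)          # B, H/2*W/2, 2C

x = torch.randn(1, 64 * 64, 96)                        # stage1 output
reduction = nn.Linear(4 * 96, 2 * 96, bias=False)
print(patch_merge_sketch(x, 64, 64, reduction).shape)  # torch.Size([1, 1024, 192])
```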
That is the logic of one BasicLayer. After the four stages we get feature maps at different scales (Swin-T):
stage1-->96, 64, 64
stage2-->192, 32, 32
stage3-->384, 16, 16
stage4--> 768, 8, 8
With these four feature maps, we can plug into downstream tasks just like ResNet and other backbones~