Interpretation of swin transformer source code
2022-06-24 16:15:00 【languageX】
In May 2020, Facebook AI released DETR (Detection Transformer) for object detection and panoptic segmentation.
In October 2020, Google proposed ViT (Vision Transformer), which classifies images with a Transformer instead of a convolutional network.
In January 2021, OpenAI announced two models: DALL·E, which generates images directly from text, and CLIP, which maps images to textual category descriptions. Both are built on Transformers.
In March 2021, Microsoft proposed Swin Transformer, which swept the leaderboards of nearly every major CV task...
Could I ignore it? I could not... So here is a summary of the Swin Transformer framework and implementation, based on my reading of the paper and code some time ago.
Paper: https://arxiv.org/abs/2103.14030
Code: https://github.com/microsoft/Swin-Transformer
Introduction to Swin Transformer
1. What Swin Transformer improves
Compared with ViT, Swin Transformer makes two main improvements:
1. A CNN-like hierarchical (multi-stage) Transformer structure
ViT keeps the feature map at a single scale, which makes it awkward to plug into downstream tasks; a segmentation encoder, for example, can easily take a ResNet-style backbone with multi-scale features, while ViT's feature map size never changes (b). Swin Transformer introduces a hierarchical structure by merging image patches, as shown in (a).
2. Lower computational complexity and memory consumption
The grey blocks in the figure above are patches, and the red blocks are windows. Swin Transformer partitions the feature map into windows and computes self-attention only inside these local, non-overlapping windows. The complexity of the original MSA versus the paper's W-MSA is Ω(MSA) = 4hwC^2 + 2(hw)^2*C versus Ω(W-MSA) = 4hwC^2 + 2M^2*hwC, where M is the window size (each window contains M*M patches) and is much smaller than h and w, so W-MSA's cost is linear in hw. Keep this complexity calculation in mind; it becomes clearer after walking through the source code.
2. How Swin Transformer implements these optimizations
For the first improvement, the paper uses the following network architecture:
The structure is divided into 4 stages; across the stages the feature map is downsampled to 1/4, 1/8, 1/16, and 1/32 of the input resolution.
For the second improvement, the paper points out that simply splitting the feature map into windows and computing self-attention inside each window has a drawback: there is no communication between windows. So it proposes alternating W-MSA and SW-MSA blocks.
W-MSA computes self-attention within non-overlapping windows, and the shifted-window scheme (shown below) shifts the window partition. A naive shift turns the original 2*2 windows into 3*3 windows, which increases the computation by roughly 1.5*1.5 times. The author's alternative is a roll operation: the feature map is rolled left and up so the window count stays 2*2, and the rolled windows now contain content coming from other windows of the previous partition. However, the A, B, and C regions should not be treated as adjacent regions, so a mask operation is needed.
Finally, remember the reverse shift that rolls the feature map back~
3. What results does Swin Transformer achieve?
The result: it tops the leaderboards of several major CV tasks..
Swin Transformer source code analysis
Below is an in-depth look at Swin Transformer from the code's perspective.
First, the main classes: BasicLayer implements the flow of one stage; SwinTransformerBlock is the main logic module of BasicLayer and the core module of the paper; WindowAttention is the attention module inside SwinTransformerBlock.
depths: (2, 2, 6, 2) determines how many SwinTransformerBlocks run in each stage (for Swin-T).
The paper proposes 4 model configurations; we take Swin-T as the example below.
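For reference, the Swin-T settings used in this walkthrough (matching the official repo's config; the dict form here is just for illustration):

import torch

swin_t_cfg = dict(
    patch_size=4,                 # patch size of the initial patch embedding
    embed_dim=96,                 # C: channels after patch_embed
    depths=[2, 2, 6, 2],          # number of SwinTransformerBlocks per stage
    num_heads=[3, 6, 12, 24],     # attention heads per stage
    window_size=7,                # M: window size for (S)W-MSA
)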
Code module logic :
patch_embed + pos_embed
stage1
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage2
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage3
-BasicLayer
--SwinTransformerBlock(*6)
---WindowAttention
stage4
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
Code logic of main modules :
1.patch_embed:PatchEmbed
Let's start with patch_embed. patch_embed simply maps the input, patch by patch, to embedding vectors. In my view it is just a convolution (the model is called Swin Transformer, yet the first step is a convolution~ convolution forever).
Assume the input is (3, 256, 256), patch_size=4, embedding_dim=96.
(1) If the resolution is not divisible by 4, pad it up to a multiple of 4.
(2) A convolution with kernel=4, stride=4 maps the image into non-overlapping 4*4 patches: (96, 64, 64).
(3) If norm is enabled, apply a LayerNorm.
(4) (3, 256, 256) passes through patch_embed and becomes (96, 64, 64); a sketch of the whole step follows below.
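A minimal sketch of this step (my simplification of the repo's PatchEmbed; the class name and defaults here are mine):

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedSketch(nn.Module):
    # simplified sketch: a strided conv maps non-overlapping 4x4 patches to 96-dim vectors
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, use_norm=True):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim) if use_norm else None

    def forward(self, x):
        _, _, H, W = x.shape
        # (1) pad H and W up to a multiple of patch_size
        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
        if pad_w or pad_h:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        # (2) kernel=4, stride=4 conv: (B, 3, 256, 256) -> (B, 96, 64, 64)
        x = self.proj(x)
        if self.norm is not None:
            # (3) LayerNorm acts on the channel dim, so flatten the spatial dims first
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)              # (B, 64*64, 96)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x

print(PatchEmbedSketch()(torch.randn(1, 3, 256, 256)).shape)   # torch.Size([1, 96, 64, 64])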
2.absolute_pos_embed
If the absolute position embedding step is enabled, a learnable (96, 64, 64) pos_embed parameter is created and added element-wise to the patch_embed output.
The embedding tensor is then flattened and transposed --> (64*64, 96); see the sketch below.
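A toy, shapes-only sketch of this step (in the repo the pos_embed parameter is interpolated to the current resolution before the add):

import torch
import torch.nn as nn

B, C, H, W = 1, 96, 64, 64
x = torch.randn(B, C, H, W)                                  # output of patch_embed
absolute_pos_embed = nn.Parameter(torch.zeros(1, C, H, W))   # learnable, one vector per patch position
x = x + absolute_pos_embed                                   # element-wise add (broadcast over the batch)
x = x.flatten(2).transpose(1, 2)                             # (B, 64*64, 96)
print(x.shape)                                               # torch.Size([1, 4096, 96])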
3.stages
The /4-resolution feature map then passes through the 4 stages, each implemented by a BasicLayer.
BasicLayer
1.attn_mask
Set window_size=7. Taking stage1 as an example, the input feature map is (64, 64). img_mask is initialized as a (70, 70) zero map (64 padded up to a multiple of 7), and window_partition then splits it into 100 windows of 7*7.
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
The purpose of the code above is to obtain 100 attn_mask matrices of size 49*49.
This attn_mask is used later by the cyclic shift, i.e. by SW-MSA.
First, the 70*70 img_mask is split into 9 regions by the h/w slices above (sizes 63, 4, and 3 with window_size=7 and shift_size=3), and each region is assigned a label 0-8 (rows*cols):
63*63=0  63*4=1  63*3=2
 4*63=3   4*4=4   4*3=5
 3*63=6   3*4=7   3*3=8
Then window_partition cuts the mask into 100 windows of 7*7 and flattens them to 100*49. Subtracting each window's label vector from itself pairwise gives 100*49*49, and every non-zero entry is set to -100. A non-zero entry means the two positions belong to different regions of the figure above. Combined with the cyclic shift, these are exactly the positions within one window whose sub-windows are not actually adjacent, so they must be masked out. A toy demo of this masking follows below.
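A toy demo (my own, not repo code) of how the pairwise subtraction plus masked_fill turns region labels into the 0/-100 mask:

import torch

labels = torch.tensor([0, 0, 4, 5])                   # toy region labels for 4 positions of one window
diff = labels.unsqueeze(0) - labels.unsqueeze(1)      # (4, 4): zero where two positions share a region
attn_mask = torch.zeros_like(diff, dtype=torch.float)
attn_mask = attn_mask.masked_fill(diff != 0, -100.0)  # non-zero => different regions => mask out
print(attn_mask)
# tensor([[   0.,    0., -100., -100.],
#         [   0.,    0., -100., -100.],
#         [-100., -100.,    0., -100.],
#         [-100., -100., -100.,    0.]])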
2.SwinTransformerBlock(*n)
(1)reshape+pad
For the (64*64, 96) input, apply layer_norm + reshape + pad. The pad makes H and W of the feature map multiples of window_size. For stage1: (64*64, 96) --> (70, 70, 96).
(2)window_mask_self_attention(W-MSA/SW-MSA)
Let's first look at the W-MSA block, i.e. the one without the cyclic shift.
(a) Apply window_partition: the feature map is divided into windows of window_size*window_size patches. (1, 70, 70, 96) is split into (100, 7, 7, 96) and then reshaped to (100, 49, 96).
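For reference, window_partition is essentially the following (paraphrased from the repo):

import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

x = torch.randn(1, 70, 70, 96)
print(window_partition(x, 7).shape)   # torch.Size([100, 7, 7, 96])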
(b) WindowAttention
Compute self-attention.
step1: obtain the Q, K, V matrices. X: (100, 49, 96) --> Q, K, V: (100, 3, 49, 32) each.
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
Concretely: a fully connected layer expands the C channels to 3C, the result is split into num_heads heads, and slicing the first dimension yields the q, k, v matrices. (100, 3, 49, 32) stands for (number of windows, attention heads, window length, C/head).
step2: compute attention.
attn = (q @ k.transpose(-2, -1))
(100, 3, 49, 32) @ (100, 3, 32, 49) --> (100, 3, 49, 49). For the principle of self-attention, see the original Transformer paper; I won't go into details here.
step3: compute relative_position_bias
The paper proposes that adding relative position encoding works better; that is, the attn computed in step2 gets relative_position_bias added to it. Like attn (per head), it should be a (3, 49, 49) tensor, shared by all windows.
Let's see how relative_position_bias is computed.
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # (2*Wh-1) * (2*Ww-1), nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
To make the relative position encoding logic easier to follow, let's assume a window size of 2.
First, establish the coordinate system :
Then compute relative_coords in the X and Y directions. The first step is to add (window_size - 1) so that all values are non-negative; the X direction is then multiplied by (2*window_size - 1) so that the later summation can distinguish coordinates such as (0, 1) and (1, 0).
Finally, summing the X and Y coordinate values gives relative_position_index.
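As a sanity check, here is a small standalone script (mine, not from the repo) that runs these steps for window_size=2 and prints the resulting relative_position_index:

import torch

window_size = 2
coords_h = torch.arange(window_size)
coords_w = torch.arange(window_size)
coords = torch.stack(torch.meshgrid(coords_h, coords_w))          # (2, 2, 2)
coords_flatten = torch.flatten(coords, 1)                         # (2, 4)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()   # (4, 4, 2)
relative_coords[:, :, 0] += window_size - 1                       # make values non-negative
relative_coords[:, :, 1] += window_size - 1
relative_coords[:, :, 0] *= 2 * window_size - 1                   # separate (0,1) from (1,0)
relative_position_index = relative_coords.sum(-1)                 # (4, 4), values in [0, 8]
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])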
From the calculation above you can also see that relative_position_bias_table (the parameter to be learned) needs (window_size + (window_size - 1)) * (2*window_size - 1), i.e. (2*window_size - 1)^2, entries per head, since that is the number of distinct relative_position_index values.
With relative_position_index and relative_position_bias_table, relative_position_bias is obtained by a table lookup:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
step4: compute attn_out
attn = attn + relative_position_bias.unsqueeze(0)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
According to the self-attention formula:
softmax(Q*K^T)*V --> (100, 3, 49, 49) @ (100, 3, 49, 32) --> (100, 3, 49, 32)
(in the code, q is pre-multiplied by the 1/sqrt(d) scale before the q @ k step).
step5: apply the output projection (a fully connected layer)
reshape + proj --> (100, 49, 96); a small sketch follows below.
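A sketch of step5 (the layer names mirror the repo's WindowAttention; the dropout rate here is my placeholder):

import torch
import torch.nn as nn

dim, num_heads, nW, N = 96, 3, 100, 49
proj = nn.Linear(dim, dim)                            # output projection
proj_drop = nn.Dropout(0.0)                           # placeholder rate

x = torch.randn(nW, num_heads, N, dim // num_heads)   # per-head attention output: (100, 3, 49, 32)
x = x.transpose(1, 2).reshape(nW, N, dim)             # merge heads: (100, 49, 96)
x = proj_drop(proj(x))                                # (100, 49, 96)
print(x.shape)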
The self-attention computed here is the same mechanism as attention in the Transformer. In NLP the input is (B, L, C) and the computed attn is L*L, giving each token's attention value to every other position. Here in CV, the feature map was first split into windows of size window_size*window_size, so L corresponds to window_size*window_size: each point in a window attends to the other points in the same window, and self-attention is computed per window.
(3)window_reverse
Since the process above operated on the output of window_partition, a window_reverse is needed to restore (100, 49, 96) back to (1, 70, 70, 96).
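window_reverse is essentially the inverse of window_partition (paraphrased from the repo; note the (100, 49, 96) windows are first reshaped back to (100, 7, 7, 96)):

import torch

def window_reverse(windows, window_size, H, W):
    # (num_windows*B, window_size, window_size, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

w = torch.randn(100, 7, 7, 96)                 # 100 windows of (7, 7, 96)
print(window_reverse(w, 7, 70, 70).shape)      # torch.Size([1, 70, 70, 96])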
(4)short_cut
After the reverse, the feature map and the original input of SwinTransformerBlock are added through a shortcut connection. Is the SwinTransformerBlock flow finished? Not yet: so far we have skipped the cyclic shift.
When the blocks are constructed, shift_size is set as
shift_size=0 if (i % 2 == 0) else window_size // 2,
so the second block in each pair needs the cyclic shift.
Its execution logic is the same as steps (1)-(4) above; the main difference is in step (2), so the following explains step (2) when shift_size is not 0.
Let's look at the SW-MSA block, the second block in the pair, i.e. the one with the cyclic shift.
(a) The same window_partition is applied and again yields a (B*100, 49, 96) feature tensor, but the feature map is cyclically shifted first:
# cyclic shift
if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    attn_mask = mask_matrix
else:
    shifted_x = x
    attn_mask = None
The meaning of this code: roll x to the left by shift_size and up by shift_size, i.e. the cyclic shift in the figure below. The point of the operation: after window_partition, W-MSA windows do not overlap, so there is no communication between them, while SW-MSA lets windows interact. The remaining problem is that the A, B, and C regions in the figure are not actually adjacent to the windows they land in; they only end up there because of the roll.
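A toy illustration (not repo code) of what torch.roll does to the feature map:

import torch

x = torch.arange(16).view(1, 4, 4, 1)                     # toy (B, H, W, C) map, shift_size=1
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))     # move content left by 1 and up by 1
print(x[0, :, :, 0])
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
print(shifted[0, :, :, 0])
# tensor([[ 5,  6,  7,  4],
#         [ 9, 10, 11,  8],
#         [13, 14, 15, 12],
#         [ 1,  2,  3,  0]])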
(b)windowAttention
The attention computation is the same as the process above, except that, as mentioned in step (a), the A, B, and C regions must be masked out when computing attention. The mask used here is exactly the attn_mask of shape (100, 49, 49) obtained in step 1 of BasicLayer~
if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)
The mask logic: suppose attn is currently (200, 3, 49, 49) (batch size 2), while the attn_mask we computed is (100, 49, 49). The mask depends only on the position inside a window and is independent of batch size and head number, so it suffices to view attn as (2, 100, 3, 49, 49) and broadcast the mask as (1, 100, 1, 49, 49).
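A shape-only check of that broadcasting, with toy tensors:

import torch

B_, num_heads, N = 200, 3, 49                    # 2 images * 100 windows, 3 heads, 49 tokens per window
attn = torch.randn(B_, num_heads, N, N)
mask = torch.randn(100, N, N)                    # attn_mask: one (49, 49) mask per window position
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
# (2, 100, 3, 49, 49) + (1, 100, 1, 49, 49): broadcasts over batch and heads
attn = attn.view(-1, num_heads, N, N)
print(attn.shape)                                # torch.Size([200, 3, 49, 49])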
Finally, after window_reverse, remember to reverse-shift shifted_x back:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
With that, the introduction of the most complicated module, SwinTransformerBlock, is finished~
3.down_sample
Downsampling (not needed after the last stage) is done by PatchMerging: it subsamples the feature map at alternating positions to halve the resolution, concatenates the four lower-resolution maps, and then trims the channel count with a fully connected layer. It is roughly the inverse of pixelShuffle.
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
    x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
That is the logic of one BasicLayer. After the four stages we get feature maps at different scales (Swin-T):
stage1-->96, 64, 64
stage2-->192, 32, 32
stage3-->384, 16, 16
stage4--> 768, 8, 8
With these four feature maps, Swin can be used like a ResNet-style backbone and connected to downstream tasks~