
Interpretation of swin transformer source code

2022-06-24 16:15:00 languageX

In May 2020, Facebook AI released DETR (Detection Transformer), applying Transformers to object detection and panoptic segmentation.

In October 2020, Google proposed ViT (Vision Transformer), using a Transformer instead of a convolutional network to classify images.

In January 2021, OpenAI released two models: DALL·E, which generates images directly from text, and CLIP, which maps images to text-described categories. Both use Transformers.

In March 2021, Microsoft proposed Swin Transformer, which topped the leaderboards of all the major CV tasks....

Could I let that slide? I could not... So here is a summary of the swin_transformer framework and implementation, based on the paper and code I read a while ago.

The paper : https://arxiv.org/abs/2103.14030

Code : https://github.com/microsoft/Swin-Transformer

swin_transformer introduction

1. swin_transformer optimization points

Compared with ViT, swin_transformer makes two improvements:

1. It introduces a CNN-like hierarchical (multi-stage) structure into the Transformer

ViT keeps a single, constant scale, which makes it hard to plug into downstream tasks. For example, a segmentation encoder stage can easily hook into a resnet-style backbone, whereas the ViT feature map size never changes (figure (b)). swin_transformer introduces a hierarchical structure by merging image patches, as in figure (a).

Figure 1: Swin Transformer vs. ViT

2. It reduces computational complexity and memory consumption

The grey blocks in the figure above are patches, and the red blocks are windows. swin_transformer partitions the feature map into windows and computes self_attention only within these local, non-overlapping windows. The computational complexity of the original MSA and of the paper's W-MSA is given by the formulas below, where M is the number of patches contained in a window, i.e. window_size, which is much smaller than h and w, so the W-MSA complexity is linear in hw. The derivation becomes clearer after we go through the source code.
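For reference, the two complexity formulas given in the paper (C is the channel dimension of the feature map):

\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C
\Omega(\mathrm{W\text{-}MSA}) = 4hwC^2 + 2M^2hwC

The quadratic (hw)^2 term of global MSA becomes a term linear in hw once attention is restricted to M x M windows.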

2. How swin_transformer implements the optimizations

For the first optimization point, the paper uses the following network architecture:

Figure: Swin Transformer architecture

The architecture is divided into 4 stages; across the stages the feature map is downsampled to 1/4, 1/8, 1/16 and 1/32 of the input resolution.

For the second optimization point, the paper notes that simply splitting the FM into windows and computing self_attention within each window has a drawback: there is no information exchange between windows. So it proposes alternating W-MSA and SW-MSA blocks.

W-MSA is self_attention computed within non-overlapping windows, while SW-MSA shifts the windows, as shown below. A naive shift turns a 2*2 grid of windows into a 3*3 grid, which increases the computation by 1.5*1.5 times. The alternative proposed by the authors is a roll operation (cyclic shift): move the 2*2 window grid left and up, so that the shifted windows contain information from several of the original windows. However, regions A, B and C are not actually adjacent to their new neighbours, so a mask operation is needed.

Finally, remember the reverse shift that rolls the whole feature map back ~

Figure: the cyclic shift + mask self_attention process

3. swin_transformer results

The result: it topped the leaderboards of several major CV tasks..

Classification
Detection
Segmentation

swin_transformer Source code analysis

Now let's take an in-depth look at swin_transformer from the code perspective.

First, the main classes: BasicLayer implements one stage; SwinTransformerBlock is the main logic module of BasicLayer and the core module of the paper; WindowAttention is the attention module used by SwinTransformerBlock.

depths: (2, 2, 6, 2) determines how many SwinTransformerBlocks are executed in each layer.

The paper proposes 4 model configurations; let's take Swin-T as the example.

Code module logic :

patch_embed + pos_embed

stage1

-BasicLayer

--SwinTransformerBlock(*2)

---WindowAttention

stage2

-BasicLayer

--SwinTransformerBlock(*2)

---WindowAttention

stage3

-BasicLayer

--SwinTransformerBlock(*6)

---WindowAttention

stage4

-BasicLayer

--SwinTransformerBlock(*2)

---WindowAttention

Code logic of main modules :

1.patch_embed:PatchEmbed

Let's start with patch_embed. patch_embed simply maps the input, patch by patch, to vectors. To me it's just a convolution (the title says swin_transformer, yet the very first step is a convolution ~ convolution is still the GOAT).

Assume the input is (3,256,256), with patch_size=4 and embedding_dim=96.

(1) If the resolution is not divisible by 4, pad it up to a multiple of 4.

(2) A plain convolution with kernel=4, stride=4 maps the image to non-overlapping 4*4 patches: (96,64,64).

(3) If norm is enabled, apply a LayerNorm.

(4) So (3,256,256) goes through patch_embed and comes out as a (96,64,64) feature (a runnable sketch follows).
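Putting the four steps together, here is a minimal runnable sketch of the patch embedding; SimplePatchEmbed and its padding arithmetic are my own simplification for illustration, not the official class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplePatchEmbed(nn.Module):
    """Sketch: pad to a multiple of patch_size, then a conv with k=4, s=4 maps 3 -> 96 channels."""
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, use_norm=True):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim) if use_norm else None

    def forward(self, x):                                # x: (B, 3, H, W), e.g. (1, 3, 256, 256)
        _, _, H, W = x.shape
        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
        x = F.pad(x, (0, pad_w, 0, pad_h))               # (1) pad W and H to multiples of 4
        x = self.proj(x)                                 # (2) conv -> (B, 96, 64, 64)
        if self.norm is not None:
            B, C, Hp, Wp = x.shape
            x = self.norm(x.flatten(2).transpose(1, 2))  # (3) LayerNorm over the channel dim
            x = x.transpose(1, 2).reshape(B, C, Hp, Wp)
        return x                                         # (4) (B, 96, 64, 64)

print(SimplePatchEmbed()(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 96, 64, 64])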

2.absolute_pos_embed

If the position_embedding step is enabled, a learnable (96,64,64) pos_embed parameter has to be learned; it is added to the patch_embed output.

The embedded feature is then flattened and transposed: (96,64,64) --> (64*64, 96), as sketched below.
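A minimal sketch of this step, assuming the absolute position embedding option is enabled; the variable names are illustrative:

import torch
import torch.nn as nn

patch_feat = torch.randn(1, 96, 64, 64)                        # output of patch_embed
absolute_pos_embed = nn.Parameter(torch.zeros(1, 96, 64, 64))  # learnable positional parameter
x = patch_feat + absolute_pos_embed                            # element-wise add
x = x.flatten(2).transpose(1, 2)                               # (1, 64*64, 96) token sequence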

3.stages

The 1/4-resolution feature map then goes through the 4 stages, each one a BasicLayer.

BasicLayer

1.attn_mask

Set window_size=7. Taking stage1 as the example, the input feature map size is (64,64). img_mask is initialised as (70,70) (padded to a multiple of window_size), and window_partition then splits the feature map into 100 windows of 7*7.

img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

The purpose of the above code is to obtain 100 attn_masks of size 49*49, i.e. a (100, 49, 49) tensor.

This attn_mask is used later for the cyclic shift, i.e. by SW-MSA.

First, the 70*70 img_mask is assigned 9 region labels (each entry below is width*height of the region):

63*63 = 0   4*63 = 1   3*63 = 2
63*4  = 3   4*4  = 4   3*4  = 5
63*3  = 6   4*3  = 7   3*3  = 8

Figure: img_mask region assignment

Then window_partition cuts the mask into 100 windows of 7*7, the data is flattened to 100*49, each element of a window is subtracted from every other element of the same window to get 100*49*49, and every non-zero entry is set to -100. A non-zero position means that the two relative positions do not belong to the same region in the figure above. Combined with the cyclic shift, these are the positions inside one window that come from non-adjacent sub_windows, which is why they must be masked out.
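To make the mask construction concrete, here is a minimal runnable toy example. The sizes Hp = Wp = 4, window_size = 2, shift_size = 1 are chosen purely for illustration (not the real stage1 values), and a window_partition equivalent to the repo's helper is included so the snippet is self-contained:

import torch

def window_partition(x, window_size):
    """Split (B, H, W, C) into (num_windows*B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

Hp = Wp = 4
window_size, shift_size = 2, 1
img_mask = torch.zeros((1, Hp, Wp, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(img_mask[0, :, :, 0])   # 4*4 grid of region labels 0..8
print(attn_mask.shape)        # torch.Size([4, 4, 4]): 4 windows, 2*2 = 4 tokens each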

2.SwinTransformerBlock(*n)

(1)reshape+pad

The (64*64, 96) input goes through layer_norm + reshape + pad. The pad makes H and W of the FM multiples of window_size. For stage1: 64*64, 96 --> 70, 70, 96.
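A small shape check of the pad step (stage1 shapes assumed; the pad arithmetic is my own sketch):

import torch
import torch.nn.functional as F

window_size = 7
x = torch.randn(1, 64 * 64, 96).view(1, 64, 64, 96)      # (B, H, W, C) after layer_norm + reshape
pad_r = (window_size - 64 % window_size) % window_size    # 6
pad_b = (window_size - 64 % window_size) % window_size    # 6
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))                  # pad C, W, H starting from the last dim
print(x.shape)                                            # torch.Size([1, 70, 70, 96])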

(2)window_mask_self_attention(W-MSA/SW-MSA)

Let's first look at the W-MSA block, i.e. the block without cyclic shift.

(a) Run window_partition: the feature map is split into window_size*window_size patches. (1, 70*70, 96) is split into (100, 7, 7, 96) and then reshaped to (100, 49, 96).

(b) WindowAttention

Compute self_attention. The attention used in the paper is the standard scaled dot-product attention plus a relative position bias B:

Attention(Q, K, V) = SoftMax(QK^T / sqrt(d) + B) V

step1: obtain the Q, K, V matrices. X: (100, 49, 96) --> Q, K, V: each (100, 3, 49, 32)

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

Concretely: the linear layer expands the C channels to 3C, the FM is then split into head_num heads, and a final slice yields the q, k, v matrices. (100, 3, 49, 32) means (number of windows, attention heads, tokens per window, C/head_num).

step2: compute the attention.

attn = (q @ k.transpose(-2, -1))

(100,3,49,32) * (100,3,32,49) --> (100,3,49,49). For the principle of self_attention, refer to the original Transformer paper; it is not repeated here.

step3: compute the relative_position_bias

The paper finds that adding relative position encoding works better: the attn computed in step2 gets a relative_position_bias added to it. Like attn, it is a per-head 49*49 matrix, i.e. (3, 49, 49), broadcast over the 100 windows.

Let's see how relative_position_bias is computed.

# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # (2*Wh-1) * (2*Ww-1), nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

To make the relative position encoding logic easier to follow, let's assume the window size is 2.

First, establish the coordinate system :

Then compute relative_coords in the X and Y directions. The first step of the relative_coords calculation adds (window_size - 1) so that all values become non-negative; the X direction is then multiplied by (2*window_size - 1) so that the later summation can still distinguish coordinates such as (0,1) and (1,0).

Figure: the relative_coords calculation process

Finally, the X and Y coordinate values are summed to obtain relative_position_index.

Figure: the relative_position_index calculation process

From the calculation above we can also see that relative_position_bias_table (the parameters to be learned) needs (window_size + (window_size - 1)) * (2*window_size - 1) = (2*window_size - 1)^2 entries per head, and the largest relative_position_index is that value minus 1.
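The following standalone snippet (window_size = 2; torch >= 1.10 assumed for the indexing argument of meshgrid) reproduces the index computation and confirms the 0..8 range:

import torch

Wh = Ww = 2
coords = torch.stack(torch.meshgrid(torch.arange(Wh), torch.arange(Ww), indexing="ij"))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                   # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += Wh - 1            # shift rows so values start from 0
relative_coords[:, :, 1] += Ww - 1
relative_coords[:, :, 0] *= 2 * Ww - 1        # make (0,1) and (1,0) map to different sums
relative_position_index = relative_coords.sum(-1)
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
# values cover 0..8, i.e. (2*Wh-1)*(2*Ww-1) = 9 table entries per head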

With relative_position_index and relative_position_bias_table in hand, relative_position_bias is obtained by a table lookup:

relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)

step4: compute the attention output

attn = attn + relative_position_bias.unsqueeze(0)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

According to the self_attention formula:

softmax(Q·K^T)·V --> (100,3,49,49) * (100,3,49,32) --> (100,3,49,32)

step5: final fully connected projection

reshape + proj --> (100, 49, 96)

This self_attention is the same mechanism as attention in the Transformer. In NLP the input is (B, L, C) and the computed attn is L*L, representing the attention of the token at each position to every other position. Here in CV, the feature map has been split into windows of size window_size*window_size, so L corresponds to window_size*window_size: attn gives the attention of each point in a window relative to the other points of the same window, and self_attention is computed per window.

(3)window_reverse

Since the above went through window_partition, a window_reverse is needed to restore (100, 49, 96) back to (1, 70, 70, 96); a sketch of it follows.
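For reference, a self-contained sketch consistent with the repo's window_reverse, restoring the stage1 shapes:

import torch

def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: (num_windows*B, ws, ws, C) -> (B, H, W, C)."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

attn_windows = torch.randn(100, 49, 96).view(-1, 7, 7, 96)  # attention output per window
x = window_reverse(attn_windows, 7, 70, 70)
print(x.shape)                                              # torch.Size([1, 70, 70, 96])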

(4)short_cut

The reversed FM and the original input of the SwinTransformerBlock are combined with a shortcut (residual) connection. Does the SwinTransformerBlock module end here? Not yet: we skipped the cyclic shift earlier.

When the blocks are built, shift_size is set as:

shift_size=0 if (i % 2 == 0) else window_size // 2,

So every second block performs the cyclic shift.

The execution logic is the same as (1)-(4) above; the main difference is in step (2), so here we explain step (2) when shift_size is not 0.

Look at the second block, the SW-MSA block, i.e. with the cyclic shift added.

(a) Apply the cyclic shift below first, then do the same window_partition to obtain the (B*100, 49, 96) feature:

# cyclic shift
if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    attn_mask = mask_matrix
else:
    shifted_x = x
    attn_mask = None

This code shifts x left by shift_size and up by shift_size, i.e. the cyclic shift in the figure below. The purpose: after window_partition, W-MSA windows never overlap, so SW-MSA is used to connect the windows. The problem, as the figure shows, is that regions A, B and C are not actually adjacent to their new neighbouring windows; they only ended up there because of the roll operation.
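A tiny illustration of what torch.roll does here, using a 4*4 labelled grid and shift_size = 1 (toy values for illustration only):

import torch

x = torch.arange(16).view(1, 4, 4, 1)
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))
print(x[0, :, :, 0])
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
print(shifted[0, :, :, 0])
# tensor([[ 5,  6,  7,  4],
#         [ 9, 10, 11,  8],
#         [13, 14, 15, 12],
#         [ 1,  2,  3,  0]])
# the first row/column wraps to the bottom/right, next to pixels it was never adjacent to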

(b) WindowAttention

The attention computation is consistent with the process described above, except that, as mentioned in step (a), regions A, B and C must be masked out when computing attention. The mask here is exactly the attn_mask (100, 49, 49) we obtained in the first step of BasicLayer ~

if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)

The main logic of the mask: suppose attn is now (200, 3, 49, 49) (a batch of 2), while the computed attn_mask is (100, 49, 49). Since the mask depends only on the window position and is independent of bs and head_num, it is enough to reshape attn to (2, 100, 3, 49, 49) and the mask to (1, 100, 1, 49, 49) and add them.

Finally, after window_reverse, remember to roll shifted_x back (reverse shift):

x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
That concludes the most complicated module, SwinTransformerBlock ~

3.down_sample

The downsample (not needed for the last stage) uses PatchMerging. It subsamples the FM at every other pixel to downsample, concatenates the four low-resolution FMs along the channel dimension, and then reduces the channels with a fully connected layer (4C -> 2C). It is essentially the inverse of pixelShuffle. A shape check is sketched after the code below.

self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
    x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
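Wrapping the snippet above into a standalone module to check the shapes; PatchMergingSketch is an illustrative name, not the class in the repo:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchMergingSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x, H, W):                    # x: (B, H*W, C)
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        if (H % 2 == 1) or (W % 2 == 1):
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[:, 0::2, 0::2, :]                   # the four interleaved sub-grids
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)
        return self.reduction(self.norm(x))        # 4C -> 2C

x = torch.randn(1, 64 * 64, 96)                    # stage1 tokens
print(PatchMergingSketch(96)(x, 64, 64).shape)     # torch.Size([1, 1024, 192]) = stage2 tokens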

That is the logic of one BasicLayer. Through the four stages we get feature maps at different scales (Swin-T):

stage1-->96, 64, 64

stage2-->192, 32, 32

stage3-->384, 16, 16

stage4--> 768, 8, 8

With these four feature maps we can, just like with resnet and other backbones, connect to downstream tasks ~
