
Detailed explanation of transpose convolution in pytorch

2022-06-24 16:06:00 Full stack programmer webmaster

Hello everyone, we meet again. I'm your friend, Quan Jun.

Preface

Transposed convolution, formally transpose convolution, is what it is called in tf and torch. In papers you will sometimes see it called deconvolution, but that term is not appropriate. Transposed convolution is not the inverse (reverse) operation of direct convolution, and it cannot recover the original tensor, so calling it deconvolution is wrong. It only matches the original tensor in shape: the result has the same shape as the original tensor.

I wrote this article because the blog posts about transposed convolution online don't explain it clearly. I read them for a long time and was still lost in the clouds, so I had to work it out myself.

1. Basic operation: staggered scanning

Definition: in this article, we call ordinary convolution direct convolution.

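The biggest difference between transpose conv and direct conv is:

Transposed convolution supports staggered scanning. The kernel may slide partly past the ends of the input. Positions where it only partially overlaps the input also produce outputs, with the missing part counted as zero.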

Obviously, staggered scanning produces more scan positions, so the output shape of a transpose conv is larger than its input. This is the basic principle by which transposed convolution restores the input shape. Of course, the values themselves cannot be restored.
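A minimal pure-Python sketch of staggered scanning, assuming it means letting the kernel hang off either end of the input, with positions outside the input counted as zero. The function name `staggered_scan` is hypothetical. It simply zero-pads k−1 elements on each side and slides the kernel with step 1, so the output is k−1 elements longer than the input.

```python
def staggered_scan(x, kernel):
    """Slide `kernel` over `x` with step 1, allowing partial overlap at both ends.

    Hypothetical helper: positions outside `x` are treated as zeros, which is the
    same as padding k-1 zeros on each side and taking an ordinary sliding dot product.
    """
    k = len(kernel)
    padded = [0] * (k - 1) + list(x) + [0] * (k - 1)
    return [
        sum(padded[j + m] * kernel[m] for m in range(k))
        for j in range(len(padded) - k + 1)
    ]


out = staggered_scan([1, 2, 3], [1, 1, 1])
print(out)       # [1, 3, 6, 5, 3]
print(len(out))  # 5 = 3 + (3 - 1): larger than the input
```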

2. Shape formula

The paper on convolution below is very well written: thorough, detailed, and complete! But it is as hard to understand as a heavenly script.

A guide to convolution arithmetic for deep learning Vincent Dumoulin, Francesco Visin https://arxiv.org/abs/1603.07285

I'll try to write my own version.

Definition 2.1 A transposed convolution that has the same kernel_size and can restore the input shape is called the corresponding transpose conv of the direct conv.

For example, suppose the input is (7×7) and direct convolution yields (3×3). In theory, many transpose convs can restore (3×3) to (7×7), which makes the topic hard to study. So we stipulate that only the transpose conv with the same kernel_size counts as the corresponding transpose conv.

Theorem 2.1 Every corresponding transpose conv of a direct conv has another direct conv that is equivalent to it in terms of shape transformation: the corresponding direct conv.

Obviously, restoring (3×3) to (7×7) can also be done with a padded direct conv, can't it? The same holds in other cases. As far as restoring the shape goes, a transpose conv can always be replaced by some direct conv. We call the one with the same kernel_size, differing only in padding, the corresponding direct conv.

We now have 3 different concepts: original direct conv, corresponding transpose conv, and corresponding direct conv.

Let's agree on some notation, used from here on.

For direct conv:

input_size = i (note: we implicitly assume the 2D input is square, so a single letter i suffices instead of two letters w and h)
output_size = o
kernel_size = k
padding = p
stride = s

For transpose conv:

input_size = i'
output_size = o'
kernel_size = k' (since we only study the corresponding transpose conv, k' = k here)
padding = p'
stride = s'

We want the output of the transposed convolution to restore the original input shape, i.e. we want o' = i.

Theorem 2.2 The shape formula of direct conv is $o = [i + 2p - (k-1)]/s$. This formula is familiar to anyone who has studied convolution. It is exact when s divides the numerator; Section 5 covers the case when it does not.
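Strictly, the division in Theorem 2.2 is exact only when s divides the numerator; otherwise the result is rounded up, as Section 5 of this article explains. Below is a small pure-Python sketch that checks this rounded-up reading against the output-length formula in the PyTorch `Conv1d` documentation, L_out = floor((L_in + 2·padding − dilation·(kernel_size − 1) − 1)/stride + 1), with dilation = 1. The helper names are illustrative, not from the original.

```python
def out_len_theorem_2_2(i, k, p, s):
    """Theorem 2.2, o = [i + 2p - (k-1)] / s, with the division rounded up (illustrative helper)."""
    numerator = i + 2 * p - (k - 1)
    return -(-numerator // s)  # integer ceiling division


def out_len_pytorch_doc(i, k, p, s, dilation=1):
    """Output length as given in the torch.nn.Conv1d documentation (illustrative helper)."""
    return (i + 2 * p - dilation * (k - 1) - 1) // s + 1


# Compare the two on a grid of valid configurations (padded input at least as long as the kernel).
mismatches = [
    (i, k, p, s)
    for i in range(1, 20)
    for k in range(1, 6)
    for p in range(0, 3)
    for s in range(1, 4)
    if i + 2 * p >= k and out_len_theorem_2_2(i, k, p, s) != out_len_pytorch_doc(i, k, p, s)
]
print(mismatches)                       # []
print(out_len_theorem_2_2(7, 3, 0, 2))  # 3, i.e. the (7x7) -> (3x3) example above
```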

Theorem 2.3 The shape formula of transpose conv is $o' = [i' + 2p' + (k'-1)]/s'$

Most beginners get lost at this step, because transposed convolution quietly redefines 2 concepts: s' and p'. Read on and you will see why I say "quietly". Let me state the conclusion first:

In the corresponding transposed convolution, s' is always 1 and p' ≤ 0.

Proof

Rearranging Theorem 2.2 gives
$i = o \cdot s + (k-1) - 2p = o \cdot s + (k-1) + 2(-p)$
Compare this with Theorem 2.3 (with k' = k):
$o' = [i' + (k-1) + 2p']/s'$

If we want $o' = i$, the corresponding transpose conv should satisfy: $i' = o \cdot s$, $p' = -p$, $s' = 1$.

These 3 conditions are not the only solution. But they are the simplest set of solutions in practice, so they are taken as the default. If you look at the torch and tf source code, this is how both are set up.

The formulas above show what it takes for the output $o'$ of the transposed convolution to exactly restore the input shape $i$ of the direct convolution. First apply stride processing to $o$. Then do a stride-1 staggered scan, which adds $(k-1)$ to the shape. Finally, remove the padding.

"A stride-1 staggered scan adds (k−1) to the shape" is self-evident. A kernel of size k placed at every position that overlaps an input of length i' by at least one element has i' + k − 1 positions.

3. Advanced operations: stride processing, stride-1 staggered scanning, and padding ablation

This section breaks down, one by one, the 3 steps derived at the end of Section 2.

3.1 Stride processing

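I call this internal zero-padding, or internal padding for short. Simply put, the input o of the transposed convolution is first stretched by a factor of stride. The gaps are filled with zeros, not with the interpolation used when resizing images: s−1 zeros go between each pair of adjacent elements. For example, with o=3 and s=2, [a, b, c] becomes [a, 0, b, 0, c], of length 5.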

Some readers will surely wonder: by the formula in Section 2, $i' = o \cdot s = 3 \cdot 2 = 6$. So how can it be 5?

In fact, the formula actually used in practice is $i' = (o-1) \cdot s + 1$.
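A minimal sketch of internal zero-padding; the helper name `insert_zeros` is hypothetical. It puts s−1 zeros between neighbouring elements and nothing at either end, which is why o=3, s=2 gives 5 rather than 6.

```python
def insert_zeros(x, s):
    """Internal zero-padding (hypothetical helper): place s-1 zeros between adjacent elements of x."""
    out = []
    for idx, value in enumerate(x):
        if idx > 0:
            out.extend([0] * (s - 1))
        out.append(value)
    return out


z = insert_zeros([1, 2, 3], 2)
print(z)       # [1, 0, 2, 0, 3]
print(len(z))  # 5 == (3 - 1) * 2 + 1, not 3 * 2 == 6
```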

I will discuss this further in Section 5.

3.2 Stride-1 staggered scanning

This was introduced in Section 1.

3.3 Padding ablation

The previous section said $p' = -p$. This means that in transposed convolution we do not add borders; we remove them. Taking p=1 as an example, p'=−1, so we strip 1 element from each side. Continuing the example above with k=3, stride processing gives 5 and the staggered scan gives 5+(3−1)=7. So finally $o' = 7 + 2 \cdot (-1) = 5$.
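The three steps can be traced as shape arithmetic with a short sketch; `three_step_lengths` is a hypothetical helper. The final length is compared with the output-length formula in the PyTorch `ConvTranspose1d` documentation, L_out = (L_in − 1)·stride − 2·padding + dilation·(kernel_size − 1) + output_padding + 1, taking dilation = 1 and output_padding = 0.

```python
def three_step_lengths(o, k, s, p):
    """Track the length through the three steps of Section 3 (hypothetical helper)."""
    after_stride = (o - 1) * s + 1       # 3.1 internal zero-padding
    after_scan = after_stride + (k - 1)  # 3.2 stride-1 staggered scan adds k-1
    after_ablation = after_scan - 2 * p  # 3.3 remove p elements from each end
    return after_stride, after_scan, after_ablation


def conv_transpose_len_pytorch_doc(o, k, s, p, dilation=1, output_padding=0):
    """Output length as given in the torch.nn.ConvTranspose1d documentation."""
    return (o - 1) * s - 2 * p + dilation * (k - 1) + output_padding + 1


print(three_step_lengths(3, 3, 2, 1))  # (5, 7, 5)

# The three-step result agrees with the documented formula on a grid of settings.
assert all(
    three_step_lengths(o, k, s, p)[2] == conv_transpose_len_pytorch_doc(o, k, s, p)
    for o in range(1, 10) for k in range(1, 6) for s in range(1, 4) for p in range(0, 3)
)
```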

4. Verifying the correctness of the 3 steps in code

Note! In torch, the weight gets reversed.

Common imports used in this section:

import torch
import torch.nn.functional as F

4.1 A basic 1D transposed convolution

Working through 2D by hand is tiring; 1D is enough to understand it. Reference code:

inputs = torch.Tensor([[1,2,3],[4,5,6]]).unsqueeze(0)  # (batch=1, C_in=2, L_in=3)
weights = torch.Tensor([1.1,2.2,3.3]).view(1,1,-1).repeat(2,1,1)  # (C_in=2, C_out=1, k=3)
print(inputs.shape)
print(weights.shape)
o =F.conv_transpose1d(inputs, weights,padding=0,stride=1)
print(o.shape)
print(o)

Output:

torch.Size([1, 2, 3])
torch.Size([2, 1, 3])
torch.Size([1, 1, 5])
tensor([[[ 5.5000, 18.7000, 41.8000, 42.9000, 29.7000]]])

A basic description of the operation above: inputs has shape [batch_size, C_in, L_in] = [1, 2, 3]; weights has shape [C_in, C_out, kernel_size] = [2, 1, 3]; the output has shape [batch_size, C_out, L_out] = [1, 1, 5]. In this example we set:

batch_size = 1
C_in = 2
L_in = 3, i.e. the output of the direct convolution, o = 3
C_out = 1
k = 3, p = 0, s = 1

C_out is essentially the number of output feature maps. We set it to 1, so the result is a single feature map.

Analyzing the results, it is easy to see that:

o = tensor([[[ 5.5000, 18.7000, 41.8000, 42.9000, 29.7000]]])
5.5 = (1+4)*1.1
18.7 = (2+5)*1.1 + (1+4)*2.2
41.8 = (3+6)*1.1 + (2+5)*2.2 + (1+4)*3.3
42.9 = (3+6)*2.2 + (2+5)*3.3
29.7 = (3+6)*3.3

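The corresponding scanning pattern is as follows. Both input channels share the same kernel and C_out = 1, so the channels can be summed to [5, 7, 9]. The stride-1 staggered scan then slides a length-3 window over [0, 0, 5, 7, 9, 0, 0]. The weights that actually multiply each window are [3.3, 2.2, 1.1]; for example, the first window [0, 0, 5] gives 5 × 1.1 = 5.5.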

This is a very strange point. We defined weights as [1.1, 2.2, 3.3], but during the computation it is internally reversed to [3.3, 2.2, 1.1].

Of course, when we use transposed convolution, the parameters are usually randomly initialized and then learned, so this reversal does not matter. But if you control a transposed convolution by hand with fixed weights, this reversal matters a lot. Be careful with it when using torch.
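One way to see where the reversal comes from is the following sketch. For stride 1 and padding 0, `F.conv_transpose1d` should match an ordinary `F.conv1d` (which PyTorch computes as cross-correlation) over the input padded by k−1 zeros on each side. This holds once the kernel is flipped and the weight's two channel axes are swapped. Per the PyTorch docs, `conv_transpose1d` takes weight shape (C_in, C_out, k) while `conv1d` takes (C_out, C_in, k).

```python
import torch
import torch.nn.functional as F

inputs = torch.tensor([[1., 2., 3.], [4., 5., 6.]]).unsqueeze(0)        # (1, 2, 3)
weights = torch.tensor([1.1, 2.2, 3.3]).view(1, 1, -1).repeat(2, 1, 1)  # (C_in=2, C_out=1, k=3)
k = weights.shape[-1]

transposed = F.conv_transpose1d(inputs, weights, padding=0, stride=1)

# Equivalent direct convolution: swap to (C_out, C_in, k), reverse the kernel,
# and pad k-1 zeros on both sides (the stride-1 staggered scan).
direct_weights = weights.transpose(0, 1).flip(-1)  # kernel is now [3.3, 2.2, 1.1]
direct = F.conv1d(inputs, direct_weights, padding=k - 1)

print(transposed)
print(torch.allclose(transposed, direct))  # True
```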

4.2 Verifying the correctness of stride processing

inputs = torch.Tensor([[1,2,3],[4,5,6]]).unsqueeze(0)  # (batch=1, C_in=2, L_in=3)
weights = torch.Tensor([1.1,2.2,3.3]).view(1,1,-1).repeat(2,1,1)  # (C_in=2, C_out=1, k=3)
print(inputs.shape)
print(weights.shape)
o =F.conv_transpose1d(inputs, weights,padding=0,stride=2)
print(o.shape)
print(o)

Output

torch.Size([1, 2, 3])
torch.Size([2, 1, 3])
torch.Size([1, 1, 7])
tensor([[[ 5.5000, 11.0000, 24.2000, 15.4000, 33.0000, 19.8000, 29.7000]]])

In this example we set p=0, s=2. Just as guessed in 3.1, the input o=3 first becomes i'=5 after stride=2 processing. A k=3, stride-1 staggered scan then gives the output shape o'=7.

The values can also be verified numerically. "Due to space limitations, I won't write it out here." Please verify it yourself.
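The values can be checked with a pure-Python sketch of the three steps; `transpose_conv1d_3steps` is a hypothetical helper. In this example every input channel uses the same kernel [1.1, 2.2, 3.3] and C_out = 1. That lets us sum the two channels first, giving [5, 7, 9]. The sliding window uses the reversed kernel, as observed in 4.1. Setting p = 1 also reproduces the 4.3 output.

```python
def transpose_conv1d_3steps(x, kernel, s=1, p=0):
    """Single-channel transposed convolution built from the three steps of Section 3.

    Hypothetical helper. The kernel is reversed before sliding, matching the reversal observed in 4.1.
    """
    k = len(kernel)
    # Step 1: stride processing (internal zero-padding).
    z = []
    for idx, value in enumerate(x):
        if idx > 0:
            z.extend([0.0] * (s - 1))
        z.append(value)
    # Step 2: stride-1 staggered scan = pad k-1 zeros on both sides, slide the reversed kernel.
    padded = [0.0] * (k - 1) + z + [0.0] * (k - 1)
    rev = kernel[::-1]
    full = [sum(padded[j + m] * rev[m] for m in range(k)) for j in range(len(padded) - k + 1)]
    # Step 3: padding ablation, drop p values from each end.
    return full[p:len(full) - p]


channels = [[1, 2, 3], [4, 5, 6]]
# Summing channels is valid only because both channels share one kernel (specific to this example).
summed = [a + b for a, b in zip(*channels)]  # [5, 7, 9]
kernel = [1.1, 2.2, 3.3]

print([round(v, 4) for v in transpose_conv1d_3steps(summed, kernel, s=2, p=0)])
# [5.5, 11.0, 24.2, 15.4, 33.0, 19.8, 29.7]
print([round(v, 4) for v in transpose_conv1d_3steps(summed, kernel, s=2, p=1)])
# [11.0, 24.2, 15.4, 33.0, 19.8]
```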

4.3 Verifying the correctness of padding ablation

Starting from the code in 4.2, just change padding to 1.

inputs = torch.Tensor([[1,2,3],[4,5,6]]).unsqueeze(0)  # (batch=1, C_in=2, L_in=3)
weights = torch.Tensor([1.1,2.2,3.3]).view(1,1,-1).repeat(2,1,1)  # (C_in=2, C_out=1, k=3)
print(inputs.shape)
print(weights.shape)
o =F.conv_transpose1d(inputs, weights,padding=1,stride=2)
print(o.shape)
print(o)

Output

torch.Size([1, 2, 3])
torch.Size([2, 1, 3])
torch.Size([1, 1, 5])
tensor([[[11.0000, 24.2000, 15.4000, 33.0000, 19.8000]]])

Note that the output in 4.2 is [5.5, 11, 24.2, 15.4, 33, 19.8, 29.7], 7 values in total. The output in 4.3 is [11, 24.2, 15.4, 33, 19.8], 5 values. With p=1, the output is the 4.2 result with 1 value removed from each end, giving o'=5.
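The same ablation can be checked for other settings with a short sketch. It assumes the cropping behaviour described in 3.3 holds generally in PyTorch: the `padding=p` output should equal the `padding=0` output with p values sliced off each end. Random tensors are used here only as test data.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 2, 4)  # (batch, C_in, L_in)
w = torch.randn(2, 3, 3)  # (C_in, C_out, k)

for s in (1, 2, 3):
    full = F.conv_transpose1d(x, w, stride=s, padding=0)
    for p in (1, 2):
        cropped = F.conv_transpose1d(x, w, stride=s, padding=p)
        # Padding ablation: drop p values from both ends of the p=0 result.
        same = torch.allclose(cropped, full[..., p:full.shape[-1] - p], atol=1e-6)
        print(s, p, tuple(cropped.shape), same)
```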

With this, we have perfectly verified all the conjectures of Section 3.

Damn, I suddenly realized something. This kernel is called a transpose conv filter. Maybe the phrase should not be parsed as (transpose conv) filter, i.e. the filter of a transpose conv. Maybe it should be parsed as transpose (conv filter), i.e. a conv filter in transposed form? Is that why the weights have to be transposed on the way in, so that 1.1, 2.2, 3.3 get reversed??? Metaphysics...

5. A supplement on stride processing

Let's return to the problem left over from Section 3: why, when o=3, s=2, k=3, p=1, is i'=5 instead of 6?

Recall the purpose of transposed convolution: we want to restore the shape of the direct convolution's input. Think about it: for which i does a direct convolution with k=3, p=1, s=2 give o=3?

You already saw in the code of Section 4 that we finally restored o'=5.

Clearly, with i=5, k=3, p=1, s=2, we do get o=3.

The key is that direct convolution has an implicit operation. If $[i + 2p - (k-1)] \% s \neq 0$, direct convolution does one extra padding step. In other words, the division is effectively rounded up.

In this example, i=5, k=3, p=1, s=2, so $i + 2p - (k-1) = 5 + 2 \cdot 1 - (3-1) = 5$, which is not divisible by $s = 2$. So 5 is padded once more to 6, and then $o = 6/2 = 3$.

Likewise, with i=6, k=3, p=1, s=2, we also get o=3.

In other words, suppose we require that the transposed convolution restore the input shape of the direct convolution. Then the conditions {o=3, k=3, p=1, s=2} have 2 solutions: i=5 or i=6.
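A brute-force sketch of this ambiguity follows. It uses the direct-convolution output-length formula from the PyTorch `Conv1d` docs (dilation 1) and lists every i that gives o=3 for k=3, p=1, s=2. The helper name `direct_out_len` is illustrative.

```python
def direct_out_len(i, k, p, s):
    """Direct-convolution output length (torch.nn.Conv1d docs, dilation=1); illustrative helper."""
    return (i + 2 * p - k) // s + 1


k, p, s, o = 3, 1, 2, 3
candidates = [i for i in range(1, 50) if i + 2 * p >= k and direct_out_len(i, k, p, s) == o]
print(candidates)       # [5, 6]
print(min(candidates))  # 5, which equals (o-1)*s + 1 + (k-1) - 2*p
```

For stride s, up to s consecutive values of i map to the same o, because the floor division discards a remainder between 0 and s−1.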

But on a computer, a transpose_conv_layer cannot output 2 solutions. It has to settle on one definite solution for the computation to be feasible. So this "internal pad" style of stride processing is essentially a man-made rule. Even when the given conditions correspond to several solutions, we always take the smallest one.
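For reference, PyTorch lets you pick a larger solution explicitly. `F.conv_transpose1d` (and `nn.ConvTranspose1d`) has an `output_padding` argument. The docs describe it as additional size added to one side of the output shape and require it to be smaller than either stride or dilation. A short sketch with the 4.3 settings:

```python
import torch
import torch.nn.functional as F

inputs = torch.tensor([[1., 2., 3.], [4., 5., 6.]]).unsqueeze(0)        # (1, 2, 3)
weights = torch.tensor([1.1, 2.2, 3.3]).view(1, 1, -1).repeat(2, 1, 1)  # (2, 1, 3)

smallest = F.conv_transpose1d(inputs, weights, stride=2, padding=1)                 # default: smallest solution
other = F.conv_transpose1d(inputs, weights, stride=2, padding=1, output_padding=1)  # the i=6 solution
print(smallest.shape[-1], other.shape[-1])  # 5 6
```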

Besides, this way of handling stride looks elegant, doesn't it?

6. Corresponding direct conv

Recall the formula obtained in Section 2:
$i = o \cdot s + (k-1) - 2p = o \cdot s + (k-1) + 2(-p)$
Rearranging further gives
$i = o \cdot s + 2(k-1) - (k-1) + 2(-p) = o \cdot s + 2(k-1-p) - (k-1)$

Compare this with the formula of Theorem 2.2:
$o'' = [i'' + 2p'' - (k''-1)]/s''$

It follows that, for the output of the corresponding direct conv to restore the input exactly, we need
$i'' = o \cdot s$, $k'' = k$, $s'' = 1$, $p'' = k-1-p$,
which gives $o'' = i$. Of course, we also run into the multiple-solution problem discussed in Section 5. So, in the same way, we adjust $i'' = o \cdot s$ to $i'' = (o-1) \cdot s + 1$. This always yields the smallest of the multiple solutions.
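The following sketch checks this numerically; `corresponding_direct_conv` is a hypothetical helper. It inserts s−1 zeros between input elements, so the length becomes (o−1)s+1. It then runs an ordinary stride-1 `F.conv1d` with padding p″ = k−1−p, using the flipped kernel with its channel axes swapped. It assumes p ≤ k−1 so that p″ is non-negative.

```python
import torch
import torch.nn.functional as F


def corresponding_direct_conv(x, w, s, p):
    """Stride-1 direct convolution meant to reproduce F.conv_transpose1d(x, w, stride=s, padding=p).

    Hypothetical helper. x: (N, C_in, L); w: (C_in, C_out, k) in conv_transpose1d layout; requires p <= k - 1.
    """
    n, c_in, length = x.shape
    k = w.shape[-1]
    z = x.new_zeros(n, c_in, (length - 1) * s + 1)
    z[..., ::s] = x                        # internal zero-padding: i'' = (o-1)*s + 1
    w_direct = w.transpose(0, 1).flip(-1)  # (C_out, C_in, k), kernel reversed
    return F.conv1d(z, w_direct, stride=1, padding=k - 1 - p)  # p'' = k - 1 - p


torch.manual_seed(0)
x = torch.randn(1, 2, 3)
w = torch.randn(2, 4, 3)  # C_in=2, C_out=4, k=3
for s in (1, 2, 3):
    for p in (0, 1, 2):
        ref = F.conv_transpose1d(x, w, stride=s, padding=p)
        out = corresponding_direct_conv(x, w, s, p)
        print(s, p, tuple(ref.shape), torch.allclose(ref, out, atol=1e-6))
```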

What is the point of this section? I don't know either. But the author of the paper cited in Section 2 is very keen on discussing it. (shrug)

He has a conclusion that is interesting, at least literally. Let s=1. If p=0, then $p'' = k-1$; the paper calls this $p''$ full padding. So we can say:

A direct conv with no padding has a corresponding direct conv with full padding.

Conversely, to get $p'' = 0$ we need $p = k-1$, so we can say:

A direct conv with full padding has a corresponding direct conv with no padding.
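A quick shape check of these two statements with s = 1, using Theorem 2.2, which is exact when s = 1. The numbers i = 7, k = 3 are just an example, and `out_len` is an illustrative helper.

```python
def out_len(i, k, p, s=1):
    """Theorem 2.2: o = [i + 2p - (k-1)] / s (exact here because s = 1); illustrative helper."""
    return (i + 2 * p - (k - 1)) // s


k = 3
# Direct conv with no padding (p = 0).
o = out_len(7, k, p=0)          # 5
p2 = k - 1 - 0                  # p'' = k - 1: full padding
print(o, out_len(o, k, p=p2))   # 5 7 -> restores i = 7

# Direct conv with full padding (p = k - 1).
o = out_len(7, k, p=k - 1)      # 9
p2 = k - 1 - (k - 1)            # p'' = 0: no padding
print(o, out_len(o, k, p=p2))   # 9 7 -> restores i = 7
```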

That is probably all it is good for.

Publisher: Full stack programmer webmaster. Please credit the source when reprinting: https://javaforall.cn/151942.html Original link: https://javaforall.cn
