A detailed explanation of transposed convolution in PyTorch
Preface
Transposed convolution goes by the proper name transpose convolution, which is what both tf and torch call it. In papers you sometimes see it called deconvolution, but that term is not appropriate: transposed convolution is not the inverse (reverse) of direct convolution and cannot restore the original tensor's values, so calling it deconvolution is wrong. Only in terms of shape does its result match the original tensor.

I wrote this article because the blog posts about transposed convolution out there never explain it clearly; I stared at them for a long time and was still lost in the clouds. So I had to do it myself.
1. The basic operation: staggered scanning

Definition 1.1. In this article, the usual convolution is called direct convolution.

The biggest difference between transpose conv and direct conv is:

Transpose conv supports staggered scanning — the kernel is also allowed to scan positions where it only partially overlaps the input, hanging off the edges.

Obviously, staggered scanning means more scan positions, so the output of a transpose conv has a larger shape than its input. This is the basic principle by which transposed convolution restores the shape of the input. (The values, of course, cannot be restored.)
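A quick shape check of this claim — a minimal sketch, with sizes chosen arbitrarily by me:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4)
w = torch.randn(1, 1, 3)
# direct conv: the kernel must stay fully inside the input, 4 -> 4 - (3-1) = 2
print(F.conv1d(x, w).shape)            # torch.Size([1, 1, 2])
# transpose conv: staggered scanning adds positions, 4 -> 4 + (3-1) = 6
print(F.conv_transpose1d(x, w).shape)  # torch.Size([1, 1, 6])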
2. Shape formulas

The paper below on convolution arithmetic is well written — extremely! detailed! — but about as easy to follow as scripture.

A guide to convolution arithmetic for deep learning. Vincent Dumoulin, Francesco Visin. https://arxiv.org/abs/1603.07285

So let me try to write my own version.
Definition 2.1. Among the transposed convolutions that can restore the input shape of a given direct conv, the one with the same kernel_size is called the corresponding transpose conv of that direct conv.

For example, a (7×7) input under some direct convolution gives (3×3). In principle many different transpose convs can map (3×3) back to (7×7), which is inconvenient for our study. So we stipulate that only the one with the same kernel_size counts as the corresponding transpose conv.
Theorem 2.1. For every corresponding transpose conv of a direct conv, there exists a direct conv that is equivalent in terms of shape transformation — the corresponding direct conv.

Obviously, mapping (3×3) back to (7×7) can also be achieved by a direct conv with padding, right? The same holds in general: as far as restoring the shape is concerned, a transpose conv can always be replaced by some direct conv. The one with the same kernel_size and only a different padding is what we call the corresponding direct conv.

We now have 3 distinct concepts: the original direct conv, the corresponding transpose conv, and the corresponding direct conv.
Let's fix some notation.

For the direct conv:
input_size = i (note: we implicitly assume the 2D input is square, so a single letter i suffices, rather than two letters w and h)
output_size = o
kernel_size = k
padding = p
stride = s

For the transpose conv:
input_size = i'
output_size = o'
kernel_size = k' (since we only study the corresponding transpose conv, k' = k)
padding = p'
stride = s'

We want the output of the transpose conv to restore the original input shape, i.e. we hope o' = i.
Theorem 2.2 (shape formula of the direct conv): o = [i + 2*p - (k-1)] / s. Anyone who has studied convolution knows this formula. (What happens when the division doesn't come out even is deferred to Section 5.)
Theorem 2.3 (shape formula of the transpose conv): o' = [i' + 2*p' + (k'-1)] / s'.
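A quick sanity check of both formulas on shapes alone — a minimal sketch, with sizes chosen by me:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 7)   # i = 7
w = torch.randn(1, 1, 3)   # k = 3
mid = F.conv1d(x, w, padding=0, stride=1)             # o  = [7 + 0 - 2] / 1 = 5
out = F.conv_transpose1d(mid, w, padding=0, stride=1) # o' = [5 + 0 + 2] / 1 = 7
print(mid.shape, out.shape)  # torch.Size([1, 1, 5]) torch.Size([1, 1, 7])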
Most beginners faint at this step, because transposed convolution quietly swaps out 2 concepts: s' and p'. Read on and you'll see why I say "quietly". Conclusion first:

For the corresponding transpose conv, s' always equals 1 and p' <= 0.
Proof. Rearranging Theorem 2.2 gives

i = o*s + (k-1) - 2*p = o*s + (k-1) + 2*(-p)

Compare with Theorem 2.3:

o' = [i' + (k-1) + 2*p'] / s'

If we want o' = i, the corresponding transpose conv should satisfy:

i' = o*s
p' = -p
s' = 1
The above 3 conditions are not the only solution, but they are the simplest set in practice, so they are taken as the default; look at the torch and tf source and you'll find exactly this setup.

The formulas above say: for the transpose conv's output o' to perfectly restore the direct conv's input shape i, we first apply stride handling to o, then perform staggered scanning with step 1, which contributes a shape gain of (k-1), and finally subtract the padding.

That "staggered scanning with step 1 yields a shape gain of (k-1)" is self-evident: with step 1, the scan visits every position where the kernel still touches the input, i' + (k-1) positions in total.
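Putting Section 2 together: a direct conv taking (7×7) to (3×3), and its corresponding transpose conv taking (3×3) back to (7×7). A minimal shape-only sketch; k=3, s=2, p=0 is one parameter set (my choice) that realizes the 7→3 example from earlier:

import torch
import torch.nn as nn

direct = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=0)
transpose = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=0)

x = torch.randn(1, 1, 7, 7)
y = direct(x)     # (7x7) -> (3x3)
z = transpose(y)  # (3x3) -> (7x7): the shape is restored, not the values
print(y.shape, z.shape)  # torch.Size([1, 1, 3, 3]) torch.Size([1, 1, 7, 7])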
3. Advanced operations: stride handling, step-1 staggered scanning, and padding ablation

This section describes, step by step, the 3-step decomposition derived at the end of Section 2.
3.1 Stride handling

I call this internal zero-padding, internal padding for short. Simply put: the input of the transpose conv, of size o, is first enlarged by a factor of stride, and the inserted positions are filled with zeros (not with the interpolation used to resize pictures). With o = 3 and s = 2, for example, the enlarged input has length 5.

At this point someone is bound to wonder: by the formula given in Section 2, i' = o*s = 3*2 = 6 — so how can it be 5?

In fact, the formula actually used in practice is

i' = (o-1)*s + 1

I'll come back to this question in Section 5.
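A minimal sketch of internal zero-padding for o = 3, s = 2 (the names are mine):

import torch

o, s = 3, 2
x = torch.Tensor([1, 2, 3])
up = torch.zeros((o - 1) * s + 1)  # i' = (3 - 1) * 2 + 1 = 5
up[::s] = x                        # original values land on every s-th slot
print(up)                          # tensor([1., 0., 2., 0., 3.])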
3.2 Step-1 staggered scanning

This was already covered in Section 1.

3.3 Padding ablation

The last section established p' = -p. This means that in transposed convolution we do not add borders — we eliminate them. Take p = 1: then p' = -1, and we strip 1 cell from each border. Finally o' = 7 + 2*(-1) = 5.
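A quick check that padding in transpose conv really just trims the ends — a minimal sketch, with sizes chosen by me:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 3)
w = torch.randn(1, 1, 3)
full = F.conv_transpose1d(x, w, padding=0, stride=2)  # o' = 7
trim = F.conv_transpose1d(x, w, padding=1, stride=2)  # o' = 5
print(torch.allclose(full[..., 1:-1], trim))          # True: p=1 ablates 1 value per end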
4. Verifying the correctness of the 3 steps in code

Watch out! In torch, the weights get reversed.

The common header for all code in this section:

import torch
import torch.nn.functional as F

4.1 A basic 1d transposed convolution
Working through a 2d example by hand is exhausting; a 1d example is enough to see what's going on. The code (adapted from an online reference):

inputs = torch.Tensor([[1,2,3],[4,5,6]]).unsqueeze(0)            # (1, 2, 3)
weights = torch.Tensor([1.1,2.2,3.3]).view(1,1,-1).repeat(2,1,1) # (2, 1, 3)
print(inputs.shape)
print(weights.shape)
o = F.conv_transpose1d(inputs, weights, padding=0, stride=1)
print(o.shape)
print(o)

The printed results:
torch.Size([1, 2, 3])
torch.Size([2, 1, 3])
torch.Size([1, 1, 5])
tensor([[[ 5.5000, 18.7000, 41.8000, 42.9000, 29.7000]]])

A rundown of the above: inputs has shape [batch_size, C_in, L_in] = [1, 2, 3]; weights has shape [C_in, C_out, kernel_size] = [2, 1, 3]; the output has shape [batch_size, C_out, L_out] = [1, 1, 5]. In this example we set:

batch_size = 1, C_in = 2, L_in = 3 (i.e. the direct conv's output o = 3), C_out = 1, k = 3, p = 0, s = 1

C_out is essentially the number of output feature maps; we set it to 1, so the result contains a single feature map.

Analyzing the result, it is not hard to find (the sums (1+4), (2+5), (3+6) come from adding the two input channels):

o = tensor([[[ 5.5000, 18.7000, 41.8000, 42.9000, 29.7000]]])
5.5  = (1+4)*1.1
18.7 = (2+5)*1.1 + (1+4)*2.2
41.8 = (3+6)*1.1 + (2+5)*2.2 + (1+4)*3.3
42.9 = (3+6)*2.2 + (2+5)*3.3
29.7 = (3+6)*3.3
Here is the very strange point: we defined weights as [1.1, 2.2, 3.3], but once inside the computation it is reversed into [3.3, 2.2, 1.1] — position 0 of the output sees kernel element 1.1, not 3.3.

Of course, transposed convolution is usually used with randomly initialized, learned parameters, in which case the reversal changes nothing. But with fixed weights, when you steer the transpose conv by hand, this reversal is very much worth noting. Be careful when using torch.
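What the reversal means concretely: viewed as a direct (gather) convolution over an input zero-padded by k-1 = 2 on each side, the effective kernel is the flipped weights. A minimal sketch for the s=1, p=0 case above:

import torch
import torch.nn.functional as F

inputs = torch.Tensor([[1, 2, 3], [4, 5, 6]]).unsqueeze(0)
weights = torch.Tensor([1.1, 2.2, 3.3]).view(1, 1, -1).repeat(2, 1, 1)

t = F.conv_transpose1d(inputs, weights, padding=0, stride=1)
# gather view: pad k-1 = 2 zeros per side, convolve with the reversed kernel;
# conv1d wants weights as (C_out, C_in, k), hence the transpose(0, 1)
d = F.conv1d(F.pad(inputs, (2, 2)), torch.flip(weights, dims=[-1]).transpose(0, 1))
print(torch.allclose(t, d))  # True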
4.2 Verifying the stride handling

inputs = torch.Tensor([[1,2,3],[4,5,6]]).unsqueeze(0)            # (1, 2, 3)
weights = torch.Tensor([1.1,2.2,3.3]).view(1,1,-1).repeat(2,1,1) # (2, 1, 3)
print(inputs.shape)
print(weights.shape)
o = F.conv_transpose1d(inputs, weights, padding=0, stride=2)
print(o.shape)
print(o)

Output:
torch.Size([1, 2, 3])
torch.Size([2, 1, 3])
torch.Size([1, 1, 7])
tensor([[[ 5.5000, 11.0000, 24.2000, 15.4000, 33.0000, 19.8000, 29.7000]]])

In this example we set p=0, s=2. Exactly as conjectured in 3.2: the input o=3 first undergoes stride=2 handling and becomes i'=5, then step-1 staggered scanning with k=3 is applied, and the final output shape is o'=7.

The numerical values can be verified as correct too. "Due to space limitations" I won't write the hand computation out here — please verify it yourself.
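If you'd rather not check by hand, here is a minimal sketch of the scatter computation (the names are mine):

import torch

x = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # (C_in=2, L_in=3)
w = torch.Tensor([1.1, 2.2, 3.3])
s, k = 2, 3

out = torch.zeros(5 + (k - 1))            # i' = 5, plus the (k-1) staggered-scan gain
for c in range(2):                        # sum over input channels
    for j in range(3):                    # each input element scatters a full kernel
        out[j * s : j * s + k] += x[c, j] * w
print(out)  # tensor([ 5.5000, 11.0000, 24.2000, 15.4000, 33.0000, 19.8000, 29.7000])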
4.3 Verifying the padding ablation

Take the code from 4.2 and simply change padding to 1.

inputs = torch.Tensor([[1,2,3],[4,5,6]]).unsqueeze(0)            # (1, 2, 3)
weights = torch.Tensor([1.1,2.2,3.3]).view(1,1,-1).repeat(2,1,1) # (2, 1, 3)
print(inputs.shape)
print(weights.shape)
o = F.conv_transpose1d(inputs, weights, padding=1, stride=2)
print(o.shape)
print(o)

Output:
torch.Size([1, 2, 3])
torch.Size([2, 1, 3])
torch.Size([1, 1, 5])
tensor([[[11.0000, 24.2000, 15.4000, 33.0000, 19.8000]]])

Notice: the output in 4.2 was [5.5, 11, 24.2, 15.4, 33, 19.8, 29.7] — 7 values. The output in 4.3 is [11, 24.2, 15.4, 33, 19.8] — 5 values. With p=1, one value is ablated at each end of the 4.2 output, giving o'=5.

With this, every conjecture from Section 3 is perfectly verified.
Ah — it suddenly dawns on me. The kernel here is called a transpose conv filter. Perhaps the bracketing is not (transpose conv) filter, i.e. the filter of a transpose conv, but rather transpose (conv filter), i.e. a conv filter in transposed form? Is that why the weights must go in transposed — why [1.1, 2.2, 3.3] gets reversed? Pure metaphysics...
5. More on stride handling

Let's come back to the question left over from Section 3: why, for o=3, s=2, k=3, p=1, is i'=5 rather than 6?

Recall again what transposed convolution is for: we want to restore, in shape, the input of the direct conv. So think about it: for what i does a direct convolution with k=3, p=1, s=2 give o=3?

You already saw in the Section 4 code that what we finally restored was o'=5.

And clearly, with i=5, k=3, p=1, s=2, we do get o=3.
The key is that direct convolution has an implicit operation: if

[i + 2*p - (k-1)] % s != 0

then the direct convolution quietly does one extra padding step.

In this case, i=5, k=3, p=1, s=2: computing i + 2*p - (k-1) = 5 + 2*1 - (3-1) = 5, which is not divisible by s=2. So 5 gets padded once more into 6, and then o = 6/2 = 3.

By the same token, i=6, k=3, p=1, s=2 also gives o=3.

In other words, under our stated goal — "the transpose conv should restore the input shape of the direct conv" — the condition set {o=3, k=3, p=1, s=2} has 2 solutions: i=5 or i=6.
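A quick check of the two solutions — a minimal sketch; the weights are random since only shapes matter here:

import torch
import torch.nn.functional as F

w = torch.randn(1, 1, 3)  # k = 3
for i in (5, 6):
    x = torch.randn(1, 1, i)
    print(i, '->', F.conv1d(x, w, stride=2, padding=1).shape)  # both yield o = 3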
But inside a computer, a transpose_conv_layer cannot output 2 solutions; we must settle on one definite solution for the computation to be feasible. So this "internal pad" style of stride handling is, in essence, a man-made rule: a given set of conditions may admit multiple solutions, but we always take the smallest one.

And you have to admit, this style of stride handling does look rather elegant.
6. The corresponding direct conv

Recall the formula obtained in Section 2:

i = o*s + (k-1) - 2*p = o*s + (k-1) + 2*(-p)

Deforming further:

i = o*s + 2*(k-1) - (k-1) - 2*p = o*s + 2*(k-1-p) - (k-1)

Compare with the formula of Theorem 2.2:

o'' = [i'' + 2*p'' - (k''-1)] / s''

It is not hard to see: for the corresponding direct conv's output to perfectly restore the input, we need

i'' = o*s
k'' = k
s'' = 1
p'' = k - 1 - p

which yields o'' = i. Of course, we run into the same multi-solution problem discussed in Section 5, so here too we fine-tune i'' = o*s into i'' = (o-1)*s + 1, which always picks the smallest of the multiple solutions.
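Putting all the pieces together, the corresponding direct conv can be built by hand: internal zero-padding, then a stride-1 direct conv with padding k-1-p and the flipped kernel. A minimal sketch reusing the tensors from Section 4 (the names are mine):

import torch
import torch.nn.functional as F

x = torch.Tensor([[1, 2, 3], [4, 5, 6]]).unsqueeze(0)            # (1, 2, 3)
w = torch.Tensor([1.1, 2.2, 3.3]).view(1, 1, -1).repeat(2, 1, 1) # (2, 1, 3)
k, s, p = 3, 2, 0
ref = F.conv_transpose1d(x, w, stride=s, padding=p)

# step 1: internal zero-padding, i'' = (o - 1) * s + 1
o = x.shape[-1]
up = torch.zeros(x.shape[0], x.shape[1], (o - 1) * s + 1)
up[..., ::s] = x
# steps 2 + 3: direct conv with s'' = 1, p'' = k - 1 - p, reversed kernel
# (conv1d wants weights as (C_out, C_in, k), hence the transpose)
w2 = torch.flip(w, dims=[-1]).transpose(0, 1)
out = F.conv1d(up, w2, stride=1, padding=k - 1 - p)
print(torch.allclose(ref, out))  # True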
What is the point of this section? Honestly, I don't know either. But the authors of the paper from Section 2 are very keen to discuss it. (Shrug.)

They do draw an amusing conclusion from the wording. Fix s=1. If p=0, we get p'' = k-1; the authors call this p'' fully-padded. So one can say:

The corresponding direct conv of a non-padded direct conv is fully-padded.

Conversely, to get p''=0 we need p = k-1, so one can say:

The corresponding direct conv of a fully-padded direct conv is non-padded.
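A shape-only check of this duality — a minimal sketch with s=1 and sizes chosen by me:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 5)
w = torch.randn(1, 1, 3)                # k = 3
print(F.conv1d(x, w, padding=0).shape)  # non-padded:   5 - (3-1) = 3
print(F.conv1d(x, w, padding=2).shape)  # fully-padded: 5 + (3-1) = 7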
That's about all it's good for.