PyTorch learning notes (VII): Vision Transformer
2022-06-25 02:35:00 【Clear thinking】
Contents
1. Patches and linear mapping
2. Adding the classification token
3. Positional encoding
4. LN, MSA and residual connection
5. LN, MLP and residual connection
6. Classification MLP
Preface: since Dosovitskiy et al. introduced the Vision Transformer (ViT), it has played a leading role in computer vision, and in many cases it outperforms traditional convolutional neural networks (CNNs). The Transformer itself comes from natural language processing (NLP), and the overall idea of ViT is not very different from NLP: a complete image is split into several tokens, and these tokens are fed into the network the way a sentence is fed in NLP, with each small piece of the image playing the role of a word.
The following figure is borrowed from the paper "Vision Transformers for Remote Sensing Image Classification".
In this figure, the image is split into nine sub-images, x1 to x9, all of the same size. Each sub-image is linearly embedded, so it becomes a one-dimensional vector. Note that x1 to x9 are cut from the original image in a fixed order, which is important: position information is then added to each token (vector), so the network can tell where each sub-image sits in the original image.
After the position information is embedded, these tokens are passed to the Transformer encoder together with one extra token used for classification. This is why the figure shows "+1" where the data enters: that 1 is the classification token. The Transformer encoder contains a layer normalization (LN), multi-head self-attention (MSA) and a residual connection, followed by a second LN, a multilayer perceptron (MLP) and another residual connection. The encoder block is usually repeated many times, similar to the blocks of ResNet. Finally, a classification MLP head classifies the image using the special classification token that was passed in at the start.
Looking back at the figure now, it should make a lot more sense.
1. Patches and linear mapping
The first question is how to turn an image into something like an English sentence. The authors' approach is to split it into several sub-images (patches) and map each one to a vector, keeping the patches in positional order.
For example, a 3*224*224 image (3 is the number of RGB channels) can be divided into 14*14 = 196 patches, each of size 16*16.
(N, C, H, W) → (N, 3, 224, 224) → (N, patches, patch_dim) → (N, 14*14, 3*16*16) = (N, 196, 768)
The 3*224*224 input image has now become a (196, 768) matrix: each row is one flattened 16*16 patch with its 3 channels. Each patch vector is then passed through a linear mapping, which can map it to a vector of any size, called the hidden dimension. In the example that follows we map 768 → 256 (the code at the end of this post defaults to hidden_d = 8). Note that the hidden dimension must be divisible by the number of attention heads.
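A minimal sketch of this step, assuming a batch of random tensors in place of real images and a hidden dimension of 256; the helper name `patchify` is hypothetical. It cuts each 3*224*224 image into a 14*14 grid of 16*16 patches (row by row), flattens each patch to 3*16*16 = 768 values, and maps it linearly to 256 dimensions.

```python
import torch
import torch.nn as nn


def patchify(images: torch.Tensor, patch: int) -> torch.Tensor:
    """(N, C, H, W) -> (N, num_patches, C * patch * patch), patches taken row by row."""
    n, c, h, w = images.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by the patch size"
    gh, gw = h // patch, w // patch
    x = images.reshape(n, c, gh, patch, gw, patch)   # split H and W into (grid, patch)
    x = x.permute(0, 2, 4, 1, 3, 5)                  # (N, gh, gw, C, patch, patch)
    return x.reshape(n, gh * gw, c * patch * patch)  # flatten every patch


images = torch.rand(2, 3, 224, 224)          # N = 2 stand-in images
patches = patchify(images, 16)               # (2, 196, 768)
linear_mapper = nn.Linear(3 * 16 * 16, 256)  # hidden dimension 256 (an assumption)
tokens = linear_mapper(patches)              # (2, 196, 256)
print(patches.shape, tokens.shape)
```

The permute is the important part: a plain `images.reshape(N, 196, 768)` gives the right shape but mixes pixels from different patches.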
2. Adding the classification token
As mentioned above, before the tokens are passed into the Transformer encoder, a classification token is added. Its purpose is to collect information from the other tokens, which happens in MSA. Once the whole image has passed through the encoder, we can classify the image using only this classification token.
Continuing with the 3*224*224 example from above:
(N, 196, 256) → (N, 196+1, 256)
The extra 1 is the classification token.
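A minimal sketch of prepending the classification token, assuming random (N, 196, 256) patch tokens; the token is placed at index 0, which is the position the post's code later reads back with `out[:, 0]`.

```python
import torch
import torch.nn as nn

hidden_d = 256
tokens = torch.rand(4, 196, hidden_d)                   # (N, 196, 256) patch tokens
class_token = nn.Parameter(torch.rand(1, 1, hidden_d))  # one learnable token shared by all images

# Copy the class token for every image in the batch and put it in front of the patch tokens
cls = class_token.expand(tokens.shape[0], -1, -1)       # (N, 1, 256), a view, no copy
tokens_with_cls = torch.cat([cls, tokens], dim=1)       # (N, 197, 256)
print(tokens_with_cls.shape)
```

Because `expand` only creates a view, all images share the same learnable parameter, and its gradient accumulates over the whole batch.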
3. Positional encoding
When the network receives these patches as input, how does it know where each patch was located in the original image?
Following Vaswani et al., this can be done simply by adding sine and cosine waves of different frequencies to the tokens.
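For reference, the sinusoidal encoding of Vaswani et al. can be written as follows, where pos is the token position, i indexes pairs of dimensions and d is the hidden dimension; the `get_positional_embeddings` function in the code at the end of this post computes this table (sine on even dimensions, cosine on odd ones).

```latex
\mathrm{PE}(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad
\mathrm{PE}(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)
```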
The tokens now have size (N, 197, 256). The positional encoding has size (197, 256) and is the same for every image, so it is repeated N times (once per image in the batch) and added to the tokens.
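A vectorized sketch of the same sine/cosine encoding, assuming an even hidden dimension; the name `sinusoidal_positions` is hypothetical. Instead of an explicit `.repeat(N, 1, 1)`, it relies on broadcasting, which adds the same (197, 256) table to every image in the batch.

```python
import torch


def sinusoidal_positions(seq_len: int, d: int) -> torch.Tensor:
    """Sinusoidal encoding of shape (seq_len, d): sine on even dims, cosine on odd dims."""
    assert d % 2 == 0, "this sketch assumes an even hidden dimension"
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    even = torch.arange(0, d, 2, dtype=torch.float32)              # 0, 2, 4, ...
    angle = pos / (10000 ** (even / d))                            # (seq_len, d/2)
    pe = torch.zeros(seq_len, d)
    pe[:, 0::2] = torch.sin(angle)
    pe[:, 1::2] = torch.cos(angle)
    return pe


tokens = torch.rand(8, 197, 256)     # (N, 197, 256), class token included
pe = sinusoidal_positions(197, 256)  # (197, 256), identical for every image
out = tokens + pe                    # broadcasting adds pe to each of the N images
print(pe.shape, out.shape)
```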
4. LN, MSA and residual connection
LN (layer normalization): for each input token, subtract the mean of its features and divide by their standard deviation; a learnable scale and shift are then applied.
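A quick check of this description, assuming a random (N, 197, 256) token tensor: `nn.LayerNorm(256)` normalizes each token over its 256 features, and at initialization (scale 1, shift 0) it matches the manual mean / standard-deviation formula. The small `eps` added to the variance avoids division by zero.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 197, 256)  # (N, tokens, hidden_d)
ln = nn.LayerNorm(256)        # normalizes over the last (feature) dimension

# Manual version: per-token mean and biased variance over the 256 features
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
manual = (x - mean) / torch.sqrt(var + ln.eps)  # weight = 1, bias = 0 at initialization

print(torch.allclose(ln(x), manual, atol=1e-5))
```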
MSA: each token (patch) is mapped to 3 different vectors: q, k and v. After the mapping, we take the dot product of q with every k and divide by the square root of the vector dimension; applying softmax to these results gives the attention weights. Finally, each attention weight is multiplied by the corresponding v and the results are summed.
In addition, each self-attention head has its own set of Q, K and V mapping functions.
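A minimal sketch of one attention head, assuming a single sequence of 197 tokens with d_head = 16 and randomly initialized mappings; `attention_head` is a hypothetical helper. In multi-head attention each head owns its own three mappings, which is what the `nn.ModuleList`s in `MyMSA` (in the code at the end of this post) hold.

```python
from typing import Callable

import torch
import torch.nn as nn

Mapping = Callable[[torch.Tensor], torch.Tensor]


def attention_head(x: torch.Tensor, q_map: Mapping, k_map: Mapping, v_map: Mapping) -> torch.Tensor:
    """One self-attention head applied to a (tokens, d_head) sequence."""
    q, k, v = q_map(x), k_map(x), v_map(x)
    scores = q @ k.T / (k.shape[-1] ** 0.5)  # (tokens, tokens): q.k scaled by sqrt(d_head)
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1: the attention weights
    return weights @ v                       # weighted sum of the v vectors


torch.manual_seed(0)
d_head = 16
x = torch.randn(197, d_head)                  # one sequence: 196 patch tokens + 1 class token
q_map, k_map, v_map = (nn.Linear(d_head, d_head) for _ in range(3))
out = attention_head(x, q_map, k_map, v_map)  # (197, 16)
print(out.shape)
```

A useful sanity check: if every key is identical, all scores in a row are equal, the softmax is uniform, and each output row is simply the average of the v vectors.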
Let's use an example to illustrate
(N, 197, 256)→(N, 197, 16, 16)→ nn.Linear(16, 16) → (N, 197, 256)
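The input is (N, 197, 256). With multi-head attention (16 heads here), each 256-dimensional token is split into 16 chunks of 16 values, giving (N, 197, 16, 16). Each head then applies its own nn.Linear(16, 16) mappings to its 16-dimensional chunk, and the outputs of all heads are concatenated back to (N, 197, 256).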
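The reshaping in this example can be sketched as follows, assuming 16 heads on a 256-dimensional hidden size and random tokens. Only the shapes are shown: head h sees features h*16 to h*16 + 15, the same slice that `MyMSA` in the code at the end of this post takes with `sequence[:, head * self.d_head: (head + 1) * self.d_head]`.

```python
import torch

n, tokens, hidden_d, n_heads = 2, 197, 256, 16
d_head = hidden_d // n_heads  # 16; hidden_d must be divisible by n_heads
x = torch.randn(n, tokens, hidden_d)

heads = x.reshape(n, tokens, n_heads, d_head)  # (N, 197, 16, 16): one 16-dim slice per head
# ... each head would apply its own q/k/v nn.Linear(16, 16) and attention here ...
merged = heads.reshape(n, tokens, hidden_d)    # concatenate the heads back: (N, 197, 256)
print(heads.shape, merged.shape)
```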
Residual connection: the input of the sub-block is added to its output, i.e. out = x + MSA(LN(x)).
As mentioned earlier, a classification token is added before the tokens enter the Transformer encoder. How does this token get information from the other tokens? Through attention: after LN, MSA and the residual connection, the classification token carries information from all the other tokens.
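A small experiment that illustrates this, assuming PyTorch's built-in `nn.MultiheadAttention` in place of the hand-written `MyMSA` and a toy sequence of 10 tokens with 32 features; the function name `block` is hypothetical. Changing only one patch token changes the output at position 0, while a per-token layer such as LayerNorm alone leaves position 0 untouched: the mixing happens in attention.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_d = 32
msa = nn.MultiheadAttention(hidden_d, num_heads=4, batch_first=True)
ln = nn.LayerNorm(hidden_d)


def block(x: torch.Tensor) -> torch.Tensor:
    h = ln(x)
    attn_out, _ = msa(h, h, h, need_weights=False)
    return x + attn_out  # residual connection


x = torch.randn(1, 10, hidden_d)          # token 0 plays the role of the classification token
x_changed = x.clone()
x_changed[0, 5] = torch.randn(hidden_d)   # replace one patch token only

before, after = block(x), block(x_changed)
print((after[0, 0] - before[0, 0]).abs().max())  # nonzero: token 0 "saw" token 5
```

The replacement token is random on purpose: adding the same constant to every feature of a token would be cancelled by LayerNorm's mean subtraction.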
5. LN, MLP and residual connection
The first step in the Transformer encoder block, described above, is LN, MSA and the residual connection. The second step applies another LN, an MLP and a second residual connection: out = out + MLP(LN(out)).
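Both steps together can be sketched as one encoder block. This is an illustration under stated assumptions, not the post's code: it uses PyTorch's `nn.MultiheadAttention`, and the MLP follows the ViT paper's two linear layers with GELU and 4x width instead of the single Linear + ReLU used at the end of this post. The names `EncoderBlock` and `mlp_ratio` are hypothetical.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block: x + MSA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, hidden_d: int, n_heads: int, mlp_ratio: int = 4):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_d)
        self.msa = nn.MultiheadAttention(hidden_d, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln1(x)
        x = x + self.msa(h, h, h, need_weights=False)[0]  # step 1: LN, MSA, residual
        x = x + self.mlp(self.ln2(x))                     # step 2: LN, MLP, residual
        return x


# The block keeps the shape, so several can be stacked (4 here, an arbitrary choice)
encoder = nn.Sequential(*[EncoderBlock(256, 16) for _ in range(4)])
tokens = torch.randn(2, 197, 256)
print(encoder(tokens).shape)
```

Since input and output shapes match, repeating the block (as the overview at the top of the post describes) is just a matter of stacking it.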
6. Classification MLP
After this series of operations, the classification MLP extracts only the classification token from each of the N sequences and uses it to obtain the class.
For example, if each token is a 256-dimensional vector (the hidden dimension of the running example) and there are 5 classes, the MLP can be a 256*5 matrix (a linear layer) followed by a softmax activation.
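A minimal sketch of this head, assuming a random encoder output of shape (N, 197, 256) with the classification token at index 0. Note that `nn.Linear` stores its weight as (out_features, in_features), so the 256*5 matrix appears as a (5, 256) tensor.

```python
import torch
import torch.nn as nn

n, hidden_d, n_classes = 4, 256, 5
encoded = torch.randn(n, 197, hidden_d)  # encoder output; token 0 is the classification token

head = nn.Sequential(
    nn.Linear(hidden_d, n_classes),      # 256 x 5 weight matrix (plus a bias)
    nn.Softmax(dim=-1),
)
cls_tokens = encoded[:, 0]               # (N, 256): keep only the classification token
probs = head(cls_tokens)                 # (N, 5) class probabilities
print(probs.shape, probs.argmax(dim=-1))
```

When training with `nn.CrossEntropyLoss`, the Softmax should be left out, because that loss expects raw logits and applies log-softmax itself.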
That completes the construction of the whole ViT network.
The Python code is as follows:
import numpy as np
import torch
import torch.nn as nn


class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=14, hidden_d=8, n_heads=2, out_d=5, device=None):
        super(MyViT, self).__init__()
        self.device = device  # kept for compatibility; buffers follow the module via .to()
        self.input_shape = input_shape  # (C, H, W), e.g. (3, 224, 224)
        self.n_patches = n_patches      # patches per side, n_patches ** 2 in total
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input height is not divisible by n_patches"
        assert input_shape[2] % n_patches == 0, "Input width is not divisible by n_patches"
        self.patch_size = (input_shape[1] // n_patches, input_shape[2] // n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper: flattened patch (C * p_h * p_w values) -> hidden_d
        self.input_d = input_shape[0] * self.patch_size[0] * self.patch_size[1]
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Fixed sinusoidal positional embedding, stored as a non-trainable buffer
        self.register_buffer(
            "pos_embed",
            get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d),
            persistent=False,
        )

        # 4a) Layer normalization 1 (over the features of each token)
        self.ln1 = nn.LayerNorm(self.hidden_d)
        # 4b) Multi-head self-attention (MSA)
        self.msa = MyMSA(self.hidden_d, n_heads)
        # 5a) Layer normalization 2
        self.ln2 = nn.LayerNorm(self.hidden_d)
        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU(),
        )
        # 6) Classification MLP (drop the Softmax if training with nn.CrossEntropyLoss,
        #    which expects raw logits)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1),
        )

    def forward(self, images):
        n, c, h, w = images.shape
        p_h, p_w = self.patch_size
        # Cut each image into a grid of patches (row by row) and flatten every patch:
        # (N, C, H, W) -> (N, n_patches ** 2, C * p_h * p_w)
        patches = (
            images.reshape(n, c, self.n_patches, p_h, self.n_patches, p_w)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(n, self.n_patches ** 2, self.input_d)
        )
        tokens = self.linear_mapper(patches)
        # Prepend the classification token: (N, n_patches ** 2 + 1, hidden_d)
        cls = self.class_token.unsqueeze(0).expand(n, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        # Add the positional embedding (broadcast over the batch)
        tokens = tokens + self.pos_embed
        # Encoder block: LN -> MSA -> residual, then LN -> MLP -> residual
        out = tokens + self.msa(self.ln1(tokens))
        out = out + self.enc_mlp(self.ln2(out))
        # Keep only the classification token and classify it
        out = out[:, 0]
        return self.mlp(out)


def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            # sine on even dimensions, cosine on odd dimensions (Vaswani et al.)
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result


class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads
        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"
        d_head = d // n_heads
        # Separate q, k, v mappings for every head, each acting on that head's d_head slice
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # sequences: (N, seq_length, d) -> (N, seq_length, d)
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]
                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
                # Scaled dot-product attention: softmax(q k^T / sqrt(d_head)) v
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))  # concatenate the heads
        return torch.stack(result)
Corrections are welcome. If you would like the source code, send me a private message or leave a comment, and I will reply when I see it.