
PyTorch learning notes (VII): Vision Transformer

2022-06-25 02:35:00 Clear thinking

Contents

1. Patches and linear mapping

2. Adding the classification token

3. Positional encoding

4. LN, MSA and residual connection

5. LN, MLP and residual connection

6. Classification MLP


Preface: Since Dosovitskiy et al. introduced the Vision Transformer (ViT), it has become one of the leading architectures in computer vision and outperforms traditional convolutional neural networks (CNNs) on many benchmarks. The Transformer originally comes from natural language processing (NLP), and the idea behind ViT is not very different from NLP: a complete image is divided into several tokens, and these tokens are fed into the network just like the words of a sentence. Each of these patches plays the role of one word.

The figure below comes from the paper "Vision Transformers for Remote Sensing Image Classification"; I borrow it here.

The figure shows an image split into 9 sub-images, x1 to x9, all of the same size. Each sub-image is flattened into a one-dimensional vector and linearly embedded. Note that x1 to x9 are taken from the original image in order, which matters: positional information is later added to each token (vector), and this is what tells the network where each patch sits in the original image.

Together with a classification token, these tokens (with positional information added) are passed into the transformer encoder. This is why the sequence length is written as the number of patches +1: the extra 1 is the classification token. The transformer encoder contains layer normalization (LN), multi-head self-attention (MSA) and a residual connection, followed by a second LN, a multilayer perceptron (MLP) and another residual connection. The encoder block is usually repeated many times, similar to the blocks of ResNet. Finally, a classification MLP head classifies the image using the special classification token that was added at the start.

Now look back at the figure above. Does it make more sense now?

1. Patches and linear mapping

The first question is how to turn an image into something like an English sentence. The author's approach is to divide it into several sub-images (patches) and map each one to a vector, keeping their positional order.

For example, take a 3×224×224 image (3 channels: RGB). We can divide it into 14×14 patches, each of size 16×16 (224 / 14 = 16).

(N, C, H, W) = (N, 3, 224, 224) → (N, patches, patch_dim) = (N, 14*14, 3*16*16) = (N, 196, 768)

The 3×224×224 input thus becomes a (196, 768) matrix per image: 196 patches, each flattened into 3×16×16 = 768 values (16×16 = 256 per channel). Each patch is then fed through a linear mapping, which can project it to a vector of any size; this size is called the hidden dimension. The code at the end of this post maps each patch to hidden_d = 8, while the shape examples in the following sections assume a hidden dimension of 256. Note that the hidden dimension must be divisible by the number of attention heads, because MSA splits it evenly among the heads.
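A minimal sketch of the patching and linear mapping step, assuming PyTorch. The helper `patchify` is hypothetical (not from the original post); it uses `Tensor.unfold` to cut out real spatial patches, since a plain `reshape` of the image tensor would just slice memory and would not produce 16×16 squares. The hidden dimension 256 is only an example value.

```python
import torch
import torch.nn as nn


def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Cut (N, C, H, W) images into non-overlapping square patches.

    Returns (N, num_patches, C * patch_size * patch_size), with patches ordered
    row by row (left to right, top to bottom), i.e. x1, x2, ... in the figure.
    """
    n, c, h, w = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch size"
    # unfold(dim, size, step) slides a window of `size` along `dim` with stride `step`
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    # (N, C, H/p, W/p, p, p) -> (N, H/p, W/p, C, p, p) so each patch keeps all its channels together
    patches = patches.permute(0, 2, 3, 1, 4, 5)
    return patches.reshape(n, (h // patch_size) * (w // patch_size), c * patch_size * patch_size)


images = torch.randn(2, 3, 224, 224)            # N = 2 RGB images
patches = patchify(images, patch_size=16)        # (2, 196, 768)

hidden_d = 256                                   # example hidden dimension
linear_mapper = nn.Linear(3 * 16 * 16, hidden_d) # the same mapping is shared by every patch
tokens = linear_mapper(patches)                  # (2, 196, 256)
print(patches.shape, tokens.shape)
```

The first patch is exactly the top-left 16×16 square of every channel, flattened, which is what keeps the patch order meaningful for the positional encoding later.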

2. Adding the classification token

As mentioned earlier, before the tokens are passed into the transformer encoder, a classification token is added. Its job is to collect information from the other tokens, which happens in MSA. Once the whole image has passed through the encoder, we can classify it using only this classification token.

Continuing the 3×224×224 example, with a hidden dimension of 256:

(N, 196, 256)→(N, 196+1, 256)

The extra 1 is the classification token, a learnable vector shared by all images.
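A minimal batched sketch of prepending the classification token, assuming PyTorch. The code at the end of the post does the same thing with a per-image `torch.vstack` loop; here `expand` and `torch.cat` handle the whole batch at once. The random patch tokens stand in for the output of the linear mapping.

```python
import torch
import torch.nn as nn

n, n_patches, hidden_d = 4, 196, 256
tokens = torch.randn(n, n_patches, hidden_d)             # patch tokens after the linear mapping

class_token = nn.Parameter(torch.rand(1, 1, hidden_d))   # one learnable token shared by all images
# expand() creates an (N, 1, hidden_d) view without copying; cat puts it at position 0
tokens = torch.cat([class_token.expand(n, -1, -1), tokens], dim=1)
print(tokens.shape)   # torch.Size([4, 197, 256])
```

Because the classification token is a parameter, it is trained together with the rest of the network, and it always sits at index 0 so it can be picked out at the end.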

3. Positional encoding

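When the network receives these patches, how does it know where each patch was located in the original image?

Following Vaswani et al., this can be done simply by adding sine and cosine waves of different frequencies to the tokens: even dimensions get a sine, odd dimensions get a cosine, and the frequency decreases along the hidden dimension.

Since the tokens have shape (N, 197, 256), the (197, 256) positional encoding is repeated N times, once for every image in the batch, and added to the tokens. In this implementation the encoding is fixed rather than learned.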
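A vectorized sketch of the sinusoidal encoding, assuming PyTorch. It uses the standard formula PE[pos, 2i] = sin(pos / 10000^(2i/d)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/d)), which is the same formula as the double loop in `get_positional_embeddings` at the end of the post; the function name here is hypothetical.

```python
import torch


def sinusoidal_positional_encoding(seq_len: int, d: int) -> torch.Tensor:
    """Fixed sine/cosine positional encoding of shape (seq_len, d)."""
    assert d % 2 == 0, "this sketch assumes an even hidden dimension"
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)   # (seq_len, 1)
    two_i = torch.arange(0, d, 2, dtype=torch.float32)               # 0, 2, 4, ..., d - 2
    freq = 1.0 / (10000 ** (two_i / d))                              # lower frequency for higher dims
    pe = torch.zeros(seq_len, d)
    pe[:, 0::2] = torch.sin(pos * freq)   # even dimensions
    pe[:, 1::2] = torch.cos(pos * freq)   # odd dimensions
    return pe


n = 4
pe = sinusoidal_positional_encoding(197, 256)   # 196 patches + 1 classification token
tokens = torch.randn(n, 197, 256)
tokens = tokens + pe.repeat(n, 1, 1)            # the same encoding is added to every image
```

Computing all positions at once avoids the Python double loop, which matters when the encoding is rebuilt on every forward pass.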

4. LN, MSA and residual connection

LN: given an input token, subtract the mean of its features and divide by their standard deviation; layer normalization then applies a learnable scale and shift.
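A small check, assuming PyTorch, that "subtract the mean, divide by the standard deviation" over each token's hidden dimension matches `nn.LayerNorm`. It relies on the facts that a freshly created `nn.LayerNorm` has scale 1 and shift 0, uses the biased variance, and adds `eps = 1e-5` by default.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 197, 256)                            # (N, tokens, hidden_d)

# Normalize every token separately over its 256 features
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)      # biased variance, as LayerNorm uses
manual = (x - mean) / torch.sqrt(var + 1e-5)            # 1e-5 is nn.LayerNorm's default eps

ln = nn.LayerNorm(256)                                  # normalizes over the last dimension only
print(torch.allclose(manual, ln(x), atol=1e-5))
```

Normalizing over the hidden dimension only (rather than over the whole 197×256 sequence) keeps every token's statistics independent of the other tokens.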

MSA: each patch token is mapped to 3 different vectors: q, k and v. The dot product of q with every k is then divided by the square root of the dimension of q and k and passed through softmax to obtain the attention weights. Finally, each attention weight is multiplied by the corresponding v and the results are summed (admittedly a bit dry).

In addition, each attention head has its own separate Q, K and V mapping functions.

Let's use an example to illustrate:

(N, 197, 256)→(N, 197, 16, 16)→ nn.Linear(16, 16) → (N, 197, 256)

The input is (N, 197, 256). With multi-head attention (16 heads here), each 256-dimensional token is split into 16 chunks of 16 dimensions, giving (N, 197, 16, 16). Each head applies its own nn.Linear(16, 16) mappings to obtain q, k and v and computes attention on its chunk; the 16 head outputs are then concatenated back into (N, 197, 256).
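A batched sketch of the same head-splitting scheme as the `MyMSA` class at the end of the post, assuming PyTorch. Each head computes softmax(q kᵀ / √d_head) v on its own 16-dimensional slice of every token. The class name `HeadSplitMSA` is hypothetical; unlike `MyMSA`, it processes the whole batch at once instead of looping over images.

```python
import torch
import torch.nn as nn


class HeadSplitMSA(nn.Module):
    """Multi-head self-attention where each head owns a d_head slice of every token."""

    def __init__(self, d: int = 256, n_heads: int = 16):
        super().__init__()
        assert d % n_heads == 0, "hidden dimension must be divisible by n_heads"
        self.n_heads, self.d_head = n_heads, d // n_heads
        # Separate q, k, v mappings for every head
        self.q = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)])
        self.k = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)])
        self.v = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, t, d = x.shape
        x = x.reshape(n, t, self.n_heads, self.d_head)             # (N, 197, 16, 16)
        outputs = []
        for h in range(self.n_heads):
            chunk = x[:, :, h]                                      # (N, 197, 16)
            q, k, v = self.q[h](chunk), self.k[h](chunk), self.v[h](chunk)
            scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5   # (N, 197, 197)
            attn = scores.softmax(dim=-1)                           # each row sums to 1
            outputs.append(attn @ v)                                # weighted sum of v: (N, 197, 16)
        return torch.cat(outputs, dim=-1)                           # heads concatenated: (N, 197, 256)


msa = HeadSplitMSA(d=256, n_heads=16)
out = msa(torch.randn(2, 197, 256))
print(out.shape)   # torch.Size([2, 197, 256])
```

Dividing by √d_head keeps the dot products from growing with the head size, so the softmax does not saturate. The original ViT applies full-width q/k/v projections before splitting into heads; this sketch follows the post's per-slice variant instead.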

Residual connection: the input of a sub-layer is added to its output, here out = tokens + MSA(LN(tokens)).

As mentioned, a classification token is added before the tokens enter the transformer encoder. How does it obtain information from the other tokens? Through LN, MSA and the residual connection: inside MSA the classification token attends to every other token, so afterwards it carries information about them.

5. LN, MLP and residual connection

As mentioned, the first step in the transformer encoder block is LN, MSA and a residual connection. The second step is LN, an MLP and another residual connection: out = out + MLP(LN(out)).
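A sketch of one complete encoder block with both steps, assuming PyTorch's `nn.MultiheadAttention` for the MSA part. The class name `EncoderBlock` and the MLP width (4× the hidden dimension, with GELU, as in the ViT paper) are assumptions; the post's own code uses a single Linear + ReLU instead. Because a block keeps the (N, 197, D) shape, several blocks can be stacked, as mentioned in the preface.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Pre-norm encoder block: x + MSA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, d: int = 256, n_heads: int = 16, mlp_ratio: int = 4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.msa = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d)
        # Two-layer MLP with GELU, as in the ViT paper (an assumption relative to the post's code)
        self.mlp = nn.Sequential(nn.Linear(d, mlp_ratio * d), nn.GELU(), nn.Linear(mlp_ratio * d, d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln1(x)
        x = x + self.msa(h, h, h, need_weights=False)[0]   # step 1: LN, MSA, residual
        x = x + self.mlp(self.ln2(x))                        # step 2: LN, MLP, residual
        return x


# The shape is preserved, so blocks can be stacked like ResNet blocks
encoder = nn.Sequential(*[EncoderBlock() for _ in range(4)])
out = encoder(torch.randn(2, 197, 256))
print(out.shape)   # torch.Size([2, 197, 256])
```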

6. Classification MLP

After the encoder, every image is represented by a sequence of 197 tokens. In the classification MLP we take only the classification token (token 0) out of each of the N sequences and use it to predict the class.

For example, if each token is a 16-dimensional vector and there are 5 classes, the MLP can be a 16×5 weight matrix (a linear layer) followed by a softmax activation.
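A minimal sketch of the classification head for the 16-dimensional, 5-class example above, assuming PyTorch. The random `encoded` tensor is a hypothetical stand-in for the encoder output.

```python
import torch
import torch.nn as nn

n, d, n_classes = 2, 16, 5
encoded = torch.randn(n, 197, d)          # encoder output; token 0 is the classification token

head = nn.Sequential(
    nn.Linear(d, n_classes),              # the 16 x 5 weight matrix (plus a bias)
    nn.Softmax(dim=-1),                   # turn the 5 scores into class probabilities
)
cls_token = encoded[:, 0]                 # (N, 16): keep only the classification token
probs = head(cls_token)                   # (N, 5), each row sums to 1
pred = probs.argmax(dim=-1)               # predicted class index for each image
```

If the model is trained with `nn.CrossEntropyLoss`, the softmax should be dropped from the head, because that loss expects raw logits and applies log-softmax internally.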

This completes the construction of the whole ViT network.

The PyTorch code is as follows:

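import numpy as np
import torch
import torch.nn as nn


class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=14, hidden_d=8, n_heads=2, out_d=5, device=None):
        # input_shape is (C, H, W), e.g. (3, 224, 224); the image is cut into n_patches x n_patches patches
        super(MyViT, self).__init__()
        self.device = device

        # Input and patch sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input height is not divisible by n_patches"
        assert input_shape[2] % n_patches == 0, "Input width is not divisible by n_patches"
        # Integer division so the patch size can be used for slicing
        self.patch_size = (input_shape[1] // n_patches, input_shape[2] // n_patches)
        self.hidden_d = hidden_d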

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        # (In forward method)

        # 4a) Layer normalization 1 (each token is normalized over its hidden dimension)
        self.ln1 = nn.LayerNorm(self.hidden_d)

        # 4b) Multi-head Self Attention (MSA) and classification token
        self.msa = MyMSA(self.hidden_d, n_heads)

        # 5a) Layer normalization 2 (each token is normalized over its hidden dimension)
        self.ln2 = nn.LayerNorm(self.hidden_d)

        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU()
        )

        # 6) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

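    def forward(self, images):
        # Split (N, C, H, W) images into (N, n_patches ** 2, C * patch_h * patch_w) patches.
        # A plain reshape would not cut out spatial patches, so unfold along H and W first.
        n, c, h, w = images.shape
        ph, pw = int(self.patch_size[0]), int(self.patch_size[1])
        patches = images.unfold(2, ph, ph).unfold(3, pw, pw)    # (N, C, n_patches, n_patches, ph, pw)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(n, self.n_patches ** 2, self.input_d)

        # 1) Linearly map every patch to the hidden dimension
        tokens = self.linear_mapper(patches)

        # 2) Prepend the classification token to every sequence: (N, n_patches ** 2 + 1, hidden_d)
        class_tokens = self.class_token.unsqueeze(0).expand(n, -1, -1)
        tokens = torch.cat([class_tokens, tokens], dim=1)

        # 3) Add the fixed positional embedding, repeated for every image in the batch
        pos = get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d)
        tokens = tokens + pos.repeat(n, 1, 1).to(tokens.device)

        # 4) LN -> MSA -> residual
        out = tokens + self.msa(self.ln1(tokens))

        # 5) LN -> MLP -> residual
        out = out + self.enc_mlp(self.ln2(out))

        # 6) Keep only the classification token and classify it
        out = out[:, 0]

        return self.mlp(out)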

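def get_positional_embeddings(sequence_length, d):
    # Fixed sinusoidal positional encoding (Vaswani et al.):
    # even dimensions use sin(i / 10000^(j/d)), odd dimensions use cos(i / 10000^((j-1)/d))
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):      # position of the token in the sequence
        for j in range(d):                # dimension inside the token
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result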

class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        # Each head works on its own d_head-sized slice of every token
        d_head = int(d / n_heads)
        # Separate q, k and v mappings for every head
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # sequences: (N, seq_length, d) -> output: (N, seq_length, d)
        result = []
        for sequence in sequences:              # loop over the images in the batch
            seq_result = []
            for head in range(self.n_heads):    # loop over the attention heads
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                # Slice of the token dimension that belongs to this head: (seq_length, d_head)
                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                # Scaled dot-product attention: softmax(q k^T / sqrt(d_head)) v
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            # Concatenate the head outputs back to (seq_length, d)
            result.append(torch.hstack(seq_result))
        # Stack the batch back to (N, seq_length, d)
        return torch.stack(result)

This post is based on https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c

Corrections are welcome. If you would like the source code, send me a private message or leave a comment, and I will reply when I see it.

Copyright notice
This article was written by [Clear thinking]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/176/202206242301421728.html