Mechanism and principle of multihead attention and masked attention
2022-07-25 05:42:00 【iioSnail】
1. About This Article
Before reading this article, you should first have a thorough understanding of Self-Attention. I recommend reading my other blog post, which analyses it layer by layer and gives a complete picture of the mechanism and principle of Self-Attention, MultiHead-Attention and Masked-Attention. The content of this article is also covered there, so the two can be read together.
2. MultiHead Attention
2.1 MultiHead Attention: Theory
Transformer uses MultiHead Attention, which in fact is not very different from Self-Attention. Let's first clarify the following points, then start the explanation:
- No matter how many heads MultiHead has, the parameter count is the same; it is not the case that more heads means more parameters (see the quick check right after this list).
- When MultiHead has only 1 head, it is still not equivalent to Self-Attention; MultiHead Attention and Self-Attention are two different things.
- MultiHead Attention also uses the Self-Attention formula.
- Besides the three matrices $W^q, W^k, W^v$, MultiHead needs to define one more matrix $W^o$.
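A quick way to verify the first point (using the MultiHeadedAttention class implemented in section 2.2 below, so run this after that code): whatever head count we pick, the parameter count is identical.

```python
# The heads only split the output of W^q, W^k, W^v; the four weight matrices
# themselves are always (d_model, d_model), so the parameter count is fixed.
for h in (1, 2, 8):
    model = MultiHeadedAttention(h, 512)
    print(h, sum(p.numel() for p in model.parameters()))
# Every iteration prints the same total: 4 weight matrices of 512x512 plus 4 biases.
```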
OK, with these points in mind, we can start explaining MultiHead Attention.
Most of the logic of MultiHead Attention is the same as Self-Attention; it only starts to differ after Q, K, V have been computed, so that is where we begin.
Suppose we have already computed the Q, K, V matrices. For Self-Attention, we could plug them straight into the formula; represented as a figure:
*(Figure: Self-Attention computed directly from Q, K, V)*
For simplicity, the figure omits the Softmax and the division by $\sqrt{d_k}$.
MultiHead Attention, however, does one extra thing before plugging into the formula: it splits Q, K, V into multiple heads along the word-vector dimension, as shown in the figure:
*(Figure: Q, K, V each split into 4 heads along the word-vector dimension)*
Here the number of heads is 4. Since Q, K, V have been split into several heads, the subsequent computation is also done per head, each head attending over its own slice, as shown in the figure:
*(Figure: each head computes attention on its own slice of Q, K, V)*
However, simply merging the per-head attention results with Concat does not work very well, so at the end an extra matrix $W^o$ is used to apply one more linear transformation to the attention output, as shown in the figure:
*(Figure: the per-head results are concatenated and passed through $W^o$)*
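Before moving to the full implementation, here is a shape-only sketch of the pipeline above, assuming $d_{model}=128$ and 4 heads as in the figures (the per-head attention output is faked with a random tensor, since only the shapes matter here):

```python
import torch
import torch.nn as nn

batch, n_words, d_model, h = 1, 7, 128, 4
d_k = d_model // h                              # 128 / 4 = 32 dims per head

Q = torch.rand(batch, n_words, d_model)         # K and V are handled the same way
# Split along the word-vector dimension into h heads
Q_heads = Q.view(batch, n_words, h, d_k).transpose(1, 2)
print(Q_heads.shape)                            # torch.Size([1, 4, 7, 32])

# ... each head computes its own attention over its 32-dim slice ...
attn_out = torch.rand(batch, h, n_words, d_k)   # stand-in for the per-head results

# Concat the heads back together, then apply W^o
concat = attn_out.transpose(1, 2).contiguous().view(batch, n_words, d_model)
W_o = nn.Linear(d_model, d_model)
print(W_o(concat).shape)                        # torch.Size([1, 7, 128])
```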
From this you can also see that more heads are not necessarily better. As for why MultiHead Attention is used at all, the explanation given by the Transformer paper is that multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. In short, it simply works better with it than without it.
2.2 Implementing MultiHead Attention in PyTorch
This code is adapted from the annotated-transformer project.

First, define a general attention function:
```python
import math

import torch
import torch.nn as nn


def attention(query, key, value):
    """
    Compute the attention result.

    Q, K, V are passed in directly; computing them from the input is the job of
    the model (see the MultiHeadedAttention class below). Q, K, V come in two
    possible shapes:
    - Self-Attention: (batch, num_words, d_model), e.g. (1, 7, 128), i.e. a batch
      of 1 sentence with 7 words, each word a 128-dim vector.
    - MultiHead Attention: (batch, num_heads, num_words, d_model / num_heads),
      e.g. (1, 8, 7, 16), i.e. batch 1, 8 heads, 7 words, 128 / 8 = 16 dims per head.
      You can see from this that "MultiHead" really just splits up the 128 dims.
    In Transformer, MultiHead Attention is used, so Q, K, V will only ever have
    the second kind of shape.
    """
    # Get d_model (or d_model / h). We can read it off query because query has the
    # same shape as the input:
    # - Self-Attention: the last dimension is the word-vector dimension, i.e. d_model;
    # - MultiHead Attention: the last dimension is d_model / h, where h is the head count.
    d_k = query.size(-1)
    # Compute QK^T / sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # Apply Softmax. p_attn is a square matrix over word positions:
    # - Self-Attention: shape (batch, num_words, num_words), e.g. (1, 7, 7)
    # - MultiHead Attention: shape (batch, num_heads, num_words, num_words)
    p_attn = scores.softmax(dim=-1)
    # Finally multiply by V.
    # - Self-Attention: the result has shape (batch, num_words, d_model); this is the final output.
    # - MultiHead Attention: the result has shape (batch, num_heads, num_words, d_model / num_heads).
    #   This is not the final output yet: the heads still have to be merged back into
    #   (batch, num_words, d_model), which is MultiHeadedAttention's job.
    return torch.matmul(p_attn, value)
```
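As a quick shape check of this helper, using the example shapes from the docstring:

```python
# Self-Attention style input: batch 1, 7 words, 128-dim word vectors
q = k = v = torch.rand(1, 7, 128)
print(attention(q, k, v).shape)   # torch.Size([1, 7, 128])

# MultiHead style input: 8 heads, each head gets 128 / 8 = 16 dims
q = k = v = torch.rand(1, 8, 7, 16)
print(attention(q, k, v).shape)   # torch.Size([1, 8, 7, 16])
```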
```python
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model):
        """h: number of heads"""
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # Define the W^q, W^k, W^v and W^o matrices.
        # If you are not sure why nn.Linear is used to define a matrix, see:
        # https://blog.csdn.net/zhaohongfei_358/article/details/122797190
        # (an nn.ModuleList is used so the parameters are registered with the module)
        self.linears = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(4)]
        )

    def forward(self, x):
        # Get the batch size
        nbatches = x.size(0)

        """
        1. Compute Q, K, V. These are the multi-head Q, K, V, so their shape is
           (batch, num_heads, num_words, d_model / num_heads).
           1.1 First apply W^q, W^k, W^v to get the Self-Attention Q, K, V,
               whose shape is (batch, num_words, d_model). Code: `linear(x)`.
           1.2 Split into multiple heads: reshape from (batch, num_words, d_model)
               to (batch, num_words, num_heads, d_model / num_heads).
               Code: `view(nbatches, -1, self.h, self.d_k)`.
           1.3 Finally swap the "num_words" and "num_heads" dimensions so the head
               dimension comes first, giving (batch, num_heads, num_words, d_model / num_heads).
               Code: `transpose(1, 2)`.
        """
        query, key, value = [
            linear(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for linear, x in zip(self.linears, (x, x, x))
        ]

        """
        2. With Q, K, V computed, call the attention function. The result x has
           shape (batch, num_heads, num_words, d_model / num_heads).
        """
        x = attention(query, key, value)

        """
        3. Merge the heads back together: turn x from
           (batch, num_heads, num_words, d_model / num_heads) into (batch, num_words, d_model).
           3.1 First swap "num_heads" and "num_words", giving
               (batch, num_words, num_heads, d_model / num_heads).
               Code: `x.transpose(1, 2).contiguous()`.
           3.2 Then merge "num_heads" and "d_model / num_heads" into one dimension,
               giving (batch, num_words, d_model).
        """
        x = (
            x.transpose(1, 2)
            .contiguous()
            .view(nbatches, -1, self.h * self.d_k)
        )

        # Finally apply W^o for one more linear transformation to get the final result.
        return self.linears[-1](x)
```
Now let's try it out:

```python
# Define 8 heads; the word-vector dimension is 512
model = MultiHeadedAttention(8, 512)
# Pass in a batch of size 2: 7 words per sentence, each word a 512-dim vector
x = torch.rand(2, 7, 512)
# Print the shape of the attention output
print(model(x).size())
```

The output is:

```
torch.Size([2, 7, 512])
```
3. Masked Attention
3.1 Why a Mask Is Needed
The Decoder in Transformer contains a Masked MultiHead Attention. This section explains it in detail.
First, let's review the attention formula:
$$
\begin{aligned}
O_{n\times d_v} = \text{Attention}(Q_{n\times d_k}, K_{n\times d_k}, V_{n\times d_v})
&= \operatorname{softmax}\left(\frac{Q_{n\times d_k} K^{T}_{d_k\times n}}{\sqrt{d_k}}\right) V_{n\times d_v} \\
&= A'_{n\times n} V_{n\times d_v}
\end{aligned}
$$
where:
$$
O_{n\times d_v}=
\begin{bmatrix} o_1\\ o_2\\ \vdots \\ o_n \end{bmatrix},\qquad
A'_{n\times n} =
\begin{bmatrix}
\alpha'_{1,1} & \alpha'_{2,1} & \cdots & \alpha'_{n,1} \\
\alpha'_{1,2} & \alpha'_{2,2} & \cdots & \alpha'_{n,2} \\
\vdots & \vdots & & \vdots \\
\alpha'_{1,n} & \alpha'_{2,n} & \cdots & \alpha'_{n,n}
\end{bmatrix},\qquad
V_{n\times d_v}=
\begin{bmatrix} v_1\\ v_2\\ \vdots \\ v_n \end{bmatrix}
$$
Suppose $(v_1, v_2, \dots, v_n)$ corresponds to (机, 器, 学, 习, 真, 好, 玩), the characters of "机器学习真好玩" ("machine learning is really fun"). Then $(o_1, o_2, \dots, o_n)$ corresponds to (机′, 器′, 学′, 习′, 真′, 好′, 玩′), where 机′ contains attention information from all of $v_1$ through $v_n$. And when computing 机′, the weights given to (机, 器, ...) are the first row of $A'$, namely $(\alpha'_{1,1}, \alpha'_{2,1}, \dots)$.
With that recap in mind, let's look at how Transformer is used. Suppose we want to use Transformer to translate the sentence "Machine learning is fun".
First, we send "Machine learning is fun" into the Encoder, which outputs a tensor called Memory, as shown in the figure:
*(Figure: the Encoder turns "Machine learning is fun" into the Memory tensor)*
After that, we pass Memory to the Decoder as one of its inputs and let the Decoder make the prediction. The Decoder does not produce "机器学习真好玩" all at once, but one word at a time (or one character at a time, depending on how you tokenize). On the first step it receives only "<bos>" and predicts "机".
Then we call the Decoder again, this time passing in "<bos> 机".
And so on, until it finally outputs <eos> and stops:
*(Figure: the Decoder predicts word by word until it outputs <eos>)*
When the Transformer outputs <eos>, the prediction is finished.
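To make the loop concrete, here is a minimal greedy-decoding sketch. `tokenize`, `model.encode`, `model.decode`, `BOS`, `EOS` and `max_len` are hypothetical placeholders for whatever Transformer implementation and vocabulary you use; the point is only the word-by-word structure:

```python
# A minimal sketch of the word-by-word (greedy) decoding loop described above.
# NOTE: tokenize, model.encode, model.decode, BOS, EOS and max_len are
# hypothetical placeholders, not part of the code in this post.
src = tokenize("Machine learning is fun")
memory = model.encode(src)              # the Encoder runs only once

ys = [BOS]                              # the Decoder starts from <bos>
for _ in range(max_len):
    logits = model.decode(memory, ys)   # the Decoder sees everything predicted so far
    next_token = logits[-1].argmax().item()   # greedy: take the most likely next word
    ys.append(next_token)
    if next_token == EOS:               # stop as soon as <eos> is produced
        break
```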
Here we can see that the Decoder predicts word by word. So suppose the Decoder's input is "机器学习": the character "习" could only see the three characters "机器学" in front of it, so at that point "习" carries attention information only about the four characters "机器学习".
However, even when the input at the last step is "<bos> 机器学习真好玩", we still must not let "习" see the three characters "真好玩" behind it, so we cover them with a mask. Why? Because if "习" were allowed to see the characters behind it, the encoding of "习" would change.
Let's analyze it:
At first we only pass in "机" (ignoring <bos>). Using the attention mechanism, "机" is encoded as, say, $[0.13, 0.73, ...]$.
The second time, we pass in "机器". Using attention again, if we do not cover the character "器", the encoding of "机" changes: it is no longer $[0.13, 0.73, ...]$, but perhaps becomes $[0.95, 0.81, ...]$.
That would mean "机" is encoded as $[0.13, 0.73, ...]$ the first time and $[0.95, 0.81, ...]$ the second time, which can cause problems for the network. So, to keep the encoding of "机" from changing, we use a mask to cover the characters after "机": even though it could attend to them, we do not let it.
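This is easy to check numerically. Below is a toy experiment (random vectors stand in for the representations of "机" and "器"): without a mask, $o_1$ changes as soon as the second character arrives; with a causal mask it stays exactly the same.

```python
import torch

torch.manual_seed(0)
d = 4
# Random 4-dim vectors standing in for the queries/keys/values of "机" and "器"
q, k, v = torch.rand(2, d), torch.rand(2, d), torch.rand(2, d)

def attn(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / d ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    return scores.softmax(dim=-1) @ v

o_step1 = attn(q[:1], k[:1], v[:1])        # step 1: only "机" is present
o_step2_plain = attn(q, k, v)              # step 2, no mask: "机" also sees "器"
causal = torch.tril(torch.ones(2, 2))      # step 2, mask: "机" may only see itself
o_step2_masked = attn(q, k, v, causal)

print(torch.allclose(o_step1[0], o_step2_plain[0]))    # False: the encoding changed
print(torch.allclose(o_step1[0], o_step2_masked[0]))   # True: the encoding is preserved
```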
Many articles explain the mask as preventing the Transformer from seeing information it should not see during training. I think this explanation is wrong: ① the Transformer Decoder works the same way at training time and at inference time, so if the mask were only there to prevent leakage during training, why would we still need it at inference time? ② What is passed to the Decoder is what the Decoder itself has already predicted; it predicted those words itself, yet we supposedly must hide them from it to "prevent information leakage"? That does not hold up.
Of course, this is just my personal view; it is possible I have misunderstood.
3.2 How the Mask Is Applied
To apply the mask, we only need to work on the scores, that is, on $A'_{n\times n}$. Let's go straight to an example:
The first time, we only have the variable $v_1$, so:
$$
\begin{bmatrix} o_1 \end{bmatrix}=\begin{bmatrix} \alpha'_{1,1} \end{bmatrix} \cdot \begin{bmatrix} v_1 \end{bmatrix}
$$
The second time, we have the two variables $v_1$ and $v_2$:
$$
\begin{bmatrix} o_1\\ o_2 \end{bmatrix} =
\begin{bmatrix} \alpha'_{1,1} & \alpha'_{2,1} \\ \alpha'_{1,2} & \alpha'_{2,2} \end{bmatrix}
\begin{bmatrix} v_1\\ v_2 \end{bmatrix}
$$
At this point, if we do not mask $A'_{2\times 2}$, the value of $o_1$ will change (the first time it was $\alpha'_{1,1}v_1$, the second time it becomes $\alpha'_{1,1}v_1+\alpha'_{2,1}v_2$). Looking at it this way, we only need to cover $\alpha'_{2,1}$, which guarantees that $o_1$ is the same both times.
So the second time actually becomes:
$$
\begin{bmatrix} o_1\\ o_2 \end{bmatrix} =
\begin{bmatrix} \alpha'_{1,1} & 0 \\ \alpha'_{1,2} & \alpha'_{2,2} \end{bmatrix}
\begin{bmatrix} v_1\\ v_2 \end{bmatrix}
$$
By analogy, by the time we reach the $n$-th step, it should become:
$$
\begin{bmatrix} o_1\\ o_2\\ \vdots \\ o_n \end{bmatrix} =
\begin{bmatrix}
\alpha'_{1,1} & 0 & \cdots & 0 \\
\alpha'_{1,2} & \alpha'_{2,2} & \cdots & 0 \\
\vdots & \vdots & & \vdots \\
\alpha'_{1,n} & \alpha'_{2,n} & \cdots & \alpha'_{n,n}
\end{bmatrix}
\begin{bmatrix} v_1\\ v_2\\ \vdots \\ v_n \end{bmatrix}
$$
3.3 Why is it negative infinity instead of 0
According to the discussion above, the masked entries are 0, so why does the source code mask with $-1e9$ (effectively negative infinity)? The relevant part of the attention source code is:
```python
if mask is not None:
    scores = scores.masked_fill(mask == 0, -1e9)
p_attn = scores.softmax(dim=-1)
```
Look carefully: the $A'_{n\times n}$ we have been discussing is the matrix after softmax, whereas the source code applies the mask before softmax. That is why negative infinity is used: after the softmax, those negative-infinity entries become 0.
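As a small illustration (a simplified stand-in for the subsequent_mask helper used in annotated-transformer, following the same `mask == 0` convention as the snippet above): build a lower-triangular mask, fill the masked scores with $-1e9$, and after softmax those positions are numerically 0 while each row still sums to 1.

```python
import torch

def subsequent_mask(size):
    # Lower-triangular matrix of ones: position i may only attend to positions <= i
    return torch.tril(torch.ones(size, size)).unsqueeze(0)   # shape (1, size, size)

scores = torch.rand(1, 4, 4)                                 # stands in for QK^T / sqrt(d_k)
scores = scores.masked_fill(subsequent_mask(4) == 0, -1e9)
p_attn = scores.softmax(dim=-1)
print(p_attn)           # the upper-triangular entries are 0
print(p_attn.sum(-1))   # each row still sums to 1
```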