(PyTorch advanced path 2) Word embedding and position embedding
2022-06-22 15:44:00 【likeGhee】
word embedding
The role of embedding is to map discrete, high-dimensional tokens (vocabulary indices, i.e. one-hot vectors) to low-dimensional dense vectors.
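Intuitively, an embedding layer is just a learnable lookup table: each token id selects one row of a weight matrix. A minimal sketch with toy sizes (separate from the translation example below):
import torch as T
import torch.nn as nn

table = nn.Embedding(num_embeddings=10, embedding_dim=4)  # 10 token ids, 4-dim vectors
token_id = T.tensor([3])
vec = table(token_id)                     # shape [1, 4]
print(T.equal(vec[0], table.weight[3]))   # True: the lookup returns row 3 of the weight matrix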
Suppose the task is English-to-German translation. First we need to build an English source sentence (the source sequence, src_seq) and a German target sentence (the target sequence, tgt_seq).
How do we build a sequence? Anyone who has touched NLP knows this: each token in the sequence is represented by its index into a vocabulary (dict).
We also need the sequence lengths, so assume src_len and tgt_len are given.
# %%
import numpy
import torch as T
import torch.nn as nn
import torch.nn.functional as F
# %%
# Suppose there are two sentences
batch_size = 2
# Each sentence has a random length of 2~4 (randint's upper bound is exclusive)
src_len = T.randint(2, 5, (batch_size, ))
tgt_len = T.randint(2, 5, (batch_size, ))
# For reproducibility, hard-code the lengths
src_len = T.Tensor([2, 4]).to(T.int32)
tgt_len = T.Tensor([4, 3]).to(T.int32)
print(src_len)
print(tgt_len)
The output is tensor([2, 4]) and tensor([4, 3]): the batch has two sentences, the src sentences have lengths 2 and 4, and the tgt sentences have lengths 4 and 3.
Then we build the sequences. Assume the src and tgt vocabularies each contain at most 8 words. We generate each sequence with randint and put it in a list. To keep the sentence lengths consistent we still need padding, using the pad function from torch.nn.functional. After that, unsqueeze and cat turn the list into a [batch_size, max_len] tensor that serves as the batch input.
# %%
# Vocabulary size
max_source_word_num = 8
max_target_word_num = 8
# Maximum sequence length
max_source_seq_len = 5
max_target_seq_len = 5
# Generate the sequences (token ids start at 1; 0 is reserved for padding)
src_seq = [T.randint(1, max_source_word_num, (L,)) for L in src_len]
# padding
src_seq = list(map(lambda x: F.pad(x, (0, max_source_seq_len - len(x))), src_seq))
# Add a leading dimension so the sequences can be concatenated along dim 0
src_seq = list(map(lambda x: T.unsqueeze(x, 0), src_seq))
# Concatenate into a [batch_size, max_len] tensor
src_seq = T.cat(src_seq, 0)
print(src_seq)
tgt_seq = [F.pad(T.randint(1, max_target_word_num, (L,)), (0, max_target_seq_len-L)) for L in tgt_len]
tgt_seq = list(map(lambda x: T.unsqueeze(x, 0), tgt_seq))
tgt_seq = T.cat(tgt_seq, 0)
print(tgt_seq)
The printed src_seq and tgt_seq are [2, 5] integer tensors, with zeros padding the tail of each shorter sentence.
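As a side note, PyTorch also ships torch.nn.utils.rnn.pad_sequence, which pads a list of 1-D tensors to the longest one in the batch and stacks them. A sketch of that alternative (it pads to the batch maximum rather than to max_source_seq_len, so it is not a drop-in replacement here):
from torch.nn.utils.rnn import pad_sequence

seqs = [T.randint(1, max_source_word_num, (int(L),)) for L in src_len]
batch = pad_sequence(seqs, batch_first=True, padding_value=0)
print(batch.shape)  # [batch_size, max(src_len)]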
With the input ready, the middle part is the embedding itself, using PyTorch's nn.Embedding API.
The first argument, num_embeddings, is the number of words; we usually take the maximum vocabulary index + 1 so that the padding id 0 is counted as well.
The second argument, embedding_dim, is the word-vector dimension; 512 is the usual choice (as in the Transformer paper), but 8 is convenient here.
# %%
model_dim = 8
# Build the embedding tables
src_embedding_table = nn.Embedding(max_source_word_num + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_target_word_num + 1, model_dim)
print(src_embedding_table.weight.size())
# Run a forward pass to check the shapes
src_embedding = src_embedding_table(src_seq)
print(src_embedding.size())
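One optional refinement (not in the code above): nn.Embedding accepts a padding_idx argument, which keeps the vector of the padding token (id 0 here) at zero and excludes it from gradient updates. A sketch under that assumption:
# Optional: treat token id 0 as padding so its embedding stays a zero vector.
src_embedding_table_pad = nn.Embedding(max_source_word_num + 1, model_dim, padding_idx=0)
print(src_embedding_table_pad(T.tensor([0])))  # all zeros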

position embedding
Attention Is All You Need gives the expression for PE (position embedding). The idea is to turn each token's position in the sentence into a vector and add it to the WE (word embedding):
PE(pos, 2i) = sin(pos / 10000^(2i/d_model)),  PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
First, PE is a two-dimensional matrix of shape [max_len, dim]. The maximum length can match max_source_seq_len; here we fix max_position_len = 5.
The argument of sin/cos can be built by broadcasting two matrices against each other (element-wise division, not matrix multiplication): a column vector of positions pos on the left and a row vector of frequency terms 10000^(2i/d_model) on the right. The even columns then take sin and the odd columns take cos.
# %%
max_position_len = 5
pos_matrix = T.arange(max_position_len).reshape((-1, 1))
print(pos_matrix)
# The exponent 2i only takes even values, so arange steps by 2
i_matrix = T.pow(10000, T.arange(0, model_dim, 2).reshape([1, -1]) / model_dim)
print(i_matrix)
# Build the positional-embedding table
pe_embedding_table = T.zeros([max_position_len, model_dim])
# Even columns: 0::2 selects every second column starting from index 0 (rows unchanged)
pe_embedding_table[:, 0::2] = T.sin(pos_matrix / i_matrix)
# Odd columns
pe_embedding_table[:, 1::2] = T.cos(pos_matrix / i_matrix)
print(pe_embedding_table)
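As a quick sanity check (an extra step, not part of the original code), one entry of the table can be compared against the closed-form formula from the paper, e.g. pos=3 and i=1, i.e. column 2:
import math
pos, i = 3, 1
expected = math.sin(pos / 10000 ** (2 * i / model_dim))
print(pe_embedding_table[pos, 2 * i].item(), expected)  # should agree up to float precision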
Next, wrap the table in an nn.Module: create an nn.Embedding and replace its weight with the precomputed table.
# %%
# Create the PE embedding by overwriting the weight of an nn.Embedding (frozen, no gradient)
pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
print(pe_embedding.weight.size())
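Equivalently (an alternative to overwriting the weight, not what the code above does), the frozen table can be wrapped with nn.Embedding.from_pretrained, whose freeze argument defaults to True:
# Alternative: build the frozen positional-embedding layer in one call.
pe_embedding_alt = nn.Embedding.from_pretrained(pe_embedding_table)  # freeze=True by default
print(pe_embedding_alt.weight.requires_grad)  # False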
To build the input we pass in position indices, which are naturally created with arange; a forward pass then yields the PE.
# %%
# Construct location index
src_pos = T.cat([T.unsqueeze(T.arange(max_position_len), 0) for _ in src_len] , 0)
print(src_pos)
tgt_pos = T.cat([T.unsqueeze(T.arange(max_position_len), 0) for _ in tgt_len] , 0)
# Forward pass: compute the src position embedding
src_pe_embedding = pe_embedding(src_pos)
print(src_pe_embedding.size())
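The paper adds PE to the word embedding (after scaling WE by sqrt(d_model)); a minimal sketch of that final step, reusing the tensors built above (the scaling factor is taken from the paper and is not part of the code in this post):
# Combine word embedding and position embedding into the encoder input.
src_input = src_embedding * (model_dim ** 0.5) + src_pe_embedding
print(src_input.size())  # [batch_size, max_position_len, model_dim]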
Complete code
# %%
import numpy
import torch as T
import torch.nn as nn
import torch.nn.functional as F
# %%
# Suppose there are two sentences
batch_size = 2
# Each sentence has a random length of 2~4 (randint's upper bound is exclusive)
src_len = T.randint(2, 5, (batch_size, ))
tgt_len = T.randint(2, 5, (batch_size, ))
print(src_len)
print(tgt_len)
# For reproducibility, hard-code the lengths
src_len = T.Tensor([2, 4]).to(T.int32)
tgt_len = T.Tensor([4, 3]).to(T.int32)
print(src_len)
print(tgt_len)
# %%
# Vocabulary size
max_source_word_num = 8
max_target_word_num = 8
# Maximum sequence length
max_source_seq_len = 5
max_target_seq_len = 5
# Generate the sequences (token ids start at 1; 0 is reserved for padding)
src_seq = [T.randint(1, max_source_word_num, (L,)) for L in src_len]
# padding
src_seq = list(map(lambda x: F.pad(x, (0, max_source_seq_len - len(x))), src_seq))
# Add a leading dimension so the sequences can be concatenated along dim 0
src_seq = list(map(lambda x: T.unsqueeze(x, 0), src_seq))
# Concatenate into a [batch_size, max_len] tensor
src_seq = T.cat(src_seq, 0)
print(src_seq)
tgt_seq = [F.pad(T.randint(1, max_target_word_num, (L,)), (0, max_target_seq_len-L)) for L in tgt_len]
tgt_seq = list(map(lambda x: T.unsqueeze(x, 0), tgt_seq))
tgt_seq = T.cat(tgt_seq, 0)
print(tgt_seq)
# %%
model_dim = 8
src_embedding_table = nn.Embedding(max_source_word_num + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_target_word_num + 1, model_dim)
print(src_embedding_table.weight.size())
# Run a forward pass to check the shapes
src_embedding = src_embedding_table(src_seq)
print(src_embedding.size())
# %%
max_position_len = 5
pos_matrix = T.arange(max_position_len).reshape((-1, 1))
print(pos_matrix)
# The exponent 2i only takes even values, so arange steps by 2
i_matrix = T.pow(10000, T.arange(0, model_dim, 2).reshape([1, -1]) / model_dim)
print(i_matrix)
# Build the positional-embedding table
pe_embedding_table = T.zeros([max_position_len, model_dim])
# Even columns: 0::2 selects every second column starting from index 0 (rows unchanged)
pe_embedding_table[:, 0::2] = T.sin(pos_matrix / i_matrix)
# Odd columns
pe_embedding_table[:, 1::2] = T.cos(pos_matrix / i_matrix)
print(pe_embedding_table)
# %%
# Create the PE embedding by overwriting the weight of an nn.Embedding (frozen, no gradient)
pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
print(pe_embedding.weight.size())
# %%
# Construct location index
src_pos = T.cat([T.unsqueeze(T.arange(max_position_len), 0) for _ in src_len] , 0)
print(src_pos)
tgt_pos = T.cat([T.unsqueeze(T.arange(max_position_len), 0) for _ in tgt_len] , 0)
# Forward pass: compute the src position embedding
src_pe_embedding = pe_embedding(src_pos)
print(src_pe_embedding.size())