当前位置:网站首页>【MGT】代码解读之model-MGT
【MGT】代码解读之model-MGT
2022-06-21 08:21:00 【panbaoran913】
论文解读:《Meta Graph Transformer: A Novel Framework for Spatial–Temporal Traffic Prediction》
代码链接:https://github.com/lonicera-yx/MGT
壹、测试主框架
一、文件目录

文件夹MGT-main下包含4个子文件夹,一个数据压缩文件,一个main.py等等。MGT.py位于子文件夹models中,包含了3个def和14个class
二、if 主函数测试MGT
if __name__ == '__main__':
print(os.getcwd())
#cfgs = yaml.safe_load(open('cfgs/HZMetro_MGT.yaml'))['model']
cfgs = yaml.safe_load(open('../cfgs/HZMetro_MGT.yaml'))['model']
model = MGT(cfgs)
# dummy data 虚拟数据
B, P, Q, N, C = 10, 4, 4, 80, 2
# B:batch_size P:history Q:feture N:Nodes C:Features
M = 73, 2 # M is tuple
eigenmaps_k = 8 #拉普拉斯特征映射降维方法的参数
n = 3
inputs = torch.randn(B, P, N, C, dtype=torch.float32)#(10,4,80,2)
targets = torch.randn(B, Q, N, C, dtype=torch.float32)#(10,4,80,2)
inputs_time0 = torch.randint(M[0], (B, P), dtype=torch.int64)#(10,4) max_int is 73
targets_time0 = torch.randint(M[0], (B, Q), dtype=torch.int64)#(10,4)
inputs_time1 = torch.randint(M[1], (B, P), dtype=torch.int64)#(10,4) max_int is 2
targets_time1 = torch.randint(M[1], (B, Q), dtype=torch.int64)#(10,4)
eigenmaps = torch.randn(N, eigenmaps_k, dtype=torch.float32)#(80,8)
transition_matrices = torch.rand(n, N, N, dtype=torch.float32)#(3,80,80)
extras = [inputs_time0, targets_time0, inputs_time1, targets_time1]
statics = {
'eigenmaps': eigenmaps, 'transition_matrices': transition_matrices}
# forward
outputs1 = model(inputs, targets, *extras, **statics) #*和**见注释1
outputs2 = model(inputs, None, *extras, **statics)
注释1:
见博文《def 参数 及参数解构 》
贰、MGT
def __init__ 结构搭建

在原文中有MTG的各种变形(如下),我们不考虑这些,只考经典的MTG
所以,在MGT下,self.noTE=self.noSE=False.共包含5个层结构:时间嵌入层、空间嵌入层、时空嵌入层、编码器结构、解码器结构。
def forward流程图
流程图中的input是一个,为了美观,所以拆分为两个分别作为输入。
叁、 三个嵌入层TE\SE\STE
1. TE
def _init__
在init中主要定义了一个不可优化的参数矩阵self.pe和两个层结构。第一个层结构含两个嵌入层,第二个层结构是一个全连接层。如下:
注释:
def forward
数据形状和层结构的搭建如下图所示,从而完成数据的时间嵌入.橙色是对于input来说,蓝色是对于target来说的。nn.Embedding,nn.linear等都是固定的层结构。
代码为:
注释:
2 SE
空间嵌入是将具有空间特征的矩阵进行线性变换即可。
3 STE
将SE后的(z_inputs,z_targets)和TE后的u,进行扩维,最后经过一个线性变换合并信息。
注释:
- torch.stack,沿一个新维度对输入张量序列进行连接,序列中所有张量应为相同形状;stack 函数返回的结果会新增一个维度,而stack()函数指定的dim参数,就是新增维度的(下标)位置。
肆、Encoder
一、Encoder
def __init__

def forward

二、EncoderLayer层
def __init__
类从cfgs中获得的变量,设置的class的属性
其中包括3个层结构:TSA,SSA和FFN
伍、时间\空间\时编码自注意力层
1.TSA层
当使用元学习的时候,包含三个层结构:MetaLearner(元学习),LayerNorm(层归一化),Linear(现象变换)
注释:

c=torch.randint(10,(num_weight_matrices, B, P, N, num_heads, d_k, d_model))
注释:
2.SSA
当使用元学习的时候,包含4个结构:Meta_learner(元学习列表),Linear(线性变换),dropout, LayerNorm(层归一化)
输入数据为:inputs,c_inputs,transition_matrices
其 数据运行如下:

3.TEDA
当使用元学习的时候,包含3个层结构:MetaLearner, LayerNorm, Linear

陆、Decoder
一、Decoder
def __init__


def forward





二、DecoderLayer
def __init__
DecoderLayer在cfgs中获得变量和类属性
查看DecoderLayer中的层结构,我们知道MTG有如下的变体,但在这里我们只考虑MGT.
MGT中有4个层结构:
具体的形状如下
TSA层
SSA层
TEDA层
FFN层
def forward

捌、其他层
一、MetaLearner层

MetaLearner包含2个全连接层,形状如下:
二、FeedForward层
包含两个全连接层(Linear)和层归一化层(LayerNorm)
三、Projection

玖、三个多头函数
1. multihead_linear_transform

2. multihead_temporal_attention

3. multihead_spatial_attention

边栏推荐
- Software engineering - Chapter 3 software requirements analysis
- 面试经验---字节
- showCTF Web入门题系列
- 5分钟搞懂MySQL - 行转列
- Interview experience - bytes
- Mono of unity 5 can also support C # 6
- 1005 spell it right (20 points) (test point 3)
- Can you implement these requirements with MySQL
- Kotlin middle tail recursive function
- Kotlin---- control statement
猜你喜欢

Mono of unity 5 can also support C # 6

Three ways to solve cross domain problems

showCTF Web入门题系列

Zhongyi Antu submitted for registration: proposed to raise 600million yuan, with annual revenue of nearly 1.2 billion yuan

【活动早知道】LiveVideoStack近期活动一览

Journal (résumé en langue c)

Vision_ Transformer code exercise

MMS for risc-v

2022-2028 global cooling on-off valve industry research and trend analysis report

Two image enhancement methods: image point operation and image graying
随机推荐
Use lua+redis+openresty to realize concurrent optimization of e-commerce Homepage
Linux安装达梦数据库/DM8(附带客户端工具安装完整版)
Three declaration methods of structure type
5分钟搞懂MySQL - 行转列
[kotlin] first day
PS prompts "script error -50 general Photoshop error, how to solve it?
Global and Chinese market for crankshaft position sensors 2022-2028: Research Report on technology, participants, trends, market size and share
【元宇宙3d大赛】
sql查看数据库/表磁盘占用情况,杀死进程终止tidb中的连接
写文章的markdown规则
[DB written interview 220] how to back up control files in oracle? What are the ways to back up control files?
[DB written interview 274] in Oracle, what is deferred segment creation?
Ads Filter Design Wizard tool I
Represent each record in the dataframe as a dictionary
Kotlin---- control statement
使用Lua+Redis+OpenResty实现电商首页并发优化
2022-2028 global postoperative pressure suit industry research and trend analysis report
Generic functions in kotlin
Redis master-slave vulnerability and remote connection vulnerability
Gql+nodejs+mysql database