论文解读:《Meta Graph Transformer: A Novel Framework for Spatial–Temporal Traffic Prediction》
代码链接:https://github.com/lonicera-yx/MGT
壹、测试主框架
一、文件目录
文件夹MGT-main
下包含4个子文件夹,一个数据压缩文件,一个main.py
等等。
MGT.py
位于子文件夹models
中,包含了3个def
和14个class
二、if 主函数测试MGT
if __name__ == '__main__':
print(os.getcwd())
#cfgs = yaml.safe_load(open('cfgs/HZMetro_MGT.yaml'))['model']
cfgs = yaml.safe_load(open('../cfgs/HZMetro_MGT.yaml'))['model']
model = MGT(cfgs)
# dummy data 虚拟数据
B, P, Q, N, C = 10, 4, 4, 80, 2
# B:batch_size P:history Q:feture N:Nodes C:Features
M = 73, 2 # M is tuple
eigenmaps_k = 8 #拉普拉斯特征映射降维方法的参数
n = 3
inputs = torch.randn(B, P, N, C, dtype=torch.float32)#(10,4,80,2)
targets = torch.randn(B, Q, N, C, dtype=torch.float32)#(10,4,80,2)
inputs_time0 = torch.randint(M[0], (B, P), dtype=torch.int64)#(10,4) max_int is 73
targets_time0 = torch.randint(M[0], (B, Q), dtype=torch.int64)#(10,4)
inputs_time1 = torch.randint(M[1], (B, P), dtype=torch.int64)#(10,4) max_int is 2
targets_time1 = torch.randint(M[1], (B, Q), dtype=torch.int64)#(10,4)
eigenmaps = torch.randn(N, eigenmaps_k, dtype=torch.float32)#(80,8)
transition_matrices = torch.rand(n, N, N, dtype=torch.float32)#(3,80,80)
extras = [inputs_time0, targets_time0, inputs_time1, targets_time1]
statics = {
'eigenmaps': eigenmaps, 'transition_matrices': transition_matrices}
# forward
outputs1 = model(inputs, targets, *extras, **statics) #*和**见注释1
outputs2 = model(inputs, None, *extras, **statics)
注释1:
见博文《def 参数 及参数解构 》
贰、MGT
def __init__
结构搭建
在原文中有MTG的各种变形(如下),我们不考虑这些,只考经典的MTG
所以,在MGT下,self.noTE=self.noSE=False.共包含5个层结构:时间嵌入层、空间嵌入层、时空嵌入层、编码器结构、解码器结构。
def forward
流程图
流程图中的input是一个,为了美观,所以拆分为两个分别作为输入。
叁、 三个嵌入层TE\SE\STE
1. TE
def _init__
在init中主要定义了一个不可优化的参数矩阵self.pe和两个层结构。第一个层结构含两个嵌入层,第二个层结构是一个全连接层。如下:
注释:
- self.register_buffer
- nn.Embedding
def forward
数据形状和层结构的搭建如下图所示,从而完成数据的时间嵌入.橙色是对于input来说,蓝色是对于target来说的。nn.Embedding,nn.linear等都是固定的层结构。
代码为:
注释:
- torch.Tensor.expand
2 SE
空间嵌入是将具有空间特征的矩阵进行线性变换即可。
3 STE
将SE后的(z_inputs,z_targets)和TE后的u,进行扩维,最后经过一个线性变换合并信息。
注释:
- torch.stack,沿一个新维度对输入张量序列进行连接,序列中所有张量应为相同形状;stack 函数返回的结果会新增一个维度,而stack()函数指定的dim参数,就是新增维度的(下标)位置。
肆、Encoder
一、Encoder
def __init__
def forward
二、EncoderLayer层
def __init__
类从cfgs中获得的变量,设置的class的属性
其中包括3个层结构:TSA,SSA和FFN
伍、时间\空间\时编码自注意力层
1.TSA层
当使用元学习的时候,包含三个层结构:MetaLearner(元学习),LayerNorm(层归一化),Linear(现象变换)
注释:
- nn.LayerNorm,《用法》,《实现》
c=torch.randint(10,(num_weight_matrices, B, P, N, num_heads, d_k, d_model))
注释:
- torch.new_full,《案例讲解》,《精彩讲解》
- torch.triu, 《详例讲解》
2.SSA
当使用元学习的时候,包含4个结构:Meta_learner(元学习列表),Linear(线性变换),dropout, LayerNorm(层归一化)
输入数据为:inputs,c_inputs,transition_matrices
其 数据运行如下:
3.TEDA
当使用元学习的时候,包含3个层结构:MetaLearner, LayerNorm, Linear
陆、Decoder
一、Decoder
def __init__
def forward
二、DecoderLayer
def __init__
DecoderLayer在cfgs中获得变量和类属性
查看DecoderLayer中的层结构,我们知道MTG有如下的变体,但在这里我们只考虑MGT.
MGT中有4个层结构:
具体的形状如下
TSA层
SSA层
TEDA层
FFN层
def forward
捌、其他层
一、MetaLearner层
MetaLearner包含2个全连接层,形状如下:
二、FeedForward层
包含两个全连接层(Linear)和层归一化层(LayerNorm)
三、Projection
玖、三个多头函数
1. multihead_linear_transform
2. multihead_temporal_attention
3. multihead_spatial_attention
今天的文章mg是什么代码_有没有解读代码的软件分享到此就结束了,感谢您的阅读。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/88907.html