mg是什么代码_有没有解读代码的软件

mg是什么代码_有没有解读代码的软件本文讲解了MGT的代码结构和模型搭建_metagraphtransformer(neurocomputing

论文解读:《Meta Graph Transformer: A Novel Framework for Spatial–Temporal Traffic Prediction》

代码链接:https://github.com/lonicera-yx/MGT


壹、测试主框架

一、文件目录

在这里插入图片描述
文件夹MGT-main下包含4个子文件夹,一个数据压缩文件,一个main.py等等。
MGT.py位于子文件夹models中,包含了3个def和14个class

二、if 主函数测试MGT

if __name__ == '__main__':
    print(os.getcwd())
    #cfgs = yaml.safe_load(open('cfgs/HZMetro_MGT.yaml'))['model']
    cfgs = yaml.safe_load(open('../cfgs/HZMetro_MGT.yaml'))['model']
    model = MGT(cfgs)

    # dummy data 虚拟数据
    B, P, Q, N, C = 10, 4, 4, 80, 2
    # B:batch_size P:history Q:feture N:Nodes C:Features
    M = 73, 2 # M is tuple
    eigenmaps_k = 8 #拉普拉斯特征映射降维方法的参数
    n = 3

    inputs = torch.randn(B, P, N, C, dtype=torch.float32)#(10,4,80,2)
    targets = torch.randn(B, Q, N, C, dtype=torch.float32)#(10,4,80,2)

    inputs_time0 = torch.randint(M[0], (B, P), dtype=torch.int64)#(10,4) max_int is 73
    targets_time0 = torch.randint(M[0], (B, Q), dtype=torch.int64)#(10,4)
    inputs_time1 = torch.randint(M[1], (B, P), dtype=torch.int64)#(10,4) max_int is 2
    targets_time1 = torch.randint(M[1], (B, Q), dtype=torch.int64)#(10,4)

    eigenmaps = torch.randn(N, eigenmaps_k, dtype=torch.float32)#(80,8)

    transition_matrices = torch.rand(n, N, N, dtype=torch.float32)#(3,80,80)

    extras = [inputs_time0, targets_time0, inputs_time1, targets_time1]
    statics = { 
   'eigenmaps': eigenmaps, 'transition_matrices': transition_matrices}

    # forward
    outputs1 = model(inputs, targets, *extras, **statics) #*和**见注释1
    outputs2 = model(inputs, None, *extras, **statics)

注释1:
见博文《def 参数 及参数解构 》

贰、MGT

def __init__ 结构搭建

在这里插入图片描述

在原文中有MTG的各种变形(如下),我们不考虑这些,只考经典的MTG
在这里插入图片描述
所以,在MGT下,self.noTE=self.noSE=False.共包含5个层结构:时间嵌入层、空间嵌入层、时空嵌入层、编码器结构、解码器结构。

def forward流程图

流程图中的input是一个,为了美观,所以拆分为两个分别作为输入。

dict

list

B,P

B,Q

B,P,d_m

B,Q,d_m

N,k

N,d_m

B,P,N,d_m

B,Q,N,d_m

B,P,N,C

B,P,N,C

B,Q,N,C

特征映射矩阵

转移矩阵

input0

input1

target0

target1

inputs

inputs

targets

extra

时间嵌入层

z_input

z_target

statics

空间嵌入层

U

时空嵌入层

c_inputs

c_targets

encoder

en-out

encoder

叁、 三个嵌入层TE\SE\STE

1. TE

def _init__
在这里插入图片描述
在init中主要定义了一个不可优化的参数矩阵self.pe和两个层结构。第一个层结构含两个嵌入层,第二个层结构是一个全连接层。如下:
在这里插入图片描述
注释:

  1. self.register_buffer
  2. nn.Embedding

def forward
数据形状和层结构的搭建如下图所示,从而完成数据的时间嵌入.橙色是对于input来说,蓝色是对于target来说的。nn.Embedding,nn.linear等都是固定的层结构。
在这里插入图片描述
代码为:
在这里插入图片描述
注释:

  1. torch.Tensor.expand

2 SE

空间嵌入是将具有空间特征的矩阵进行线性变换即可。
在这里插入图片描述

3 STE

将SE后的(z_inputs,z_targets)和TE后的u,进行扩维,最后经过一个线性变换合并信息。
在这里插入图片描述

注释:

  1. torch.stack,沿一个新维度对输入张量序列进行连接,序列中所有张量应为相同形状;stack 函数返回的结果会新增一个维度,而stack()函数指定的dim参数,就是新增维度的(下标)位置。

肆、Encoder

一、Encoder

def __init__

在这里插入图片描述

def forward

在这里插入图片描述

二、EncoderLayer层

def __init__

类从cfgs中获得的变量,设置的class的属性
在这里插入图片描述
其中包括3个层结构:TSA,SSA和FFN
在这里插入图片描述

伍、时间\空间\时编码自注意力层

1.TSA层

当使用元学习的时候,包含三个层结构:MetaLearner(元学习),LayerNorm(层归一化),Linear(现象变换)
在这里插入图片描述
注释:

  1. nn.LayerNorm,《用法》,《实现》

在这里插入图片描述

c=torch.randint(10,(num_weight_matrices, B, P, N, num_heads, d_k, d_model))

注释:

  1. torch.new_full,《案例讲解》,《精彩讲解》
  2. torch.triu, 《详例讲解》

2.SSA

当使用元学习的时候,包含4个结构:Meta_learner(元学习列表),Linear(线性变换),dropout, LayerNorm(层归一化)
在这里插入图片描述
输入数据为:inputs,c_inputs,transition_matrices
其 数据运行如下:
在这里插入图片描述
在这里插入图片描述

3.TEDA

当使用元学习的时候,包含3个层结构:MetaLearner, LayerNorm, Linear
在这里插入图片描述
在这里插入图片描述

陆、Decoder

一、Decoder

def __init__

在这里插入图片描述
在这里插入图片描述

def forward

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

二、DecoderLayer

def __init__

DecoderLayer在cfgs中获得变量和类属性
在这里插入图片描述
查看DecoderLayer中的层结构,我们知道MTG有如下的变体,但在这里我们只考虑MGT.
在这里插入图片描述
MGT中有4个层结构:
在这里插入图片描述
具体的形状如下
TSA层
在这里插入图片描述
SSA层
在这里插入图片描述
TEDA层
在这里插入图片描述
FFN层
在这里插入图片描述

def forward

在这里插入图片描述

捌、其他层

一、MetaLearner层

在这里插入图片描述
MetaLearner包含2个全连接层,形状如下:
在这里插入图片描述

二、FeedForward层

包含两个全连接层(Linear)和层归一化层(LayerNorm)
在这里插入图片描述

三、Projection

在这里插入图片描述

玖、三个多头函数

1. multihead_linear_transform

在这里插入图片描述

2. multihead_temporal_attention

在这里插入图片描述

3. multihead_spatial_attention

在这里插入图片描述

今天的文章mg是什么代码_有没有解读代码的软件分享到此就结束了,感谢您的阅读。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/88907.html

(0)
编程小号编程小号

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注