PyG Tutorial (6): Customizing a Message-Passing Network
2022-06-21 06:44:00 【Si Xi is towering】
1. Preface
In the previous article I mainly introduced the message-passing mechanism of GNNs. PyG provides a message-passing base class, torch_geometric.nn.MessagePassing, which handles message passing automatically; by inheriting from this class you can easily build your own message-passing GNN.
The main contents of this article are: an analysis of the MessagePassing class, and implementing GAT by inheriting from MessagePassing.
2. How to Customize a Message-Passing Network
To customize a GNN model, first inherit from the MessagePassing class and then override the following methods:
- message(...): builds the messages to be passed;
- aggregate(...): aggregates the messages passed from source nodes to target nodes;
- update(...): updates the node representations.
None of these methods has to be overridden: if the default implementations in MessagePassing meet your needs, you can leave them as they are.
2.1 Constructors
After inheriting from the MessagePassing class, you can pass parameters to the base class via super().__init__() in your constructor to configure certain message-passing behaviors. The initializer of the MessagePassing class is declared as follows (its parameters are the ones a subclass can pass up to the parent class):
def __init__(self, aggr: Optional[str] = "add",
             flow: str = "source_to_target", node_dim: int = -2,
             decomposed_layers: int = 1):
Description of common parameters:

| Parameter | Explanation |
|---|---|
| aggr | The aggregation scheme used in message passing; common options include add, mean, min, and max. |
| flow | The direction of message propagation: source_to_target means from source node to target node, target_to_source means from target node to source node. |
| node_dim | The axis along which propagation takes place (-2 refers to the node dimension). |
2.2 The propagate function
Before introducing the three message-passing functions, let's first look at the propagate function. It is the entry point of message propagation: calling it executes message, aggregate, and update in turn to complete the passing, aggregation, and update steps. It is declared as follows:
propagate(self, edge_index: Adj, size: Size = None, **kwargs)
Parameter description:

| Parameter | Explanation |
|---|---|
| edge_index | The edge indices. |
| size | The shape of the adjacency matrix; None means a square matrix, and if the adjacency matrix is not square the shape must be passed explicitly. |
| **kwargs | Additional data needed to build, aggregate, and update messages. Anything passed to propagate can be received by the three message-passing functions. |
This function is usually called with edge_index and the node features x.
2.3 The message function
The message function builds the node messages. Tensors passed to propagate can be mapped onto the central and neighbor nodes simply by appending _i or _j to the corresponding argument name: by convention, _i denotes the central (target) node and _j denotes the neighbor (source) node.
Example:
def forward(self, x, edge_index):
    return self.propagate(edge_index, x=x)

def message(self, x_i, x_j, edge_index_i):
    pass
In this example, propagate receives two arguments, edge_index and x, so the message function can construct its own arguments from them. The arguments constructed in the message function above are:
- x_i: the matrix of central-node feature vectors, one row per edge (note that this is not the same as the node feature matrix of the graph);
- x_j: the matrix of neighbor-node feature vectors, one row per edge;
- edge_index_i: the indices of the central nodes.
Note that if flow='source_to_target', messages are passed from neighbor nodes to central nodes; if flow='target_to_source', messages are passed from central nodes to neighbor nodes. The default is the former.
2.4 The aggregate function
The aggregation function aggregate combines the messages coming from a node's neighbors; common schemes include add, mean, min, and max, and the scheme can be set through the aggr parameter of super().__init__(). The first argument of this function is the output (return value) of the message function.
2.5 The update function
The update function updates each node's representation; the output (return value) of aggregate is passed as its first argument.
3. GAT in Practice
In this section we inherit from the MessagePassing class to build a graph attention network (GAT).
3.1 The message-passing mechanism of GAT
The message-passing formulas of GAT are as follows:
$$
\begin{aligned}
h_i^{(l+1)} &= \sum_{j\in \mathcal{N}(i)} \alpha_{ij}^{l} W^{(l)} h_j^{(l)} \\
\alpha_{ij}^{l} &= \mathrm{softmax}_i\left(e_{ij}^{l}\right) \\
e_{ij}^{l} &= \mathrm{LeakyReLU}\left(\vec{a}^{T} \left[ W h_i^{(l)} \,\Vert\, W h_j^{(l)} \right]\right)
\end{aligned}
$$
where $h_i^{(l)}$ and $h_j^{(l)}$ denote the feature vectors of nodes $i$ and $j$ at layer $l$. The formulas show that to aggregate neighbor messages, we first compute the attention weight of each neighbor with respect to the central node and then take the weighted sum.
3.2 Implementation
Following the message-passing mechanism of Section 3.1, the GAT convolution layer can be implemented as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax, add_remaining_self_loops


class GATConv(MessagePassing):
    def __init__(self, in_feats, out_feats, alpha, drop_prob=0.0):
        super().__init__(aggr="add")
        self.drop_prob = drop_prob
        self.lin = nn.Linear(in_feats, out_feats, bias=False)
        self.a = nn.Parameter(torch.zeros(size=(2 * out_feats, 1)))
        self.leakyrelu = nn.LeakyReLU(alpha)
        nn.init.xavier_uniform_(self.a)

    def forward(self, x, edge_index):
        edge_index, _ = add_remaining_self_loops(edge_index)
        # compute Wh
        h = self.lin(x)
        # start message propagation
        h_prime = self.propagate(edge_index, x=h)
        return h_prime

    def message(self, x_i, x_j, edge_index_i):
        # compute a^T (Wh_i || Wh_j)
        e = torch.matmul(torch.cat([x_i, x_j], dim=-1), self.a)
        e = self.leakyrelu(e)
        # normalize the attention coefficients per target node
        alpha = softmax(e, edge_index_i)
        alpha = F.dropout(alpha, self.drop_prob, self.training)
        return x_j * alpha


if __name__ == "__main__":
    conv = GATConv(in_feats=3, out_feats=3, alpha=0.2)
    x = torch.rand(4, 3)
    edge_index = torch.tensor(
        [[0, 1, 1, 2, 0, 2, 0, 3], [1, 0, 2, 1, 2, 0, 3, 0]], dtype=torch.long)
    x = conv(x, edge_index)
    print(x.shape)
In the implementation above, the attention weights are already computed in the message function (the message-building phase), so the subsequent aggregation step only needs to sum the weighted neighbor messages, which is achieved by super().__init__(aggr="add"). The message update requires no special handling, so update is left at its default and does not need to be overridden.
With the GAT convolution layer above, you can build a complete GAT model.
4. Conclusion
It is very convenient to implement our own message-passing graph neural network with PyG. The introduction in this article is of course not exhaustive and may be supplemented later.