PyG Tutorial (6): Customizing the Message Passing Network

2022-06-21 06:44:00 Si Xi is towering

One. Preface

The previous article introduced the message passing mechanism of GNNs. PyG provides a message passing base class, torch_geometric.nn.MessagePassing, which automates the bookkeeping of message propagation; by inheriting from it you can easily build your own message passing GNN.

This article covers two things: a walkthrough of the MessagePassing class, and an implementation of GAT by inheriting from MessagePassing.

Two. How to Customize a Message Passing Network

To customize a GNN model, first inherit from the MessagePassing class, then override the following methods:

  • message(...): builds the message to be passed along each edge;
  • aggregate(...): aggregates the messages passed from source nodes to each target node;
  • update(...): updates the embedding of each node.

None of these methods has to be overridden: if the default implementation in MessagePassing meets your needs, you can leave it untouched. A minimal skeleton is sketched below.
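To make this concrete, here is a minimal sketch of a custom layer (the class name MyConv and the identity-style message are made up for illustration): it inherits from MessagePassing, picks an aggregation scheme in the constructor, and overrides message.

import torch.nn as nn
from torch_geometric.nn import MessagePassing


class MyConv(MessagePassing):
    def __init__(self, in_feats, out_feats):
        super().__init__(aggr="add")  # sum incoming messages
        self.lin = nn.Linear(in_feats, out_feats)

    def forward(self, x, edge_index):
        # transform node features, then kick off message passing
        return self.propagate(edge_index, x=self.lin(x))

    def message(self, x_j):
        # pass each neighbor's (transformed) features along unchanged
        return x_j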

2.1 The constructor

After inheriting from MessagePassing, the subclass constructor can call super().__init__() to pass parameters to the base class and thereby configure aspects of the message passing behavior. The MessagePassing initializer is declared as follows (its parameters are exactly the ones a subclass can pass up to the parent):

def __init__(self, aggr: Optional[str] = "add",
             flow: str = "source_to_target", node_dim: int = -2,
             decomposed_layers: int = 1):

Description of common parameters:

  • aggr: the aggregation scheme for message passing; common choices are "add", "mean", "min", and "max".
  • flow: the direction of message propagation; "source_to_target" propagates from source nodes to target nodes, "target_to_source" the reverse.
  • node_dim: the axis along which propagation is performed.
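For example, a hypothetical layer that averages neighbor messages (MeanConv is an illustrative name, not a PyG class) would forward these parameters to the base class like so:

from torch_geometric.nn import MessagePassing


class MeanConv(MessagePassing):
    def __init__(self):
        # average incoming messages, propagating along source -> target edges
        super().__init__(aggr="mean", flow="source_to_target", node_dim=-2)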

2.2 propagate function

Before introducing the three message passing methods, let's look at the propagate function, the entry point of message propagation: calling it executes message, aggregate, and update in sequence to carry out the passing, aggregation, and updating steps. Its declaration is as follows:

propagate(self, edge_index: Adj, size: Size = None, **kwargs)

Parameter description:

  • edge_index: the edge indices.
  • size: the shape of the adjacency matrix; None means a square matrix, whereas a non-square adjacency matrix requires the shape to be passed explicitly.
  • **kwargs: any additional data needed to build, aggregate, and update messages can be passed into propagate; these arguments can then be received by the three message passing methods.

In the common case this function is called with just edge_index and the node features x.
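For the non-square (bipartite) case, propagate accepts a pair of feature matrices together with an explicit size. A minimal sketch, with made-up names (BipartiteConv, x_src, x_dst):

from torch_geometric.nn import MessagePassing


class BipartiteConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="add")

    def forward(self, x_src, x_dst, edge_index):
        # (x_src, x_dst): separate feature matrices for source and target nodes;
        # size declares the adjacency matrix shape [num_src, num_dst]
        return self.propagate(edge_index, x=(x_src, x_dst),
                              size=(x_src.size(0), x_dst.size(0)))

    def message(self, x_j):
        return x_j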

2.3 message function

The message function builds the per-node messages. Any tensor passed to propagate can be mapped to the central node or the neighbor nodes simply by appending _i or _j to the corresponding argument name; by convention, _i denotes the central (target) node and _j denotes the neighbor (source) node.

Example:

def forward(self, x, edge_index):
    return self.propagate(edge_index, x=x)

def message(self, x_i, x_j, edge_index_i):
    pass

In this example, propagate receives two arguments, edge_index and x, so message can derive its own parameters from them. The parameters constructed in the message function above are:

  • x_i: the matrix of central node feature vectors, one row per edge (note that this differs from the node feature matrix of the graph);
  • x_j: the matrix of neighbor node feature vectors, one row per edge;
  • edge_index_i: the indices of the central nodes.

Note that if flow='source_to_target', messages flow from neighbor nodes to central nodes; if flow='target_to_source', they flow from central nodes to neighbor nodes. The former is the default.
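To make the _i/_j mapping concrete, here is a small runnable sketch (the DebugConv class and the toy graph are invented for illustration) that prints what message sees:

import torch
from torch_geometric.nn import MessagePassing


class DebugConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="add")

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j, edge_index_i):
        # one row per edge: x_i holds central (target) node features,
        # x_j holds neighbor (source) node features
        print(x_i.shape, x_j.shape, edge_index_i)
        return x_j


x = torch.rand(4, 8)                    # 4 nodes with 8 features each
edge_index = torch.tensor([[0, 1, 2],   # source nodes
                           [1, 2, 3]])  # target nodes
DebugConv()(x, edge_index)  # prints torch.Size([3, 8]) twice, then tensor([1, 2, 3])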

2.4 aggregate function

The aggregation function aggregate combines the messages coming from a node's neighbors. Common schemes include "add" (sum), "mean", and "max", selectable via the aggr parameter of super().__init__(). The first argument of this function is the output (return value) of the message function.
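Overriding it is rarely necessary; as a sketch, an override that delegates to the default scheme might look like the following (the keyword signature matches the PyG source as far as I know, but treat the exact parameter list as an assumption for your version):

# inside a MessagePassing subclass:
def aggregate(self, inputs, index, ptr=None, dim_size=None):
    # inputs: the per-edge messages returned by message()
    # index:  for every message, the target node it is aggregated into
    return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)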

2.5 update function

The update function updates each node's embedding; the output (return value) of aggregate is passed as its first argument.
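Like message, it can also receive any extra argument passed to propagate. A sketch (the residual-style combination is my own example, not something MessagePassing prescribes):

# inside a MessagePassing subclass:
def update(self, inputs, x):
    # inputs: the aggregated neighborhood message for each node
    # x: the original node features, forwarded via propagate(edge_index, x=x)
    return inputs + x  # e.g. a residual-style update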

Three. GAT in Practice

In this section we inherit from MessagePassing to build a graph attention network (GAT).

3.1 The message passing mechanism of GAT

The message passing formulas of GAT are as follows:
$$
\begin{aligned}
h_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{l} W^{(l)} h_j^{(l)} \\
\alpha_{ij}^{l} &= \mathrm{softmax}_i\left(e_{ij}^{l}\right) \\
e_{ij}^{l} &= \mathrm{LeakyReLU}\left(\vec{a}^T \left[ W h_i^{(l)} \,\|\, W h_j^{(l)} \right]\right)
\end{aligned}
$$
where $h_i^{(l)}$ and $h_j^{(l)}$ denote the feature vectors of nodes $i$ and $j$ at layer $l$. As the formulas show, aggregating neighbor messages requires first computing the attention weight from each neighbor node to the central node, and then taking the weighted sum.

3.2 Implementation

Following the message passing mechanism of Section 3.1, the GAT convolution layer can be implemented as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax, add_remaining_self_loops


class GATConv(MessagePassing):
    def __init__(self, in_feats, out_feats, alpha, drop_prob=0.0):
        super().__init__(aggr="add")  # attention-weighted neighbor messages are summed
        self.drop_prob = drop_prob
        self.lin = nn.Linear(in_feats, out_feats, bias=False)  # the weight matrix W
        self.a = nn.Parameter(torch.zeros(size=(2*out_feats, 1)))  # the attention vector a
        self.leakyrelu = nn.LeakyReLU(alpha)
        nn.init.xavier_uniform_(self.a)

    def forward(self, x, edge_index):
        edge_index, _ = add_remaining_self_loops(edge_index)
        # compute Wh
        h = self.lin(x)
        # start message propagation
        h_prime = self.propagate(edge_index, x=h)
        return h_prime

    def message(self, x_i, x_j, edge_index_i):
        # compute e_ij = LeakyReLU(a^T [Wh_i || Wh_j])
        e = torch.matmul(torch.cat([x_i, x_j], dim=-1), self.a)
        e = self.leakyrelu(e)
        # normalize the attention coefficients over each central node's edges
        alpha = softmax(e, edge_index_i)
        alpha = F.dropout(alpha, self.drop_prob, self.training)
        return x_j * alpha


if __name__ == "__main__":
    conv = GATConv(in_feats=3, out_feats=3, alpha=0.2)
    x = torch.rand(4, 3)
    edge_index = torch.tensor(
        [[0, 1, 1, 2, 0, 2, 0, 3], [1, 0, 2, 1, 2, 0, 3, 0]], dtype=torch.long)
    x = conv(x, edge_index)
    print(x.shape)

In the implementation above, the attention weights are already computed in the message function (the message building phase), so the subsequent aggregation only needs to sum the weighted neighbor messages, which is exactly what super().__init__(aggr="add") achieves. No special message update is required, so update is left at its default.

With the GATConv layer above, a complete GAT model can be constructed, as sketched below.
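For instance, a minimal two-layer model on top of GATConv might look like this (the GAT class, its hyperparameters, and the ELU between layers are my own illustrative choices, not prescribed by the article):

import torch
import torch.nn as nn
import torch.nn.functional as F


class GAT(nn.Module):
    def __init__(self, in_feats, hidden_feats, num_classes, alpha=0.2, drop_prob=0.5):
        super().__init__()
        self.conv1 = GATConv(in_feats, hidden_feats, alpha, drop_prob)
        self.conv2 = GATConv(hidden_feats, num_classes, alpha, drop_prob)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)


model = GAT(in_feats=3, hidden_feats=8, num_classes=2)
x = torch.rand(4, 3)
edge_index = torch.tensor(
    [[0, 1, 1, 2, 0, 2, 0, 3], [1, 0, 2, 1, 2, 0, 3, 0]], dtype=torch.long)
print(model(x, edge_index).shape)  # torch.Size([4, 2])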

Four. Conclusion

PyG makes it very convenient to implement your own message passing graph neural network. The coverage here is of course not exhaustive and may be supplemented later.

Source: https://yzsam.com/2022/172/202206210624432710.html