PyG tutorial (3): neighbor sampling
2022-06-21 06:43:00 【Si Xi is towering】
1. Why is neighbor sampling needed?
Large graphs are very common in the GNN field, but because GPU memory is limited, a large graph often cannot be loaded onto the GPU for training. Neighbor sampling addresses this and makes it possible to scale GNNs to large graphs. PyG provides many neighbor sampling methods; see torch_geometric.loader for details. This article takes the neighbor sampling of GraphSAGE as an example, which PyG implements as NeighborLoader.
NeighborSampler is also a PyG implementation of GraphSAGE neighbor sampling, but it has been deprecated and will be removed in a future version.
2. NeighborLoader in detail
2.1 How GraphSAGE neighbor sampling works
Suppose the number of sampling layers is $K$ and the number of neighbors sampled in layer $k$ is $S_k$. Neighbor sampling in GraphSAGE then proceeds as follows:
- Step one: start from a mini-batch of nodes $\mathcal{B}$ whose neighbors are to be sampled;
- Step two: sample the $1$-hop neighbors of $\mathcal{B}$ to obtain $\mathcal{B}_1$, then sample the $1$-hop neighbors of $\mathcal{B}_1$ (that is, the $2$-hop neighbors of the initial node set) to obtain $\mathcal{B}_2$, and so on for $K$ rounds, which yields a subgraph associated with the initial mini-batch of nodes.
The left side of the figure shows the 2-layer neighbor sampling example given in GraphSAGE, in which the number of neighbors sampled per layer, $S_k$, is the same for every layer (3 in the figure).
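The two steps above can be sketched in plain Python. This is only an illustration of the procedure as described here, not PyG's implementation: the helper name `sample_neighbors`, the adjacency-list dictionary `adj`, and the toy graph are all assumptions made for the example, and sampling is done without replacement.

```python
import random


def sample_neighbors(adj, seeds, num_neighbors, rng=None):
    """Hypothetical helper: return the node sets [B, B_1, ..., B_K].

    adj maps each node to a list of its neighbors, seeds is the initial
    mini-batch B, and num_neighbors[k] is the per-node limit S_k.
    """
    rng = rng or random.Random()
    layers = [set(seeds)]
    frontier = set(seeds)
    for s_k in num_neighbors:
        sampled = set()
        for node in sorted(frontier):
            neighbors = adj.get(node, [])
            # Sample at most s_k distinct neighbors of this node.
            k = min(s_k, len(neighbors))
            sampled.update(rng.sample(neighbors, k))
        layers.append(sampled)
        frontier = sampled  # the next round samples around these nodes
    return layers


# Toy undirected graph stored as adjacency lists (made up for illustration).
adj = {0: [1, 2, 3], 1: [0, 4], 2: [0, 5], 3: [0], 4: [1], 5: [2]}
print(sample_neighbors(adj, seeds=[0], num_neighbors=[2, 2]))
```

The union of the returned sets gives the nodes of the sampled subgraph; a real loader would also collect the edges that were followed.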

2.2 API overview
In PyG, GraphSAGE neighbor sampling is implemented by torch_geometric.loader.NeighborLoader, whose initializer takes the following parameters:
def __init__(
    self,
    data: Union[Data, HeteroData],
    num_neighbors: NumNeighbors,
    input_nodes: InputNodes = None,
    replace: bool = False,
    directed: bool = True,
    transform: Callable = None,
    neighbor_sampler: Optional[NeighborSampler] = None,
    **kwargs,
)
The commonly used parameters are:
- data: the graph object to sample from, either a heterogeneous graph (HeteroData) or a homogeneous graph (Data);
- num_neighbors: the maximum number of neighbors sampled for each node in each iteration (each layer), of type List[int]; for example, [2, 2] means that 2 layers are sampled and each node samples at most 2 neighbors per layer;
- input_nodes: the indices of the original-graph nodes around which the subgraphs are sampled, i.e. $\mathcal{B}$ from Section 2.1, of type torch.Tensor;
- directed: if set to False, all edges between the sampled nodes are included;
- **kwargs: extra arguments for torch.utils.data.DataLoader, such as batch_size and shuffle (see the DataLoader API for details).
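Since num_neighbors caps how many neighbors each node may sample per layer, it also bounds the size of a sampled subgraph. The sketch below works out that bound; `max_sampled_nodes` is a hypothetical helper written for this article, not a PyG function, and it assumes every entry of num_neighbors is a non-negative integer.

```python
def max_sampled_nodes(batch_size, num_neighbors):
    """Upper bound on the number of nodes in one sampled subgraph."""
    total = batch_size      # the seed nodes themselves
    frontier = batch_size   # nodes whose neighbors are sampled next
    for s_k in num_neighbors:
        frontier *= s_k     # each frontier node adds at most s_k nodes
        total += frontier
    return total


print(max_sampled_nodes(1, [2, 2]))  # 1 + 2 + 4 = 7
```

The actual subgraph is often smaller, because a node may have fewer neighbors than the limit and the same node can be reached more than once.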
2.3 Sampling in practice
To keep the visualization readable, this section uses the KarateClub dataset provided by PyG. The dataset describes the social relationships among the members of a karate club: the nodes are the 34 members, and two nodes are connected by an edge if the corresponding members also socialize outside the club. The dataset is visualized below:

The following source code loads the dataset, visualizes it, and performs neighbor sampling:
import torch
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader


def draw(graph):
    # n_id holds, for each subgraph node, its index in the original graph.
    nids = graph.n_id
    graph = to_networkx(graph)
    for i, nid in enumerate(nids):
        graph.nodes[i]['txt'] = str(nid.item())
    node_labels = nx.get_node_attributes(graph, 'txt')
    # print(node_labels)
    # {0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
    nx.draw_networkx(graph, labels=node_labels, node_color='#00BFFF')
    plt.axis("off")
    plt.show()


dataset = KarateClub()
g = dataset[0]
# print(g)
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
# Record the original node indices so they are carried into each subgraph.
g.n_id = torch.arange(g.num_nodes)
# torch.tensor([14]) gives an integer index; torch.Tensor([14]) would be float.
for s in NeighborLoader(g, num_neighbors=[2, 2], input_nodes=torch.tensor([14])):
    draw(s)
    break
In the source code above, 2 layers are sampled, each node samples at most 2 neighbors per layer, and the initial node set is {14}. The corresponding sampling result is shown below:

As the figure above shows, the first iteration samples two 1-hop neighbors of node {14}, namely {32,33}; the second iteration then samples from 32 and 33 separately, obtaining {2,8} and {18,30}. Because sampling is random, the sampled nodes can differ from run to run.
Note that in the subgraph returned by NeighborLoader, the global node indices are mapped to local indices of the subgraph. To map the nodes of the sampled subgraph back to the corresponding nodes of the original graph, you can create an attribute on the original graph that records the mapping, as in the sampling source code above:
g.n_id = torch.arange(g.num_nodes)
This way, the nodes of the sampled subgraph also carry the n_id attribute, so they can be mapped back to the original graph. The visualization in the example above relies on this; the corresponding mapping is:
{0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
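The same n_id lookup can be applied to edges. The sketch below uses plain Python lists in place of tensors: the n_id values are the ones from the mapping above, while `to_global` and the two local edges are hypothetical and chosen only for illustration, not taken from the sampled subgraph.

```python
def to_global(n_id, local_edges):
    """Translate (src, dst) pairs from subgraph indices to original indices."""
    return [(n_id[src], n_id[dst]) for src, dst in local_edges]


# n_id[i] is the original index of subgraph node i (from the mapping above).
n_id = [14, 32, 33, 18, 30, 28, 20]
# Illustrative local edges: subgraph nodes 1 and 2 pointing to node 0.
local_edges = [(1, 0), (2, 0)]
print(to_global(n_id, local_edges))  # [(32, 14), (33, 14)]
```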
Conclusion
PyG offers far more neighbor sampling implementations than the one covered here; see the official documentation of torch_geometric.loader for details.