当前位置:网站首页>PyG教程(3):邻居采样
PyG教程(3):邻居采样
2022-06-21 06:24:00 【斯曦巍峨】
一.为什么需要邻居采样?
在GNN领域,大图是非常常见的,但由于GPU显存的限制,大图是无法放到GPU上进行训练的。为此,可以采用邻居采样,这样一来可以将GNN扩展到大图上。在PyG中,邻居采样的方式有很多种,具体详解torch_geometric.loader。本文以GraphSage中的邻居采样为例进行介绍,其在PyG中实现为NeighborLoader。
NeighborSampler也是PyG中关于GraphSage中邻居采样的实现,但已经被弃用,在未来版本中会被删除。
二.NeighborLoader详解
2.1 GraphSage邻居采样原理
假设采样的层数为 K K K,每层采样的邻居数为 S k S_k Sk,GraphSage中邻居采样是这样进行的:
- 步骤一:首先给定要采样邻居的小批量节点集 B \mathcal{B} B;
- 步骤二:对 B \mathcal{B} B的 1 1 1跳(hop)邻居进行采样,然后得到 B 1 \mathcal{B}_1 B1,然后对 B 1 \mathcal{B}_1 B1的 1 1 1跳邻居进行采样(即最初结点集的 2 2 2跳邻居)得到 B 2 \mathcal{B}_2 B2,如此往复进行 K K K次,得到最初小批量节点集相关的一个子图。
下图左是GraphSage中给出的一个2层邻居采样的示例,其中每层采样的邻居数 S k S_k Sk是相等的(图中为3)。

2.2 API介绍
PyG中,GraphSage的邻居采样实现为torch_geometric.loader.NeighborLoader,其初始化函数参数为:
def __init__(
self,
data: Union[Data, HeteroData],
num_neighbors: NumNeighbors,
input_nodes: InputNodes = None,
replace: bool = False,
directed: bool = True,
transform: Callable = None,
neighbor_sampler: Optional[NeighborSampler] = None,
**kwargs,
)
常用参数说明如下:
data:要采样的图对象,可以为异构图HeteroData,也可以为同构图Data;num_neighbors:每个节点每次迭代(每层)采样的最大邻居数,List[int]类型,例如[2,2]表示采样2层,每层中每个节点最多采样2个邻居;input_nodes:从原始图中采样得到的子图中需要包含的原始图中节点索引,即2.1节中最初的 B \mathcal{B} B,torch.Tensor()类型;directed:如果设置为False,将包括所有采样节点之间的所有边;**kwargs:torch.utils.data.DataLoader的额外参数,例如batch_size,shuffle(具体详见该API)。
2.3 采样实践
为了可视化的美观性,本小节采用的图数据是PyG中提供的KarateClub数据集,该数据集描述了一个空手道俱乐部会员的社交关系,节点为34名会员,如果两位会员在俱乐部之外仍保持社交关系,则在对应节点间连边,该数据集的可视化如下所示:

下面是对该数据集的加载、可视化以及邻居采样的源码:
import torch
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader
def draw(graph):
nids = graph.n_id
graph = to_networkx(graph)
for i, nid in enumerate(nids):
graph.nodes[i]['txt'] = str(nid.item())
node_labels = nx.get_node_attributes(graph, 'txt')
# print(node_labels)
# {0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
nx.draw_networkx(graph, labels=node_labels, node_color='#00BFFF')
plt.axis("off")
plt.show()
dataset = KarateClub()
g = dataset[0]
# print(g)
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
g.n_id = torch.arange(g.num_nodes)
for s in NeighborLoader(g, num_neighbors=[2, 2], input_nodes=torch.Tensor([14])):
draw(s)
break
在上述源码中,设置的采样层数为2层、每个节点每层采样最多采样2个邻居,采样的初始节点集为{14},其对应的采样结果如下所示:

从上图可以看出,在第一次迭代中,采样了节点{14}的两个1跳邻居{32,33},然后在第二次迭代中对{32,33}分别进行采样得到{2,8]}和{18,30}。
需要注意是通过NeighborLoader返回的子图中,全局节点索引会映射到到与该子图对应的局部索引。因此,若要将当前采样子图中的节点映射会原来图中对应的节点,可以在原始图中创建一个属性来完成两者之间的映射,例如采样实践源码中的:
g.n_id = torch.arange(g.num_nodes)
如此以来,采样后子图中的节点同样包含n_id属性,这样就可以将子图的节点映射回去了,上述示例中对图进行可视化便利用了这一点,其对应的映射为:
{
0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
结语
PyG中对于邻居采样的实现远远不止上述这一种,具体参见如下官网资料:
边栏推荐
- [data mining] final review Chapter 2
- Aurora8b10b IP use-04-ip routine application example
- 我的高考经历与总结
- 第8期:云原生—— 大学生职场小白该如何学
- leetcode 675. Cutting down trees for golf competitions - (day29)
- Aurora8B10B IP使用 -05- 收发测试应用示例
- 端口占用解决
- [data mining] final review Chapter 1
- 【笔记自用】myeclipse连接MySQL数据库详细步骤
- nametuple的源码为什么要使用.replace(‘,‘, ‘ ‘).split()而不是.split(‘,‘)
猜你喜欢

利用burp进行爆破(普通爆破+验证码爆破)

Deeply understand the gradient disappearance of RNN and why LSTM can solve the gradient disappearance

Course design of simulated bank deposit and withdrawal management system in C language (pure C language version)

Latest analysis on operation of refrigeration and air conditioning equipment in 2022 and examination question bank for operation of refrigeration and air conditioning equipment

Construction and protection of small-scale network examination

机器学习之数据归一化(Feature Scaling)

contos7 安装svn服务端

Issue 7: roll inside and lie flat. How do you choose

Aurora8b10b IP usage-01-introduction and port description

FPGA - 7 Series FPGA selectio -03- ILOGIC of logic resources
随机推荐
Aurora8B10B IP使用 -02- IP功能设计技巧
[data mining] final review Chapter 2
TypeError: iter() returned non-iterator of type ‘xxx‘
Pychart sets the default interpreter for the project
创新项目实训:数据爬取
Sqlmap tool
How powerful are spectral graph neural networks
Idea usage record
Memorizing Normality to Detect Anomaly: Memory-augmented Deep Autoencoder for Unsupervised Anomaly D
leetcode 410. Maximum value of split array - (Day30)
What is the shortcut button 4 click of pychart?
[MySQL] database multi table operation customs clearance tutorial (foreign key constraint, multi table joint query)
That's great. MySQL's summary is too comprehensive
机器学习之数据归一化(Feature Scaling)
Leetcode 75 - three implementation methods of color classification [medium]
C语言实现模拟银行存取款管理系统课程设计(纯C语言版)
Sqlmap命令大全
520泡泡的源码
第7期:内卷和躺平,你怎么选
Basic use of JPA