当前位置:网站首页>PyG教程(7):剖析邻域聚合
PyG教程(7):剖析邻域聚合
2022-06-22 05:56:00 【斯曦巍峨】
一.前言
上篇文章《PyG教程(6):自定义消息传递网络》主要介绍了消息传递GNN的大致框架。本文主要聚焦于消息传播中的邻域聚合,本文将介绍PyG是如何将节点的邻居的消息聚合到节点本身的。
二.PyG中的邻域聚合
PyG中邻域聚合是通过aggregate(inputs, index)函数来完成的,该函数的第一个参数inputs为消息构建函数message()构建的消息,该函数还存在一个参数index,这个参数对于消息聚合是十分关键的,它指示了inputs中每条消息属于哪个节点的邻域。下图便很好的解释了PyG中的消息聚合:

上述栗子中展示的是包含4个顶点、8条边的graph,其中input为在8条边上传播的消息、index为各条边上消息的归属,即目标节点的索引。通过index,可以将属于同一个节点邻域的消息聚合到一起,常见的聚合包括sum、mean、mean、mul和min等。
在PyG中通过scatter函数来实现上述过程,查看MessagePassing的源码,可以看到其aggregate函数的定义如下:
def aggregate(self, inputs: Tensor, index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor:
r""" 注释太长略 """
if ptr is not None:
ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
return segment_csr(inputs, ptr, reduce=self.aggr)
else:
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
reduce=self.aggr)
aggregate函数中scatter函数源码为:
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
reduce: str = "sum") -> torch.Tensor:
r""" 注释太长略 """
if reduce == 'sum' or reduce == 'add':
return scatter_sum(src, index, dim, out, dim_size)
if reduce == 'mul':
return scatter_mul(src, index, dim, out, dim_size)
elif reduce == 'mean':
return scatter_mean(src, index, dim, out, dim_size)
elif reduce == 'min':
return scatter_min(src, index, dim, out, dim_size)[0]
elif reduce == 'max':
return scatter_max(src, index, dim, out, dim_size)[0]
else:
raise ValueError
其中便包含了前面提到的5种聚合方式。对于这些聚合方式,只需要在继承MessagePassing类时,通过super().__init__来向该类传递参数aggr参数的值即可。
三.torch_scatter模块
若用户需要自定义消息聚合,则在重写的aggregate()函数中,同样可以使用MessagePassing中的scatter函数,只需要导入torch_scatter模块即可。
在torch_scatter模块中也实现了scatter函数,其声明如下:
scatter: (src: Tensor, index: Tensor, dim: int = -1, out: Tensor | None = None, dim_size: int | None = None, reduce: str = "sum") -> Tensor
常用参数说明:
| 参数 | 说明 |
|---|---|
src | 每条边上的源节点生成的消息 |
index | 指示每条边上消息需要聚合到哪个节点上 |
dim | 指示沿着那个维度(轴)应用index进行聚合 |
reduce | 聚合操作,包括sum、mul、 mean、 min 和 max |
注意,torch_scatter也为上述的几种聚合单独提供了API:
torch_scatter.scatter_add()
torch_scatter.scatter_max()
torch_scatter.scatter_mean()
torch_scatter.scatter_min()
torch_scatter.scatter_mul()
为了方便理解,下面给出一个栗子,假设存在一个包含3个顶点、6条边的图:

假设0、1、2三个顶点生成的消息分别为1、2、3,则图中6条边的消息inputs和相应的index构造如下:
inputs = torch.tensor([[1], [1], [2], [2], [3], [3]])
index = torch.tensor([1, 2, 0, 2, 0, 1])
应用torch_scatter.scatter()函数的结果如下:
out = torch_scatter.scatter(src=inputs, index=index, dim=0, reduce="sum")
print(out)
""" tensor([[5], [4], [3]]) """
可以看到节点0接受来自节点1,2的消息得到2+3=5,节点1接受来自节点0,2的消息得到1+3=4,而节点2接受来自节点0,1的消息得到1+2=3。
四.结语
参考资料:
通过本文可以加深对PyG中消息聚合过程的理解,这将有助于更好的自定义GNN模型。以上便是本文的全部内容,若有任何错误,请批评指正。
边栏推荐
- MFC tabctrl control to modify label size
- Machine learning concept sorting (no formula)
- Logback自定义Pattern参数解析
- idea插件EasyCode的使用
- Vulkan pre rotation processing equipment direction
- Unity development - scene asynchronous loading
- 组合逻辑块的测试平台
- R语言观察日志(part24)--writexl包
- 3D asset optimization and vertex data management for performance optimization
- 单细胞论文记录(part13)--SpaGCN: Integrating gene expression, spatial location and histology to ...
猜你喜欢

单细胞论文记录(part11)--ClusterMap for multi-scale clustering analysis of spatial gene expression

3D asset optimization and vertex data management for performance optimization

Creating GLSL Shaders at Runtime in Unity3D

System identification of automatic control principle

Conversion between gray code and binary

【CPU设计实战】数字逻辑电路设计基础(一)

Single cell thesis record (part13) -- spagcn: integrating gene expression, spatial location and history to

Shengxin visualization (Part1) -- histogram

生信可视化(part2)--箱线图

What about computer jam?
随机推荐
基于卫星测深的牙买加沿岸水深测量
[soft test] senior system architecture designer learning experience sharing
TiDB 社区线下交流会,天津 & 石家庄的小伙伴看过来~
单细胞文献学习(part3)--DSTG: deconvoluting spatial transcriptomics data through graph-based AI
W800芯片平台进入OpenHarmony主干
Server PHP related web page development environment construction
Single cell paper record (Part11) -- clustermap for multi-scale clustering analysis of spatial gene expression
【云计算重点复习】
TCP connection details
牛客-TOP101-BM27
触 发 器
401-字符串(344. 反转字符串、541. 反转字符串II、题目:剑指Offer 05.替换空格、151. 颠倒字符串中的单词)
关于MNIST线性模型矩阵顺序问题
文献记录(part106)--GRAPH AUTO-ENCODER VIA NEIGHBORHOOD WASSERSTEIN RECONSTRUCTION
汇顶科技GR551x系列开发板已支持OpenHarmony
System identification of automatic control principle
Surfer格网文件裁剪
使用Systemverilog描述状态机
MFC TabCtrl 控件修改标签尺寸
DataBricks从开源到商业化踩过的坑