当前位置:网站首页>Learning to Pre-train Graph Neural Networks(图预训练与微调差异)
Learning to Pre-train Graph Neural Networks(图预训练与微调差异)
2022-07-25 11:08:00 【上杉翔二】
博主曾经整理过一篇图预训练的文章,此后有很多在Graph上做Pretraning的文章层出不穷,但基本上万变不离其宗,都是在node-level和graph-level上做自监督学习。
为什么自监督策略有效?
- 多层结构,下层固定了再去train上层
- 多任务,可以去bias更泛化
- 同领域预训练,学到更多相关信息
但是预训练和微调之间总有差距,如何解决这个gap变成为一个棘手的问题,本篇博文博主将整理几种解决方案。
Learning to Pre-train Graph Neural Networks
这篇文章来自AAAI 2021。其核心的思想其实就是:如何缓解GNN预训练和微调之间的优化误差?
首先作者论证了 GNN 预训练是一个两阶段的流程:
- Pre-traning。先在大规模图数据集上进行预训练。即对参数theta进行更新使其最小化: θ 0 = a r g m i n θ L p r e ( f θ ; D p r e ) \theta_0=argmin_{\theta} L^{pre}(f_{\theta};D^{pre}) θ0=argminθLpre(fθ;Dpre)
- Fine-tuning。在下游数据上进行微调。用上一步训练好了的 θ 0 \theta_0 θ0上进行微调,即做梯度下降: θ 1 = θ 0 − η ∇ θ 0 L f i n e ( f θ 0 ; D t r ) \theta_1=\theta_0-\eta \nabla_{\theta_0} L^{fine}(f_{\theta_0};D^{tr}) θ1=θ0−η∇θ0Lfine(fθ0;Dtr)
作者认为发现这两个步骤之间是存在一些差异的,即在fine-turning虽然是用到了 θ 0 \theta_0 θ0,但 θ 0 \theta_0 θ0是固定的,它的得到是对fine-tuning的数据不可见的,即不会考虑到下游要怎么微调。这样就会造成Pre-traning和Fine-tuning之间的优化偏差,而这一差异在一定程度上影响预训练模型的迁移效果。
因此,作者提出了一种自监督预训练策略L2P-GNN,关键的两点博主认为是:
- 在pre–traning中做Fine-tuning。即既然有gap,那么在pre–traning的过程中就做类似Fine-tuning的事情就好。有些类似借用元学习的思想,学习如何去learn。
- 在node-level和graph-level上做自监督学习。

模型架构如上图,比较重要的是task construction和dual adaptation这两部分。
Task Construction
为了在pre–traning的过程中就做类似Fine-tuning的事情,作者的思路就是提前把数据集也划分成training和testing就好。对于需要Pre-training的多个task,每个task都会这样划分,对应图中的support set和query set。
而为了模拟在下游训练集合上的微调,就直接在支持集上训练损失函数得到可迁移先验知识,然后适配其在查询集上的表现即可。
Dual Adaptation
为了缩小预训练和微调过程之间的差距,在预训练过程中优化模型快速适应新任务的能力是至关重要的。为了将局部信息和全局信息都编码到先验信息中,所以作者提出双重适应在node和graph两个层面进行更新。
- 节点级适应.。这个与之前文章的方法一致,也是进行采样然后计算: L n o d e ( ψ ; S G c ) = ∑ − l n ( σ ( h u T h v ) ) − l n ( σ ( h u T h v ′ ) ) L^{node}(\psi;S^c_G)=\sum -ln(\sigma(h^T_uh_v))-ln(\sigma(h^T_uh_v')) Lnode(ψ;SGc)=∑−ln(σ(huThv))−ln(σ(huThv′)) 此时就进行节点级的参数更新: ψ ′ = ψ − α ∂ ∑ L n o d e ( ψ ; S G c ) ∂ ψ \psi'=\psi-\alpha \frac{\partial \sum L^{node}(\psi;S^c_G)}{\partial \psi} ψ′=ψ−α∂ψ∂∑Lnode(ψ;SGc)
- 图级适应。同样的,用采用子图的方法计算(图的表示通过pooling得到): L g r a p h ( ω ; S G ) = ∑ − l o g ( σ ( h S G c T h G ) ) − l n ( σ ( h S G c T h G ′ ) ) L^{graph}(\omega;S_G)=\sum -log(\sigma(h^T_{S^c_G}h_G))-ln(\sigma(h^T_{S^c_G}h_G')) Lgraph(ω;SG)=∑−log(σ(hSGcThG))−ln(σ(hSGcThG′))然后图级别的参数更新: ω ′ = ω − β ∂ L g r a p h ( ω ; S G ) ∂ ω \omega'=\omega-\beta \frac{\partial L^{graph}(\omega;S_G)}{\partial \omega} ω′=ω−β∂ω∂Lgraph(ω;SG)
- 先验知识的优化。经过节点级和图级自适应后,已经将全局先验知识适配 θ \theta θ为了任务特定的知识 θ ′ = { ψ ′ , ω ′ } \theta'=\{\psi',\omega'\} θ′={ ψ′,ω′}。然后用它来反向传播得到优化 θ \theta θ:
θ ← θ − γ ∂ ∑ L ( θ ′ ; Q G ) ∂ θ \theta \leftarrow \theta-\gamma \frac{\partial \sum L(\theta';Q_G)}{\partial \theta} θ←θ−γ∂θ∂∑L(θ′;QG) L ( θ ′ ; Q G ) = 1 k ∑ L n o d e ( ψ ; S G c ) + L g r a p h ( ω ; S G ) L(\theta';Q_G)=\frac{1}{k}\sum L^{node}(\psi;S^c_G)+L^{graph}(\omega;S_G) L(θ′;QG)=k1∑Lnode(ψ;SGc)+Lgraph(ω;SG)
paper:https://yuanfulu.github.io/publication/AAAI-L2PGNN.pdf
code:https://github.com/rootlu/L2P-GNN

Adaptive Transfer Learning on GNN
来自KDD2021。传统的预训练方案并没有设计下游的自适应学习,无法做到上下游一致。因此作者借助元学习设计了一个权重模型adaptive auxilizry loss weighting model来控制上游self-supervised任务和下游target task之间的一致性。
- 传统方法。在大量无标签数据上进行自监督任务学习+用自监督任务学习到的节点表征来辅助目标任务的学习。
- 作者的transfer方法。用joint loss来微调在参数,这样便会自适应保留pre-training阶段的有效信息,即通过计算辅助任务与目标任务梯度之间的余弦相似度similarity来学习Adaptive Auxiliary Loss Weighting,以量化辅助任务与目标任务的一致性。
paper:https://arxiv.org/abs/2107.08765
边栏推荐
猜你喜欢

varest蓝图设置json

brpc源码解析(四)—— Bthread机制

The most efficient note taking method in the world (change your old version of note taking method)

Chapter 4 linear equations
![[imx6ull notes] - a preliminary exploration of the underlying driver of the kernel](/img/0f/a0139be99c61fde08e73a5be6d6b4c.png)
[imx6ull notes] - a preliminary exploration of the underlying driver of the kernel

LeetCode 50. Pow(x,n)

Onenet platform control w5500 development board LED light

WIZnet嵌入式以太网技术培训公开课(免费!!!)

软件测试阶段的风险

Javescript loop
随机推荐
Javescript loop
A beautiful gift for girls from programmers, H5 cube, beautiful, exquisite, HD
Eigenvalues and eigenvectors of matrices
微星主板前面板耳机插孔无声音输出问题【已解决】
硬件连接服务器 tcp通讯协议 gateway
基于Caffe ResNet-50网络实现图片分类(仅推理)的实验复现
30 sets of Chinese style ppt/ creative ppt templates
The principle analysis of filter to solve the request parameter garbled code
硬件外设=maixpy3
Objects in JS
Layout management ==pyqt5
Experimental reproduction of image classification (reasoning only) based on caffe resnet-50 network
Return and finally? Everyone, please look over here,
W5500 adjusts the brightness of LED light band through upper computer control
软件缺陷的管理
winddows 计划任务执行bat 执行PHP文件 失败的解决办法
相似矩阵,可对角化条件
什么是全局事件总线?
SQL language (V)
创新突破!亚信科技助力中国移动某省完成核心账务数据库自主可控改造