当前位置:网站首页>【CVPR2022 oral】Balanced Multimodal Learning via On-the-fly Gradient Modulation
【CVPR2022 oral】Balanced Multimodal Learning via On-the-fly Gradient Modulation
2022-07-13 17:48:00 【AI前沿理论组@OUC】

论文:https://arxiv.org/abs/2203.15332
代码:https://github.com/GeWu-Lab/OGM-GE_CVPR2022
这是一个来自人民大学GeWu-Lab的工作,被CVPR2022接收并选为Oral Presentation,相关代码已经开源。
1、研究动机
使用多模态数据进行分类有助于提高分类性能,但是,实际上现有的方法并没有有效的挖掘多个模态数据的性能。(如下图所示,在多模态模型中特定模态编码器的性能反而不如单模态模型,这说明现有的模型对于单模态特征表示的挖掘不足)

作者认为,这个现象的主要原因是:两个模态有一个起主导作用,另一个起辅助作用。起主导作用的模态会对另一个模态的优化起抑制作用。因此,为了解决这一问题,作者提出了OGM-GE(on-the-fly gradient modulation, generalization enhancement)。
作者举了一个非常生动的例子,在速度滑冰团体追逐比赛中,是以团队最后一名完成的队员记录团队成绩:通常情况下优势模态队员与较慢的弱势模态队员缺乏联系与互助,即使率先到达终点(训练收敛)也不得不停下来等待弱势模态队员(如图2的上半部分所示);但是通过OGM-GE方法的调控,让优势队员慢下来带带弱势队员。弱势队员由于受到帮助(比如滑冰比赛中的借风等因素)速度有所提高,这样团队总体成绩便能提高(如图2的下半部分所示)。

2、主要方法
为了解决多模态分类中的优化不平衡问题,作者根据模态间的效果差异自适应地调制梯度,并结合高斯噪声的泛化性增强能力,提出了具有较强适用性的OGM-GE优化方法,整个方法的算法框架下图所示。

首先,通过监控它们对学习目标的贡献的差异来自适应地调节每种模态的梯度。为了有效地衡量多模态模型中单模态表征能力的差异,作者设计了模态间差异比率这一指标;在训练进程中,根据该模态间的差异比率,动态地赋予不同模态不同的梯度比例系数,这样在整个训练过程中,都可以自适应地调控模型的优化进程。
其次,通过分析,作者发现对梯度添加小于1的系数进行调控会削弱优化过程中随机梯度噪声的强度,进而可能潜在地影响模型的泛化能力。因此,在梯度调控的基础上增加了噪声补强策略,即在梯度上额外增加高斯噪声,以恢复(甚至在原有基础上增强)随机梯度噪声强度,从而提升模型的泛化性能。
(1)on-the-fly gradient modulation
两个模态就有两个 encoder ,用 φ u \varphi^{u} φu 表示 ,其中 u ∈ { a , v } u \in \{a,v\} u∈{ a,v},分别对应audio和video两个模态。其中encoder的参数为 θ u \theta^u θu。使用梯度下降法优化时为:
θ t + 1 u = θ t u − η ∇ θ u L ( θ t u ) . \theta_{t+1}^{u} =\theta_{t}^{u}-\eta \nabla_{\theta^{u}} L(\theta_{t}^{u}). θt+1u=θtu−η∇θuL(θtu).
因此,作者的想法是,可以自适应的调节各个模态的优化速度,因此定义一个差异比率 ρ t u \rho^u_t ρtu:
s i a = ∑ k = 1 M 1 k = y i ⋅ softmax ( W t a ⋅ φ t a ( θ a , x i a ) + b 2 ) k , s_{i}^{a}= \sum_{k=1}^M 1_{k=y_{i}} \cdot \text{softmax} (W^{a}_{t} \cdot \varphi^{a}_{t}(\theta^{a},x_{i}^{a})+\frac{b}{2})_{k}, sia=k=1∑M1k=yi⋅softmax(Wta⋅φta(θa,xia)+2b)k,
s i v = ∑ k = 1 M 1 k = y i ⋅ softmax ( W t v ⋅ φ t v ( θ v , x i v ) + b 2 ) k , s_{i}^{v}= \sum_{k=1}^M 1_{k=y_{i}} \cdot \text{softmax} (W^{v}_{t} \cdot \varphi^{v}_{t}(\theta^{v},x_{i}^{v})+\frac{b}{2})_{k}, siv=k=1∑M1k=yi⋅softmax(Wtv⋅φtv(θv,xiv)+2b)k,
ρ t v = ∑ i ∈ B t s i v ∑ i ∈ B t s i a . \rho^{v}_{t}=\frac{\sum_{i \in B_{t}} s_{i}^{v} } {\sum_{i \in B_{t}} s^{a}_i}. ρtv=∑i∈Btsia∑i∈Btsiv.
因为有两个模态 audio 和 video, ρ t a \rho^a_t ρta定义为 ρ t v \rho^v_t ρtv的倒数。作者使用 ρ t u \rho^u_t ρtu动态监测多个模态之间的贡献差异,通过下面公式来自适应的调节梯度:
k t u = { 1 − tanh ( α ⋅ ρ t u ) ρ t u > 1 1 others, k^{u}_{t}=\left\{\begin{array}{cl} 1-\tanh (\alpha \cdot \rho^{u}_{t}) & \text { }\rho ^{u}_{t}>1 \\ 1 & \text { others, }\end{array}\right. ktu={ 1−tanh(α⋅ρtu)1 ρtu>1 others,
其中, α \alpha α 是一个超参数。作者将 k t u k^u_t ktu用在了SGD优化方法中,在迭代中,网络参数更新过程如下:
θ t + 1 u = θ t u − η ⋅ k t u g ~ ( θ t u ) . \theta^{u}_{t+1} =\theta_{t}^{u}-\eta \cdot k_{t}^{u}\tilde{g}(\theta_{t}^{u}). θt+1u=θtu−η⋅ktug~(θtu).
可以看出,通过 k t u k^u_t ktu 降低了主导模态的优化,而其它模态不受影响,从而缓解模态不平衡问题。
(2)generalization enhancement
定理:SGD 中的噪声与其泛化能力密切相关,SGD噪声越大,泛化能力越好。SGD 噪声的协方差与学习率和批量大小的比率成正比。
根据定理,梯度协方差的值越高,通常会带来更好的泛化能力。但经过作者的推导,OGM 方法会使 SGD 的泛化性能下降。所以有必要开发一种控制 SGD 噪声以提高泛化能力的方法。
作者引入了随机采样的高斯噪声 h ( θ t u ) h(\theta^u_t) h(θtu)到梯度中,迭代公式变为:
θ t + 1 u = θ t u − η ∇ θ u L ′ ( θ t u ) + η ξ t ′ , ξ t ′ ∼ N ( 0 , ( k t u ) 2 ⋅ Σ s g d ( θ t u ) ) , \theta_{t+1}^{u} =\theta_{t}^{u}-\eta \nabla_{\theta^{u}} L^{\prime}(\theta_{t}^{u})+\eta \xi_{t}^{\prime}, \\\xi_{t}^{\prime} \sim \mathcal{N}(0,(k_{t}^{u})^{2}\cdot\Sigma^{sgd}(\theta_{t}^{u})), θt+1u=θtu−η∇θuL′(θtu)+ηξt′,ξt′∼N(0,(ktu)2⋅Σsgd(θtu)),
整个算法流程如下:

代码如下:
a, v, out = model(spec, image)
out_v = (torch.mm(v, torch.transpose(model.fusion_module.fc_out.weight[:, :512], 0, 1)) +
model.fusion_module.fc_out.bias / 2)
out_a = (torch.mm(a, torch.transpose(model.fusion_module.fc_out.weight[:, 512:], 0, 1)) +
model.fusion_module.fc_out.bias / 2)
score_v = sum([softmax(out_v)[i][label[i]] for i in range(out_v.size(0))])
score_a = sum([softmax(out_a)[i][label[i]] for i in range(out_a.size(0))])
ratio_v = score_v / score_a
ratio_a = 1 / ratio_v
if ratio_v > 1:
coeff_v = 1 - tanh(args.alpha * relu(ratio_v))
coeff_a = 1
else:
coeff_a = 1 - tanh(args.alpha * relu(ratio_a))
coeff_v = 1
for name, parms in model.named_parameters():
layer = str(name).split('.')[1]
if 'audio' in layer and len(parms.grad.size()) == 4:
parms.grad *= coeff_a
parms.grad += torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8)
if 'visual' in layer and len(parms.grad.size()) == 4:
parms.grad *= coeff_v
parms.grad += torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8)
3、实验结果
作者首先将 OGM-GE 方法应用于几种常见的融合方法: baseline、concatenation 和 summation,以及专门设计的融合方法:FiLM 。性能如下表所示。从表中可以看出,两个模态性能不平衡,audio模态性能要明显优于 visual 模态。结a合OGM-GE方法以后,可以看到模型的性能有显著提升。

和其它modulation strategy 的对比。 作者和 modality-dropout 和 gradient-blending 进行了对比(如表2所示),可以看出所有 modulation方法都取得了性能提升,

消融实验非常有趣,在VGGSound数据集上,一开始OGM-GE方法的性能没有不使用时性能好,但是最后OGM-GE方法又实现了性能显著的提升。这是因为一开始我们方法在后期很好的挖掘了另一个模态的信息,实现了性能的提升。在图3中,展示了 ρ a \rho^a ρa的变化,可以看出,使用了OGM-GE训练的模型(蓝线),模态间的不平衡比例要显著小于直接训练的模型(黄线)。但由于不同模态间的信息量等存在天然差异等因素,该比例可能无法逼近1,只能在合理范围内减小。


边栏推荐
- 消息转发机制--拯救你的程序崩溃
- antd setFieldsValue警告问题 Cannot use `setFieldsValue` until you use `getFieldDecorator` or
- unity实验-模拟太阳系星体运动
- El button display and disable
- JS array de duplication
- Web API——获取元素、事件基础
- 一个简单的英文自然语言处理流程
- Vue+axios+mysql realizes paging query, condition query and batch deletion
- Expanding knowledge -- hijacking technology of JS
- Solutions to Oracle database error codes
猜你喜欢

Moss privacy computing all-in-one machine has passed 83 evaluations of Shenzhen Guojin evaluation center

unity实验-模拟太阳系星体运动

About the installation and use of visual studio 2022

36.js-- prototype chain 2-- (mainly written test questions)

ES6 -- Deconstruction assignment (key)

【MIT Missing Semester 2】Shell Tools

JS downloads files according to binary data
利用Spark预测回头客实验报告

Unity experiment - simulating the motion of stars in the solar system

How did the situation that NFT trading market mainly uses eth standard for trading come into being?
随机推荐
Solutions to Oracle database error codes
Tkmapper uses weekend splicing conditions to query conditions
Class loader + reflection +properties
最详细的window10虚拟机安装,手把手安装虚拟机,解决家庭版window找不到Hyper-V选项
《代码整洁之道》读后笔记
网络通信安全部分笔记二
Es6--string (string)
使用base64对图片进行编码、对byte[]进行编码
36.js-- prototype chain 2-- (mainly written test questions)
解读AFNetworking4.0请求原理
MySQL-约束
sniffer Pro對ARP協議的分析、捕獲與模擬攻擊
ES6 -- arrow function
_button.enable=NO不起作用
tkMapper之使用Weekend拼接条件进行条件查询
MSF利用永恒之蓝渗透win2003
曾入选顶会的技术完成产品化 蚂蚁链推出版权AI计算引擎
sniffer Pro对ARP协议的分析、捕获与模拟攻击
Unity experiment - gravity hitting the wall
消息转发机制--拯救你的程序崩溃