当前位置:网站首页>MAML (Model-Agnostic Meta-Learning) 解读
MAML (Model-Agnostic Meta-Learning) 解读
2022-06-22 10:47:00 【千羽QY】
论文地址:proceedings.mlr.press/v70/finn17a/finn17a.pdf
5.1 简介
Model-Agnostic:可适用于任何梯度下降的模型,可用于不同的学习任务(如分类、回归、策略梯度RL)。
Meta-Learning:在大量的学习任务上训练模型,从而让模型仅用小数量的训练样本就可以学习新任务(加速fine-tune)。不同的任务有不同的模型。
需要考虑将先前的经验与少量新信息融合,同时避免过拟合。
方法的核心是训练模型的初始参数,从而使模型仅在少量新任务样本上通过几步梯度更新就达到最好性能。
5.2 方法
首先通过如下算法1初始化网络权重,然后在新任务上微调训练。
对上述算法1的解读:
MAML的目的是:学习网络的初始化权重,从而使网络只在新任务上训练一步或几步就能达到很好的效果。
和深度学习的核心一样,算法1中的训练任务和初始化权重后的微调的测试任务都是样本,和一般深度学习中的训练集和测试集的目的和概念一样。
对上述描述进行公式化:假设网络的初始化权重为 θ \theta θ,网络在不同新任务 τ i \tau_i τi的训练集上经过一步梯度更新后的权重为 θ i ′ \theta'_i θi′,使用更新后的权重 θ i ′ \theta'_i θi′在新任务 τ i \tau_i τi的测试集上计算损失 L i ( θ i ′ ) \mathcal{L}_i (\theta'_i) Li(θi′),MAML的目的是使不同新任务 τ i \tau_i τi上的损失之和最小,公式如下:
L = m i n ∑ τ i ∼ p ( τ ) L i ( θ i ′ ) L = min ~ \sum_{\tau_i \sim p(\tau)} \mathcal{L}_i (\theta'_i) L=min τi∼p(τ)∑Li(θi′)
以上述为总损失函数,对网络权重 θ \theta θ进行梯度下降,如下:
θ ← θ − β ∇ θ ∑ τ i ∼ p ( τ ) L i ( θ i ′ ) = θ − β ∑ τ i ∼ p ( τ ) ∇ θ L i ( θ i ′ ) \theta \leftarrow \theta - \beta \nabla_{\theta} \sum_{\tau_i \sim p(\tau)} \mathcal{L}_i (\theta'_i) \\ ~~ = \theta - \beta \sum_{\tau_i \sim p(\tau)} \nabla_{\theta} \mathcal{L}_i (\theta'_i) θ←θ−β∇θτi∼p(τ)∑Li(θi′) =θ−βτi∼p(τ)∑∇θLi(θi′)
计算 ∇ θ L i ( θ i ′ ) \nabla_{\theta} \mathcal{L}_i (\theta'_i) ∇θLi(θi′):
借用李宏毅老师讲义中的公式, ϕ = θ \phi=\theta ϕ=θ, θ ^ = θ i ′ \hat{\theta}=\theta'_i θ^=θi′, ∇ θ L i ( θ i ′ ) = ∇ ϕ l ( θ ^ ) \nabla_{\theta} \mathcal{L}_i (\theta'_i) = \nabla_{\phi} l(\hat\theta) ∇θLi(θi′)=∇ϕl(θ^), ∇ ϕ l ( θ ^ ) \nabla_{\phi} l(\hat\theta) ∇ϕl(θ^)可以分解为如下公式,

其中, θ ^ \hat{\theta} θ^由 ϕ \phi ϕ计算得到,如下:
通过如下公式计算 ∇ ϕ l ( θ ^ ) \nabla_{\phi} l(\hat\theta) ∇ϕl(θ^)中的每一项导数:
计算二阶导数非常耗时,所以MAML论文中提出使用一阶导数近似方法,即假设二阶导数都为0,对公式简化如下:

简化后, ∇ ϕ l ( θ ^ ) → ∇ θ ^ l ( θ ^ ) \nabla_{\phi} l(\hat\theta) \rightarrow \nabla_{\hat\theta} l(\hat\theta) ∇ϕl(θ^)→∇θ^l(θ^),原梯度下降公式转化为:
θ ← θ − β ∑ τ i ∼ p ( τ ) ∇ θ i ′ L i ( θ i ′ ) \theta \leftarrow \theta - \beta \sum_{\tau_i \sim p(\tau)} \nabla_{\theta'_i} \mathcal{L}_i (\theta'_i) θ←θ−βτi∼p(τ)∑∇θi′Li(θi′)
即,直接对每个更新后的 θ i ′ \theta'_i θi′计算梯度,将梯度作用到更新前的 θ \theta θ上。
问题:
1、为什么循环随机采样多个任务进行学习?
答:构建足够多的不同任务,使网络得到充分训练,从而在面向新任务时只通过几步更新就能达到较好的效果。
2、为什么第一次计算梯度与第二次计算梯度使用相同任务下的不同样本,即support set和query set?
答:前者是训练集,用于计算得到 θ i ′ \theta'_i θi′,后者是测试集,用于计算损失。
3、相比于先在一大堆任务上预训练(每次只计算一次梯度),再在新任务上微调,优势是什么?
答:预训练的目的是使网络在所有任务上的性能达到最优,将这个最优模型用于新任务微调时,可能陷入局部最优值等问题;而MAML的目的是使模型在新任务上训练几步后的性能达到最优,考虑的是未来的最优值,因此不会在某些任务上达到最优,而在其他任务上陷入次优。
更多细节请参考:
https://zhuanlan.zhihu.com/p/57864886
https://www.bilibili.com/video/BV1w4411872t?p=7&vd_source=383540c0e1a6565a222833cc51962ed9
边栏推荐
- iNFTnews | 观点:市场降温或是让NFT应用走向台前的机会
- CVPR 2022 Oral | 以运动为导向的点云单目标跟踪新范式
- 【毕业季·进击的技术er】青春不散场
- nodejs基础快速复习
- Batch create / delete files, folders, modify file name suffixes
- Niuke.com Huawei question bank (31~40)
- Super simple C language Snake does not flash screen double buffer
- Denso China adopts Oracle HCM cloud technology solution to accelerate the digital transformation of human resources
- 在 Laravel 中使用计算列
- 每日一题day5-1636. 按照频率将数组升序排序
猜你喜欢

Gartner表示:云数据库发展强劲,但本地数据库仍然充满活力

字节二面:Redis主节点的Key已过期,但从节点依然读到过期数据是为什么?怎么解决?

SQL statement of final examination for College Students

Kirin software and Geer software focus on the development of network data security

The data intelligence infrastructure upgrade window is approaching? See Chapter 9 how Yunji dingodb breaks through data pain points

iNFTnews | 观点:市场降温或是让NFT应用走向台前的机会

2022年深入推进IPv6部署和应用,该如何全面实现安全升级改造?

Laravel 中类似 WordPress 的钩子和过滤器

将有色液体图像转换成透明液体,CMU教机器人准确掌控向杯中倒多少水

In 2022, IPv6 deployment and application will be further promoted. How can we comprehensively realize security upgrading and transformation?
随机推荐
Leetcode algorithm The penultimate node in the linked list
MySQL daily experience [02]
MySQL uses SQL statements to modify field length and field name
代码签名证书一旦泄露 危害有多大
Backbone! Youxuan software was selected as one of the top 100 digital security companies in China in 2022
PHP开发的网站,如何实现批量打印快递单的功能?
大学生期末考试SQL语句
线程死锁的理解
Learn to view object models with VisualStudio developer tools
Quel est le risque de divulgation d'un certificat de signature de code?
世界上第一个“半机械人”去世,改造自己只为“逆天改命”
Use of libevent
在 Laravel 中使用计算列
数据智能基础设施升级窗口将至?看九章云极 DingoDB 如何击破数据痛点
从MVC原理开始手敲一个MVC框架,带你体会当大神的乐趣
How harmful is the code signature certificate once it is leaked
laravel 开发 文章URL 生成器
TCP 3次握手的通俗理解
Solve the problem that the chrome icon of Google browser in win7 taskbar is missing and abnormally blank
符合我公司GIS开源解决方案的探讨