当前位置:网站首页>Semi-supervised Learning入门学习——Π-Model、Temporal Ensembling、Mean Teacher简介
Semi-supervised Learning入门学习——Π-Model、Temporal Ensembling、Mean Teacher简介
2022-06-27 08:43:00 【umbrellalalalala】
知乎同名账号同步发布。
通过两篇论文简单入门学习半监督的思想。论文:
arxiv1610(ICLR17)Temporal Ensembling for Semi-Supervised Learning
arxiv1703(NIPS17)Mean teachers are better role models- Weight-averaged consistency targets improve semi-supervised deep learning results
简要介绍Π-Model、Temporal Ensembling、Mean Teacher。欢迎交流,喜欢的话请点赞关注,之后持续更新。
多个网络的集成通常比单个网络更强。
过去也通过dropout、dropconnect、stochastic depth等技术间接证明了这一点,以及在swapout network中,训练聚焦于一个特定的网络子集。这些技术使得网络的训练结果可被视为ensemble of trained sub-networks。
作者是将上述观点扩展到了将单个网络的不同epochs的输出(结合不同的正则化和对输入数据的增强)进行集成:
We extend this idea by forming ensemble predictions during training, using the outputs of a single network on different training epochs and under different regularization and input augmentation conditions.
训练仍旧在单个网络上进行,但由于dropout,不同epoch的预测结果对应于大量单个子网络的集成的预测。
总之可以将单个网络的不同epoch的预测结果集成起来,而这种集成预测(ensemble prediction)能够用于半监督。如果将集成预测结果和正在被训练的当前网络的输出相比较,那么集成预测的结果应该更接近于无标签数据对应的未知的标签。所以可以将集成预测的结果作为无标签的输入对应的label(可以将集成预测结果当做是一个伪标签)。
(但这里我有个疑问,因为大前提是认为多个网络的集成比单个网络更强大,虽然不同epoch的集成可以视为不同网络的集成,但是后来的epoch按理来说要比先前的epoch更优秀,所以把当前epoch和之前所有epoch做集成,我认为可能并不会更强大,因为之前所有的epoch都是更菜的epoch。所以我觉得这样做有用的原因可能不是因为这样集成会更强,而是这样集成会更菜,从而可以视为一种正则化)
作者的方法非常依赖于dropout正则化和丰富多样的input数据增强,如果两者都没有,那么用上述方法推断出的unlabeled data的label(伪标签)就没有太多的可信度。
作者提出的方法是self-ensembling,进一步作者发现,在全标签的情况下,这个方法也能够提升分类效果,并提供了对错误标签的容忍度。
两种方式来实现self-ensembling,Π-model和temporal ensembling。
作者用分类问题来阐述方法,N个input数据,M个有label,类别一共有C类。
一、Π-model简介
模型流程图如上图所示,看伪代码就足够理解它了:
x表示input,y表示label,z表示预测值。有两个z,表示对input做不同的数据增强、经过不同dropout网络生成的两个预测结果,一个z和label y做交叉熵损失,然后两个z之间做均方差损失。将两个损失加权求和(有标签数据两个损失项都用,无标签数据只用第二个损失项),网络参数用ADAM进行优化。
关于将 z i z_i zi和 z i ~ \tilde{z_i} zi~之间最小化(loss的第二项),paper中给了一些说法:
1,让两个z之间的dark knowledge尽可能接近,这是一个比要求只有最终分类保持不变强得多的要求。
2,因为dropout的存在,在训练过程中网络的输出是一个随机变量,对于相同input和相同网络而言,会产生不同的输出结果(即对于同一个x,产生的两个z不相同)。数据增强亦是如此,也会造成两个z之间的difference,这个difference可以视为是分类问题的错误(an error in classification),由于两个z对应的输入x是同一个,所以最小化两个z之间的difference就是一个合理的目标。
*对于无标签的数据,学者提出了一个一致性约束假设:
对于无标签数据,对模型和数据加一定扰动,预测结果一致。
内容来源:https://blog.csdn.net/u011345885/article/details/111758193
上述的dropout就是对模型的扰动,数据增强就是对数据的扰动。
关于将两个z之间的差异最小化,还有两种思想:
1,一致性正则化:指的是根据上述一致性约束来构造loss;
2,伪标签:指的是将其中一个z视为伪标签,让另一个z逼近这个伪标签。
(如果是用思想2来看待π model的无监督loss term,即伪标签的思想,那么temporal ensembling就是改进了伪标签,后文会讲述)
关于权重 w w w,它的公式是 w ( t ) = e x p [ − 5 ( 1 − T ) 2 ] w(t)=exp[-5(1-T)^2] w(t)=exp[−5(1−T)2],在前80个epochs,T线性地由0变成1,从而w的值从较小的正数逐渐变为1.所以一开始,训练主要取决于loss中的有监督分量,即仅取决于标记数据。需要注意的是loss中的无监督分量要上升得足够慢,否则网络容易陷入到退化解中,无法获得有意义的分类。
(作者在给出w(t)公式的附录里也给出了其他训练细节:除了上述权重在前80个epochs的变化外,学习率和Adam β 1 \beta_1 β1都需要衰减,batchsize是100,网络一共训练300个epochs。)
二、Temporal Ensembling简介
如果将Π-model中的 z ~ i \tilde z_i z~i视为 z i z_i zi的伪标签的话,那么这个伪标签是不太好的,temporal ensembling就是针对这一点进行了改进,具体见paper中的图片:
注意上图的角标 i i i不是指时序,而是指一共 N N N个数据中的第 i i i个,生成的 z i z_i zi要参与生成下个epochs的第 i i i个数据的伪标签。
(注意是每个epoch,而不是每个batch,来改变一次伪标签,这种改变其实非常缓慢。之后的工作比如说mean teacher也指出,这个方法对于大数据集来说是非常难处理的)
注意上述伪代码, z ~ \tilde{z} z~表示 N N N个数据的伪标签,每个伪标签 z i ~ \tilde{z_i} zi~是一个 C C C维向量,作者的意思是在minibatch的循环就能够完成对 z ~ \tilde{z} z~的更新(每次循环更新一个 z i ~ \tilde{z_i} zi~),但是为了表述清晰,伪代码将更新写在了epoch的循环中。
和Π-model不同就在于end for语句后面的两行,作者将α设置为0.6,其中这两行的第二行作者称其为:对startup bias的纠正,作者说这个和Adam是类似的:
A similar bias correction has been used in, e.g., Adam (Kingma & Ba, 2014) and mean-only batch normalization (Salimans & Kingma, 2016).
实际上就是由于 Z Z Z是采用 α Z + ( 1 − α ) z \alpha Z + (1 - \alpha)z αZ+(1−α)z的公式计算,是历史值 Z Z Z和新的值 z z z累加(注意累加这个词,后面几句还会提及)所得。最开始的时候 Z Z Z为0,所以计算所得的 Z Z Z的值就是 ( 1 − α ) z (1 - \alpha)z (1−α)z,这个时候 t t t为1,则 Z / ( 1 − α t ) Z/(1 - \alpha^t) Z/(1−αt)就是 z z z本身,也就是将值放大到了本来该有的样子。随着 t t t的增大, ( 1 − α t ) (1 - \alpha^t) (1−αt)的值越来越接近1,则除以它的放大效果就会减弱。也就是这个分母是一开始起作用,解决的是累加在开始阶段导致值偏小(因为如刚才所说,累加是历史值 Z Z Z和新的值 z z z加权求和,在开始阶段,历史值 Z Z Z很小,甚至在 t t t为1的时候历史值为0),所以才说它解决的是startup bias,即开始阶段的偏差(偏小),用除以一个小于1的数将它放大。
由于采用的是类似于滑动平均的思想去构造伪标签,所以在第一个epochs需要单独设置一些参数,第一个epochs中的w(t)设置为0,表示loss中只有有监督分量。
temporal ensembling相对于Π-model的好处:
1,训练更快,因为一个epochs不再用计算两个output z;
2,训练结果比Π-model更不noisy(具体啥意思作者没说,应该就是结果更稳定吧)。
Second, the training targets z ~ \tilde z z~ can be expected to be less noisy than with Π-model.
参考资料
https://blog.csdn.net/u011345885/article/details/111758193
三、mean teachers
作者指出了temporal ensembling的缺点:每个epoch更新一次伪标签,如果面对的是很大的数据集,那么这种更新方式会变得很缓慢,这是很有问题的。为了克服这个问题,作者提出的方法是对模型的权重进行滑动平均,而非对伪标签的生成进行滑动平均。
To overcome this problem, we propose Mean Teacher, a method that averages model weights instead of label predictions.
作者解释了已有的模型:模型拥有双重的身份——学生和教师。作为学生,模型一如既往地进行学习;作为教师,模型生成一个target(即伪标签)。(注意mean teacher每个batch更新一次teacher model的参数,详见后文)
由于模型它自己生成target,这样可能会造成不正确的事情,所以提升target或者说伪标签的质量就是需要考虑的事情。作者认为有两种提升target的方式:
- 仔细选择对数据或者模型的扰动,而不只是施加加性或者乘性噪声。(补充paper中描述:对于两个相似的数据点,一个良好的模型应当给出相同的预测结果)
- 仔细选择一个teacher model,而不是直接将student model复制过来作为teacher model本身。
第一种方式已经被下述方法使用了:
Miyato, Takeru, Maeda, Shin-ichi, Koyama, Masanori, and Ishii, Shin. Virtual Adversarial Training: a Regularization Method for Supervised and Semi-supervised Learning. arXiv:1704.03976 [cs, stat], April 2017. arXiv: 1704.03976.
链接:https://ieeexplore.ieee.org/document/8417973
(TPAMI, 421 paper citations (20220507))
作者采用第二种方法,具体做法是,对于每个batch,用反向传播的方式更新student model的参数,然后用EMA的方式更新teacher model的参数:
student model和teacher model都有分类的能力,但在训练结束后,teacher model可能有更好的正确率。
所谓EMA的方式,就是用student model的参数,按照下述公式来更新teacher model的参数(注意最开始是直接将student model的参数复制给teacher model):
在ramp-up阶段,α的值设定为0.99,之后的训练过程中设置为0.999( α \alpha α越大,student model的参数对teacher model的参数影响越小)。这是因为初始时 student 模型训练的很快,而 teacher 需要忘记之前的、不正确的 student 权重;在 student 提升很慢的时候, teacher 记忆越长越好。
其实如果讲述个人想法,我觉得teacher model的存在会让student model的更新速度变慢(因为上图中的consistency cost),从而起到一个正则化的作用。但也可以用之前temporal ensembling那里的思想,认为多个模型的集成要强于单个模型(由于teacher model的模型参数是每个阶段的student model模型参数的滑动平均,所以可以视为不同阶段的student model的集成。而且 α \alpha α非常接近1,所以teacher model参数的更新速度非常的慢)。
但我个人觉得正则化的思想更能说服我,并且喊了其他大佬的博客中,出现了一个词叫“consistency regularization”即一致性正则,所以应该可以用正则化的观点来看待这几个方法:
Π-Model、Temporal Ensembling 和 Mean Teacher 三者都是利用一致性正则(consistency regularization)来进行半监督学习(semi-supervised learning)。
来源:https://blog.csdn.net/chanbo8205/article/details/108846097
再回归一下标题,自监督的简单应用的例子也在上图,就是数字识别问题,工作流程刻画的较为清晰,就不赘述了。
作者的实验和若干方法进行了对比,其中将Π-model复制过来当作baseline,然后将它修改成使用weight-averaged consistency的形式(即上图的consistency cost),记为Π (ours)。看一下作者做的实验和实验结果:
可以看出在label较少的时候,mean teacher要好很多,部分情况mean teacher不是最佳。值得注意的是原始Π-model在label不全的时候比Π (ours)好,这似乎无法说明在Π-model的情况下,令模型使用weight-averaged consistency会得到更好的效果。
先看到这里,其他实验部分就不赘述了。如有不通意见,欢迎评论区交流。
边栏推荐
猜你喜欢
Persistence mechanism of redis
A classic interview question covering 4 hot topics
Refer to | the computer cannot access the Internet after the hotspot is turned on in win11
多网络设备存在时,如何配置其上网优先级?
This, constructor, static, and inter call must be understood!
VIM from dislike to dependence (20) -- global command
2022.06.26(LC_6100_统计放置房子的方式数)
Analysis of orthofinder lineal homologous proteins and result processing
100%弄明白5种IO模型
How Oracle converts strings to multiple lines
随机推荐
Redis transactions
Understanding mvcc in MySQL transactions is super simple
Analysis of orthofinder lineal homologous proteins and result processing
How Oracle converts strings to multiple lines
Linux下Redis的安装
Rough reading DS transunet: dual swing transformer u-net for medical image segmentation
See how much volatile you know
Lvgl description 3 about the use of lvgl Guide
Digital ic-1.9 understands the coding routine of state machine in communication protocol
Object含有Copy方法?
Matlab tips (18) matrix analysis -- entropy weight method
vim 从嫌弃到依赖(19)——替换
Order by injection of SQL injection
I'm almost addicted to it. I can't sleep! Let a bug fuck me twice!
Five basic types of redis
Fake constructor???
Several cases that do not initialize classes
MATLAB小技巧(18)矩阵分析--熵权法
[MySQL basic] general syntax 1
oracle怎样将字符串转为多行