当前位置:网站首页>标签平滑(label smoothing)
标签平滑(label smoothing)
2022-07-22 19:54:00 【Billie使劲学】
目录
标签平滑(label smoothing)出自GoogleNet v3
关于one-hot编码的详细知识请见:One-hot编码
1.标签平滑主要解决什么问题?
传统的one-hot编码会带来的问题:无法保证模型的泛化能力,使网络过于自信会导致过拟合。
全概率和0概率鼓励所属类别和其他类别之间的差距尽可能加大,而由梯度有界可知,这种情况很难adapt。会造成模型过于相信预测的类别。而标签平滑可以缓解这个问题。
2.标签平滑是怎么操作的?
标签平滑是把one-hot中概率为1的那一项进行衰减,避免过度自信,衰减的那部分的自信被平均分到每一个类别中。
例如:
一个4分类任务,label = (0,1,0,0)
labeling smoothing = (
,1-0.001+
,
,
)=(0.00025,0.99925,0.00025,0.00025)
其中,概率加起来等于1。
3.标签平滑公式
交叉熵(Cross Entropy):
其中,q为标签值,p为预测结果,k为类别。即q为one-hot编码结果。
labeling smothing:将q进行标签平滑变为q',让模型输出的p分布去逼近q'。
,其中u(k)为一个概率分布,这里采用均匀分布(
),则得到
其中,
为原分布q, ϵ ∈(0,1)是一个超参数。
由以上公式可以看出,这种方式使label有 ϵ 概率来自于均匀分布, 1−ϵ 概率来自于原分布。这就相当于在原label上增加噪声,让模型的预测值不要过度集中于概率较高的类别,把一些概率放在概率较低的类别。
故标签平滑后的交叉熵损失函数为:
那这个公式是怎么得来的呢?
将q'(k|x)带入交叉熵损失函数:

![=-\sum_{k=1}^{k}log(p_k)[(1-\varepsilon )\delta _{k,y}+\frac{\varepsilon }{k}]](http://img.inotgo.com/imagesLocal/202207/23/202207221953347716_11.gif)
![=-\sum_{k=1}^{k}log(p_k)(1-\varepsilon )\delta _{k,y}+[-\sum_{k=1}^{k}log(p_k)\frac{\varepsilon }{k}]](http://img.inotgo.com/imagesLocal/202207/23/202207221953347716_6.gif)
![=(1-\varepsilon )*[-\sum_{k=1}^{k}log(p_k)\delta _{k,y}]+\varepsilon *[-\sum_{k=1}^{k}log(p_k)\frac{1}{k}]](http://img.inotgo.com/imagesLocal/202207/23/202207221953347716_5.gif)

这样就得到了标签平滑公式。
4.代码实现
class LabelSmoothingCrossEntropy(nn.Module):
def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
super(LabelSmoothingCrossEntropy, self).__init__()
self.eps = eps
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, output, target):
c = output.size()[-1]
log_pred = torch.log_softmax(output, dim=-1)
if self.reduction == 'sum':
loss = -log_pred.sum()
else:
loss = -log_pred.sum(dim=-1)
if self.reduction == 'mean':
loss = loss.mean()
return loss * self.eps / c + (1 - self.eps) * torch.nn.functional.nll_loss(log_pred, target,
reduction=self.reduction,
ignore_index=self.ignore_index)
边栏推荐
- 12306史上最奇葩验证码:正常用户可轻松识别 抢票软件被拒之门外
- 再学电商项目之谷粒商城之ES6新特性
- postman “status“: 500, “error“: “Internal Server Error“, “message“: “The request was rejecte“
- 常见的跨域问题
- 小程序wx.setStorageSync后,在用getStorageSync获取数据有时会获取不到
- Electromagnetic field and electromagnetic wave experiment 4. Be familiar with the application of CST Studio Software in the electromagnetic field
- How to do if the control panel program cannot be uninstalled? Compulsory uninstallation software tutorial
- Zhimeng dedecms forgot to manage the background password retrieval method
- 对线程池的了解与应用你掌握多少
- Gb28181 summary of common problems in the use and secondary development of livegbs streaming media service
猜你喜欢

Analysis of cache read and write strategy

常见的跨域问题

【JDBC】报错Exception in thread “main”com.mysql.jdbc.exceptions.jdbc4.CommunicationsException: Communica

Pikachu shooting range SQL injection search injection clearance steps

Common operators

Q6ui layout operation

Database system design: partition

Demo19- (to be updated)

STL container - vector simulation implementation

LUR caching algorithm
随机推荐
怎样删除c盘非系统文件 c盘爆红了可以删除的文件汇总
Gb28181 summary of common problems in the use and secondary development of livegbs streaming media service
电脑一拖二显示器分辨率怎么调? 两个显示器设置不同分辨率的技巧
thinkphp URL_ Mode =0 specific usage of normal mode
小黑啃leetcode:589. N 叉树的前序遍历
What is the difference between GPU and CPU? Introduction to the meaning of GPU in different computers
局域网SDN技术硬核内幕 - 14 三 从物到人——SDN走进园区网络
关注公众号免费领取小米移动电源是真的吗?微信朋友圈送小米移动电源
Pikachu shooting range SQL injection search injection clearance steps
MySql的DDL和DML和DQL的基本语法
奇瑞艾瑞泽8产品力超能打,“全优”可不是白叫的
Face algorithms
局域网SDN技术硬核内幕 - 3 前传 突破多核的瓶颈——虚拟化
小米金融今日(5月11日)正式上线 白送10000元体验金 附官方地址
软件卸载不掉 显示请等待当前程序完成卸载或更改的解决办法
百度钱包帮你还信用卡 跨行还款0手续费 实时到帐 新人奖励5元
【高并发基石】多线程、守护线程、线程安全、线程同步、互斥锁
小米活期宝和余额宝哪个好?小米活期宝与阿里余额宝区别详细对比介绍
常见的跨域问题
Analysis of cache read and write strategy