当前位置:网站首页>阻止深度神经网络过拟合(Mysteries of Neural Networks Part II)
阻止深度神经网络过拟合(Mysteries of Neural Networks Part II)
2022-06-25 13:10:00 【小帅吖】
原文标题:Preventing Deep Neural Network from Overfitting
原文作者:Piotr Skalski
原文链接:https://medium.com/towards-data-science/preventing-deep-neural-network-from-overfitting-953458db800a
多亏大量的参数(数万甚至数百万)神经网络拥有很大的自由度能够适应大量复杂的数据集。这种独特的能力使他们能够接管许多在“传统”机器学习时代难以取得任何进展的领域,如图像识别、目标检测或自然语言处理。然而,有时他们最大的优势变成了潜在的弱点。 缺乏对模型学习过程的控制可能会导致过拟合——当我们的神经网络与训练集紧密拟合时,很难对新数据进行泛化和预测。了解这一问题的根源和防止它发生的方法,是一个成功的神经网络设计的必要。
How do you know your NN is overfitting?
Train, Dev and Test sets
发现我们的模型是过拟合是困难的。 我们训练过的模型已经投入生产,然后我们开始意识到有些地方出了问题,这种情况并不少见。 事实上,只有面对新的数据,才能确保一切正常工作。 但是,在训练过程中,我们要尽可能的还原真实的情况。 出于这个原因,将数据集分为三部分是一个很好的实践——训练集、验证集(也称为交叉验证或保留)和测试集。 我们的模型通过只看到这些部分中的第一部分来学习。 验证集是用来跟踪我们的进展和得出结论,以优化模型。 同时,我们在训练过程的最后使用测试集来评估我们的模型的性能。 使用全新的数据可以让我们对我们的算法工作得如何有一个不偏不倚的看法。
确保您的交叉验证和测试集来自相同的分布,以及它们准确地反映我们期望在未来收到的数据,这是非常重要的。 只有这样,我们才能确信我们在学习过程中做出的决定会让我们更接近一个更好的解决方案。 我知道你在想什么……“我应该如何划分我的数据集?” 直到最近,最常被推荐的分割方式之一是60/20/20,但在大数据时代,当我们的数据集可以计数数百万条目时,那些固定的比例不再合适。 简而言之,一切都取决于我们所处理的数据集的大小。 如果我们有数百万个条目可供支配,也许将它们除以98/1/1的比率会是更好的主意。 我们的验证集和测试集应该足够大,使我们对模型的性能有很高的信心。 根据数据集的大小划分数据集的推荐方法如图所示。
Bias and Variance
多亏我们对数据进行了适当的准备,我们给自己提供了评估模型性能的工具。 然而,在我们开始得出任何结论之前,我们应该熟悉两个新概念——偏差和方差。 为了让我们更好地理解这个复杂的问题,我们将用一个简单的例子,希望它能让我们建立一个有价值的直觉。 我们的数据集由两类点组成,位于二维空间中,如下图所示。
由于这是一个简单的演示,这次我们将跳过测试集,只使用训练集和交叉验证集。 接下来,我们将准备三个模型:第一个是简单线性回归,另外两个是由几个紧密相连的层构建的神经网络。 如下图所示,我们可以看到使用这些模型定义的分类边界。 在右上角的第一个模型非常简单,因此有很高的偏差,即它不能找到所有特征和结果之间的显著联系。 这是可以理解的——我们的数据集中有很多噪声,因此简单的线性回归不能有效地处理它。 神经网络的表现要好得多,但第一个(如图左下角所示)与数据的吻合过于紧密,这使得它在验证集上的表现明显更差。 这意味着它有很高的方差——它适合噪声而不是预期的输出。 在最后一个模型中,通过使用正则化,这种不良影响得到了缓解。
top right corner — linear regression; bottom left corner — neural network; bottom right corner — neural network with regularisation
我给出的这个例子很简单——我们只有两个特性,并且在任何时候我们都可以创建一个图表并直观地检查我们的模型的行为。 当我们在一个多维空间中操作时,如果数据集包含几十个特征,我们该怎么办? 然后我们比较使用训练集和交叉验证集计算的误差值。 当然,我们应该追求的最佳情况是,在这两个集合中都有较低的错误率。 主要的问题是定义什么是低错误率——在某些情况下,它可以是1%,在其他情况下,它可以高达10%或更多。 当训练神经网络时,它有助于建立一个基准来比较我们的模型的性能。 通常这将是人类执行这项任务的效率水平。 然后我们尽量确保我们的算法在训练过程中有一个接近我们的参考水平的误差。 如果我们已经实现了这个目标,但是当我们在保持集上验证它时,错误率显著增加,这可能意味着我们过度拟合(高方差)。 另一方面,如果我们的模型在训练集和交叉验证上的表现很差,那么它可能太弱,有很高的偏差。 当然,这个问题要复杂得多,涉及的范围也非常广泛,可以单独写一篇文章来讨论这个问题,但是这些基本信息应该足以理解下面的分析。
Ways to prevent overfitting
有很多方法可以帮助我们当我们遇到过拟合的问题。其中一些,比如获取更多的数据,是相当通用的,每次都很有效。 而其他的,比如正则化,则需要大量的技巧和经验。 对我们的NN施加太多的限制可能会损害其有效学习的能力。 现在让我们来看看一些最流行的减少过拟合的方法,并讨论它们起作用的原因。
L1 and L2 Regularizations
我们遇到过拟合问题时首先需要尝试的便是通过正则化方法进行解决。它涉及到在损失函数中添加一个额外的元素,这将惩罚我们的模型过于复杂,或者简单地说,在权值矩阵中使用过高的值。 通过这种方式,我们试图限制它的灵活性,但也鼓励它基于多种特性构建解决方案。 这种方法的两个流行版本是L1 -最小绝对偏差(LAD)和L2 -最小二乘误差(LS)。 描述这些规则的方程如下所示。
在大多数情况下,使用L1是可取的,因为它将不太重要的特征的权值减少到零,经常从计算中完全消除它们。 在某种程度上,它是自动特征选择的内置机制。 此外,L2在具有大量异常值的数据集上的表现不是很好,使用值平方的结果在模型中最小化异常值的影响,以牺牲更流行的案例。
现在让我们看看我们在偏差和方差的例子中使用的两个神经网络。 其中之一,正如我之前提到的,我使用正则化来消除过拟合。 我发现在三维空间中可视化权值矩阵,并比较有和没有正则化的模型之间得到的结果是一个有趣的想法。 我还使用正则化对许多模型进行了模拟,改变λ值以验证其对权重矩阵中包含的值的影响。 (说实话,我这么做也是因为我觉得它看起来超级酷,而且我不想让你睡着。) 矩阵的行和列索引对应于横轴值,权值被解释为垂直坐标。
Lambda factor and its effect
前面提到的公式在L1和L2的两个版本的正则化,我引入了超参数λ也称为正则化率。在选择其值时,我们试图在模型的简单性和拟合它与训练数据之间找到最佳平衡点。增加λ值也会增加正则化效果。上图所示。我们可以立即注意到,没有调节的模型得到的平面,以及具有很低λ系数值的模型非常“湍流”。有许多具有显著值的峰。在应用超参数值较高的L2正则化后,图呈扁平化趋势。最后,我们可以看到在0.1或1附近设置lambda值会导致模型中权重值的急剧下降。我鼓励您检查用于创建这些可视化的源代码。
Dropout
另外一种流行的方法放弃了目前非常流行的神经网络正则化方法。这个想法实际上非常简单——我们神经网络的每个单元(除了属于输出层的单元)在计算中被暂时忽略的概率p。超参数p被称为dropout rate,通常它的默认值设置为0.5。然后,在每次迭代中,我们根据指定的概率随机选择我们丢弃的神经元。因此,每次我们使用的神经网络都更小。下面的可视化显示了一个遭受dropout的神经网络示例。我们可以看到,在每次迭代中,来自第二层和第四层的随机神经元是如何失活的。
这种方法的有效性是相当令人惊讶和违反直觉的。毕竟,在现实世界中,如果工厂经理每天随机挑选员工并把他们打发回家,那么工厂的生产率就不会提高。让我们从单个神经元的角度来看这个问题。由于在每次迭代中,任何输入值都可能被随机消除,神经元试图平衡风险,而不偏爱任何特征。这样,权值矩阵中的值分布更加均匀。模型希望避免它提出的解决方案不再有意义的情况,因为它不再有来自非活动特性的信息流。
Early Stopping
下图显示了在学习过程的后续迭代中,测试集和交叉验证集上计算的精度值的变化。我们马上就能看到,我们最后得到的模型并不是我们所能创造的最好的。老实说,这比我们经历了150个时代后的情况要糟糕得多。为什么不在模型开始过拟合之前中断学习过程呢?这一发现启发了一种流行的过拟合减少方法,即提早停止。
在实践中,每隔几次迭代对我们的模型进行抽样,并检查它与验证集的工作情况是非常方便的。每一个比以前所有模型性能更好的模型都被保存。我们还设置了一个限制,最大迭代次数即在此范围内没有过程被记录。当超过该值时,学习将停止。虽然早期停止可以显著改善我们的模型的性能,但在实践中,它的应用极大地复杂化了我们模型的优化过程。它很难与其他常规技术相结合。
Conclusions
认识到我们的神经网络是过拟合的能力,以及我们可以应用解决方案来防止它发生的知识是基本的。然而,这些是非常广泛的主题,不可能在一篇文章中充分详细地描述它们。因此,我的目标是提供一些基本的直觉,让大家知道规则化或退出等技巧是如何工作的。这些话题对我来说很难理解,我希望我能帮助你解决它们。
边栏推荐
- 關於一道教材題的講解
- 关于数据在内存中存储的相关例题
- Prototype and prototype chain - constructor and instanceof
- C# 切换中英文输入法
- [pit avoidance means "difficult"] actionref current. Reload() does not take effect
- Detailed explanation of string operation functions and memory functions
- 解决报错:Creating window glfw ERROR: GLEW initalization error: Missing GL version
- Nova中的api
- The priority of catch() and then (..., ERR) of promise
- Cesium learning notes
猜你喜欢
Drago Education - typescript learning
Custom vertical table
“移动云杯”算力网络应用创新大赛火热报名中!
Related examples of data storage in memory
戴尔电脑cpu温度过高怎么办
Knowledge of initial C language 2.0
What if the CPU temperature of Dell computer is too high
Rust, the best choice for programmers to start a business?
À propos du stockage des données en mémoire
DE2-115 FPGA开发板的VGA显示
随机推荐
Introduction to mongodb chapter 01 introduction to mongodb
Class usage and inheritance in ES6
用NumPy实现神经网络(Mysteries of Neural Networks Part III)
Cesium learning notes
Restful and RPC
关于数据在内存中存储的相关例题
请问通达信股票开户是安全的吗?
Insight into heap and stack stored in new string() /string() in JS
Discuz copy today's headlines template /discuz news and information business GBK template
Application of tactile intelligent sharing-rk3568 in financial self-service terminal
Nr-arfcn and channel grid, synchronous grid and GSCN
Always maintain epidemic prevention and control and create a safe and stable social environment
Django framework - caching, signaling, cross site request forgery, cross domain issues, cookie session token
Knowledge of initial C language 2.0
leetcode:456. 132 mode [monotone stack]
New Gospel of drug design: Tencent, together with China University of science and technology and Zhejiang University, developed an adaptive graph learning method to predict molecular interactions and
Some knowledge about structure, enumeration and union
C# 切换中英文输入法
“移动云杯”算力网络应用创新大赛火热报名中!
Openstack -- creating virtual machines for Nova source code analysis