当前位置:网站首页>【南瓜书ML】(task4)神经网络中的数学推导
【南瓜书ML】(task4)神经网络中的数学推导
2022-07-24 21:45:00 【山顶夕景】
学习总结
文章目录
一、神经元模型和MLP基础
如CTR任务中,判断用户有点击过这个物品”等等。细胞体在接收到这些信号的时候,会做一个简单的判断,然后通过轴突输出一个信号,这个输出信号大小代表了用户对这个物品的感兴趣程度。

上图中的激活函数是 sigmoid 激活函数,它的数学定义是: f ( z ) = 1 1 + e − z f(z)=\frac{1}{1+\mathrm{e}^{-z}} f(z)=1+e−z1它的函数图像就是上图中的 S 型曲线,它的作用:
- 把输入信号从(-∞,+∞)的定义域映射到(0,1)的值域(因为在点击率预测,推荐问题中,往往是要预测一个从 0 到 1 的概率)。
- sigmoid 函数处处可导,方便之后的梯度下降学习过程,所以它成为了经常使用的激活函数。比较流行的还有 tanh、ReLU 等。
1.1 BP神经网络如何学习
多神经元构成的网络具有更强的拟合数据能力,如下图的一个由输入层、两神经元隐层和单神经元输出层组成的简单神经网络:
每个蓝色神经元的构造都和刚才的单神经元构造相同,h1和 h2神经元的输入是由 x1和 x2组成的特征向量,而神经元 o1的输入则是由 h1和 h2输出组成的输入向量。
神经网络的重要训练方法,前向传播(Forward Propagation)和反向传播(Back Propagation)。
(1)前向传播
前向传播的目的是在当前网络参数的基础上得到模型对输入的预估值,也就是我们常说的模型推断过程。比如说,我们要通过一位同学的体重、身高预测 TA 的性别,前向传播的过程就是给定体重值 71,身高值 178,经过神经元 h1、h2和 o1的计算,得到一个性别概率值,比如说 0.87,这就是 TA 可能为男性的概率。
(2)损失函数
如果这位同学的真实性别是男,那真实的概率值就是 1,根据公式 2 的绝对值误差定义,这次预测的损失就是|1-0.87| = 0.13。
l 1 ( y i , y ^ i ) = ∣ y i − y ^ i ∣ l_{1}\left(y_{i}, \hat{y}_{i}\right)=\left|y_{i}-\hat{y}_{i}\right| l1(yi,y^i)=∣yi−y^i∣
(3)梯度下降
发现了预测值和真实值之间的误差(Loss),我们就要用这个误差来指导权重的更新,让整个神经网络在下次预测时变得更准确。最常见的权重更新方式就是梯度下降法,它是通过求取偏导的形式来更新权重的。比如,我们要更新权重 w5,就要先求取损失函数到 w5 的偏导 ∂ L o 1 ∂ w 5 \frac{\partial L_{o 1}}{\partial w_{5}} ∂w5∂Lo1从数学角度来看,梯度的方向是函数增长速度最快的方向,那么梯度的反方向就是函数下降最快的方向,所以让损失函数减小最快的方向就是我们希望梯度 w5 更新的方向。这里我们再引入一个超参数α,它代表了梯度更新的力度,也称为学习率。现在可以写出梯度更新的公式了: w 5 t + 1 = w 5 t − α ∗ ∂ L o 1 ∂ w 5 w_{5}^{t+1}=w_{5}^{t}-\alpha * \frac{\partial L_{o 1}}{\partial w_{5}} w5t+1=w5t−α∗∂w5∂Lo1公式中的 w5当然可以换成其他要更新的参数,公式中的 t 代表着更新的次数。
对输出层神经元来说(图中的 o1),我们可以直接利用梯度下降法计算神经元相关权重(即图 5 中的权重 w5和 w6)的梯度,从而进行权重更新,但对隐层神经元的相关参数(比如 w1),我们又该如何利用输出层的损失进行梯度下降呢?
——“利用求导过程中的链式法则(Chain Rule)”。通过链式法则我们可以解决梯度逐层反向传播的问题。最终的损失函数到权重 w1的梯度是由损失函数到神经元 h1输出的偏导,以及神经元 h1输出到权重 w1的偏导相乘而来的。也就是说,最终的梯度逐层传导回来,“指导”权重 w1的更新。
∂ L o 1 ∂ w 1 = ∂ L o 1 ∂ h 1 ⋅ ∂ h 1 ∂ w 1 \frac{\partial L_{o 1}}{\partial w_{1}}=\frac{\partial L_{o 1}}{\partial h_{1}} \cdot \frac{\partial h_{1}}{\partial w_{1}} ∂w1∂Lo1=∂h1∂Lo1⋅∂w1∂h1
1.2 线性可分和非线性可分
要解决非线性可分问题,需要考虑用多层功能神经元。如下即可以解决异或问题的两层感知机:
二、误差逆传播BP算法
- BP算法的思想:首先将误差反向传播给隐含层神经元,调节隐含层到输出层的连接权重与输出层神经元的阈值;接着根据隐含层神经元的均方误差,来调节输入层到隐含层的连接权值与隐含层神经元的阈值。
- BP算法的目标:最小化训练集D上的累计误差: E = 1 m ∑ k = 1 m E k E=\frac{1}{m} \sum_{k=1}^{m} E_{k} E=m1k=1∑mEk
BP算法基本流程:
输入:
训练集 D = { ( x k , y k ) } k = 1 m D=\left\{\left(x_{k}, y_{k}\right)\right\}_{k=1}^{m} D={ (xk,yk)}k=1m;学习率 η \eta η
过程:
(1) 在 ( 0 , 1 ) (0,1) (0,1) 范围内随机初始化网络中所有连接权和阈值;
(2) repeat
(3) ——for all ( x k , y k ) ∈ D \left(x_{k}, y_{k}\right) \in D (xk,yk)∈D do
(4) ————根据当前参数和 y ^ j k = f ( β j − θ j ) \hat{y}_{j}^{k}=f\left(\beta_{j}-\theta_{j}\right) y^jk=f(βj−θj) 计算当前样本的输出 y ^ k \hat{y}_{k} y^k;
(5) ————根据 g j = y ^ j k ( 1 − y ^ j k ) ( y j k − y ^ j k ) g_{j}=\hat{y}_{j}^{k}\left(1-\hat{y}_{j}^{k}\right)\left(y_{j}^{k}-\hat{y}_{j}^{k}\right) gj=y^jk(1−y^jk)(yjk−y^jk) 计算输出层神经元的梯度项 g j g_{j} gj;
(6) ————根据 e h = b h ( 1 − b h ) ∑ j = 1 l w h j g j e_{h}=b_{h}\left(1-b_{h}\right) \sum_{j=1}^{l} w_{h j} g_{j} eh=bh(1−bh)∑j=1lwhjgj 隐藏层神经元的梯度项 e h e_{h} eh;
(7) ————更新连接权 w h j , v i h w_{h j}, v_{i h} whj,vih 与阈值 θ j , γ h \theta_{j}, \gamma_{h} θj,γh;
(8) ——end for
(9) until 达到停止条件
输出:连接权与阈值确定的多层前馈神经网络
防止过拟合:
- 早停(early stopping):将数据分成训练集和验证集,训练集用来计算梯度、更新权和阈值,验证集用来估计误差,若训练集误差降低但验证集误差升高,则停止训练,同事返回具有最小验证集误差的连接权和阈值。
- 正则化(regularization):在loss函数中增加一个用来描述网络复杂度的式子,如连接权和阈值的平方和: E = λ 1 m ∑ k = 1 m E k + ( 1 − λ ) ∑ i w i 2 , E=\lambda \frac{1}{m} \sum_{k=1}^{m} E_{k}+(1-\lambda) \sum_{i} w_{i}^{2}, E=λm1k=1∑mEk+(1−λ)i∑wi2,
- E k E_{k} Ek 表示第 k k k 个训练样例上的误差,
- w i w_{i} wi 表示连接权和阈值
- λ ∈ ( 0 , 1 ) \lambda \in(0,1) λ∈(0,1) 用于对经验误差与网络复杂度这两项进行折中, 常通过交叉验证法来估计.
三、全局最小 与 局部极小
- 基于梯度的搜索(梯度下降法)是使用最广泛的参数寻优方法。但是如果误差函数有多个局部极小值,误差函数若在当前点的梯度为0,则不能保证找到的解是全局最小。
- 试图跳出局部极小值:
- 模拟退火:在每一步都以一定的概率接受比当前解更差的结果。
- 随机梯度下降:在计算梯度时加入随机因素,所以即使陷入局部极小值点,计算出来的梯度仍可能不为0,就有机会跳出局部极小继续搜索最优参数。
Reference
[1] 陈希孺编著.概率论与数理统计[M].中国科学技术大学出版社,2009
[2] B 站视频教程:https://www.bilibili.com/video/BV1Mh411e7VU
[3] 线上南瓜书:https://datawhalechina.github.io/pumpkin-book/#/chapter1/chapter1
[4] 开源地址:https://github.com/datawhalechina/pumpkin-book
边栏推荐
- Today, there's a power failure for one day.... stop working for another day. Don't forget to study
- How much does it cost to build your own personal server
- 2022 Niuke multi school 7.23
- 【MLFP】《Face Presentation Attack with Latex Masks in Multispectral Videos》
- P2404 splitting of natural numbers
- String matching (Huawei)
- Drawing library matplotlibmatplotlib quick start
- [CCNA experiment sharing] routing between VLANs of layer 3 switches
- OSI的体系结构,以及各层协议
- Atcoder beginer contest 260 a~f problem solution
猜你喜欢

Binary search

Information system project manager must recite the core examination site (47) project subcontract

Lecun proposed that mask strategy can also be applied to twin networks based on vit for self supervised learning!
![[jzof] 06 print linked list from end to end](/img/c7/c2ac4823b5697279b81bec8f974ea9.png)
[jzof] 06 print linked list from end to end
![[image processing] pyefd.elliptic_ fourier_ How descriptors are used](/img/72/d2c825ddd95f541b37b98b2d7f6539.png)
[image processing] pyefd.elliptic_ fourier_ How descriptors are used

npm Warn config global `--global`, `--local` are deprecated. Use `--location=global` instead

String matching (Huawei)

Smarter! Airiot accelerates the upgrading of energy conservation and emission reduction in the coal industry

How to output position synchronization of motion control

How do test / development programmers survive the midlife crisis? You can see it at a glance
随机推荐
About the acid of MySQL, there are thirty rounds of skirmishes with mvcc and interviewers
Binary search
Mysql database commands
Image processing notes (1) image enhancement
What are the most problematic database accounts in DTS?
Practical skills!!
2022 Tsinghua summer school notes L2_ 1 basic composition of neural network
What should I pay attention to when choosing the self built database access method on ECs?
Is it safe to open an account on Alipay
Using skills and design scheme of redis cache (classic collection version)
Selenium test page content download function
Scientific computing toolkit SciPy Fourier transform
Gather relevant knowledge points and expand supplements
Diou and ciou loss of loss function
Can century model simulate soil respiration? Practice technology application and case analysis of century model
Uniqueness and ordering in set
Which bank outlet in Zhejiang can buy ETF fund products?
Lenovo Filez helps Zhongshui North achieve safe and efficient file management
PR 2022 22.5 Chinese version
What is a self built database on ECs?