当前位置:网站首页>梯度下降法介紹-黑馬程序員機器學習講義
梯度下降法介紹-黑馬程序員機器學習講義
2022-06-24 04:35:00 【黑馬程序員官方】
學習目標
- 知道全梯度下降算法的原理
- 知道隨機梯度下降算法的原理
- 知道隨機平均梯度下降算法的原理
- 知道小批量梯度下降算法的原理
上一節中給大家介紹了最基本的梯度下降法實現流程,常見的梯度下降算法有:
- 全梯度下降算法(Full gradient descent),
- 隨機梯度下降算法(Stochastic gradient descent),
- 小批量梯度下降算法(Mini-batch gradient descent),
- 隨機平均梯度下降算法(Stochastic average gradient descent)
它們都是為了正確地調節權重向量,通過為每個權重計算一個梯度,從而更新權值,使目標函數盡可能最小化。其差別在於樣本的使用方式不同。
1 全梯度下降算法(FG)
計算訓練集所有樣本誤差,對其求和再取平均值作為目標函數。
權重向量沿其梯度相反的方向移動,從而使當前目標函數减少得最多。
因為在執行每次更新時,我們需要在整個數據集上計算所有的梯度,所以批梯度下降法的速度會很慢,同時,批梯度下降法無法處理超出內存容量限制的數據集。
批梯度下降法同樣也不能在線更新模型,即在運行的過程中,不能增加新的樣本。
其是在整個訓練數據集上計算損失函數關於參數θ的梯度:

2 隨機梯度下降算法(SG)
由於FG每迭代更新一次權重都需要計算所有樣本誤差,而實際問題中經常有上億的訓練樣本,故效率偏低,且容易陷入局部最優解,因此提出了隨機梯度下降算法。
其每輪計算的目標函數不再是全體樣本誤差,而僅是單個樣本誤差,即每次只代入計算一個樣本目標函數的梯度來更新權重,再取下一個樣本重複此過程,直到損失函數值停止下降或損失函數值小於某個可以容忍的閾值。
此過程簡單,高效,通常可以較好地避免更新迭代收斂到局部最優解。其迭代形式為

其中,x(i)錶示一條訓練樣本的特征值,y(i)錶示一條訓練樣本的標簽值
但是由於,SG每次只使用一個樣本迭代,若遇上噪聲則容易陷入局部最優解。
3 小批量梯度下降算法(mini-batch)
小批量梯度下降算法是FG和SG的折中方案,在一定程度上兼顧了以上兩種方法的優點。
每次從訓練樣本集上隨機抽取一個小樣本集,在抽出來的小樣本集上采用FG迭代更新權重。
被抽出的小樣本集所含樣本點的個數稱為batch_size,通常設置為2的幂次方,更有利於GPU加速處理。
特別的,若batch_size=1,則變成了SG;若batch_size=n,則變成了FG.其迭代形式為

4 隨機平均梯度下降算法(SAG)
在SG方法中,雖然避開了運算成本大的問題,但對於大數據訓練而言,SG效果常不盡如人意,因為每一輪梯度更新都完全與上一輪的數據和梯度無關。
隨機平均梯度算法克服了這個問題,在內存中為每一個樣本都維護一個舊的梯度,隨機選擇第i個樣本來更新此樣本的梯度,其他樣本的梯度保持不變,然後求得所有梯度的平均值,進而更新了參數。
如此,每一輪更新僅需計算一個樣本的梯度,計算成本等同於SG,但收斂速度快得多。
5 小結
- 全梯度下降算法(FG)【知道】
- 在進行計算的時候,計算所有樣本的誤差平均值,作為我的目標函數
- 隨機梯度下降算法(SG)【知道】
- 每次只選擇一個樣本進行考核
- 小批量梯度下降算法(mini-batch)【知道】
- 選擇一部分樣本進行考核
- 隨機平均梯度下降算法(SAG)【知道】
- 會給每個樣本都維持一個平均值,後期計算的時候,參考這個平均值
边栏推荐
- I have an agreement with IOT
- 提pr,push 的时候网络超时配置方法
- 一文简述:供应链攻击知多少
- 2. in depth tidb: entry code analysis and debugging tidb
- 外网访问svn服务器(外网访问部署在云上的svn服务器)
- Abnova荧光原位杂交(FISH)探针解决方案
- Worthington胰蛋白酶的物化性质及特异性
- High availability architecture design to deal with network failure of operators
- How to restart the ECS? What are the differences between ECS restart and normal computers?
- How does the VPS server upload data? Is the VPS server free to use?
猜你喜欢
Summary of Android interview questions in 2020 (intermediate)

ServiceStack. Source code analysis of redis (connection and connection pool)

MySQL - SQL execution process

What is the data center

SAP MTS/ATO/MTO/ETO专题之八:ATO模式2 D+空模式策略用85

Abnova膜蛋白脂蛋白体解决方案

Jointly build Euler community and share Euler ecology | join hands with Kirin software to create a digital intelligence future

开源之夏2022中选结果公示,449名高校生将投入开源项目贡献

Abnova fluorescence in situ hybridization (FISH) probe solution

apipost接口断言详解
随机推荐
What is Ping? How can the server disable Ping?
Apipost interface assertion details
Advanced authentication of uni app [Day12]
apipost接口断言详解
Openeuler kernel technology sharing issue 20 - execution entity creation and switching
Kubernetes resource topology aware scheduling optimization
event
Training course of mixed accuracy from simple to deep
How to select a telemedicine program system? These four points are the key!
Database answers build standard, answer as required
Easyanticheat uses to inject unsigned code into a protected process (1)
getAttribute 返回值为null
How to install software on ECs is it expensive to rent ECS
博士申请 | 香港科技大学(广州)刘浩老师招收数据挖掘方向全奖博士/硕士
15+ urban road element segmentation application, this segmentation model is enough
Submit sitemap to Baidu
Worthington脱氧核糖核酸酶I特异性和相关研究
ARM 架构、ARM7、ARM9、STM32、Cortex M3 M4 、51、AVR 有啥区别
How to modify the channel name registered by the camera in the easygbs national standard platform?
What does VPS server mean? What is the difference between a VPS server and an ECS?