[Pumpkin Book ML] (Task 4) Mathematical Derivations in Neural Networks
2022-07-24 21:55:00 【Evening scenery at the top of the mountain】
Learning summary
1. Neuron Model and MLP Basics
Take a CTR (click-through rate) task as an example: the input signals to a neuron can be features such as "whether the user has clicked this item". When the cell body receives these signals, it makes a simple judgment and then outputs a signal through the axon; the size of the output signal represents the user's interest in the item.

The activation function in the figure above is the sigmoid activation function, defined as
$$f(z)=\frac{1}{1+\mathrm{e}^{-z}}$$
Its graph is the S-shaped curve shown in the figure. Its main properties:
- It maps an input signal from the domain $(-\infty,+\infty)$ to the range $(0,1)$ (in click-through-rate prediction and recommendation problems we often need to predict a probability between 0 and 1).
- The sigmoid function is differentiable everywhere, which makes the subsequent gradient-descent learning process convenient, so it has become a very commonly used activation function. Other popular choices include tanh and ReLU. A small numerical sketch of the function and its derivative follows this list.
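As a quick check, here is a minimal NumPy sketch (my own illustration, not from the original text) of the sigmoid function and its derivative $f'(z)=f(z)\bigl(1-f(z)\bigr)$, the form that reappears in the BP gradient terms later:

```python
import numpy as np

def sigmoid(z):
    """Map any real input to the (0, 1) range."""
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_grad(z):
    """Derivative of sigmoid: f(z) * (1 - f(z)), reused later in back propagation."""
    s = sigmoid(z)
    return s * (1.0 - s)

for z in (-5.0, 0.0, 5.0):
    print(f"z={z:+.1f}  sigmoid={sigmoid(z):.4f}  grad={sigmoid_grad(z):.4f}")
```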
1.1 How a BP Neural Network Learns
A network composed of multiple neurons has a much stronger ability to fit data. The figure below shows a simple neural network consisting of an input layer, a hidden layer with two neurons, and an output layer with a single neuron:
Each blue neuron has the same structure as the single neuron described above. The input to neurons h1 and h2 is the feature vector formed by x1 and x2, and the input to neuron o1 is the vector formed by the outputs of h1 and h2.
The key training procedures of a neural network are forward propagation (Forward Propagation) and back propagation (Back Propagation).
(1) Forward propagation
The purpose of forward propagation is to obtain the model's estimate for a given input under the current network parameters, i.e. what we usually call model inference. For example, suppose we want to predict a student's gender from their weight and height. Forward propagation takes the weight value 71 and the height value 178, passes them through the computations of neurons h1, h2 and o1, and produces a gender probability, say 0.87, which is the predicted probability that the student is male.
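A minimal sketch of this forward pass is given below. The weights are made-up values (the figure's actual weights are not given in the text), so the printed probability will not exactly match the 0.87 above; the point is only the flow input layer → hidden layer → output layer:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Illustrative weights (assumed values, not from the original figure).
w1, w2 = 0.01, 0.02    # x1, x2 -> h1
w3, w4 = 0.03, 0.01    # x1, x2 -> h2
w5, w6 = 0.8, 0.6      # h1, h2 -> o1

x1, x2 = 71.0, 178.0   # the student's weight and height

# Forward propagation: input layer -> hidden layer -> output layer.
h1 = sigmoid(w1 * x1 + w2 * x2)
h2 = sigmoid(w3 * x1 + w4 * x2)
o1 = sigmoid(w5 * h1 + w6 * h2)
print(f"predicted probability of being male: {o1:.2f}")
```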
(2) Loss function
If the student's true gender is male, the true probability is 1. According to the definition of the absolute-value error in formula (2), the loss of this prediction is |1 - 0.87| = 0.13.
$$l_{1}\left(y_{i}, \hat{y}_{i}\right)=\left|y_{i}-\hat{y}_{i}\right|$$
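As a one-line check (my own snippet, not from the text), the absolute-error loss of the prediction above:

```python
def l1_loss(y_true, y_pred):
    """Absolute-value error l1(y, y_hat) = |y - y_hat|."""
    return abs(y_true - y_pred)

print(f"{l1_loss(1.0, 0.87):.2f}")   # 0.13, the loss of the gender prediction above
```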
(3) Gradient descent
Having obtained the error (loss) between the predicted value and the true value, we use this error to guide the weight update so that the whole network predicts more accurately next time. The most common way to update the weights is gradient descent, which updates each weight using a partial derivative. For example, to update the weight w5 we first compute the partial derivative of the loss function with respect to w5, $\frac{\partial L_{o 1}}{\partial w_{5}}$. Mathematically, the gradient points in the direction in which the function grows fastest, so the opposite direction of the gradient is the direction in which the function decreases fastest; making the loss decrease as fast as possible is exactly the update direction we want for w5. We also introduce a hyperparameter α, which controls the strength of each gradient update and is called the learning rate. The gradient update formula can then be written as
$$w_{5}^{t+1}=w_{5}^{t}-\alpha \cdot \frac{\partial L_{o 1}}{\partial w_{5}}$$
where w5 can of course be replaced by any other parameter to be updated, and t denotes the update step.
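A minimal one-step sketch of this update, with assumed values; a squared-error loss is used here only so that the derivative has a simple closed form (the absolute error above would give just the sign of the residual):

```python
# One gradient-descent step on w5 (all values assumed for illustration).
alpha = 0.5                             # learning rate
w5, h1, o1, y = 0.8, 0.6, 0.87, 1.0     # current weight, h1 output, prediction, label

# With L = 0.5 * (y - o1)^2 and o1 = sigmoid(w5*h1 + w6*h2),
# dL/dw5 = -(y - o1) * o1 * (1 - o1) * h1   (sigmoid derivative = o1 * (1 - o1)).
grad_w5 = -(y - o1) * o1 * (1.0 - o1) * h1
w5 = w5 - alpha * grad_w5
print(f"updated w5: {w5:.4f}")
```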
For an output-layer neuron (o1 in the figure), we can directly use gradient descent to compute the gradients of its associated weights (the weights w5 and w6 in Figure 5) and update them. But for the parameters associated with a hidden-layer neuron (such as w1), how do we use the loss at the output layer to perform gradient descent?
The answer is to use the chain rule (Chain Rule) during differentiation. With the chain rule, the gradient can be propagated backward layer by layer. The gradient of the final loss with respect to the weight w1 is the partial derivative of the loss with respect to the output of neuron h1, multiplied by the partial derivative of the output of h1 with respect to w1. In other words, the final gradient is passed back layer by layer to "guide" the update of w1:
$$\frac{\partial L_{o 1}}{\partial w_{1}}=\frac{\partial L_{o 1}}{\partial h_{1}} \cdot \frac{\partial h_{1}}{\partial w_{1}}$$
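The sketch below applies this chain rule to the small 2-2-1 network, with assumed (scaled) inputs and weights and a squared-error loss; it computes ∂L/∂w1 as exactly the product of the two factors above:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Assumed scaled inputs, label, and weights (illustration only).
x1, x2, y = 0.71, 1.78, 1.0
w1, w2, w3, w4 = 0.1, 0.2, 0.3, 0.1   # input -> hidden
w5, w6 = 0.8, 0.6                     # hidden -> output

# Forward pass.
h1 = sigmoid(w1 * x1 + w2 * x2)
h2 = sigmoid(w3 * x1 + w4 * x2)
o1 = sigmoid(w5 * h1 + w6 * h2)
loss = 0.5 * (y - o1) ** 2

# Backward pass via the chain rule: dL/dw1 = (dL/dh1) * (dh1/dw1).
dL_dh1 = -(y - o1) * o1 * (1 - o1) * w5    # loss -> o1 -> output of h1
dh1_dw1 = h1 * (1 - h1) * x1               # output of h1 -> w1
dL_dw1 = dL_dh1 * dh1_dw1
print(f"loss={loss:.4f}  dL/dw1={dL_dw1:.6f}")
```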
1.2 Linearly Separable and Non-linearly Separable Problems
To solve problems that are not linearly separable, we need to use multiple layers of functional neurons. The figure below shows a two-layer perceptron that can solve the XOR problem:
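To make this concrete, here is a two-layer perceptron with hand-set weights and threshold (step) activations that realizes XOR; these particular weights are one possible choice for illustration, not necessarily the ones in the figure:

```python
def step(z):
    """Threshold (unit-step) activation used by the classic perceptron."""
    return int(z >= 0)

def xor(x1, x2):
    h1 = step(1 * x1 + 1 * x2 - 0.5)     # h1 = x1 OR x2
    h2 = step(-1 * x1 - 1 * x2 + 1.5)    # h2 = NOT (x1 AND x2)
    return step(1 * h1 + 1 * h2 - 1.5)   # output = h1 AND h2

for a in (0, 1):
    for b in (0, 1):
        print(a, b, "->", xor(a, b))     # prints the XOR truth table
```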
2. The Error Back Propagation (BP) Algorithm
- The idea of the BP algorithm: first propagate the error backward to the hidden-layer neurons, adjusting the connection weights from the hidden layer to the output layer and the thresholds of the output-layer neurons; then, based on the error of the hidden-layer neurons, adjust the connection weights from the input layer to the hidden layer and the thresholds of the hidden-layer neurons.
- The goal of the BP algorithm: minimize the cumulative error on the training set D:
$$E=\frac{1}{m} \sum_{k=1}^{m} E_{k}$$
The basic flow of the BP algorithm:
Input: training set $D=\left\{\left(x_{k}, y_{k}\right)\right\}_{k=1}^{m}$; learning rate $\eta$
Process:
(1) Randomly initialize all connection weights and thresholds in the network within the range $(0,1)$;
(2) repeat
(3) —— for all $\left(x_{k}, y_{k}\right) \in D$ do
(4) ———— compute the output $\hat{y}^{k}$ of the current sample from the current parameters, using $\hat{y}_{j}^{k}=f\left(\beta_{j}-\theta_{j}\right)$;
(5) ———— compute the gradient term of the output-layer neurons, $g_{j}=\hat{y}_{j}^{k}\left(1-\hat{y}_{j}^{k}\right)\left(y_{j}^{k}-\hat{y}_{j}^{k}\right)$;
(6) ———— compute the gradient term of the hidden-layer neurons, $e_{h}=b_{h}\left(1-b_{h}\right) \sum_{j=1}^{l} w_{h j} g_{j}$;
(7) ———— update the connection weights $w_{h j}, v_{i h}$ and the thresholds $\theta_{j}, \gamma_{h}$;
(8) —— end for
(9) until the stopping condition is reached
Output: a multilayer feedforward neural network with all connection weights and thresholds determined
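The loop below is a minimal NumPy sketch of this procedure for one hidden layer. The weight and threshold updates follow the standard rules implied by the gradient terms above ($\Delta w_{hj}=\eta g_j b_h$, $\Delta\theta_j=-\eta g_j$, $\Delta v_{ih}=\eta e_h x_i$, $\Delta\gamma_h=-\eta e_h$); the toy data, layer sizes, learning rate and epoch count are assumptions for illustration:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
X = rng.random((8, 2))                                    # 8 toy samples, 2 features
Y = (X.sum(axis=1) > 1.0).astype(float).reshape(-1, 1)    # toy binary labels

d, q, l = 2, 3, 1        # input, hidden, output layer sizes
eta = 0.5                # learning rate

# (1) initialize connection weights and thresholds in (0, 1)
v, gamma = rng.random((d, q)), rng.random(q)   # input -> hidden weights, hidden thresholds
w, theta = rng.random((q, l)), rng.random(l)   # hidden -> output weights, output thresholds

for epoch in range(1000):                      # (2)/(9) repeat until stopping condition
    for x, y in zip(X, Y):                     # (3) for all (x_k, y_k) in D
        b = sigmoid(x @ v - gamma)             #     hidden-layer outputs b_h
        y_hat = sigmoid(b @ w - theta)         # (4) current output
        g = y_hat * (1 - y_hat) * (y - y_hat)  # (5) output-layer gradient term
        e = b * (1 - b) * (w @ g)              # (6) hidden-layer gradient term
        w += eta * np.outer(b, g); theta -= eta * g     # (7) update w_hj, theta_j
        v += eta * np.outer(x, e); gamma -= eta * e     #     update v_ih, gamma_h

pred = sigmoid(sigmoid(X @ v - gamma) @ w - theta)
print("training mean squared error:", float(np.mean((pred - Y) ** 2)))
```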
Preventing overfitting:
- Early stopping: split the data into a training set and a validation set. The training set is used to compute gradients and update the weights and thresholds, while the validation set is used to estimate the error. If the training error keeps decreasing but the validation error rises, stop training and at the same time return the connection weights and thresholds that achieved the minimum validation error.
- Regularization: add a term to the loss function that describes the complexity of the network, such as the sum of squares of the connection weights and thresholds (a small sketch of this regularized error follows this list):
$$E=\lambda \frac{1}{m} \sum_{k=1}^{m} E_{k}+(1-\lambda) \sum_{i} w_{i}^{2},$$
where
- $E_{k}$ is the error on the $k$-th training sample,
- $w_{i}$ denotes the connection weights and thresholds,
- $\lambda \in(0,1)$ trades off the empirical error against the network complexity, and is usually estimated by cross-validation.
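A minimal sketch (toy arrays, my own illustration) of computing this regularized cumulative error:

```python
import numpy as np

def regularized_error(per_sample_errors, params, lam=0.5):
    """E = lam * (1/m) * sum(E_k) + (1 - lam) * sum(w_i^2)."""
    empirical = np.mean(per_sample_errors)     # (1/m) * sum of E_k
    complexity = np.sum(np.square(params))     # sum of squared weights and thresholds
    return lam * empirical + (1.0 - lam) * complexity

E_k = np.array([0.13, 0.05, 0.20])     # toy per-sample errors
w = np.array([0.8, 0.6, -0.3])         # toy connection weights and thresholds
print(f"{regularized_error(E_k, w, lam=0.5):.4f}")
```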
3. Global Minimum and Local Minimum
- Gradient-based search (gradient descent) is the most widely used parameter-optimization method. However, if the error function has multiple local minima, then a point where the gradient of the error function is 0 may only be a local minimum, and we cannot guarantee that the solution found is the global minimum.
- Strategies for trying to escape a local minimum:
- Simulated annealing: at each step, accept a result worse than the current solution with a certain probability (a small sketch of this acceptance rule follows this list).
- Stochastic gradient descent: add a random factor when computing the gradient, so that even at a local minimum the computed gradient may still be non-zero, giving a chance to escape the local minimum and keep searching for better parameters.
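Here is a minimal sketch of a simulated-annealing style acceptance rule (a Metropolis-type probability $e^{-\Delta/T}$, my own illustration rather than a formula from the text): a worse solution is accepted with a probability that shrinks as the temperature T is lowered.

```python
import math
import random

def accept(current_loss, new_loss, temperature):
    """Accept a worse solution with probability exp(-(new - current) / T)."""
    if new_loss <= current_loss:
        return True                    # always accept an improvement
    return random.random() < math.exp(-(new_loss - current_loss) / temperature)

random.seed(0)
print(accept(0.40, 0.43, temperature=1.0))    # usually True: may jump out of a local minimum
print(accept(0.40, 0.43, temperature=0.01))   # almost always False at low temperature
```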