PyTorch learning: implementing univariate linear regression with gradient descent
2022-07-24 11:09:00 【practical_sharp】
Univariate linear regression
A univariate linear model is very simple. Suppose we have a variable $x_i$ and a target $y_i$, where each $i$ corresponds to one data point, and we want to build the model

$$\hat{y}_i = w x_i + b$$

Here $\hat{y}_i$ is our prediction; we want $\hat{y}_i$ to fit the target $y_i$. In plain terms, we look for the function that fits $y_i$ with the smallest error, that is, we minimize

$$\frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2$$
Gradient
Mathematically, a gradient is a derivative; for a multivariate function it is the vector of partial derivatives. For a function $f(x, y)$, its gradient is

$$\left(\frac{\partial f}{\partial x},\ \frac{\partial f}{\partial y}\right)$$

which can be written as grad $f(x, y)$ or $\nabla f(x, y)$. At a specific point $(x_0, y_0)$, the gradient is $\nabla f(x_0, y_0)$.
What is the gradient good for? Geometrically, the gradient at a point gives the direction in which the function changes fastest. Concretely, for the function $f(x, y)$ at the point $(x_0, y_0)$, the function increases fastest along the direction of $\nabla f(x_0, y_0)$, so following the gradient leads us most quickly toward a maximum of the function; conversely, moving in the opposite direction of the gradient leads us most quickly toward a minimum.
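As a quick sanity check of this idea (a minimal sketch, not part of the original post, previewing the autograd machinery used later), PyTorch can compute such a gradient numerically. For $f(x, y) = x^2 + y^2$ the gradient is $(2x,\ 2y)$, so at the point $(1, 2)$ it should be $(2, 4)$:

import torch

# f(x, y) = x^2 + y^2, evaluated at the point (1, 2)
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
f = x ** 2 + y ** 2
f.backward()           # fills x.grad and y.grad with the partial derivatives
print(x.grad, y.grad)  # expected: tensor(2.) tensor(4.)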
Gradient descent method
With this understanding of gradients, the principle of gradient descent is easy to grasp. We want to minimize the error defined above, that is, find its minimum point, and we can reach it by moving in the direction opposite to the gradient.
Here is an intuitive picture. Suppose we are standing somewhere on a mountain and do not know the way down, so we decide to descend one step at a time: at each position we compute the gradient, take one step in the negative gradient direction (currently the steepest way down), then recompute the gradient at the new position and again step along the steepest descent. Step by step, we eventually feel we have reached the foot of the mountain. Of course, proceeding this way we may never reach the true foot of the mountain, only some local low point.
Applied to our problem, we keep changing the values of w and b along the direction opposite to the gradient, until we find the pair of w and b that minimizes the error.
When updating, we also need to decide how large each update should be; in the mountain analogy, this is the length of each downhill step. This length is called the learning rate, denoted $\eta$. The learning rate matters a great deal: different learning rates lead to very different results. A learning rate that is too small makes the descent extremely slow, while one that is too large makes the iterates jump around noticeably.
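The effect of the learning rate is easy to see on a toy one-dimensional function (a hypothetical example added here, not from the original post). For $f(x) = x^2$ the gradient is $2x$, so the update is $x := x - \eta \cdot 2x$; a tiny learning rate barely moves, while one that is too large overshoots and diverges:

def descend(eta, x=10.0, steps=5):
    # gradient descent on f(x) = x^2, whose gradient is 2 * x
    for _ in range(steps):
        x = x - eta * 2 * x
    return x

print(descend(0.01))  # too small: still far from the minimum at 0 (about 9.04)
print(descend(0.5))   # a good step size: reaches 0 immediately
print(descend(1.1))   # too large: the iterate overshoots and blows up (about -24.9)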
Finally, the update rules are

$$w := w - \eta \frac{\partial f(w,\ b)}{\partial w}$$
$$b := b - \eta \frac{\partial f(w,\ b)}{\partial b}$$
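For the squared-error loss defined at the beginning, $f(w, b) = \frac{1}{n}\sum_{i=1}^{n}(w x_i + b - y_i)^2$, these partial derivatives can be worked out explicitly with the chain rule:

$$\frac{\partial f}{\partial w} = \frac{2}{n}\sum_{i=1}^{n}(w x_i + b - y_i)\, x_i, \qquad \frac{\partial f}{\partial b} = \frac{2}{n}\sum_{i=1}^{n}(w x_i + b - y_i)$$

These are exactly the quantities that PyTorch's automatic differentiation computes for us in the code below.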
By iterating these updates, we eventually find an optimal pair of w and b; this is the principle of gradient descent.
Implementing univariate linear regression in PyTorch
Experimental data to be fitted
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt

# The training data: 15 (x, y) pairs
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                    [9.779], [6.182], [7.59], [2.167], [7.042],
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                    [3.366], [2.596], [2.53], [1.221], [2.827],
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

plt.plot(x_train, y_train, 'bo')
plt.show()
The resulting scatter plot of the training data is displayed.
Model definition and initialization
# Convert to Tensors
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

# Define the parameters w and b
w = Variable(torch.randn(1), requires_grad=True)  # random initialization
b = Variable(torch.zeros(1), requires_grad=True)  # initialize with zero

# Build the linear regression model
x_train = Variable(x_train)
y_train = Variable(y_train)

def linear_model(x):
    return x * w + b

y_ = linear_model(x_train)
plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')
plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')
plt.show()
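A side note, not in the original code: since PyTorch 0.4, Variable has been merged into Tensor, so on a recent version the same parameters can be set up without it. A roughly equivalent sketch:

# Parameters created directly as tensors that track gradients (PyTorch >= 0.4)
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

def linear_model(x):
    return x * w + b

y_ = linear_model(x_train)  # x_train can stay a plain tensor, no Variable wrapper needed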

Next we need to compute the error function, i.e. the mean squared error $\frac{1}{n}\sum_{i=1}^{n}(\hat{y}_i - y_i)^2$:
# Compute the error
def get_loss(y_, y):
    return torch.mean((y_ - y) ** 2)

loss = get_loss(y_, y_train)
# Print the loss to see how large it is
print(loss)
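Since get_loss is just the mean squared error, it could equally be replaced by PyTorch's built-in loss; a small sketch of that alternative (not what this post uses):

import torch.nn as nn

criterion = nn.MSELoss()       # mean squared error, same as get_loss above
loss = criterion(y_, y_train)  # should give the same value as get_loss(y_, y_train)
print(loss)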

Having defined the error function, we next need to compute the gradients of w and b:
# Automatic differentiation
loss.backward()
# Look at the gradients of w and b
print(w.grad)
print(b.grad)
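These automatically computed gradients should agree with the analytic formulas given earlier; a quick hand check (a sketch added for illustration, not in the original post):

# Analytic gradients of the mean squared error, computed directly from the data
err = (x_train * w + b - y_train).data
dw = 2 * torch.mean(err * x_train.data)  # analytic d loss / d w, should match w.grad
db = 2 * torch.mean(err)                 # analytic d loss / d b, should match b.grad
print(dw, db)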

Update parameters once
# Update parameters once
w.data = w.data - 1e-2 * w.grad.data
b.data = b.data - 1e-2 * b.grad.data
y_ = linear_model(x_train)
plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')
plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')
plt.show()

As the plot shows, after this single update the red line has moved below the blue points and still does not fit the blue ground truth particularly well, so we need to repeat the update several times.
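Using .data for the in-place update is the older idiom shown throughout this post; on current PyTorch the same step is usually wrapped in torch.no_grad() so that the update itself is not tracked by autograd. A minimal sketch of that variant:

with torch.no_grad():  # do not record these in-place updates in the autograd graph
    w -= 1e-2 * w.grad
    b -= 1e-2 * b.grad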
Loop update
for e in range(10):  # perform 10 updates
    y_ = linear_model(x_train)
    loss = get_loss(y_, y_train)

    w.grad.zero_()  # remember to zero the gradients
    b.grad.zero_()  # remember to zero the gradients
    loss.backward()

    w.data = w.data - 1e-2 * w.grad.data  # update w
    b.data = b.data - 1e-2 * b.grad.data  # update b
    print('epoch: {}, loss: {}'.format(e, loss.item()))

y_ = linear_model(x_train)
plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')
plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')
plt.show()
The output of the steps above, in order: the initial loss, the first gradients of w and b, and the loss logged during the 10 updates:

tensor(13.2937, grad_fn=<MeanBackward0>)
tensor([-47.0074])
tensor([-6.9313])
epoch: 0, loss: 0.46915704011917114
epoch: 1, loss: 0.23152294754981995
epoch: 2, loss: 0.22683243453502655
epoch: 3, loss: 0.22645452618598938
epoch: 4, loss: 0.22615781426429749
epoch: 5, loss: 0.2258642166852951
epoch: 6, loss: 0.22557203471660614
epoch: 7, loss: 0.22528137266635895
epoch: 8, loss: 0.22499223053455353
epoch: 9, loss: 0.2247045636177063
After 10 updates, we find that the red predictions fit the blue ground truth much better.
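For reference, the same training loop is usually written with torch.optim.SGD, which handles the zero-grad/step bookkeeping; a sketch of that alternative (not part of the original post), assuming x_train and y_train are the tensors prepared above:

import torch
from torch import nn, optim

model = nn.Linear(1, 1)  # wraps w and b in a single module
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)

for e in range(10):
    y_ = model(x_train)
    loss = criterion(y_, y_train)
    optimizer.zero_grad()  # zero the gradients
    loss.backward()        # compute new gradients
    optimizer.step()       # apply w := w - lr * grad
    print('epoch: {}, loss: {}'.format(e, loss.item()))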