Machine learning: linear regression
2022-06-24 21:57:00 【Weng Weiqiang】
Low-level API implementation:
1. Randomly initialize the data
import matplotlib.pyplot as plt
import tensorflow as tf

TRUE_W = 3.0
TRUE_b = 2.0
NUM_SAMPLES = 100

# Initialize random data
X = tf.random.normal(shape=[NUM_SAMPLES, 1]).numpy()
noise = tf.random.normal(shape=[NUM_SAMPLES, 1]).numpy()
y = X * TRUE_W + TRUE_b + noise  # add noise

plt.scatter(X, y)
2. Define the univariate regression model and fit the curve:
f(w, b, x) = w * x + b
class Model(object):  # body of the model
    def __init__(self):
        self.W = tf.Variable(tf.random.uniform([1]))  # randomly initialize the parameters
        self.b = tf.Variable(tf.random.uniform([1]))

    def __call__(self, x):
        return self.W * x + self.b  # w * x + b

model = Model()  # instantiate the model
plt.scatter(X, y)
plt.plot(X, model(X), c='r')

It can be seen that the fit is not very good yet, so we continue by training the model.
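As a quick check before training, the mean squared error of the untrained model can be printed (a minimal sketch, assuming the X, y, and model objects defined above; initial_loss is an illustrative name):

initial_loss = tf.reduce_mean(tf.square(model(X) - y))  # MSE of the randomly initialized model
print(f'Initial MSE: {initial_loss.numpy():.4f}')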
3. Use the loss function and gradient-descent iterations to obtain a good fit
Loss function (mean squared error):
loss(w, b) = mean((f(w, b, x) - y)^2)
Update the parameters:
w ← w − lr * ∂loss(w, b)/∂w
b ← b − lr * ∂loss(w, b)/∂b
where lr is the learning rate.
Finally, run ten iterations:
def loss_fn(model, x, y):
    y_ = model(x)
    return tf.reduce_mean(tf.square(y_ - y))  # mean squared error

EPOCHS = 10
LEARNING_RATE = 0.1
for epoch in range(EPOCHS):  # number of iterations
    with tf.GradientTape() as tape:
        loss = loss_fn(model, X, y)  # compute the loss
    dW, db = tape.gradient(loss, [model.W, model.b])  # compute the gradients
    model.W.assign_sub(LEARNING_RATE * dW)  # W <- W - lr * dW
    model.b.assign_sub(LEARNING_RATE * db)  # b <- b - lr * db
    # Print the training progress
    print(f'Epoch[{epoch}/{EPOCHS}], loss[{loss}], W/b[{model.W.numpy()}/{model.b.numpy()}]')
Plot the data and the fitted line again to see the result:
plt.scatter(X, y)
plt.plot(X, model(X), c='r')
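The gradients that tf.GradientTape computes in the loop above can also be written out by hand from the MSE loss, which makes the update rules concrete. Below is a minimal sketch of a single hand-written update step, assuming the same model, X, y, and LEARNING_RATE as above (dW_manual and db_manual are illustrative names):

error = model(X) - y                                 # prediction error, shape [NUM_SAMPLES, 1]
dW_manual = tf.reduce_mean(2.0 * error * X, axis=0)  # d loss / d W, shape [1]
db_manual = tf.reduce_mean(2.0 * error, axis=0)      # d loss / d b, shape [1]
model.W.assign_sub(LEARNING_RATE * dW_manual)        # W <- W - lr * dW
model.b.assign_sub(LEARNING_RATE * db_manual)        # b <- b - lr * db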
High-level API implementation:
Use Keras, the high-level library built into TensorFlow:
model = tf.keras.Sequential()  # create a new sequential model
model.add(tf.keras.layers.Dense(units=1, input_dim=1))  # add a linear (dense) layer
model.compile(optimizer='sgd', loss='mse')  # define the loss function and optimization method
model.fit(X, y, epochs=10, batch_size=32)  # train the model
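To see what the Keras model learned, the parameters of the Dense layer can be read back and compared with TRUE_W and TRUE_b. A minimal sketch using the standard get_weights() and predict() APIs (the names kernel and bias are illustrative):

kernel, bias = model.layers[0].get_weights()  # learned weight matrix and bias of the Dense layer
print('learned W:', kernel.flatten(), 'learned b:', bias)  # compare with TRUE_W = 3.0, TRUE_b = 2.0
plt.scatter(X, y)
plt.plot(X, model.predict(X), c='r')  # plot the fitted line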