当前位置:网站首页>机器学习:线性回归
机器学习:线性回归
2022-06-24 19:28:00 【翁炜强】
低级API实现:
1.随机初始化数据
import matplotlib.pyplot as plt
import tensorflow as tf
TRUE_W=3.0
TRUE_b=2.0
NUM_SAMPLES=100
#初始化随机数据
X=tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
noise=tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
y=X*TRUE_W+TRUE_b+noise #添加噪声
plt.scatter(X,y)
2.
定义一元回归模型并拟合曲线:
𝑓(𝑤,𝑏,𝑥)=𝑤∗𝑥+𝑏
class Model(object): #object模型的主体
def __init__(self):
self.W = tf.Variable(tf.random.uniform([1])) # 随机初始化参数
self.b = tf.Variable(tf.random.uniform([1]))
def __call__(self, x):
return self.W * x + self.b # w*x + b
model = Model() # 实例化模型
plt.scatter(X, y)
plt.plot(X, model(X), c='r')

可见拟合效果不是很好 因此继续训练模型
3.利用损失函数 去 进行梯度下降迭代 得到好的拟合结果
损失函数:

更新参数:
𝑏←b−𝑙𝑟∗∂loss(𝑤,𝑏)
w←w−𝑙𝑟∗∂loss(𝑤,𝑏)
lr指是学习率
最后迭代十次
def loss_fn(model,x,y):
y_=model(x)
return tf.reduce_mean(tf.square(y_ -y))
EPOCHS =10
LEARNING_RATE=0.1
for epoch in range (EPOCHS): #迭代次数
with tf.GradientTape() as tape:
loss=loss_fn(model,X,y)#计算损失
dW,db=tape.gradient(loss,[model.W,model.b]) #计算梯度
model.W.assign_sub(LEARNING_RATE*dW)
model.b.assign_sub(LEARNING_RATE*db)
#输出计算结果
print(f'Epoch[{epoch}/{EPOCHS}], loss[{loss}], W/b[{model.W.numpy()}/{model.b.numpy()}]')
plt.scatter(X, y)
plt.plot(X, model(X), c='r')得到以下结果:
高阶API实现:
使用tensorflow现有库中的keras
model = tf.keras.Sequential() # 新建顺序模型
model.add(tf.keras.layers.Dense(units=1, input_dim=1)) # 添加线性层
model.compile(optimizer='sgd', loss='mse') # 定义损失函数和优化方法
model.fit(X, y, epochs=10, batch_size=32) # 训练模型
边栏推荐
- Bld3 getting started UI
- Make tea and talk about heroes! Leaders of Fujian Provincial Development and Reform Commission and Fujian municipal business office visited Yurun Health Division for exchange and guidance
- Antdb database online training has started! More flexible, professional and rich
- Intelligent fish tank control system based on STM32 under Internet of things
- Please open online PDF carefully
- Implementation of adjacency table storage array of graph
- VSCode无网环境快速迁移开发环境(VIP典藏版)
- Byte software testing basin friends, you can change jobs. Is this still the byte you are thinking about?
- leetcode_191_2021-10-15
- BBR bandwidth per second conversion logic
猜你喜欢

Understanding openstack network

【吴恩达笔记】机器学习基础

Pattern recognition - 1 Bayesian decision theory_ P1

【论】A deep-learning model for urban traffic flow prediction with traffic events mined from twitter

【论】Deep learning in the COVID-19 epidemic: A deep model for urban traffic revitalization index

Wireshark packet capturing skills summarized by myself

EditText controls the soft keyboard to search

123. the best time to buy and sell shares III

Byte software testing basin friends, you can change jobs. Is this still the byte you are thinking about?
![[product design and R & D collaboration tool] Shanghai daoning provides you with blue lake introduction, download, trial and tutorial](/img/0f/e0b261496d04ca3da8a7d7d19e5bf1.png)
[product design and R & D collaboration tool] Shanghai daoning provides you with blue lake introduction, download, trial and tutorial
随机推荐
XTransfer技术新人进阶秘诀:不可错过的宝藏Mentor
如何做到全彩户外LED显示屏节能环保
架构实战营 第 6 期 毕业设计
SYSCALL_ Define5 setsockopt code flow
Tutorial on obtaining JD cookies by mobile browser
[精选] 多账号统一登录,你如何设计?
The most important thing at present
Li Kou daily question - day 26 -496 Next larger element I
leetcode-201_2021_10_17
Docking of arkit and character creator animation curves
Slider controls the playback progress of animator animation
自己总结的wireshark抓包技巧
Call process of package receiving function
VSCode无网环境快速迁移开发环境(VIP典藏版)
【产品设计研发协作工具】上海道宁为您提供蓝湖介绍、下载、试用、教程
VirtualBox virtual machine installation win10 Enterprise Edition
memcached全面剖析–5. memcached的应用和兼容程序
Tdengine can read and write through dataX
Volcano成Spark默认batch调度器
TCP specifies the source port
