当前位置:网站首页>机器学习:线性回归
机器学习:线性回归
2022-06-24 19:28:00 【翁炜强】
低级API实现:
1.随机初始化数据
import matplotlib.pyplot as plt
import tensorflow as tf
TRUE_W=3.0
TRUE_b=2.0
NUM_SAMPLES=100
#初始化随机数据
X=tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
noise=tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
y=X*TRUE_W+TRUE_b+noise #添加噪声
plt.scatter(X,y)
2.
定义一元回归模型并拟合曲线:
𝑓(𝑤,𝑏,𝑥)=𝑤∗𝑥+𝑏
class Model(object): #object模型的主体
def __init__(self):
self.W = tf.Variable(tf.random.uniform([1])) # 随机初始化参数
self.b = tf.Variable(tf.random.uniform([1]))
def __call__(self, x):
return self.W * x + self.b # w*x + b
model = Model() # 实例化模型
plt.scatter(X, y)
plt.plot(X, model(X), c='r')

可见拟合效果不是很好 因此继续训练模型
3.利用损失函数 去 进行梯度下降迭代 得到好的拟合结果
损失函数:

更新参数:
𝑏←b−𝑙𝑟∗∂loss(𝑤,𝑏)
w←w−𝑙𝑟∗∂loss(𝑤,𝑏)
lr指是学习率
最后迭代十次
def loss_fn(model,x,y):
y_=model(x)
return tf.reduce_mean(tf.square(y_ -y))
EPOCHS =10
LEARNING_RATE=0.1
for epoch in range (EPOCHS): #迭代次数
with tf.GradientTape() as tape:
loss=loss_fn(model,X,y)#计算损失
dW,db=tape.gradient(loss,[model.W,model.b]) #计算梯度
model.W.assign_sub(LEARNING_RATE*dW)
model.b.assign_sub(LEARNING_RATE*db)
#输出计算结果
print(f'Epoch[{epoch}/{EPOCHS}], loss[{loss}], W/b[{model.W.numpy()}/{model.b.numpy()}]')
plt.scatter(X, y)
plt.plot(X, model(X), c='r')得到以下结果:
高阶API实现:
使用tensorflow现有库中的keras
model = tf.keras.Sequential() # 新建顺序模型
model.add(tf.keras.layers.Dense(units=1, input_dim=1)) # 添加线性层
model.compile(optimizer='sgd', loss='mse') # 定义损失函数和优化方法
model.fit(X, y, epochs=10, batch_size=32) # 训练模型
边栏推荐
猜你喜欢

力扣每日一题-第26天-496.下一个更大元素Ⅰ

如何做到全彩户外LED显示屏节能环保

C语言实现DNS请求器
![[camera Foundation (II)] camera driving principle and Development & v4l2 subsystem driving architecture](/img/b5/23e3aed317ca262ebd8ff4579a41a9.png)
[camera Foundation (II)] camera driving principle and Development & v4l2 subsystem driving architecture

(待补充)GAMES101作业7提高-实现微表面模型你需要了解的知识

多路转接select

memcached全面剖析–2. 理解memcached的內存存儲

Memcached comprehensive analysis – 3 Deletion mechanism and development direction of memcached

福建省发改委福州市营商办莅临育润大健康事业部指导视察工作

Byte software testing basin friends, you can change jobs. Is this still the byte you are thinking about?
随机推荐
Transport layer UDP & TCP
Vscode netless environment rapid migration development environment (VIP collection version)
MySQL optimizes query speed
煮茶论英雄!福建省发改委、市营商办领导一行莅临育润大健康事业部交流指导
Multi task model of recommended model: esmm, MMOE
直击“三夏”生产:丰收喜报频传 夏播紧锣密鼓
自己总结的wireshark抓包技巧
Decoration home page custom full screen video playback effect GIF dynamic picture production video tutorial playback code operation settings full screen center Alibaba international station
LeetCode-513. 找树左下角的值
TCP Jprobe utilization problem location
Li Kou daily question - day 26 -496 Next larger element I
Shengzhe technology AI intelligent drowning prevention service launched
优雅的自定义 ThreadPoolExecutor 线程池
在每个树行中找最大值[分层遍历之一的扩展]
memcached全面剖析–3. memcached的删除机制和发展方向
ping: www.baidu. Com: unknown name or service
02---纵波不可能产生的现象
【Camera基础(一)】Camera摄像头工作原理及整机架构
Datakit 代理实现局域网数据统一汇聚
Volcano成Spark默认batch调度器
