当前位置:网站首页>tensorflow2的GradientTape求梯度
tensorflow2的GradientTape求梯度
2022-06-23 11:09:00 【河北一帆】
上一篇博文使用tensorflow2创建神经网络_河北一帆的博客-CSDN博客中创建的神经网络,使用optimizer的minimize方法进行损失函数优化,其原理是对每层的权重和偏置求梯度,进行梯度下降更新。
tensorflow2中提供了GradientTape方式能够非常方便的实现求梯度,通过求梯度,自己实现minimize方法。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
class NN:
def __init__(self):
# 创建一个隐藏层的神经网络模型
self.model = tf.keras.Sequential()
self.model.add(tf.keras.layers.Dense(10, activation='relu'))
self.model.add(tf.keras.layers.Dense(1, activation='relu'))
self.model.build((None, 2)) # 2表示 input是2维
def output(self, x):
return self.model(x)
def train(self, x, y):
with tf.GradientTape(watch_accessed_variables=False) as tape:
tape.watch(self.model.trainable_variables)
loss = tf.reduce_mean(tf.square(y - self.output(x))) # 均方误差损失函数
gradients = tape.gradient(loss, self.model.trainable_variables)
optimizer = tf.keras.optimizers.Adam()
optimizer.apply_gradients(zip(gradients, self.model.trainable_weights))
print(loss)
if __name__ == "__main__":
nn = NN()
# 训练集
x = np.float32(np.random.uniform(0, 10, size=(100, 1)))
y = np.float32(np.random.uniform(0, 10, size=(100, 1)))
z = 3 * x + y
xx = np.hstack((x, y))
for i in range(1000):
nn.train(xx, z)
# 测试集
x_verify = np.float32(np.random.uniform(0, 10, size=(100, 1)))
y_verify = np.float32(np.random.uniform(0, 10, size=(100, 1)))
z_verify = 3 * x_verify + y_verify
xx_verify = np.hstack((x_verify, y_verify))
z_ = nn.output(xx_verify).numpy() # Tensor张量转成numpy array
plt.plot(z_verify)
plt.plot(z_)
plt.show()
self.model.trainable_variables是两层网络的权重和偏置,[w1, b1, w2, b2] = self.model.trainable_variables可以打断点观察。
gradients = tape.gradient(loss, self.model.trainable_variables)即对网络权重和偏置求偏导
网络权重更新时
公式中的alpha,可以使用Adam优化算法变步长计算,以下两行实现的就是该功能
optimizer = tf.keras.optimizers.Adam()
optimizer.apply_gradients(zip(gradients, self.model.trainable_weights))也可使用SGD方法固定Learning rate即
的值
边栏推荐
- Which securities company has the lowest Commission for opening a mobile account? Is it safe to open an account online now?
- Noi OJ 1.3 09: circle related computing C language
- TTY drive frame
- Share a mobile game script source code
- 新派科技美学、原生物联网操作系统重塑全屋智能
- 每日一题day7-1652. 拆炸弹
- 一年多时间时移世易,中国芯片不断突破,美国芯片却难以卖出
- Opencloudos uses snap to install NET 6
- 从0到1,IDE如何提升端侧研发效率?| DX研发模式
- The simplest DIY actuator cluster control program based on 51 single chip microcomputer, pca9685, IIC and PTZ
猜你喜欢

直播带货app源码搭建中,直播CDN的原理是什么?

Picture storage -- Reference

ESP32-CAM无线监控智能网关的设计与实现

单向链表实现--计数

程序中创建一个子进程,然后父子进程各自独自运行,父进程在标准输入设备上读入小写字母,写入管道。子进程从管道读取字符并转化为大写字母。读到x结束

Install the typescript environment and enable vscode to automatically monitor the compiled TS file as a JS file

The simplest DIY actuator controller based on 51 single chip microcomputer

安卓安全/逆向面试题

今天14:00 | 12位一作华人学者开启 ICLR 2022

“互联网+”大赛命题火热对接中 | 一图读懂百度38道命题
随机推荐
最简单DIY基于STM32的远程控制电脑系统①(电容触摸+按键控制)
torch权重转mindspore
MAUI使用Masa blazor组件库
1154. 一年中的第几天
Groovy之Map操作
Force buckle 1319 Number of connected network operations
强化责任意识和底线思维 全力筑牢抗洪抢险“安全堤”
运行时应用自我保护(RASP):应用安全的自我修养
Simplest DIY steel patriot machine gun controller based on Bluetooth, 51 MCU and steering gear
经济小常识
直播带货app源码搭建中,直播CDN的原理是什么?
Win10 无线网络,系统搜索不到WLAN的,解决办法 (以及 VMnet1,8)
Simplest DIY mpu6050 gyroscope attitude control actuator program based on stm32f407 Explorer development board
Noi OJ 1.3 14: elephant drinking water C language
如何用 Redis 实现一个分布式锁
Noi OJ 1.4 04: odd even ASCII value judgment C language
Rancher 2.6 全新 Monitoring 快速入门
File has not been synchronized when NFS is mounted
Noi OJ 1.2 06: round floating point numbers to zero
The simplest DIY actuator cluster control program based on 51 single chip microcomputer, pca9685, IIC and PTZ