当前位置:网站首页>200 行代码,深入分析动态计算图的原理及实现
200 行代码,深入分析动态计算图的原理及实现
2022-07-23 07:20:00 【林地宁宁】
200 行代码,深入分析动态计算图的原理及实现
原文地址:CSDN 博客
文章目录
1. 前言
机器学习这几年可是大红大紫,各行各业的人都往这里涌入,硬是在机器学习这一领域里挤出了一片人口红海。而在机器学习领域,神经网络由于自己下限低、上限高的特点,赢得了不少人的青睐。
在神经网络中,却有一件我们经常使用,经常耳闻,但又不太熟悉的东西——BP 算法。入门“炼丹”的小萌新往往会对这个一头雾水,久经沙场的老油条对这个也可能不了解细节。
我在查阅许多文章后,发现大多数文章对 BP 算法的介绍往往是点到为止,更深入者也就在数学公式推导层面止步,涉及到代码层面的博主鲜少,更很少提及 BP 算法在神经网络中的更广泛实现——计算图机制。
于是,秉持着“科普”的原则,笔者就撰写了这篇有关于 BP 算法以及计算图原理的文章,并在其中以笔者自己的代码实现,详细地讲解计算图的工作机制,并最终与成熟的计算框架进行比较。
2. BP 算法
BP 算法,又名反向传播算法,是目前深度学习的理论基石。其原始论文于 1986 年由 D. Rumelhart 发表在 Nature 上1。在其论文中,就已经使用 MSE(Mean Square Error) 均方误差作为训练目标,并使用多层的 MLP 感知机作为模型,进行亲戚关系的分类。

当前时代的神经网络,早已比当时的网络来的更加庞大,几百个万的模型参数比比皆是,GPT-3 甚至已经上千亿的模型,而其最基本的算法,却来自于 40 年前,让人感到不可思议。
对于 BP 算法的理解其实非常简单。假设神经网络的的损失是 L L L, x \bm{x} x 是输入向量, W i j \bm{W}_{ij} Wij 是第 i i i 层的第 j j j 个参数,那么根据梯度下降的原理,我们需要得到 L L L 对 W i j \bm{W}_{ij} Wij 偏微分值:
∇ W i j = ∂ L ∂ W i j \nabla\bm{W}_{ij}=\frac{\partial L}{\partial \bm{W}_{ij}} ∇Wij=∂Wij∂L
设 η \eta η 为学习率,则最终的参数更新算法为:
W i j t + 1 = W i j t − η ⋅ ∇ W i j \bm{W}_{ij}^{t+1}=\bm{W}_{ij}^{t}-\eta\cdot\nabla\bm{W}_{ij} Wijt+1=Wijt−η⋅∇Wij
然后问题来了:怎么计算 ∂ L ∂ W i j \frac{\partial L}{\partial \bm{W}_{ij}} ∂Wij∂L?
许多的博文都对这个问题作出众多的解释,大部分人会选择使用数学推导的形式阐述,最终结果或许可能如下:
这串花里胡哨的东西,对数学系的同学来说刚刚好,对笔者来说可不好。讲到底 BP 算法就是一个偏导数的链式法则应用,写这么复杂真的有用吗?
∂ y ∂ x 1 = ∂ y ∂ x n ⋅ ∂ y ∂ x n − 1 ⋅ ⋯ ⋅ ∂ x 2 ∂ x 1 \frac{\partial y}{\partial x_1}=\frac{\partial y}{\partial x_n}\cdot \frac{\partial y}{\partial x_{n-1}}\cdot\dots\cdot\frac{\partial x_2}{\partial x_1} ∂x1∂y=∂xn∂y⋅∂xn−1∂y⋅⋯⋅∂x1∂x2
看吧!如果我把上面这串链式法则的公式, y y y 换成 L L L, x 1 x_1 x1 换为 W i j \bm{W}_{ij} Wij,剩下的 x i x_i xi 换为神经网络中的一些其他变量,不就把 BP 算法拆成了许多更小的偏导数的乘积吗?
对于 BP 算法的数学机理,了解到这已经足够。下一节,笔者将以程序员的角度,带大家看 BP 算法的另一个视角——计算图机制。
3. 计算图
本章中通过一个实际的例子,给出计算图的详细说明,并引出了计算图反向传播机制的定理。
3.1 计算图定义
计算图是描述计算过程的数据结构,而且通常是 DAG 图(有向无环图)。
在计算图中,每一个节点表示一个变量(值),每一条边表示数据的流动方向,并且每一条边的值被定义为边的首尾节点的偏导数值。例如:
这幅图表示以下的三个算式:
c = a + b d = b + 1 e = c × d \begin{aligned} c&=a+b\\ d&=b+1\\ e&=c\times d \end{aligned} cde=a+b=b+1=c×d
在这副计算图中,每个节点都表示着一个变量值,每条边表示数据的流动。在每条边上,笔者提前算出了每条边的末尾节点对起始节点的偏导数,例如边 (b,d) 的偏导数就是 ∂ d ∂ b = 1 \frac{\partial d}{\partial b}=1 ∂b∂d=1。
3.2 计算图机制
拥有计算图的定义后,下面来详细介绍一下计算图是如何对应 BP 算法的。
3.2.1 前向传播
对应于 BP 算法的前向传播(Forward Pass)过程,计算图的前向传播其实相同,就是把计算图的每个节点的值都计算出来。
例如,在上面的示例图中,若设 a = 2 , b = 1 a=2,b=1 a=2,b=1,那么前向传播的过程就把其他的节点值都算出来:
c = a + b = 3 d = b + 1 = 2 e = c × d = 5 \begin{aligned} c&=a+b=3\\ d&=b+1=2\\ e&=c\times d=5 \end{aligned} cde=a+b=3=b+1=2=c×d=5
前向传播没有理解上的难点,大家一眼就能明白,而难点在于反向传播的过程中。
3.2.2 反向传播
在反向传播的机制中,笔者并不打算引入过于复杂的数学公式来证明,而是选择用更加浅显易懂的大白话,说明计算图在反向传播过程中的工作原理。

对于示例图中,如果想求 ∂ e ∂ b \frac{\partial e}{\partial b} ∂b∂e,该怎么办?首先,从节点 b b b 开始,可以发现 b b b 通过作用于 c c c 和 d d d,进而对节点 e e e 造成了影响。
这个连环影响的现象表达成数学的形式,即为
Δ e = ∂ e ∂ c ⋅ Δ c + ∂ e ∂ d ⋅ Δ d = ∂ e ∂ c ⋅ ( ∂ c ∂ b ⋅ Δ b ) + ∂ e ∂ d ⋅ ( ∂ d ∂ b ⋅ Δ b ) \begin{aligned} \Delta e&=\frac{\partial e}{\partial c}\cdot\Delta c + \frac{\partial e}{\partial d}\cdot\Delta d \\ &=\frac{\partial e}{\partial c}\cdot(\frac{\partial c}{\partial b}\cdot \Delta b) + \frac{\partial e}{\partial d}\cdot(\frac{\partial d}{\partial b}\cdot \Delta b) \end{aligned} Δe=∂c∂e⋅Δc+∂d∂e⋅Δd=∂c∂e⋅(∂b∂c⋅Δb)+∂d∂e⋅(∂b∂d⋅Δb)
上式左右两侧同时除以 Δ b \Delta b Δb,则可以不严谨的得到:
∂ e ∂ b = ∂ e ∂ c ⋅ ∂ c ∂ b + ∂ e ∂ d ⋅ ∂ d ∂ b \frac{\partial e}{\partial b}=\frac{\partial e}{\partial c}\cdot\frac{\partial c}{\partial b} + \frac{\partial e}{\partial d}\cdot\frac{\partial d}{\partial b} ∂b∂e=∂c∂e⋅∂b∂c+∂d∂e⋅∂b∂d
仔细地观察这个式子,对比下图可以发现:式子的前半部分 ∂ e ∂ c ⋅ ∂ c ∂ b \frac{\partial e}{\partial c}\cdot\frac{\partial c}{\partial b} ∂c∂e⋅∂b∂c,正好是路线 A 的边上梯度值的乘积;同理,式子的后半部分 ∂ e ∂ d ⋅ ∂ d ∂ b \frac{\partial e}{\partial d}\cdot\frac{\partial d}{\partial b} ∂d∂e⋅∂b∂d,也是路线 B 的边上梯度值的乘积。

从这里例子,可以总结出计算图的最终定理。
定理(计算图反向传播机制):计算图上任意两点 x x x 和 y y y,且 y y y 在 x x x 之后,则 ∂ y ∂ x \frac{\partial y}{\partial x} ∂x∂y 的值为点 x x x 到点 y y y 上所有的不重复路径上的边值乘积的总和。
如果觉得这个定理有点难懂,那么其详细的计算过程如下:
- 找到所有从点 x x x 到 y y y 的不重复路径,记作集合 P \mathcal{P} P
- 对任意 p i ∈ P p_i \in \mathcal{P} pi∈P,计算路径 p i p_i pi 上所有边值乘积 M i M_i Mi
- 则 ∂ y ∂ x = ∑ p i ∈ P M i \frac{\partial y}{\partial x}=\sum^{p_i\in \mathcal{P}} M_i ∂x∂y=∑pi∈PMi
对应到这个例子,就是说:从路线 A,得到其路径上的乘积为 d d d;从路线 B,得到其路径上的乘积为 c c c。那么最终的结果为
∂ e ∂ b = d + c = a + 2 b + 1 \frac{\partial e}{\partial b}=d+c=a+2b+1 ∂b∂e=d+c=a+2b+1
由于在前向传播的过程中,所有的变量值我们都已经确定,所以算出 ∂ e ∂ b \frac{\partial e}{\partial b} ∂b∂e 的过程也就迎刃而解了。
有兴趣的同学可以试着验证其他的变量,看它们是否符合此规律。此外,笔者更推荐对其他的计算图检查,可以加深对这条规则的理解。
4. 代码实现
下面就是代码实现的部分咯,觉得麻烦的小伙伴可以跳过不看哦,但还是希望能给我的代码点个 star 收藏一下,十分感激!ヾ(≧▽≦*)o
Github 仓库:toy_computational_graph
4.1 Operation 定义
在个人的 200 行代码的实现中,大部分代码用于实现加减乘除的操作,事实上真正涉及反向传播的代码可能不足 30 行。下面是关于 Operation 的基类定义:
class Operation(ABC):
def __init__(self):
super().__init__()
# 反向传播过程中所需要的上下文 ctx
self.ctx: Optional[Dict] = None
# 记录输入的节点
self.inputs: List[Value] = []
def __call__(self, *args) -> Scalar:
self.inputs = list(args)
self.ctx = dict()
ret = self.forward(args, ctx=self.ctx)
ret.op = self
return ret
@staticmethod
@abstractmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
# 进行前向传播,并将反向传播的必要信息存放于 ctx 中
pass
@staticmethod
@abstractmethod
def backward(grad_output: float, ctx=None) -> List[float]:
# 反向传播的过程,返回每条输入边的累积梯度值
# grad_output 是从更加往后的节点传播到此处的累积梯度乘积
pass
可见,每个 Operation 其实就有以下功能:
- 记录输入节点
- 记录前向传播过程中产生的上下文
- 前向传播
- 反向传播
根据这个基类,最终派生出了加减乘除操作的实现类:
class AddOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
return Scalar(x.value + y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
return [grad_output, grad_output]
class SubOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
return Scalar(x.value - y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
return [grad_output, -grad_output]
class MulOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
ctx["x"] = x
ctx["y"] = y
return Scalar(x.value * y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
x, y = ctx["x"].value, ctx["y"].value
return [grad_output * y, grad_output * x]
class DivOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
assert y.value != 0, "Division by zero"
ctx["x"] = x
ctx["y"] = y
return Scalar(x.value / y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
x, y = ctx["x"].value, ctx["y"].value
return [grad_output / y, -x * grad_output / (y ** 2)]
代码简短而且清爽,适合读者学习。
4.2 数值类型
由于这个 codebase 体量不大,因此只允许使用 float 的包装类 Scalar 作为数值类型。其中 Value 类是 Scalar 类的基类,其定义并实现了反向传播的机制,如下:
class Value:
def __init__(self, op: Optional[Operation]):
self.op = op
self.grad = 0.
def zero_grad(self):
# 梯度清零,类似于 PyTorch
self.grad = 0.
def backward(self, grad_output: Optional[float] = None):
# 反向传播的实际执行,就是从此节点,迭代地把累积梯度乘积向更前的节点传播
# 等节点根据所传入的累积梯度乘积,更新完自身的梯度值后,就继续进行此过程
# 注:在保证 DAG 的前提下,此过程相等于遍历图上的所有不同路径
grad_output = grad_output if grad_output is not None else 1.
self.grad += grad_output
if self.op is not None:
prev_grads = self.op.backward(grad_output, ctx=self.op.ctx)
for input, prev_grad in zip(self.op.inputs, prev_grads):
input.backward(prev_grad)
至于 Scalar 类,只是实现了 __add__ 之类的加减乘除的 Dunder 函数的封装类,大致如下:
class Scalar(Value):
def __init__(self, value: numbers.Number, op: Optional[Operation] = None):
super().__init__(op)
self._value = float(value)
def __add__(self, other):
from operation import AddOperation
if isinstance(other, Scalar):
op = AddOperation()
return op(self, other)
elif isinstance(other, numbers.Number):
op = AddOperation()
return op(self, Scalar(other))
else:
raise TypeError("unsupported type")
... ...
由于 Scalar 类并不包括太多实际操作,因此完整代码供有兴趣的读者自行查看。
4.3 运行结果
详细代码可以查看代码仓库中的 example.py,结果如下:
example1:
x=10.0, y=2.0, r=x+2*y=14.0
=> x.grad=1.0, y.grad=2.0
example2:
x=10.0, r=x*x=100.0
=> x.grad=20.0
example3:
x=10.0, r=x*(x+1)=110.0
=> x.grad=21.0
example4:
x=8.0, y=4.0, r=x/y=2.0
=> x.grad=0.25, y.grad=-0.5
example5:
x=3.0, r=1/(x*x+1)=0.1
=> x.grad=-0.06
example6:
x=8.0, y=3.0, r=(x*x+1)/(y*y-1)=8.125
=> x.grad=2.0, y.grad=-6.09375
以上六个例子的运算结果均正确。
5. 杂谈
事实上,我这个 demo 和 PyTorch 一样,采用的是动态计算图的形式,即计算图是在运算的过程中实时产生。相反的,Tensorflow 就是采用静态计算图,其计算图需要在一开始就进行编译并固定。
相较于我这个毫无优化的 demo,PyTorch 对于计算图的优化则是出神入化。首先在这个计算图的迭代过程中,明显可以发现,不同路径之间的乘积是可以并行计算的。
同时,从计算图机制的定理中可以发现,由于各个路径上的梯度最终是相加起来的,因此并行下最好的实现方式就是将各个变量的梯度都初始化为 0,否则梯度相加后会出错。这也是为什么 PyTorch 训练时,会需要 zero_grad() 这一步。当然,笔者的实现中也仿效了这一设计。
6. 总结
本文从程序员的角度,总结出了计算图机制下的运行定理,并给出了约 200 行的代码实现,希望能够帮助所有正在入门机器学习的人。
如果您觉得本文有价值,还希望您能给我的文章点个赞、收藏和关注的三连,我们下期再见!ヾ( ̄▽ ̄)ByeBye
最后的最后,附上本文代码的 repo 地址:toy_computational_graph,希望读者能点几个 star 支持一下!
边栏推荐
- [play with FPGA in simple terms to learn 10 ----- simple testbench design]
- 回溯法解决 八皇后问题
- 微信小程序--动态设置导航栏颜色
- LeetCode_52_N皇后Ⅱ
- Learn about canvas
- 学会用canvas构建折线图、柱状图、饼状图
- Qt Creator .pro文件根据kit添加对应库
- Vs2019:constexpr function "qcountleadingzerobits" cannot generate constant expressions
- 专题讲座5 组合数学 学习心得(长期更新)
- 图形管线(一)后处理阶段 alpha测试 模版测试 深度测试 混合
猜你喜欢

Special lecture 5 combinatorial mathematics learning experience (long-term update)

Hardware system architecture of 4D millimeter wave radar

Detailed explanation of decision tree

Remove title block

ES6 - weekly examination questions

魔兽地图编辑器触发器笔记

【STM32】串口通信基础知识

学会用canvas构建折线图、柱状图、饼状图

微服务重点

Kept dual machine hot standby
随机推荐
【Ardunio】2种方法控制舵机
Learn about canvas
Smart canteen data analysis system
How to deal with the new development mode when doing testing?
Special lecture 5 combinatorial mathematics learning experience (long-term update)
How to ensure the reliable transmission of messages? What if the message is lost
Chapter II relational database after class exercises
C#:readonly与const
[图形学]ASTC纹理压缩格式
类和对象(上)
LeetCode_ 2342_ Maximum sum of digits and equal pairs
基于OpenCV实现对图片及视频中感兴趣区域颜色识别
【STM32】串口通信基础知识
面试官:有了解过ReentrantLock的底层实现吗?说说看
These five points should be considered in the production of enterprise science and technology exhibition hall
图像处理5:膨胀
LeetCode_2341_数组能形成多少数对
[cocos creator] spin animation, monitoring and playback end
养老机构智能视频监控解决方案,用新技术助力养老院智慧监管
挖财开户风险性大吗,安全吗?