当前位置:网站首页>PyTorch的自动求导
PyTorch的自动求导
2022-07-24 15:18:00 【强强学习】
文章目录
1. 基本概念
1.1 requires_grad
如果需要对某个张量进行求导,则在初始化时必须赋值requires_grad=True,例如
a = torch.tensor(2.3, requires_grad=True)
1.2 计算图
在计算图中,数据用椭圆表示,加减乘除等操作用矩形表示。通过计算图将数据和操作表示成了二叉树结构。
1.3 叶子节点
在计算图中,由用户自己创建的数据就叫叶子节点,比如上图的w, x, b,也可以说未经过计算得出来的数据就是叶子节点。可以用 a.is_leaf(注意是属性,而不是方法)判断是否为叶子节点。
print(a.is_leaf)
1.4 grad_fn
z.grad_fn输出的是<AddBackward0 at 0x8ea34d2cd342>,表示的是z对应的直接操作运算。
tensor由某个操作获得,在PyTorch每个操作的反向传播函数是已经被定义好的,比如z是由add即加操作得到的,那么z.grad_fn得到的就是add函数的反向传播函数(求导函数)。注意我们的得到的是AddBackward0 后面有个0,说明一个计算图中可以出现很多次add,每个add的反向传播函数是不一样的。
1.5 next_functions
z.grad_fn.next_functions 输出的是
(<AccumulateGrad at 0x7fb73c7cdad0>, 0L))
((<MulBackward0 at 0x7fb73c7cd7d0>, 0L),
z是由add操作得到的,那么add操作的输入是b和 y,输出的就是b.grad_fn和y.grad_fn。
AccumulateGrad是什么?a.grad是什么?为什么梯度要置0?
对于y.grad_fn我们可以知道MulBackward就是乘积操作对应的反向传播函数。
b只是一个叶子结点,是一个Tensor,它的grad_fn即Accumlate_Grad表示这个b的导数是可积累的。比如你第一次方向传播一次,我们得出b的导数为3,即a.grad为3,但是你再求导一次,就会发现a.grad为6,这就是所谓的可累加。所以在Pytorch里面,每一个batch即每一次反向传播前都会把梯度下降即grad都置为0。
1.6 retain_graph=True backward()
z.backward(retain_graph=True)
z.backward()表示从z求出来的是z对各个变量的导数。
retain_graph=True表示保存中间变量。比如我们计算z对w的导数发现导数就是y,注意这个y在我们上面举例的计算图中不是我们自己指定的,是中间求出来的,我们第一次z.backward()求z对w的导数会取到y的值。但是如果我们这次传播完立刻在想传播一次,那么就会报错,因为一次梯度玩会自动把中间的计算东西释放掉,也就是第二次传播时候就没有y了,除非你再前向传播一次。所以我们可以提前指定这个保证第一次传播完中间变量仍然存在。
1.7 hook函数
非叶子节点的导求出来后会被释放,如果想看其导数,可以用autograd.grad或者hook函数。
2. 总结
关于autograd,我们需要知道的就是我们可以在创建tensor的时候指定 requires_grad = True 使得可求导,然后在最终函数用 z.backward()。用a.grad查看导数(梯度)。
边栏推荐
- 【Flutter -- 布局】流式布局(Flow和Wrap)
- 异或程序
- spark:指定日期输出相应日期的日志(入门级-简单实现)
- Exomiser对外显子组变体进行注释和优先排序
- Jmeter-调用上传文件或图片接口
- Various searches (⊙▽⊙) consolidate the chapter of promotion
- spark:获取日志中每个时间段的访问量(入门级-简单实现)
- Unity 使用NVIDIA FleX for Unity插件实现制作软体、水流流体、布料等效果学习教程
- Huawei camera capability
- 2022 RoboCom 世界机器人开发者大赛-本科组(省赛)---第一题 不要浪费金币 (已完结)
猜你喜欢

Outlook tutorial, how to create tasks and to DOS in outlook?

2022 robocom world robot developer competition - undergraduate group (provincial competition) -- question 3 running team robot (finished)
DS binary tree - parent and child nodes of binary tree

各种Normalization的直观理解

Problem handling of repeated restart during Siemens botu installation

Leetcode high frequency question 56. merge intervals, merge overlapping intervals into one interval, including all intervals

Here comes the problem! Unplug the network cable for a few seconds and plug it back in. Does the original TCP connection still exist?

Decrypt "sea Lotus" organization (domain control detection and defense)

【Bug解决】Win10安装pycocotools报错

Isprs2018/ cloud detection: cloud/shadow detection based on spectral indexes for multi/hyp multi / hyperspectral optical remote sensing imager cloud / shadow detection
随机推荐
Unity 使用NVIDIA FleX for Unity插件实现制作软体、水流流体、布料等效果学习教程
onBlur和onChange冲突解决方法
pip 安装报错 error in anyjson setup command: use_2to3 is invalid.
JS data transformation -- Transformation of tree structure and tile structure
Rest style
DDD based on ABP -- Entity creation and update
Summary of feature selection: filtered, wrapped, embedded
VAE(变分自编码器)的一些难点分析
The first n rows sorted after dataframe grouping nlargest argmax idmax tail!!!!
How to set packet capturing mobile terminal
[USENIX atc'22] an efficient distributed training framework whale that supports the super large-scale model of heterogeneous GPU clusters
Performance test - analyze requirements
4279. 笛卡尔树
维护香港服务器安全的9个关键措施
循环结构practice
JMeter - call the interface for uploading files or pictures
Learning rate adjustment strategy in deep learning (1)
spark:获取日志中每个时间段的访问量(入门级-简单实现)
C # exit login if there is no operation
Performance test - Preparation of test plan