Summary of the differences between the copy_(), detach(), .data and clone() operations in PyTorch
2022-06-22 00:45:00 【非晚非晚】
1. clone
b = a.clone()
clone() creates a tensor with the same shape, dtype and device as the source tensor. The two do not share memory, but gradients flowing through the new tensor (b) are accumulated onto the source tensor (a). Note that after b = a.clone(), b is not a leaf node, so its gradient cannot be accessed.
import torch
a = torch.tensor([1.,2.,3.],requires_grad=True)
b = a.clone()
print('===========================memory not shared=========================')
print(type(a), a.data_ptr())
print(type(b), b.data_ptr())
print('===========================values after clone=========================')
print('a: ', a) # a: tensor([1., 2., 3.], requires_grad=True)
print('b: ', b) # b: tensor([1., 2., 3.], grad_fn=<CloneBackward0>)
c = a ** 2
d = b ** 3
print('===========================backward pass=========================')
c.sum().backward() # dc/da = 2a
print('a.grad: ', a.grad) # a.grad: tensor([2., 4., 6.])
d.sum().backward() # dd/db = 3b**2
print('a.grad: ', a.grad) # a.grad: tensor([ 5., 16., 33.]) -- b's gradient is accumulated onto a
#print('b.grad: ', b.grad) # b.grad: None -- b is no longer a leaf of the computation graph, so b.grad is not populated
Output:
===========================memory not shared=========================
<class 'torch.Tensor'> 93899916787840
<class 'torch.Tensor'> 93899917014528
===========================values after clone=========================
a: tensor([1., 2., 3.], requires_grad=True)
b: tensor([1., 2., 3.], grad_fn=<CloneBackward0>)
===========================backward pass=========================
a.grad: tensor([2., 4., 6.])
a.grad: tensor([ 5., 16., 33.])
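When an independent copy with an accessible gradient is needed, clone() and detach() are commonly combined. A minimal sketch, not part of the original article:
import torch
a = torch.tensor([1., 2., 3.], requires_grad=True)
# clone() allocates new memory, detach() cuts the graph, requires_grad_() makes b a leaf
b = a.clone().detach().requires_grad_(True)
(b ** 3).sum().backward()
print(b.grad) # tensor([ 3., 12., 27.]) -- b is now a leaf, so its grad is populated
print(a.grad) # None -- nothing flows back to a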
2. copy_
b = torch.empty_like(a).copy_(a)
copy_() needs a target tensor: b has to be constructed first and a is then copied into it, whereas clone() needs no pre-existing target.
copy_() accomplishes much the same thing as clone(), with some differences: copy_() is called on the target tensor, takes the tensor to copy from as its argument, and returns the target tensor, while clone() is called on the source tensor and returns a new tensor. clone() can also be invoked as torch.clone(), with the source tensor passed as the argument.
import torch
a = torch.tensor([1., 2., 3.],requires_grad=True)
b = torch.empty_like(a).copy_(a)
print('====================copy_: different memory======================')
print(a.data_ptr())
print(b.data_ptr())
print('====================copy_: printed values======================')
print(a)
print(b)
c = a ** 2
d = b ** 3
print('===================backward through c=======================')
c.sum().backward()
print(a.grad) # tensor([2., 4., 6.])
print('===================backward through d=======================')
d.sum().backward()
print(a.grad) # tensor([ 5., 16., 33.]) -- the source tensor's gradient is accumulated
#print(b.grad) # None
Output:
====================copy_: different memory======================
94358408685568
94358463065088
====================copy_: printed values======================
tensor([1., 2., 3.], requires_grad=True)
tensor([1., 2., 3.], grad_fn=<CopyBackwards>)
===================backward through c=======================
tensor([2., 4., 6.])
===================backward through d=======================
tensor([ 5., 16., 33.])
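As a quick check of the calling conventions described above (a small sketch, not from the original article): torch.clone() takes the source tensor as its argument, and copy_() returns the target tensor itself.
import torch
a = torch.tensor([1., 2., 3.], requires_grad=True)
b = torch.clone(a)                 # equivalent to a.clone()
c = torch.empty_like(a).copy_(a)   # copy_ returns its target, so c is the filled tensor
print(b.grad_fn)  # <CloneBackward0 object at ...>
print(c.grad_fn)  # <CopyBackwards object at ...>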
3. detach
detach() returns a tensor derived from the calling tensor. The new tensor shares data memory with the source (so their values are necessarily identical), but its requires_grad is False and it carries none of the source tensor's computation-graph information.
import torch
a = torch.tensor([1., 2., 3.],requires_grad=True)
b = a.detach()
print('=========================shared memory==============================')
print(a.data_ptr())
print(b.data_ptr())
print('=========================original vs. detached==============================')
print(a)
print(b)
c = a * 2
d = b * 3 # cannot backpropagate through b
print('=========================backward through the original==============================')
c.sum().backward()
print(a.grad)
print('=========================the detached tensor cannot backpropagate==============================')
# d.sum().backward() # would raise: element 0 of tensors does not require grad and does not have a grad_fn
Output:
=========================shared memory==============================
94503766034432
94503766034432
=========================original vs. detached==============================
tensor([1., 2., 3.], requires_grad=True)
tensor([1., 2., 3.])
=========================backward through the original==============================
tensor([2., 2., 2.])
=========================the detached tensor cannot backpropagate==============================
Since b has been detached from the computation graph, PyTorch naturally stops tracking its subsequent computations. To make b rejoin the graph, simply call b.requires_grad_().
PyTorch will then track computations on b, but gradients will not flow from b back to a: the gradient is cut at the detach point. Because b shares memory with a, however, the values of a and b always stay equal.
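A minimal sketch of these two points (the variable names follow the example above; the snippet itself is not from the original article):
import torch
a = torch.tensor([1., 2., 3.], requires_grad=True)
b = a.detach()
b.requires_grad_()    # b rejoins the graph as a new leaf
(b * 3).sum().backward()
print(b.grad)         # tensor([3., 3., 3.])
print(a.grad)         # None -- the gradient is cut at the detach point
with torch.no_grad():
    a.mul_(2)         # shared memory: b's values change along with a's
print(b)              # tensor([2., 4., 6.], requires_grad=True)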
4. data
The .data attribute retrieves a tensor's underlying data. What it returns has the same characteristics as the result of detach() described above: it shares memory with the source and carries no gradient information. However, .data can be unsafe: because the two tensors share memory, changing one changes the other, and while detach() makes a subsequent backward pass raise an error after such a change, .data lets it run and silently produce wrong gradients.
import torch
x = torch.FloatTensor([[1., 2.]]) # x.requires_grad == False by default; only floating-point tensors can backpropagate
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True
d = torch.matmul(x, w1) # after the matmul, d.requires_grad == True (the same holds for addition)
d_ = d.data # d and d_ share memory; d_.requires_grad == False
# d_ = d.detach() # d and d_ would also share memory, but backward() would then raise an error
f = torch.matmul(d, w2)
d_[:] = 1 # modifying d_ in place also changes d's values
f.backward() # with .data the gradient is silently wrong; with detach() this line raises an error
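To make the contrast concrete, here is a minimal sketch (not from the original article) of the detach() variant. The in-place write bumps the version counter that d and d_ share, so autograd detects the modification and raises an error, instead of silently computing w2's gradient from the overwritten value of d (1 instead of the true 4) as .data does:
import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True
d = torch.matmul(x, w1)
d_ = d.detach()   # shares memory (and the version counter) with d
f = torch.matmul(d, w2)
d_[:] = 1         # the in-place write is recorded in the version counter
try:
    f.backward()
except RuntimeError as e:
    print(e)      # "...has been modified by an inplace operation..."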