A PyTorch pitfall: why won't the L1Loss go down during model training?
2022-06-25 06:39:00 【一个菜鸟的奋斗】
I was recently training a regression model with L1Loss and found the loss was extremely unstable during training and the results were poor. I finally tracked down the cause!
The original code looked like this:
criterion = nn.L1Loss()

def train(epoch):
    print('Epoch {}:'.format(epoch + 1))
    model.train()  # switch to train mode
    for i, sample_batched in enumerate(train_dataloader):
        input, target = sample_batched['geno'], sample_batched['pheno']
        # compute output
        output = model(input.float().cuda())
        loss = criterion(output, target.float().cuda())
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
The problem with the code above is this line:
loss = criterion(output, target.float().cuda())
My batch size is 4, so output has size [4, 1], a 2-D tensor, while target has size [4]. The loss still evaluates to a single reasonable-looking number, which is exactly why I didn't spot the problem!
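A minimal reproduction of the trap, with hypothetical shapes and values (a sketch of the behavior described above, not the original training code):

import torch
import torch.nn as nn

criterion = nn.L1Loss()
output = torch.randn(4, 1, requires_grad=True)  # stand-in model output, shape [4, 1]
target = torch.randn(4)                         # stand-in labels, shape [4]

loss = criterion(output, target)  # emits a UserWarning about mismatched sizes
print(loss)                       # still one plausible scalar, so the bug is easy to miss

With that in mind, let's look at the l1_loss source in the PyTorch library: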
def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
    # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
    r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor

    Function that takes the mean element-wise absolute value difference.

    See :class:`~torch.nn.L1Loss` for details.
    """
    if not torch.jit.is_scripting():
        tens_ops = (input, target)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce,
                reduction=reduction)
    if not (target.size() == input.size()):
        warnings.warn("Using a target size ({}) that is different to the input size ({}). "
                      "This will likely lead to incorrect results due to broadcasting. "
                      "Please ensure they have the same size.".format(target.size(), input.size()),
                      stacklevel=2)
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    if target.requires_grad:
        ret = torch.abs(input - target)
        if reduction != 'none':
            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    else:
        expanded_input, expanded_target = torch.broadcast_tensors(input, target)
        ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
    return ret
The warning in the source requires input and target to have the same size; otherwise broadcasting will likely produce incorrect results. My own code suppressed warnings, so I never saw this one! A reminder to everyone: never casually ignore warnings, and read them as carefully as you read errors.
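One way to make sure a warning like this can never slip past again is to escalate warnings to hard errors while debugging. A minimal sketch using the standard-library warnings module (my suggestion, not part of the original code):

import warnings
warnings.simplefilter('error', UserWarning)  # any UserWarning now raises an exception instead of printing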
After changing the code as follows, the problem went away:
loss = criterion(output.squeeze(), target.float().cuda())
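For reference, a couple of equivalent ways to align the shapes (a sketch; note that a bare squeeze() also drops the batch dimension whenever a batch happens to contain a single sample, so squeezing the specific dimension is safer):

loss = criterion(output.squeeze(1), target.float().cuda())    # squeeze dim 1: [4, 1] -> [4]
loss = criterion(output, target.float().cuda().unsqueeze(1))  # or lift the target: [4] -> [4, 1]
assert output.squeeze(1).shape == target.shape                # cheap guard against this class of bug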
Now that the problem is fixed, I still want to understand why the size mismatch breaks the model; otherwise all that bug hunting would have been for nothing.
Let's first try the wrong input, where input has size [4, 1] and target has size [4]:
input = tensor([[-0.3704, -0.2918, -0.6895, -0.6023]], device='cuda:0', grad_fn=<PermuteBackward>)
target = tensor([ 63.6000, 127.0000, 102.2000, 115.4000], device='cuda:0')
expanded_input, expanded_target = torch.broadcast_tensors(input, target)
ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
Returned expanded_input:
tensor([[-0.3704, -0.2918, -0.6895, -0.6023],
[-0.3704, -0.2918, -0.6895, -0.6023],
[-0.3704, -0.2918, -0.6895, -0.6023],
[-0.3704, -0.2918, -0.6895, -0.6023]], device='cuda:0',
grad_fn=<PermuteBackward>)
Returned expanded_target:
tensor([[ 63.6000, 63.6000, 63.6000, 63.6000],
[127.0000, 127.0000, 127.0000, 127.0000],
[102.2000, 102.2000, 102.2000, 102.2000],
[115.4000, 115.4000, 115.4000, 115.4000]], device='cuda:0')
Returned ret:
tensor(102.5385, device='cuda:0', grad_fn=<PermuteBackward>)
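What happened is ordinary broadcasting: a [4, 1] tensor paired with a [4] tensor is expanded to [4, 4] on both sides, so the loss averages 16 cross-paired differences instead of the 4 intended ones. A standalone sketch of the same kind of expansion (hypothetical CPU tensors reusing the values above):

import torch

a = torch.tensor([[-0.3704], [-0.2918], [-0.6895], [-0.6023]])  # shape [4, 1]
b = torch.tensor([63.6, 127.0, 102.2, 115.4])                   # shape [4]
ea, eb = torch.broadcast_tensors(a, b)
print(ea.shape, eb.shape)      # torch.Size([4, 4]) torch.Size([4, 4])
print((ea - eb).abs().mean())  # averages 16 element pairs, not 4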
Next, the correct input, where input has size [4] and target has size [4]:
input = tensor([-0.3704, -0.2918, -0.6895, -0.6023], device='cuda:0',
grad_fn=<PermuteBackward>)
target = tensor([ 63.6000, 127.0000, 102.2000, 115.4000], device='cuda:0')
expanded_input, expanded_target = torch.broadcast_tensors(input, target)
ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
Returned expanded_input:
tensor([-0.3704, -0.2918, -0.6895, -0.6023], device='cuda:0',
       grad_fn=<PermuteBackward>)
Returned ret:
tensor(102.5385, device='cuda:0', grad_fn=<PermuteBackward>)
After the mean reduction, the returned ret happens to be identical in both cases (for these values every output is negative and every target positive, so averaging all 16 broadcast pairs equals averaging the 4 matched pairs); the only difference is the intermediate expanded_input. Could that intermediate difference change the gradients? To verify, let's print the gradients of the model parameters:
for name, parms in model.named_parameters():
    print('name:', name)
    print('grad_requirs:', parms.requires_grad)
    print('grad_value:', parms.grad)
Gradients for the wrong input (input size [4, 1], target size [4]):
===
name: module.linear1.bias
grad_requirs: True
grad_value: tensor([-0.1339, 0.0000, 0.0505, 0.0219, -0.1498, 0.0265, -0.0604, -0.0385,
0.0471, 0.0000, 0.0304, 0.0000, 0.0000, 0.0406, 0.0066, 0.0000,
-0.0259, -0.1544, 0.0000, -0.0208, 0.0050, 0.0000, 0.0625, -0.0474,
0.0000, 0.0858, -0.0116, 0.0777, 0.0000, -0.0828, 0.0000, -0.1265],
device='cuda:0')
===
name: module.linear2.weight
grad_requirs: True
grad_value: tensor([[-0.9879, -0.0000, -1.0088, -0.1680, -0.7312, -0.0066, -0.3093, -0.7478,
-0.3104, -0.0000, -0.1615, -0.0000, -0.0000, -0.3162, -0.1047, -0.0000,
-0.4030, -0.3385, -0.0000, -0.1738, -0.0831, -0.0000, -0.3490, -0.1129,
-0.0000, -0.8220, -0.0279, -0.3754, -0.0000, -0.3566, -0.0000, -0.5950]],
device='cuda:0')
===
name: module.linear2.bias
grad_requirs: True
grad_value: tensor([-1.], device='cuda:0')
===
Gradients for the correct input (input size [4], target size [4]):
===
name: module.linear1.bias
grad_requirs: True
grad_value: tensor([-0.1351, 0.0000, 0.0000, 0.0000, -0.0377, 0.0000, -0.0809, -0.0394,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0202, 0.0098, -0.0365,
-0.0263, -0.2063, -0.1533, -0.0626, 0.0050, 0.0000, 0.0000, -0.0950,
0.0000, 0.0000, -0.0348, 0.0000, 0.0000, -0.1108, -0.0402, -0.1693],
device='cuda:0')
===
name: module.linear2.weight
grad_requirs: True
grad_value: tensor([[-7.4419, 0.0000, 0.0000, 0.0000, -1.9245, 0.0000, -2.7927, -2.4551,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.0309, -0.4843, -0.0211,
-1.7046, -7.7090, -0.1696, -0.9997, -0.0862, 0.0000, 0.0000, -2.0397,
0.0000, 0.0000, -0.3125, 0.0000, 0.0000, -3.9532, -0.0643, -6.5799]],
device='cuda:0')
===
name: module.linear2.bias
grad_requirs: True
grad_value: tensor([-1.], device='cuda:0')
===
Sure enough, the gradients are different! And the mechanism is clear: with the broadcast [4, 4] tensors, each output element is compared against all four targets, so its gradient averages the sign terms of four mismatched pairs rather than coming from its single matched pair. The backward pass is genuinely different even though the scalar loss looked fine. The lesson: understand the mechanics behind every line of code you write; never take anything for granted!
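For completeness, here is a self-contained sketch (a hypothetical toy linear model with random data, not the original training code) that reproduces the gradient discrepancy end to end:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(8, 1)
x = torch.randn(4, 8)
target = torch.randn(4)

# mismatched shapes: [4, 1] vs [4] silently broadcasts to [4, 4] (with a warning)
F.l1_loss(model(x), target).backward()
grad_bad = model.weight.grad.clone()

model.zero_grad()
# matched shapes: [4] vs [4]
F.l1_loss(model(x).squeeze(1), target).backward()
print(torch.allclose(grad_bad, model.weight.grad))  # False here: the two gradients differ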