当前位置:网站首页>单目深度估计模型Featdepth实战中的问题和拓展
单目深度估计模型Featdepth实战中的问题和拓展
2022-07-25 09:22:00 【苹果姐】
关于Featdepth模型的原理和源码解读,可以参照以下两篇博客:
苹果姐:单目深度估计自监督模型Featdepth解读(上)——论文理解和核心源码分析
苹果姐:单目深度估计自监督模型Featdepth解读(下)——openMMLab框架分析和使用
博主在Featdepth实战中遇到了不少问题,如分布式多卡训练工具DDP使用中的问题,包括GPU分配上的bug、如何打印多卡全局loss、如何进行全局同步BN(SyncBN)、如何进行数据shuffle等。可能因为源码时间比较久,所以有不完善的地方。在此对于这些问题的解决方法和大家进行探讨。
一、GPU分配的bug问题
这个问题博主深入研究DDP机制后整理了以下博文供参考:
苹果姐:深入理解pytorch分布式并行处理工具DDP——从工程实战中的bug说起
虽然DDP中梯度的全局更新是自动的,但源码中自定义了DistOptimizerHook:
class DistOptimizerHook(OptimizerHook):
def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
self.grad_clip = grad_clip
self.coalesce = coalesce
self.bucket_size_mb = bucket_size_mb
def after_train_iter(self, runner):
runner.optimizer.zero_grad()
runner.outputs['loss'].backward()
allreduce_grads(runner.model, self.coalesce, self.bucket_size_mb)
if self.grad_clip is not None:
self.clip_grads(runner.model.parameters())
runner.optimizer.step()
这里面调用了allreduce_grads函数,用coalesce参数配置了是否按照tensor的type分组,并变为contiguous 1D buffer(内存中连续)再进行reduce,可能可以进行内存优化。
关于剩下的三个问题涉及的原理也在上文中进行了分析,建议先阅读,以下介绍在featdepth中具体的解决方式。
二、log打印机制以及可以进行的扩展
featdepth中的log打印部分也是集成在mmcv框架里,使用注册hook的方式。在config文件里可以看到:
log_config = dict(interval=500,
hooks=[dict(type='TextLoggerHook'),])
这里设置了hook名称为TextLoggerHook,间隔是500,TextLoggerHook是mmcv内置的一种打印log的hook类型,可以用runner.register_training_hooks统一注册:
runner.register_training_hooks(cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config)
在函数内部又通过register_logger_hooks(logconfig)将TextLoggerHook以及interval一起注册。在mmcv的TextLoggerHook类源码中可以看到,它是从runner.log_buffer.output中取出需要打印的内容,再用runner.logger进行打印,具体代码就不贴了。这两个前者用来存储日志内容,后者用来打印日志。
在runner.train()和runner.val()中都会调用runner.runiter(),内部会把每一次迭代的结果中的log_vars存储在runner.log_buffer中,然后在Loggerhook中计算interval次数的平均值。
def run_iter(self, data_batch, train_mode, **kwargs):
if self.batch_processor is not None:
outputs = self.batch_processor(
self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
'and "model.val_step()" must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
而log_vars是在自定义的batch_processor(在trainner.py)中保存的,内容是每一次迭代的各项loss值:
def batch_processor(model, data, train_mode):
data = change_input_variable(data)
model_out, losses = model(data)
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
'{} is not a tensor or list of tensors'.format(loss_name))
loss = sum(_value for _key, _value in log_vars.items())
log_vars['loss'] = loss
new_log_vars=OrderedDict()
for name in log_vars:
new_log_vars[str(name)] = log_vars[name].item()
outputs = dict(loss=loss,
log_vars=new_log_vars,
num_samples=len(data[('color', 0 , 0)].data))
return outputs
具体在打印log的地方,也就是runner.logger,只有0号GPU进行了打印。所以由此可以看出,源码中只能打印主节点GPU上的loss。而且由于源码中在训练阶段只传入了train_dataset,所以只能打印训练的日志,而验证集是在有真值的情况下通过计算rmse等指标来看,如果没有真值则不显示验证集loss。
由此可以进行两个扩展:一是打印全局平均loss,二是加入验证集的loss。
打印全局平均loss可以在batchprocessor中手动进行loss的全局同步,也就是all_reduce,这个同时作用于train和val模式。
for loss_name, loss_value in losses.items():
dist.all_reduce(loss_value.div_(torch.cuda.device_count()))
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
……
显示验证集的loss可以在data_loader中加入验证集:
data_loaders = [build_dataloader(dataset_train,
cfg.imgs_per_gpu,
cfg.workers_per_gpu,
dist=True),
build_dataloader(dataset_val,
cfg.imgs_per_gpu,
cfg.workers_per_gpu,
dist=True)
]
但目前还有一个问题:在模型结构(net.py)中将val模式设置为不输出loss:
def forward(self, inputs):
outputs = self.DepthDecoder(self.DepthEncoder(inputs["color_aug", 0, 0]))
if self.training:
outputs.update(self.predict_poses(inputs))
features = self.Encoder(inputs[("color", 0, 0)])
outputs.update(self.Decoder(features, 0))
loss_dict = self.compute_losses(inputs, outputs, features)
return outputs, loss_dict
return outputs
这里需要修改为:
def forward(self, inputs):
outputs = self.DepthDecoder(self.DepthEncoder(inputs["color_aug", 0, 0]))
outputs.update(self.predict_poses(inputs))
features = self.Encoder(inputs[("color", 0, 0)])
outputs.update(self.Decoder(features, 0))
loss_dict = self.compute_losses(inputs, outputs, features)
return outputs, loss_dict
并且要在get_dataset中将train和val模式的输入帧数都改成默认值(源代码val模式只输入一帧):
dataset = dataset(cfg.in_path,
filenames,
cfg.height,
cfg.width,
#cfg.frame_ids if training else [0], 这里需要修改
cfg.frame_ids,
is_train=training,
img_ext=img_ext,
gt_depth_path=cfg.gt_depth_path)
然后再在config文件中将validate设置为True即可。可以看到在runner中默认设置是train模式每一个interval打印一次,val模式每一个epoch打印一次。这样就可以看到所有的日志输出了。
三、全局同步BN问题
本问题也在上面博文中进行了说明,具体是在trainer.py的_dist_train()函数的DDP包装之前执行以下代码:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)
四、数据shuffle问题
源码中的数据shuffle也是通过注册hook实现的:
runner.register_hook(DistSamplerSeedHook())
DistSamplerSeedHook的源码:
@HOOKS.register_module
class DistSamplerSeedHook(Hook):
def before_epoch(self, runner):
runner.data_loader.sampler.set_epoch(runner.epoch)
在每个epoch前设置了随机数种子进行了shuffle,无需重新设置。
其他实战中的问题后期将继续更新,请关注。
边栏推荐
- Server CUDA toolkit multi version switching
- In depth interpretation of C language random number function and how to realize random number
- [deep learning] convolutional neural network
- 解决QTCreator使用VS编译中文乱码错误
- chmod和chown对挂载的分区的文件失效
- 文件--初识
- Constant power wireless charging based on stm32
- [data mining] nearest neighbor and Bayesian classifier
- SurfaceView 闪屏(黑一下问题)
- [code source] daily question tree
猜你喜欢

Redis string structure command

OC--继承和多态and指针

Assignment 7.21 Joseph Ring problem and decimal conversion

Swift创作天气APP

【cf】Round 128 C. Binary String

Voice chat app source code - produced by NASS network source code

A picture explains SQL join left and right

OC -- category extension agreement and delegation

无向连通图邻接表的创建输出广度深度遍历

【数据挖掘】第四章 分类任务(决策树)
随机推荐
A picture explains SQL join left and right
UI - infinite rotation chart and column controller
cf #785(div2) C. Palindrome Basis
如何将其他PHP版本添加到MAMP
初识Opencv4.X----图像直方图绘制
Prim minimum spanning tree (diagram)
How many regions can a positive odd polygon be divided into
自定义Dialog 实现 仿网易云音乐的隐私条款声明弹框
Expect+sh realize automatic interaction
OC--类别 扩展 协议与委托
Customize the view to realize the background of redeeming lottery tickets [elementary]
How to import a large amount of data in MATLAB
Assignment 7.21 Joseph Ring problem and decimal conversion
Learning new technology language process
Use of map () function in JS
缺陷检测网络--混合监督(kolektor缺陷数据集复现)
@3-2 optimal threshold of CCF 2020-12-2 final forecast
Flutter rive multi state example
【降维打击】希尔伯特曲线
Create personal extreme writing process - reprint