当前位置:网站首页>给RepVGG填坑?其实是RepVGG2的RepOptimizer开源
给RepVGG填坑?其实是RepVGG2的RepOptimizer开源
2022-06-23 09:33:00 【智源社区】
在神经网络结构设计中,我们经常会引入一些先验知识,比如ResNet的残差结构。然而我们还是用常规的优化器去训练网络。在本工作中,我们提出将先验信息用于修改梯度数值,称为梯度重参数化,对应的优化器称为RepOptimizer。我们着重关注VGG式的直筒模型,训练得到RepOptVGG模型,他有着高训练效率,简单直接的结构和极快的推理速度。

论文链接:https://arxiv.org/abs/2205.15242
官方仓库:https://github.com/DingXiaoH/RepOptimizers
与RepVGG的区别
- RepVGG加入了结构先验(如1x1,identity分支),并使用常规优化器训练。而RepOptVGG则是将这种先验知识加入到优化器实现中
- 尽管RepVGG在推理阶段可以把各分支融合,成为一个直筒模型。但是其训练过程中有着多条分支,需要更多显存和训练时间。而RepOptVGG可是 真-直筒模型,从训练过程中就是一个VGG结构
- 我们通过定制优化器,实现了结构重参数化和梯度重参数化的等价变换,这种变换是通用的,可以拓展到更多模型

将结构先验知识引入优化器
我们注意到一个现象,在特殊情况下,每个分支包含一个线性可训练参数,加一个常量缩放值,只要该缩放值设置合理,则模型性能依旧会很高。我们将这个网络块称为Constant-Scale Linear Addition(CSLA)我们先从一个简单的CSLA示例入手,考虑一个输入,经过2个卷积分支+线性缩放,并加到一个输出中,我们考虑等价变换到一个分支内,那等价变换对应2个规则:
初始化规则
融合的权重需为:

更新规则
针对融合后的权重,其更新规则为:

这部分公式可以参考附录A中,里面有详细的推导一个简单的示例代码为:
import torchimport numpy as npnp.random.seed(0)np_x = np.random.randn(1, 1, 5, 5).astype(np.float32)np_w1 = np.random.randn(1, 1, 3, 3).astype(np.float32)np_w2 = np.random.randn(1, 1, 3, 3).astype(np.float32)alpha1 = 1.0alpha2 = 1.0lr = 0.1conv1 = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)conv2 = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)conv1.weight.data = torch.nn.Parameter(torch.tensor(np_w1))conv2.weight.data = torch.nn.Parameter(torch.tensor(np_w2))torch_x = torch.tensor(np_x, requires_grad=True)out = alpha1 * conv1(torch_x) + alpha2 * conv2(torch_x)loss = out.sum()loss.backward()torch_w1_updated = conv1.weight.detach().numpy() - conv1.weight.grad.numpy() * lrtorch_w2_updated = conv2.weight.detach().numpy() - conv2.weight.grad.numpy() * lrprint(torch_w1_updated + torch_w2_updated)import torchimport numpy as npnp.random.seed(0)np_x = np.random.randn(1, 1, 5, 5).astype(np.float32)np_w1 = np.random.randn(1, 1, 3, 3).astype(np.float32)np_w2 = np.random.randn(1, 1, 3, 3).astype(np.float32)alpha1 = 1.0alpha2 = 1.0lr = 0.1fused_conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)fused_conv.weight.data = torch.nn.Parameter(torch.tensor(alpha1 * np_w1 + alpha2 * np_w2))torch_x = torch.tensor(np_x, requires_grad=True)out = fused_conv(torch_x)loss = out.sum()loss.backward()torch_fused_w_updated = fused_conv.weight.detach().numpy() - (alpha1**2 + alpha2**2) * fused_conv.weight.grad.numpy() * lrprint(torch_fused_w_updated)在RepOptVGG中,对应的CSLA块则是将RepVGG块中的3x3卷积,1x1卷积,bn层替换为带可学习缩放参数的3x3卷积,1x1卷积进一步拓展到多分支中,假设s,t分别是3x3卷积,1x1卷积的缩放系数,那么对应的更新规则为:

第一条公式对应输入通道==输出通道,此时一共有3个分支,分别是identity,conv3x3, conv1x1第二条公式对应输入通道!=输出通道,此时只有conv3x3, conv1x1两个分支第三条公式对应其他情况需要注意的是CSLA没有BN这种训练期间非线性算子(training-time nonlinearity),也没有非顺序性(non sequential)可训练参数。
边栏推荐
- AI: the Elephant in Room
- [geek challenge 2019] hardsql
- ionic5表单输入框和单选按钮
- Cesium加载正射影像方案
- Basic use of lua
- Redis learning notes - redis cli explanation
- Redis学习笔记—客户端通讯协议RESP
- [plugin:vite:import-analysis]Failed to resolve import “@/“ from ““.Does the file exist
- Redis learning notes RDB of persistence mechanism
- GPIO初识
猜你喜欢

三层架构与SSM之间的对应关系

Cesium loading orthophoto scheme
![[MRCTF2020]Ez_bypass](/img/cd/bd6fe5dfc3f1942a9959a9dab9e7e0.png)
[MRCTF2020]Ez_bypass
![[SUCTF 2019]CheckIn](/img/0e/75bb14e7a3e55ddc5126581a663bfb.png)
[SUCTF 2019]CheckIn
【NanoPi2试用体验】裸机第一步
[nanopi2 trial experience] the first step of bare metal

Pizza ordering design - simple factory model

设CPU有16根地址线,8根数据线,并用MREQ作为访存控制线号......存储器与CPU的连接
![[ciscn2019 North China Day2 web1]hack world](/img/bf/51a24fd2f9f0e13dcd821b327b5a00.png)
[ciscn2019 North China Day2 web1]hack world

A 32KB cache with direct mapping Memory exercises after class
随机推荐
位绑定
【CTF】bjdctf_ 2020_ babyrop
全局快门和卷帘快门的区别
Redis learning notes - redis and Lua
UEFI源码学习3.7 - NorFlashDxe
[GXYCTF2019]BabySQli
UEFI 学习3.6 - ARM QEMU上的ACPI表
Typora set up image upload service
三层架构与SSM之间的对应关系
【CTF】 2018_rop
Cesium loading orthophoto scheme
栈(Stack)的链式实现详解----线性结构
RGB与CMYK颜色模式
[SUCTF 2019]CheckIn
进入小公司的初级程序员要如何自我提高?
Redis学习笔记—数据类型:哈希(hash)
Redis learning notes master-slave copy
Cookie和Session入门
Redis学习笔记—客户端通讯协议RESP
[GYCTF2020]Blacklist