Loading selected pretrained parameters into a custom (modified) PyTorch model and freezing them
2022-06-26 05:34:00 【Little beaver flower made by Rua】
Parts of this article are based on https://zhuanlan.zhihu.com/p/34147880
One. This is the more general method: it loads the pretrained parameters according to your own model's parameters, assigning each tensor to the entry with the same name. Layers you added to the original model are simply not loaded and keep their initial values.
import torch

dict_trained = torch.load(self.args.load_path, map_location=torch.device('cpu'))
dict_new = model.state_dict()
# 1. filter out keys that do not exist in the new model
dict_trained = {k: v for k, v in dict_trained.items() if k in dict_new}
# 2. overwrite the matching entries in the new model's state dict
dict_new.update(dict_trained)
model.load_state_dict(dict_new)
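Method one can be tried on two toy models. In this sketch `Pretrained` and `Extended` are hypothetical stand-ins for the checkpoint's model and your modified model, and `load_matching` is a hypothetical helper. It also adds a shape check that the snippet above does not have, so a layer that keeps its name but changes size is skipped instead of making `load_state_dict` fail.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the model the checkpoint was trained with.
class Pretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 3)    # 3 output classes

# Hypothetical modified model: same encoder, resized head, one extra layer.
class Extended(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 5)    # different number of classes
        self.dense = nn.Linear(8, 8)   # layer that does not exist in the checkpoint

def load_matching(model, dict_trained):
    """Copy every checkpoint tensor whose name AND shape match the model."""
    dict_new = model.state_dict()
    matched = {k: v for k, v in dict_trained.items()
               if k in dict_new and v.shape == dict_new[k].shape}
    dict_new.update(matched)
    model.load_state_dict(dict_new)
    return sorted(matched)  # names of the tensors that were actually loaded

torch.manual_seed(0)
src = Pretrained()
dst = Extended()
loaded = load_matching(dst, src.state_dict())
print(loaded)  # ['encoder.bias', 'encoder.weight']
```

`head` is skipped because its shape changed, and `dense` keeps its random initialization because the checkpoint has no such key.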
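Two. This method is considerably more involved and has to be adapted to your own model. In my case the model adds four layers, 'dense', 'unary_affine', 'binary_affine' and 'classifier', and `j += 8` skips their weights and biases (4 layers × 2 tensors); the same kind of name-based skipping is often used when configuring weight decay. The 'crf' parameters of the pretrained model are not loaded either. The loop walks both key lists in parallel, so it only works if the shared layers appear in the same order in both state dicts.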
import torch

dict_trained = torch.load(self.args.load_path, map_location=torch.device('cpu'))
dict_new = model.state_dict().copy()
trained_list = list(dict_trained.keys())
new_list = list(dict_new.keys())
j = 0
no_load = {'dense', 'unary_affine', 'binary_affine', 'classifier'}  # layers added to the new model
for i in range(len(trained_list)):
    if j >= len(new_list):
        break  # no keys left in the new model
    flag = False
    if 'crf' in trained_list[i]:
        continue  # do not load the pretrained CRF parameters
    for nd in no_load:
        # 'bert' is excluded so BERT's own internal 'dense' sublayers are not skipped
        if nd in new_list[j] and 'bert' not in new_list[j]:
            flag = True
    if flag:
        j += 8  # skip the weight and bias of the 4 added layers (4 x 2 = 8 keys)
    else:
        dict_new[new_list[j]] = dict_trained[trained_list[i]]
        if new_list[j] != trained_list[i]:
            print("i:{}, new_state_dict: {}, trained state_dict: {} mismatch".format(i, new_list[j], trained_list[i]))
        j += 1  # advance to the next key of the new model
model.load_state_dict(dict_new)
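The index bookkeeping in method two (`j += 8`, `j += 1`) is easy to get wrong. Below is a sketch of the same positional idea, not the author's code, assuming the checkpoint and the new model list their shared layers in the same order: first drop the keys to ignore on each side (the CRF in the checkpoint, the added layers in the new model), then pair the remaining keys one-to-one. The toy models and the helper `load_by_position` are hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical checkpoint model: an encoder followed by a CRF-like layer.
class OldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(4, 8)
        self.crf = nn.Linear(8, 2)

# Hypothetical new model: the encoder was renamed and new layers were added.
class NewModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.dense = nn.Linear(8, 8)
        self.classifier = nn.Linear(8, 2)

def load_by_position(model, dict_trained, skip_trained, skip_new):
    """Pair checkpoint keys with model keys by position after dropping skipped prefixes."""
    dict_new = model.state_dict()
    # Prefix matching only hits top-level modules, so nested sublayers such as
    # bert.*.dense are not caught by accident.
    trained_keys = [k for k in dict_trained
                    if not any(k.startswith(s) for s in skip_trained)]
    new_keys = [k for k in dict_new
                if not any(k.startswith(s) for s in skip_new)]
    if len(trained_keys) != len(new_keys):
        raise ValueError(f"{len(trained_keys)} checkpoint keys vs {len(new_keys)} model keys")
    for kt, kn in zip(trained_keys, new_keys):
        if dict_trained[kt].shape != dict_new[kn].shape:
            raise ValueError(f"shape mismatch: {kt} -> {kn}")
        if kt != kn:
            print(f"mapping {kt} -> {kn}")  # names differ, only the order matched
        dict_new[kn] = dict_trained[kt]
    model.load_state_dict(dict_new)
    return list(zip(trained_keys, new_keys))

old, new = OldModel(), NewModel()
pairs = load_by_position(new, old.state_dict(),
                         skip_trained={"crf"},
                         skip_new={"dense", "classifier"})
```

Checking the key counts and shapes up front turns a silent misalignment into an immediate error.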
Later I found a much simpler method:
If you only want to reuse the parameters of the layers your model shares with the pretrained model, set the `strict` argument of `load_state_dict` to False. Its default is True, which requires the keys (layer names) of the pretrained checkpoint and your own network to match exactly; otherwise loading fails. With strict=False, missing and unexpected keys are ignored, but a parameter whose name matches and whose shape differs still raises a RuntimeError, so layers whose dimensions changed must be removed from the checkpoint first (for example with the filtering of method one). The implementation is as follows:
model.load_state_dict(torch.load(self.args.load_path), strict=False)
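What `strict=False` does and does not forgive can be seen on small hypothetical modules: `load_state_dict` returns the lists `missing_keys` and `unexpected_keys`, while a same-name tensor with a different shape still raises a `RuntimeError`.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical checkpoint: encoder + head.
src = nn.Sequential(OrderedDict([("encoder", nn.Linear(4, 8)),
                                 ("head", nn.Linear(8, 3))]))

# Model 1: same encoder, no head, a new "dense" layer.
dst = nn.Sequential(OrderedDict([("encoder", nn.Linear(4, 8)),
                                 ("dense", nn.Linear(8, 8))]))
result = dst.load_state_dict(src.state_dict(), strict=False)
print("missing:", result.missing_keys)        # in the model, absent from the checkpoint
print("unexpected:", result.unexpected_keys)  # in the checkpoint, absent from the model

# Model 2: same names, but the head has a different output size.
dst2 = nn.Sequential(OrderedDict([("encoder", nn.Linear(4, 8)),
                                  ("head", nn.Linear(8, 5))]))
try:
    dst2.load_state_dict(src.state_dict(), strict=False)
    size_error = False
except RuntimeError as e:
    size_error = "size mismatch" in str(e)  # strict=False does not cover shapes
print("size mismatch raised:", size_error)
```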
PS: If you run into an error, print both the keys of your modified model's parameters and the keys of the loaded checkpoint; comparing the two usually pinpoints the problem.
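A minimal sketch of that comparison, using a hypothetical helper `compare_keys` and toy models: instead of eyeballing two long key lists, take set differences and also list same-name tensors whose shapes differ.

```python
import torch.nn as nn

def compare_keys(model_state, ckpt_state):
    """Report how a model's state_dict and a checkpoint's state_dict differ."""
    model_keys, ckpt_keys = set(model_state), set(ckpt_state)
    only_model = sorted(model_keys - ckpt_keys)
    only_ckpt = sorted(ckpt_keys - model_keys)
    shape_diff = sorted(k for k in model_keys & ckpt_keys
                        if model_state[k].shape != ckpt_state[k].shape)
    print("only in model:     ", only_model)
    print("only in checkpoint:", only_ckpt)
    print("shape mismatch:    ", shape_diff)
    return only_model, only_ckpt, shape_diff

# Hypothetical example: the new model resizes layer 1 and appends layer 2.
ckpt = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3)).state_dict()
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 5), nn.Linear(5, 2))
report = compare_keys(model.state_dict(), ckpt)
```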
Three. Freeze the parameters of these layers
In short:
for param in model.parameters():  # freezes every parameter of the model
    param.requires_grad = False
There are many ways to freeze layers; here I use the freezing approach that matches the loading methods above.
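Freezing the whole model, as in the loop above, leaves nothing to train. To freeze only the tensors that came from the pretrained checkpoint, one option (a sketch, not necessarily the author's exact code) is to take the names of the loaded keys, e.g. the filtered `dict_trained` from method one, and match them against `named_parameters()`. The toy model and the hard-coded `loaded_keys` set are hypothetical.

```python
from collections import OrderedDict

import torch.nn as nn

# Hypothetical model: a pretrained encoder plus a freshly added classifier.
model = nn.Sequential(OrderedDict([
    ("encoder", nn.Linear(4, 8)),     # assume these weights came from the checkpoint
    ("classifier", nn.Linear(8, 2)),  # newly added, trained from scratch
]))

# Keys that were actually loaded from the checkpoint (e.g. the filtered
# dict_trained of method one); hard-coded here for illustration.
loaded_keys = {"encoder.weight", "encoder.bias"}

for name, param in model.named_parameters():
    # Freeze exactly the parameters that came from the checkpoint.
    param.requires_grad = name not in loaded_keys

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['classifier.weight', 'classifier.bias']
```

Note that `named_parameters()` does not cover buffers such as BatchNorm running statistics; those keep updating in training mode unless the corresponding submodules are switched to `eval()`.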
I suggest reading these discussions:
https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088
or
https://discuss.pytorch.org/t/correct-way-to-freeze-layers/26714
Correspondingly, during training the optimizer should only be given the parameters with requires_grad = True, so:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)
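A quick check of that setup on a hypothetical toy network (the model, data and `lr` value are illustrative assumptions): after one training step the frozen first layer is unchanged, the trainable last layer has moved, and the optimizer holds only the two trainable tensors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for p in model[0].parameters():  # freeze the first layer
    p.requires_grad = False

lr = 1e-3  # illustrative value
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

frozen_before = model[0].weight.clone()
head_before = model[2].weight.clone()

# One training step on random data.
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(model[0].weight, frozen_before))  # frozen layer untouched
print(torch.equal(model[2].weight, head_before))    # trainable layer updated
```

Filtering also keeps the frozen tensors out of the optimizer state; if you unfreeze them later, they have to be added back, for example with `optimizer.add_param_group`.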