Loading only the needed pretrained parameters into a custom (modified) PyTorch model, and freezing them
2022-06-26 05:26:00 【被rua弄的小狸花】
This post draws in part on https://zhuanlan.zhihu.com/p/34147880
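1. This approach is fairly general: load the pretrained parameters according to the parameters of your own model, copying every entry whose name matches. Layers you added on top of the original model have no counterpart in the checkpoint, so they are not loaded and keep their initial values.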
import torch

dict_trained = torch.load(self.args.load_path, map_location=torch.device('cpu'))
dict_new = model.state_dict()
# 1. filter out unnecessary keys
dict_trained = {k: v for k, v in dict_trained.items() if k in dict_new}
# 2. overwrite entries in the existing state dict
dict_new.update(dict_trained)
# 3. load the merged state dict into the model
model.load_state_dict(dict_new)
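As a minimal, self-contained sketch of the same idea, the snippet below uses two toy models in place of a real checkpoint. The class names Pretrained and Extended and their layers are hypothetical, not from the original post. It also compares tensor shapes, a check the code above does not make, because load_state_dict raises an error when a same-named tensor has a different shape.

```python
import torch
import torch.nn as nn


class Pretrained(nn.Module):
    """Hypothetical stand-in for the model that produced the checkpoint."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.crf = nn.Linear(8, 3)


class Extended(nn.Module):
    """Hypothetical modified model: same encoder, plus a new layer."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)      # same name and shape as in Pretrained
        self.classifier = nn.Linear(8, 5)   # new layer, absent from the checkpoint


torch.manual_seed(0)
dict_trained = Pretrained().state_dict()    # stands in for torch.load(path)
model = Extended()
dict_new = model.state_dict()

# Keep only entries whose name AND shape match the new model.
matched = {
    k: v for k, v in dict_trained.items()
    if k in dict_new and v.shape == dict_new[k].shape
}
dict_new.update(matched)
model.load_state_dict(dict_new)

print(sorted(matched))  # the keys that were actually copied
```

Here only the encoder entries are copied; classifier keeps its random initialization and the crf entries of the checkpoint are dropped.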
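2. This approach is considerably more involved and has to be adapted to your own model. In my case the new model adds four layers, 'dense', 'unary_affine', 'binary_affine' and 'classifier', and j += 8 skips their weights and biases (4 layers, 2 tensors each). The name-based selection is the same kind of thing you do when picking parameters for weight decay. The 'crf' part of the original model's parameters is also left unloaded.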
dict_trained = torch.load(self.args.load_path, map_location=torch.device('cpu'))
dict_new = self.model.state_dict().copy()
trained_list = list(dict_trained.keys())
new_list = list(dict_new.keys())
j = 0  # index into new_list; i indexes trained_list
no_load = {'dense', 'unary_affine', 'binary_affine', 'classifier'}
for i in range(len(trained_list)):
    flag = False
    if 'crf' in trained_list[i]:
        continue  # do not load the pretrained CRF parameters
    for nd in no_load:
        if nd in new_list[j] and 'bert' not in new_list[j]:
            flag = True
    if flag:
        # skip the weights and biases of the four added layers
        # (nothing is copied for trained_list[i] in this iteration)
        j += 8
    else:
        dict_new[new_list[j]] = dict_trained[trained_list[i]]
        if new_list[j] != trained_list[i]:
            print("i:{}, new_state_dict: {} trained state_dict: {} do not match".format(i, new_list[j], trained_list[i]))
    j += 1  # advance separately from i, since the two key lists are not aligned
self.model.load_state_dict(dict_new)
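The positional matching can also be sketched without a fixed offset. The helper below is a variant, not the code above: it first filters both key lists by name and then pairs what remains in order. The function name align_keys and the example key names are hypothetical, and it assumes the remaining keys of the two models appear in the same order.

```python
def align_keys(trained_keys, new_keys, skip_trained, skip_new, keep_marker="bert"):
    """Pair checkpoint keys with model keys by position after name filtering.

    A model key is dropped when it contains one of `skip_new` and does not
    contain `keep_marker`; a checkpoint key is dropped when it contains one
    of `skip_trained`. Returns {new_key: trained_key}.
    """
    def is_added_layer(key):
        return keep_marker not in key and any(s in key for s in skip_new)

    trained = [k for k in trained_keys if not any(s in k for s in skip_trained)]
    new = [k for k in new_keys if not is_added_layer(k)]
    if len(trained) != len(new):
        raise ValueError(
            "cannot align: {} checkpoint keys vs {} model keys".format(len(trained), len(new))
        )
    return dict(zip(new, trained))


# Hypothetical key lists, only to show the behaviour.
trained_keys = [
    "bert.layer.dense.weight", "bert.layer.dense.bias",
    "proj.weight", "proj.bias",
    "crf.transitions",
]
new_keys = [
    "encoder.bert.layer.dense.weight", "encoder.bert.layer.dense.bias",
    "dense.weight", "dense.bias",
    "proj.weight", "proj.bias",
    "classifier.weight", "classifier.bias",
]
mapping = align_keys(trained_keys, new_keys,
                     skip_trained=("crf",), skip_new=("dense", "classifier"))
for new_key, trained_key in mapping.items():
    print(new_key, "<-", trained_key)
```

With such a mapping, the copy step becomes dict_new[new_key] = dict_trained[trained_key] for each pair, and a length mismatch is reported before anything is overwritten.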
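I later learned of a simpler approach:
Once your own model is defined, if you only want the pretrained parameters for the parts whose structure is unchanged, set the strict argument of load_state_dict to False when loading. It defaults to True, which requires the keys of the pretrained state dict to match the keys of your network exactly, otherwise loading fails. With strict=False, missing and unexpected keys are ignored, but a tensor with a matching name and a different shape still raises an error. The call looks like this: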
model.load_state_dict(torch.load(self.args.load_path), strict=False)
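A small sketch of what strict=False does in practice, again with hypothetical toy models rather than a real checkpoint. load_state_dict returns an object with missing_keys and unexpected_keys, which is a convenient way to see what was skipped. The last part shows that a shape mismatch still raises an error even with strict=False.

```python
import torch
import torch.nn as nn


class Pretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.crf = nn.Linear(8, 3)


class Extended(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 5)


class Resized(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 16)  # same name, different shape


state = Pretrained().state_dict()        # stands in for torch.load(path)

model = Extended()
result = model.load_state_dict(state, strict=False)
print("missing:", result.missing_keys)        # in the model, not in the checkpoint
print("unexpected:", result.unexpected_keys)  # in the checkpoint, not in the model

# strict=False does not excuse shape mismatches.
try:
    Resized().load_state_dict(state, strict=False)
    shape_error_raised = False
except RuntimeError:
    shape_error_raised = True
print("shape mismatch raised:", shape_error_raised)
```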
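PS: If you hit an error, print the keys of your modified model's state dict and the keys of the loaded checkpoint, compare them, and fix whatever differs.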
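One possible way to do that comparison is a small helper that reports which keys are missing, which are unused, and which have different shapes. The function name diff_state_dicts and the example shapes are hypothetical. It works on plain name-to-shape mappings, which for a real model could be built as {k: tuple(v.shape) for k, v in model.state_dict().items()}.

```python
def diff_state_dicts(model_shapes, ckpt_shapes):
    """Compare two {parameter name: shape} mappings and report the differences."""
    model_keys, ckpt_keys = set(model_shapes), set(ckpt_shapes)
    return {
        # keys the model expects but the checkpoint does not provide
        "missing_in_checkpoint": sorted(model_keys - ckpt_keys),
        # keys the checkpoint provides but the model has no use for
        "unused_from_checkpoint": sorted(ckpt_keys - model_keys),
        # keys present on both sides whose shapes differ
        "shape_mismatch": sorted(
            k for k in model_keys & ckpt_keys
            if tuple(model_shapes[k]) != tuple(ckpt_shapes[k])
        ),
    }


# Hypothetical shapes, only to show the report.
model_shapes = {
    "encoder.weight": (8, 4),
    "encoder.bias": (8,),
    "classifier.weight": (5, 8),
    "classifier.bias": (5,),
}
ckpt_shapes = {
    "encoder.weight": (16, 4),
    "encoder.bias": (8,),
    "crf.transitions": (3, 3),
}
report = diff_state_dicts(model_shapes, ckpt_shapes)
for section, keys in report.items():
    print(section, keys)
```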
3. Freezing these layers' parameters
In short:
for p in model.parameters():
    p.requires_grad = False
There are many ways to do this; the one used here corresponds to the loading approach above.
These threads are worth reading:
https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088
or
https://discuss.pytorch.org/t/correct-way-to-freeze-layers/26714
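A minimal sketch of freezing that mirrors name-based loading: freeze exactly the parameters whose names were loaded from the checkpoint and leave the added layers trainable. The toy model and the loaded_keys set are hypothetical. In practice loaded_keys could be the keys of the filtered dictionary from approach 1.

```python
import torch.nn as nn


class Extended(nn.Module):
    """Hypothetical modified model: pretrained encoder plus a new classifier."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 5)


model = Extended()

# Assumed to be the names that were copied from the checkpoint.
loaded_keys = {"encoder.weight", "encoder.bias"}

for name, param in model.named_parameters():
    if name in loaded_keys:
        param.requires_grad = False  # frozen: no gradient is computed for it

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
print("trainable:", trainable)
print("frozen:", frozen)
```

Note that state dict keys can also include buffers (for example BatchNorm running statistics), which do not show up in named_parameters and have no requires_grad to set.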
Correspondingly, at training time the optimizer should update only the parameters whose requires_grad is True, so:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)
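To check that this behaves as intended, the sketch below freezes the first layer of a toy network, takes one optimizer step, and compares weights before and after. The network, the learning rate, and the dummy loss are assumptions made purely for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy network: net[0] plays the role of the pretrained part, net[1] the new head.
net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
for p in net[0].parameters():
    p.requires_grad = False

lr = 1e-3
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, net.parameters()), lr=lr
)

frozen_before = net[0].weight.clone()
head_before = net[1].weight.clone()

# One training step on random data with a dummy loss.
loss = net(torch.randn(16, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(net[0].weight, frozen_before)
head_changed = not torch.equal(net[1].weight, head_before)
print("frozen layer unchanged:", frozen_unchanged)
print("head updated:", head_changed)
```

The frozen layer never receives a gradient (its .grad stays None), so it would stay fixed even if it were passed to the optimizer. Filtering mainly keeps the optimizer state limited to the parameters that actually train.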