Saving, Loading, and Resuming Training of PyTorch Models
2022-06-22 19:27:00 【Weiyaner】
As models keep getting larger, training them to completion in a single run is increasingly impractical on low-compute platforms, so it is well worth saving the model during training and resuming from that checkpoint later to save time. The code is as follows:
Imports
import torch
from torch import nn
import numpy as np
Define the model
Define a small MLP classifier with two linear layers:
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear1(x)
        return x
## Randomly generate two groups of labeled data
rand1 = torch.rand((100, 64)).to(torch.float)
label1 = np.random.randint(0, 10, size=100)
label1 = torch.from_numpy(label1).to(torch.long)
rand2 = torch.rand((100, 64)).to(torch.float)
label2 = np.random.randint(0, 10, size=100)
label2 = torch.from_numpy(label2).to(torch.long)
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss = nn.CrossEntropyLoss()
## Train for 10 epochs
epoch = 10
for i in range(epoch):
    output = model(rand1)
    my_loss = loss(output, label1)
    optimizer.zero_grad()
    my_loss.backward()
    optimizer.step()
    print("epoch:{} loss:{}".format(i, my_loss))
The output is below. Note these loss values so we can compare them with the initial loss when training resumes later:
epoch:0 loss:2.3494179248809814
epoch:1 loss:2.287858009338379
epoch:2 loss:2.2486231327056885
epoch:3 loss:2.2189149856567383
epoch:4 loss:2.193182945251465
epoch:5 loss:2.167125940322876
epoch:6 loss:2.140075206756592
epoch:7 loss:2.1100614070892334
epoch:8 loss:2.0764594078063965
epoch:9 loss:2.0402779579162598
Saving the model
torch.save is used to save a model, and there are generally two approaches: the first simply saves the entire model object; the second saves the individual pieces (model parameters, optimizer state, and so on) as state dicts inside a dictionary.
# Save the entire model object (note: the model_para/ directory must already exist)
save_path = r'model_para/'
torch.save(model, save_path + 'model_full.pth')
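For completeness, a model saved this way can be restored with a single torch.load call. A minimal sketch (not in the original post; the variable name is illustrative), noting that the MyModel class definition must be importable wherever the file is loaded:

# Sketch: restore the whole model object saved above.
# The MyModel class definition must be available in the loading script.
restored_model = torch.load(save_path + 'model_full.pth')
restored_model.eval()  # switch to eval mode if it will only be used for inference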
Save the model parameters, the optimizer state, and the epoch count together:
def save_model(save_path, epoch, optimizer, model):
    torch.save({
        'epoch': epoch + 1,
        'optimizer_dict': optimizer.state_dict(),
        'model_dict': model.state_dict()},
        save_path)
    print("model save success")

save_model(save_path + 'model_dict.pth', epoch, optimizer, model)
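In a longer run you would usually checkpoint periodically inside the training loop rather than only once at the end. A minimal sketch reusing save_model above (the every-5-epochs interval is an arbitrary illustrative choice):

# Sketch: checkpoint every few epochs during training (interval is illustrative).
for i in range(epoch):
    output = model(rand1)
    my_loss = loss(output, label1)
    optimizer.zero_grad()
    my_loss.backward()
    optimizer.step()
    if (i + 1) % 5 == 0:
        save_model(save_path + 'model_dict.pth', i, optimizer, model)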
Loading the model
The saved .pth file is loaded with torch.load, as follows:
def load_model(save_name, optimizer, model):
    model_data = torch.load(save_name)
    model.load_state_dict(model_data['model_dict'])
    optimizer.load_state_dict(model_data['optimizer_dict'])
    print("model load success")
Inspect the weights of the model we just trained:
print(model.state_dict()['linear.weight'])
tensor([[-0.0215, 0.0299, -0.0255, ..., -0.0997, -0.0899, 0.0499],
[-0.0113, -0.0974, 0.1020, ..., 0.0874, -0.0744, 0.0801],
[ 0.0471, 0.1373, 0.0069, ..., -0.0573, -0.0199, -0.0654],
...,
[ 0.0693, 0.1900, 0.0013, ..., -0.0348, 0.1541, 0.1372],
[ 0.1672, -0.0086, 0.0189, ..., 0.0926, 0.1545, 0.0934],
[-0.0773, 0.0645, -0.1544, ..., -0.1130, 0.0213, -0.0613]])
Instantiate a new model, load the previously saved parameter file into it, and print the same layer's weights:
new_model = MyModel()
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
load_model(save_path+'model_dict.pth', new_optimizer, new_model)
print(new_model.state_dict()['linear.weight'])
The new model's parameters are identical to the original model's, which confirms that the parameters were loaded successfully.
model load success
tensor([[-0.0215, 0.0299, -0.0255, ..., -0.0997, -0.0899, 0.0499],
[-0.0113, -0.0974, 0.1020, ..., 0.0874, -0.0744, 0.0801],
[ 0.0471, 0.1373, 0.0069, ..., -0.0573, -0.0199, -0.0654],
...,
[ 0.0693, 0.1900, 0.0013, ..., -0.0348, 0.1541, 0.1372],
[ 0.1672, -0.0086, 0.0189, ..., 0.0926, 0.1545, 0.0934],
[-0.0773, 0.0645, -0.1544, ..., -0.1130, 0.0213, -0.0613]])
Resuming training
With the saved parameters loaded into the new model, we continue training and watch the loss: it picks up roughly where the previous run's final loss left off and keeps decreasing, which shows that resuming training worked (a sketch of reusing the saved epoch count follows the loss output below).
epoch = 10
for i in range(epoch):
    output = new_model(rand1)
    my_loss = loss(output, label1)
    new_optimizer.zero_grad()
    my_loss.backward()
    new_optimizer.step()
    print("epoch:{} loss:{}".format(i, my_loss))
epoch:0 loss:2.0036799907684326
epoch:1 loss:1.965193271636963
epoch:2 loss:1.924098253250122
epoch:3 loss:1.881495714187622
epoch:4 loss:1.835693359375
epoch:5 loss:1.7865667343139648
epoch:6 loss:1.7352293729782104
epoch:7 loss:1.6832704544067383
epoch:8 loss:1.6308385133743286
epoch:9 loss:1.5763107538223267
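The checkpoint also stores an epoch count, which the loop above does not use. A minimal sketch of how it could be read back to keep a continuous epoch counter across runs (start_epoch and total_epochs are illustrative names, not from the original code):

# Sketch: resume the epoch counter from the checkpoint (names are illustrative).
checkpoint = torch.load(save_path + 'model_dict.pth')
start_epoch = checkpoint['epoch']      # saved as epoch + 1, i.e. the next epoch to run
total_epochs = start_epoch + 10        # train for 10 more epochs
for i in range(start_epoch, total_epochs):
    output = new_model(rand1)
    my_loss = loss(output, label1)
    new_optimizer.zero_grad()
    my_loss.backward()
    new_optimizer.step()
    print("epoch:{} loss:{}".format(i, my_loss))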
Problems caused by mismatched data distributions
I also noticed an issue here. Two groups of data were generated earlier, and the original model was trained on rand1, so the saved parameters only help if we keep training on rand1. If we continue on rand2 instead, the model effectively trains from scratch (see the losses below): both groups were generated at random, their distributions are essentially unrelated, and a model trained on the first group is nearly useless on the second.
epoch:0 loss:2.523787498474121
epoch:1 loss:2.469816207885742
epoch:2 loss:2.4141526222229004
epoch:3 loss:2.379054069519043
epoch:4 loss:2.3563807010650635
epoch:5 loss:2.319946765899658
epoch:6 loss:2.271805763244629
epoch:7 loss:2.2274367809295654
epoch:8 loss:2.186885118484497
epoch:9 loss:2.144239902496338
In real training, however, batches are assumed to come from the same distribution, so this issue does not normally arise.
That covers the three key steps of saving, loading, and resuming training of a PyTorch model. I hope it helps!
Happy training.