当前位置:网站首页>pytorch 保存和加载模型
pytorch 保存和加载模型
2022-07-23 01:28:00 【Mick..】
模型的保存和加载
1 只保存和加载模型参数
torch.save(model.state_dict(), PATH) ###将模型的参数保存到这个地址下,后缀名为pt
model = model(*args, **kwargs) ###定义模型
model.load_state_dict(torch.load(PATH)) ##导入模型参数
2 保存和加载整个模型
torch.save(model,path)
model=torch.load(path)这种方式可以直接保存整个模型,在应用的时候不用再重新定义模型。
定义网络结构
这里定义了最简单的网络结构。两层的全连接层
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.layer1=nn.Linear(1,3) ###线性层
self.layer2=nn.Linear(3,1)
def forward(self,x):
x=self.layer1(x)
x=torch.relu(x) ###relu激活函数
x=self.layer2(x)
return x
训练神经网络
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
epoches=2000
# 学习率定义为0.01
learning_rate=0.01
# 创建一个模型
model=Net()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-5)
criterion = nn.MSELoss() #定义损失函数
# 使用优化器来更新网络权重,lr为学习率,
for i in range(epoch): #设定训练epoch次
model.train() #将模型的状态设置为train
for j in Sample: #对每一个样本进行遍历
optimizer.zero_grad() #将梯度清理,为这次的梯度计算做准备
output = model(j)
loss = criterion(output, target)
loss.backward() ###这里记录的平均loss
optimizer.step() #更新网络权重
if (epoch+1) % 10==0: ##每十次打印一下当前的状态
print("Epoch {} / {},loss {:.4f}".format(epoch+1,num_epoches,loss.item()))
Pytorch模型保存
##torch.save() 可以保存字典类型的数据
save_checkpoint({'loss': i, 'state_dict': model.state_dict()},dir)
def save_checkpoint(state, dic): #state是模型的权重和状态 dic是模型保存的目录
if not os.path.exists(dir):
os.makedirs(directory)
fileName = directory + 'last.pth'
torch.save(state,fileName)#使用torch.save函数直接对训练好的模型进行保存
边栏推荐
猜你喜欢

Huawei applications have called the checkappupdate interface. Why is there no prompt for version update in the application

求解最大公约数和最小公倍数

Developers must see | devweekly issue 1: what is time complexity?

Advantages of BGP machine room

一个月学透阿里整理的分布式架构笔记

【MySQL从入门到精通】【高级篇】(七)设计一个索引&InnoDB中的索引方案

-Bash: wget: command not found

Learn the distributed architecture notes sorted out by Alibaba in one month
How many points can you get on the latest UnionPay written test for test engineers?

Cbcgpcolordialog control used by BCG
随机推荐
Transformer summary
【管理篇 / 升级】* 02. 查看升级路径 * FortiGate 防火墙
How to learn MySQL efficiently and systematically?
Advantages of server hosting, server leasing and virtual machine
[C language] file operation
1059 Prime Factors
【C语言】预处理详解
727. 最小窗口子序列 滑动窗口
力扣(LeetCode)203. 移除链表元素(2022.07.22)
PyTorch可视化
2302. Count the number of subarrays with a score less than k - sliding array - double hundred code
1646. 获取生成数组中的最大值递归法
35岁程序员,早到的中年危机
727. Minimum window subsequence sliding window
Compose与RecyclerView结合效果会是怎样的?
[cann training camp] learning notes - Comparison between diffusion and Gan, dalle2 and Party
opensmile简介和安装过程中遇到的问题记录
【无标题】
模板学堂丨JumpServer安全运维审计大屏
2000. reverse word prefix