当前位置:网站首页>pytorch:模型的保存与导出
pytorch:模型的保存与导出
2022-06-23 15:09:00 【代码小白的成长】
方法一: 保存模型和模型参数
torch.save( network, savePath )
def save_network( save_dir, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(save_dir, save_filename)
torch.save(network, save_path)
network = torch.load( loadPath )
def load( load_dir, network_label, epoch_label='latest'):
load_filename = '%s_net_%s.pth' % (epoch_label, network_label)
load_path = os.path.join(load_dir, load_filename)
return torch.load(load_path )
特点:模型的导入和导出很容易,但是如果数据量比较大,会消耗大量时间。占用的内存也比较高。
方法二: 只保存模型参数 (推荐)
torch.save( network.state_dict(), savePath )
def save_network(save_dir, network, network_label, epoch_label):
save_filename = '%s_net_%s_params.pkl' % (epoch_label, network_label)
save_path = os.path.join(save_dir, save_filename)
torch.save(network.state_dict(), save_path)
由于模型保存的是参数,所以在测试阶段,要先定义网络,再把导出模型参数赋值给定义的网络:
# Load model
#定义网络
G = Generator(opt.input_nc, opt.output_nc)
G.cuda()
# 把保存的参数导出,并赋值给网络
model_dir = os.path.join(opt.checkpoints_dir, opt.name)
load_filename = '%s_net_%s_params.pkl' % (epoch_label, network_label)
load_path = os.path.join(save_dir, load_filename)
G.load_state_dict(torch.load(load_path))
边栏推荐
- Converging ecology, enabling safe operation, Huawei cloud security, cloud brain intelligent service security
- How can genetic testing help patients fight disease?
- Raspberry PI installing the wiring pi
- Introduction to the push function in JS
- [pyside2] pyside2 window is on the top of Maya (note)
- C. Phoenix and Towers-Codeforces Global Round 14
- C. Add One--Divide by Zero 2021 and Codeforces Round #714 (Div. 2)
- VGG下载(.net文件和imagenet-vgg-verydeep-19)
- 电荷泵原理讲义,电压是怎么“泵”上去的?
- OpenResty 基础
猜你喜欢

MySQL日志管理怎么配置

Important knowledge of golang: rwmutex read / write lock analysis

Shandong: food "hidden money", consumption "sweeping monk"

Arrays in JS

golang 重要知识:atomic 原子操作
The idea and method of MySQL master-slave only synchronizing some libraries or tables

Important knowledge of golang: waitgroup parsing

变压器只能转换交流电,那直流电怎么转换呢?

Three simple tips for accelerating yarn install

stylegan3:alias-free generative adversarial networks
随机推荐
stylegan2:analyzing and improving the image quality of stylegan
139. Séparation des mots
TCP协议三次握手和四次挥手抓包分析
139. 单词拆分
Arrays in JS
Shandong: food "hidden money", consumption "sweeping monk"
5 minutes to quickly launch web applications and APIs (vercel)
How can genetic testing help patients fight disease?
JS里的数组
Big factory Architect: how to draw a grand business map?
MySQL advanced statement I
golang 重要知识:context 详解
Sectigo(Comodo)证书的由来
Raspberry PI installing the wiring pi
JS traversal array (using the foreach () method)
C. Add One--Divide by Zero 2021 and Codeforces Round #714 (Div. 2)
The work and development steps that must be done in the early stage of the development of the source code of the live broadcasting room
stylegan3:alias-free generative adversarial networks
基金开户是有什么风险?开户安全吗
mysql 系列:总体架构概述