当前位置:网站首页>pytorch:模型的保存与导出
pytorch:模型的保存与导出
2022-06-23 15:09:00 【代码小白的成长】
方法一: 保存模型和模型参数
torch.save( network, savePath )
def save_network( save_dir, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(save_dir, save_filename)
torch.save(network, save_path)
network = torch.load( loadPath )
def load( load_dir, network_label, epoch_label='latest'):
load_filename = '%s_net_%s.pth' % (epoch_label, network_label)
load_path = os.path.join(load_dir, load_filename)
return torch.load(load_path )
特点:模型的导入和导出很容易,但是如果数据量比较大,会消耗大量时间。占用的内存也比较高。
方法二: 只保存模型参数 (推荐)
torch.save( network.state_dict(), savePath )
def save_network(save_dir, network, network_label, epoch_label):
save_filename = '%s_net_%s_params.pkl' % (epoch_label, network_label)
save_path = os.path.join(save_dir, save_filename)
torch.save(network.state_dict(), save_path)
由于模型保存的是参数,所以在测试阶段,要先定义网络,再把导出模型参数赋值给定义的网络:
# Load model
#定义网络
G = Generator(opt.input_nc, opt.output_nc)
G.cuda()
# 把保存的参数导出,并赋值给网络
model_dir = os.path.join(opt.checkpoints_dir, opt.name)
load_filename = '%s_net_%s_params.pkl' % (epoch_label, network_label)
load_path = os.path.join(save_dir, load_filename)
G.load_state_dict(torch.load(load_path))
边栏推荐
- Gartner's latest report: development of low code application development platform in China
- Important knowledge of golang: detailed explanation of context
- JS里的数组
- 进销存软件排行榜前十名!
- stylegan1: a style-based henerator architecture for gemerative adversarial networks
- 股票开账户如何优惠开户?在线开户安全么?
- TCP协议三次握手和四次挥手抓包分析
- Variable declaration of go language
- 30. concatenate substrings of all words
- MySQL advanced statement I
猜你喜欢

重卡界销售和服务的“扛把子”,临沂广顺深耕产品全生命周期服务

他山之石 | 微信搜一搜中的智能问答技术

Important knowledge of golang: atomic atomic operation

进销存软件排行榜前十名!

Three simple tips for accelerating yarn install
Solution to the problem that MySQL cannot be started in xampp

JS garbage collection

golang 重要知识:context 详解

嵌入式软件架构设计-程序分层
mysql主从只同步部分库或表的思路与方法
随机推荐
Sorting out and summarizing the handling schemes for the three major exceptions of redis cache
30. 串联所有单词的子串
[MAE]Masked Autoencoders掩膜自编码器
C. Product 1 Modulo N-Codeforces Round #716 (Div. 2)
Unshift() and shift() of JS
Converging ecology, enabling safe operation, Huawei cloud security, cloud brain intelligent service security
C. Set or Decrease-Educational Codeforces Round 120 (Rated for Div. 2)
现在我要买股票,怎么开户?手机开户安全么?
F5 application strategy status report in 2022: edge deployment and load security become the focus of attention in the Asia Pacific Region
Arrays in JS
Shandong: food "hidden money", consumption "sweeping monk"
C. Phoenix and Towers-Codeforces Global Round 14
进销存软件排行榜前十名!
看,这就是调制解调原理分析!附仿真文件
JS garbage collection
Convert JSON file of labelme to coco dataset format
【Pyside2】 pyside2的窗口在maya置顶(笔记)
Big factory Architect: how to draw a grand business map?
List query sorting parameter processing
Moher College - manual SQL injection vulnerability test (MySQL database)