当前位置:网站首页>pytorch模型转libtorch和onnx格式的通用代码
pytorch模型转libtorch和onnx格式的通用代码
2022-08-02 14:09:00 【虹夭】
依赖
- torch
- onnx
- onnx simplifer
需要自己设置的重要参数
- model_path 模型权重路径
- model 网络实例
- inp 样例输入,就是一个shape合法的tensor,batchsize(第一维)设置为1就行
下面以torchvision自带的resnet101模型为例。权重是使用官方的预训练模型,调用resnet101(pretrained=True)时会自动下载到%USERPROFILE%/.cache/torch/hub下面
import onnx
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from torchvision.models.resnet import resnet101
from utils.func import file_size, colorstr
model_path = './weights/resnet101.pth' # 模型权重路径
model = resnet101() # 模型对象
height, width = 640, 640
inp = torch.zeros([1, 3, height, width]) # 样例输入,用于trace
# common
half = True # fp16量化
# onnx profile
onnx_export = True # 是否输出onnx格式
opset_version = 13 # 算子集版本
dynamic = False # 是否动态输入batchsize,需要设置下面两个选项
input_names = ['inputs']
dynamic_axes = {
'inputs': {
0: 'batch', 1: 'kp28'}, # 动态batchsize设置
'output': {
0: 'batch', 1: 'classes'}}
simplify = True # 是否简化
# libtorch profile
libtorch_export = True # 是否输出libtorch格式
optimize = False # 针对移动端优化,不是移动端别用
strict = False # 严格模式,设置False就行
if __name__ == '__main__':
model.load_state_dict(torch.load(model_path))
model.cpu().eval()
if half:
inp, model = inp.half(), model.half()
if onnx_export:
prefix = colorstr('ONNX:')
f = model_path.replace('.pth', '.onnx') # filename
torch.onnx.export(model, inp, f, verbose=False, opset_version=opset_version, input_names=input_names,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
dynamic_axes=dynamic_axes if dynamic else None)
# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# print(onnx.helper.printable_graph(model_onnx.graph)) # print
# Simplify
if simplify:
try:
import onnxsim
print(f'simplifying with onnx-simplifier {
onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(
model_onnx,
dynamic_input_shape=dynamic,
input_shapes={
'images': list(inp.shape)} if dynamic else None)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{
prefix} simplifier failure: {
e}')
print(f'{
prefix} export success, saved as {
f} ({
file_size(f):.1f} MB)')
if libtorch_export:
prefix = colorstr('TorchScript:')
try:
print(f'\n{
prefix} starting export with torch {
torch.__version__}...')
f = model_path.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, inp, strict=strict)
(optimize_for_mobile(ts) if optimize else ts).save(f)
print(f'{
prefix} export success, saved as {
f} ({
file_size(f):.1f} MB)')
except Exception as e:
print(f'{
prefix} export failure: {
e}')
边栏推荐
猜你喜欢
随机推荐
FP7195芯片PWM转模拟调光至0.1%低亮度时恒流一致性的控制原理
Win10上帝模式干嘛的?Win10怎么开启上帝模式?
Bash shell位置参数
Win11 computer off for a period of time without operating network how to solve
FP7195降压恒流PWM转模拟调光零压差大功率驱动方案原理图
Seq2Seq模型PyTorch版本
arm push/pop/b/bl汇编指令
DP1332E内置c8051的mcu内核NFC刷卡芯片国产兼容NXP
PyTorch(13)---优化器_随机梯度下降法
Win10安装了固态硬盘还是有明显卡顿怎么办?
Actual combat Meituan Nuxt +Vue family bucket, server-side rendering, mailbox verification, passport authentication service, map API reference, mongodb, redis and other technical points
为vscode配置clangd
Win7遇到错误无法正常开机进桌面怎么解决?
CS4398音频解码替代芯片DP4398完全兼容DAC解码
Letter combination of LeetCode2 phone number
Please make sure you have the correct access rights and the repository exists.问题解决
Win11系统找不到dll文件怎么修复
PyTorch(14)---使用现有的模型及其修改
STM32F1和F4的区别
DP4301无线收发SUB-1G芯片兼容CC1101智能家居



![[论文阅读] ACT: An Attentive Convolutional Transformer for Efficient Text Classification](/img/59/88db682b6ff82d3612fd582cd499b2.png)





