Introduction to ONNX Runtime
2022-07-25 09:33:00 【fengbingchun】
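ONNX Runtime, released by Microsoft, optimizes and accelerates machine learning inference and training for ONNX models. It is a cross-platform inference and training machine-learning accelerator. The source code is at https://github.com/microsoft/onnxruntime, the latest release at the time of writing is v1.11.1, and it is licensed under MIT. It has two main parts: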
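1. ONNX Runtime Inferencing: a high-performance inference engine
(1). Runs on many operating systems, including Windows, Linux, Mac, Android, and iOS;
(2). Uses hardware acceleration to improve performance, including CUDA, TensorRT, DirectML, and OpenVINO;
(3). Supports models from deep learning frameworks such as PyTorch and TensorFlow, which must first be converted to ONNX through each framework's export interface;
(4). Models can be trained in Python and then deployed in C++, Java, and other applications.
2. ONNX Runtime Training: released in April 2021, it speeds up model training in PyTorch, supports CUDA acceleration, and is currently used mostly on Linux.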
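The hardware acceleration in item (2) is exposed through what ONNX Runtime calls execution providers: `onnxruntime.get_available_providers()` lists the providers compiled into the installed package, and `InferenceSession` accepts an ordered `providers` list. The sketch below filters a preferred order against what is available and always keeps `CPUExecutionProvider` as the final fallback. The helper `pick_providers` and the preference order are hypothetical choices for illustration, not part of the ONNX Runtime API.

```python
from typing import Iterable, List

# Hypothetical preference order: TensorRT, then CUDA, DirectML, OpenVINO, and finally CPU.
DEFAULT_PREFERENCE = [
    "TensorrtExecutionProvider",
    "CUDAExecutionProvider",
    "DmlExecutionProvider",
    "OpenVINOExecutionProvider",
    "CPUExecutionProvider",
]


def pick_providers(available: Iterable[str],
                   preferred: Iterable[str] = DEFAULT_PREFERENCE) -> List[str]:
    """Return the preferred providers that are actually available, in preference order.

    CPUExecutionProvider is appended if it is missing, so the session always has a fallback.
    """
    available_set = set(available)
    chosen = [p for p in preferred if p in available_set]
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen


if __name__ == "__main__":
    try:
        import onnxruntime
    except ImportError:
        print("onnxruntime is not installed")
    else:
        available = onnxruntime.get_available_providers()
        print("version:", onnxruntime.__version__)
        print("available:", available)
        print("chosen:", pick_providers(available))
        # The chosen list would then be passed as:
        # onnxruntime.InferenceSession("model.onnx", providers=pick_providers(available))
```

Order matters: ONNX Runtime tries the providers in list order when assigning graph nodes, so keeping `CPUExecutionProvider` last gives operators unsupported by an accelerator somewhere to fall back to. Which names appear in the available list depends on the installed package (a CPU-only package versus a GPU build, for example).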
Install it with conda:

```shell
conda install -c conda-forge onnxruntime
```

The following test code classifies images with ResNet-50:
```python
import json
import os
import tarfile
import urllib.request

import cv2
import numpy as np
import onnxruntime

# reference: https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb


def download_onnx_model():
    labels_file_name = "imagenet-simple-labels.json"
    model_tar_name = "resnet50v2.tar.gz"
    model_directory_name = "resnet50v2"
    onnx_model_url = "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.tar.gz"
    imagenet_labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"

    if os.path.exists(model_tar_name) and os.path.exists(labels_file_name):
        print("files exist, no need to download")
    else:
        print("files don't exist, downloading ...")
        # retrieve the model from the ONNX model zoo, plus the ImageNet labels
        urllib.request.urlretrieve(onnx_model_url, filename=model_tar_name)
        urllib.request.urlretrieve(imagenet_labels_url, filename=labels_file_name)
        print("download completed")

    # extract whenever the model directory is missing, even if the archive was downloaded earlier
    if not os.path.isdir(model_directory_name):
        print("start decompressing ...")
        with tarfile.open(model_tar_name) as file:
            file.extractall("./")

    return model_directory_name, labels_file_name


def load_labels(path):
    with open(path) as f:
        data = json.load(f)
    return np.asarray(data)


def images_preprocess(images_path, images_name):
    # ImageNet per-channel mean and standard deviation (RGB order)
    mean_vec = np.array([0.485, 0.456, 0.406])
    stddev_vec = np.array([0.229, 0.224, 0.225])
    input_data = []
    for name in images_name:
        image_file = os.path.join(images_path, name)
        img = cv2.imread(image_file)
        if img is None:  # cv2.imread returns None instead of raising on failure
            raise FileNotFoundError(f"cannot read image: {image_file}")
        img = cv2.resize(img, (224, 224))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
        # HWC -> CHW, the layout the model expects
        data = np.array(img).transpose(2, 0, 1)
        # convert the input data to float32
        data = data.astype('float32')
        # normalize each channel: scale to [0, 1], subtract the mean, divide by the std
        norm_data = np.zeros(data.shape).astype('float32')
        for i in range(data.shape[0]):
            norm_data[i, :, :] = (data[i, :, :] / 255 - mean_vec[i]) / stddev_vec[i]
        # add the batch dimension: (1, 3, 224, 224)
        norm_data = norm_data.reshape(1, 3, 224, 224).astype('float32')
        input_data.append(norm_data)
    return input_data


def softmax(x):
    x = x.reshape(-1)
    e_x = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e_x / e_x.sum(axis=0)


def postprocess(result):
    return softmax(np.array(result)).tolist()


def inference(onnx_model, labels, input_data, images_name, images_label):
    # since ONNX Runtime 1.9, builds that include non-CPU providers require providers to be set explicitly
    session = onnxruntime.InferenceSession(onnx_model, providers=["CPUExecutionProvider"])
    # get the name of the first input of the model
    input_name = session.get_inputs()[0].name
    for count, data in enumerate(input_data):
        print(f"{count+1}. image name: {images_name[count]}, actual value: {images_label[count]}")
        # an empty list of output names returns all model outputs
        raw_result = session.run([], {input_name: data})
        res = postprocess(raw_result)
        idx = np.argmax(res)
        print(f"    result: idx: {idx}, label: {labels[idx]}, percentage: {round(res[idx]*100, 4)}%")
        sort_idx = np.flip(np.squeeze(np.argsort(res)))
        print("    top 5 labels are:", labels[sort_idx[:5]])


def main():
    model_directory_name, labels_file_name = download_onnx_model()
    labels = load_labels(labels_file_name)
    print("the number of categories is:", len(labels))  # 1000

    images_path = "../../data/image/"
    images_name = ["5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg", "10.jpg"]
    images_label = ["goldfish", "hen", "ostrich", "crocodile", "goose", "sheep"]
    if len(images_name) != len(images_label):
        print("Error: the number of images and the number of labels don't match")
        return

    input_data = images_preprocess(images_path, images_name)
    onnx_model = os.path.join(model_directory_name, "resnet50v2.onnx")
    inference(onnx_model, labels, input_data, images_name, images_label)
    print("test finish")


if __name__ == "__main__":
    main()
```

[Figure: test images]

[Figure: execution results]
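The per-channel loop in `images_preprocess` and the flatten-then-softmax step in `postprocess` can also be written with NumPy broadcasting. The sketch below normalizes one resized HWC RGB `uint8` image in a single expression, stacks several images into an `(N, 3, H, W)` batch, and applies a row-wise softmax and top-k to an `(N, classes)` logits array. The function names are hypothetical. Sending more than one image in a single `session.run` call assumes the model's first input dimension is dynamic; that should be checked with `session.get_inputs()[0].shape` before relying on it, otherwise keep one image per call as the test code does.

```python
import numpy as np

# ImageNet per-channel statistics, the same values as in images_preprocess (RGB order)
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_hwc(img_rgb: np.ndarray) -> np.ndarray:
    """Convert one HWC RGB uint8 image to a normalized CHW float32 array."""
    x = img_rgb.astype(np.float32) / 255.0
    # MEAN and STD have shape (3,) and broadcast over the last (channel) axis
    x = (x - MEAN) / STD
    return x.transpose(2, 0, 1)  # HWC -> CHW


def make_batch(images_rgb) -> np.ndarray:
    """Stack already-resized HWC RGB images into an (N, 3, H, W) float32 batch."""
    return np.stack([normalize_hwc(img) for img in images_rgb]).astype(np.float32)


def softmax_rows(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def top_k(probs: np.ndarray, k: int = 5) -> np.ndarray:
    """Indices of the k largest probabilities in each row, largest first."""
    return np.argsort(probs, axis=-1)[..., ::-1][..., :k]
```

With the names from the test code, a batched call would look like `logits = session.run([], {input_name: make_batch(rgb_images)})[0]`, followed by `labels[top_k(softmax_rows(logits.reshape(len(rgb_images), -1)))]`, where `rgb_images` is a hypothetical list of 224x224 RGB arrays produced by the same `cv2.resize` and `cv2.cvtColor` steps.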