当前位置:网站首页>Halide::Generator生成器使用说明
Halide::Generator生成器使用说明
2022-07-24 01:57:00 【慷仔】
Halide::Generator生成器使用说明
我们之前已经能够使用Halide调用OpenCL后端了,但是我们的代码总是在建立计算图和调度之后,通过realize实例化来运行。
这显然会对AI模型的核函数优化、推理运行时等环节造成不小的影响。而且这种方式无法剥离Halide这个庞大的依赖库,
导致最终推理时依赖过多、接口过于复杂。
1、前言
本次的主要工作是尝试使用Halide::Generator生成器进行核函数生成,并试验如何使用不同的后端。
关键信息如下所示:
- 多计算图
- 调度gpu约束
- 输入输出定义信息
- 自动调度操作
- 生成器调用示例
2、核心代码总览
#include <stdio.h>
#include "Halide.h"
#include "HalideBuffer.h"
#include "clock.h"
using namespace Halide;
using namespace Halide::Tools;
//尝试进行核函数生成操作。
// Halide generator that sharpens a 3-D uint8 image with a five-tap filter and
// then remaps every sample through a power-law (exponent 1.2) look-up table.
class image_f4 : public Halide::Generator<image_f4> {
public:
    // Input buffer: image data with 8-bit samples, 3 dimensions, indexed (x, y, c).
    Input<Buffer<uint8_t>> input{"input", 3};
    // Output buffer: image data with 8-bit samples, 3 dimensions, indexed (x, y, c).
    Output<Buffer<uint8_t>> output{"output", 3};

    // Defines the algorithm and applies the hand-written schedule for the LUT stage.
    void generate()
    {
        // NOTE(review): ii, xo, yo, xi and yi are never used. They are kept because
        // anonymous Vars appear to be numbered in construction order, so removing
        // them would presumably rename the generated kernel — confirm before
        // cleaning up.
        Var x, y, c, i, ii, xo, yo, xi, yi;
        Func lut;
        Func curved;
        Func padded;

        // Tone curve: lut(i) = clamp((i / 255)^1.2 * 255, 0, 255), stored as uint8.
        lut(i) = cast<uint8_t>(clamp(pow(i / 255.0f, 1.2f) * 255.0f, 0, 255));

        // Augment the input with a boundary condition: x and y are clamped to the
        // valid range so the neighbour reads below never leave the input.
        padded(x, y, c) = input(clamp(x, 0, input.width() - 1),
                                clamp(y, 0, input.height() - 1), c);

        // Cast it to 16-bit to do the math.
        Func padded16;
        padded16(x, y, c) = cast<uint16_t>(padded(x, y, c));

        // Next we sharpen it with a five-tap filter: twice the centre sample minus
        // the average of its four direct neighbours.
        // NOTE(review): the subtraction is carried out in uint16, so it wraps when
        // the neighbour average exceeds twice the centre — confirm this is intended.
        Func sharpen;
        sharpen(x, y, c) = (padded16(x, y, c) * 2 -
                            (padded16(x - 1, y, c) +
                             padded16(x, y - 1, c) +
                             padded16(x + 1, y, c) +
                             padded16(x, y + 1, c)) /
                                4);

        // Apply the tone curve by using the sharpened value as the LUT index.
        curved(x, y, c) = lut(sharpen(x, y, c));

        // Custom schedule: compute the whole LUT at root, split its index by 16
        // into (block, thread) and map those onto GPU blocks and GPU threads.
        lut.compute_root();
        Var block, thread;
        lut.split(i, block, thread, 16);
        lut.gpu_blocks(block)
            .gpu_threads(thread);

        output(x, y, c) = curved(x, y, c);
    }

    // Contains no active statements: the only schedule in this generator is the
    // one applied to `lut` inside generate().
    void schedule()
    {
        /* THE SCHEDULE */
        // Disabled size estimates. Each call is kept on a single commented line:
        // a continuation line left uncommented is live code and breaks the build.
        // NOTE(review): if each pair is {min, extent}, the channel entry {1, 3}
        // probably should be {0, 3} — confirm before enabling.
        // input.set_estimates({{0, 1024}, {0, 1024}, {1, 3}});
        // output.set_estimates({{0, 1024}, {0, 1024}, {1, 3}});
    }
};
// Make the image_f4 class available to the generator driver under the name "image_f4".
HALIDE_REGISTER_GENERATOR(image_f4, image_f4)

// Program entry point: hands the command line to Halide's generator driver,
// which reports problems on std::cerr, and returns the driver's status.
int main(int argc, char **argv) {
    const int exit_code = Halide::Internal::generate_filter_main(argc, argv, std::cerr);
    return exit_code;
}
3、具体说明和注意事项
3.1、生成器使用流程
- 1、建立基于父类Halide::Generator的核生成器
// Skeleton of the kernel generator: derive from Halide::Generator and pass the
// class itself as the template argument.
class image_f4 : public Halide::Generator<image_f4> {};
- 2、声明输出输入buffer信息
// Input buffer: image data with 8-bit samples and 3 dimensions.
Input<Buffer<uint8_t>> input{"input", 3};
// Output buffer: image data with 8-bit samples and 3 dimensions.
// NOTE(review): the original note says input and output have exactly the same
// shape; only the matching element type and dimension count are visible here.
Output<Buffer<uint8_t>> output{"output", 3};
- 3、声明核函数的halide计算图实现
// Defines the pipeline: clamp-padded input -> 16-bit copy -> five-tap sharpen ->
// tone-curve look-up, plus a hand-written GPU schedule for the look-up table.
void generate()
{
// NOTE(review): ii, xo, yo, xi and yi are not used anywhere in this function.
Var x, y, c, i, ii, xo, yo, xi, yi;
Func lut;
Func curved;
Func padded;
// Tone curve: lut(i) = clamp((i / 255)^1.2 * 255, 0, 255), stored as uint8.
lut(i) = cast<uint8_t>(clamp(pow(i / 255.0f, 1.2f) * 255.0f, 0, 255));
// Augment the input with a boundary condition.
// x and y are clamped to the valid range so neighbour reads stay inside the input.
padded(x, y, c) = input(clamp(x, 0, input.width() - 1),
clamp(y, 0, input.height() - 1), c);
// Cast it to 16-bit to do the math.
Func padded16;
padded16(x, y, c) = cast<uint16_t>(padded(x, y, c));
// Next we sharpen it with a five-tap filter.
// Result: twice the centre sample minus the average of its four direct neighbours.
Func sharpen;
sharpen(x, y, c) = (padded16(x, y, c) * 2 -
(padded16(x - 1, y, c) +
padded16(x, y - 1, c) +
padded16(x + 1, y, c) +
padded16(x, y + 1, c)) /
4);
// The sharpened value is used as the index into the look-up table.
curved(x, y, c) = lut(sharpen(x, y, c));
/*----------- custom schedule (begin) -----------*/
// Compute the whole LUT at root, split its index by 16 into (block, thread),
// then map those two loops onto GPU blocks and GPU threads.
lut.compute_root();
Var block, thread;
lut.split(i, block, thread, 16);
lut.gpu_blocks(block)
.gpu_threads(thread);
/*----------- custom schedule (end) -----------*/
output(x, y, c) = curved(x, y, c);
}
- 4、自动调度设置(可选)
// Optional scheduling hook. It contains no active statements; the size
// estimates below are disabled.
void schedule()
{
    /* THE SCHEDULE */
    // Each disabled call is kept on a single commented line: a continuation line
    // left uncommented is live code and breaks the build.
    // NOTE(review): if each pair is {min, extent}, the channel entry {1, 3}
    // probably should be {0, 3} — confirm before enabling.
    // input.set_estimates({{0, 1024}, {0, 1024}, {1, 3}});
    // output.set_estimates({{0, 1024}, {0, 1024}, {1, 3}});
}
- 5、注册生成代码操作
// Register the image_f4 generator class under the name "image_f4".
HALIDE_REGISTER_GENERATOR(image_f4, image_f4)

// The image_f4 kernel implementation is produced later, driven by the
// arguments passed in through argv.
int main(int argc, char **argv) {
    const int status = Halide::Internal::generate_filter_main(argc, argv, std::cerr);
    return status;
}
- 6、命令行操作
# Start from an empty output directory: clear it when it already exists,
# otherwise create it.
if test -d "./halide_generate_file"; then
    rm -rf halide_generate_file/*
else
    mkdir halide_generate_file
fi
# Assumes the overview code above has been compiled into an executable named "test".
# target=x86-64-linux-opencl selects the OpenCL target for the generated code.
# NOTE(review): the original note also credits "-r GPU" for forcing OpenCL; in
# Halide's generator command line -r appears to name a standalone runtime
# rather than select a backend — confirm.
./test -g image_f4 -e c_header,c_source -o halide_generate_file target=x86-64-linux-opencl -r GPU
# The generated code is then written into the halide_generate_file directory.
3.2、生成代码内容总览
- opencl核函数代码如下所示:
/* OpenCL C x86-64-linux-opencl*/
// Generated preamble: small float helpers and name aliases placed ahead of the
// kernels. Of these, the kernel in this listing only references float_from_bits
// and pow_f32.
#pragma OPENCL FP_CONTRACT ON
// Reinterprets a 32-bit pattern as a float (no numeric conversion).
inline float float_from_bits(unsigned int x) {
return as_float(x);}
// Constants for the special float values.
inline float nan_f32() {
return NAN; }
inline float neg_inf_f32() {
return -INFINITY; }
inline float inf_f32() {
return INFINITY; }
// Classification predicates forwarding to the OpenCL built-ins.
inline bool is_nan_f32(float x) {
return isnan(x); }
inline bool is_inf_f32(float x) {
return isinf(x); }
inline bool is_finite_f32(float x) {
return isfinite(x); }
// Aliases from *_f32 names to the corresponding OpenCL math built-ins.
#define sqrt_f32 sqrt
#define sin_f32 sin
#define cos_f32 cos
#define exp_f32 exp
#define log_f32 log
#define abs_f32 fabs
#define floor_f32 floor
#define ceil_f32 ceil
#define round_f32 round
#define trunc_f32 trunc
#define pow_f32 pow
#define asin_f32 asin
#define acos_f32 acos
#define tan_f32 tan
#define atan_f32 atan
#define atan2_f32 atan2
#define sinh_f32 sinh
#define asinh_f32 asinh
#define cosh_f32 cosh
#define acosh_f32 acosh
#define tanh_f32 tanh
#define atanh_f32 atanh
// The "fast" variants map to the native_* built-ins.
#define fast_inverse_f32 native_recip
#define fast_inverse_sqrt_f32 native_rsqrt
// Expands to nothing.
#define halide_unused(x)
// Empty kernel. NOTE(review): judging by its name it presumably guarantees the
// program always contains at least one kernel — confirm.
__kernel void _at_least_one_kernel(int x) {
}
// Address spaces for _kernel_f0_s0_v3_v9___block_id_x
#define __address_space__f0 __global
// Fills the look-up table _f0: each work-item computes one entry
// _f0[i] = (uchar)(clamp((i / 255)^1.2, 0, 1) * 255), with i derived from the
// work-group id and local id. The __shared argument is not used in the body.
__kernel void _kernel_f0_s0_v3_v9___block_id_x(
__address_space__f0 uchar *restrict _f0,
__local int16* __shared)
{
int _f0_s0_v3_v9___block_id_x = get_group_id(0);
int ___thread_id_x = get_local_id(0);
// Table index i = group_id * 16 + local_id; there is no bounds check before the store.
int _0 = _f0_s0_v3_v9___block_id_x * 16;
int _1 = _0 + ___thread_id_x;
float _2 = (float)(_1);
// Multiply by 1/255 (0.00392157) instead of dividing by 255.
float _3 = float_from_bits(998277249 /* 0.00392157 */);
float _4 = _2 * _3;
// Raise to the power 1.2.
float _5 = float_from_bits(1067030938 /* 1.2 */);
float _6 = pow_f32(_4, _5);
// Clamp to [0, 1] first, then scale by 255.
float _7 = float_from_bits(1065353216 /* 1 */);
float _8 = min(_6, _7);
float _9 = float_from_bits(0 /* 0 */);
float _10 = max(_8, _9);
float _11 = float_from_bits(1132396544 /* 255 */);
float _12 = _10 * _11;
// Convert to uchar and store the table entry.
uchar _13 = (uchar)(_12);
_f0[_1] = _13;
} // kernel _kernel_f0_s0_v3_v9___block_id_x
#undef __address_space__f0
4、后续计划与安排
到了这一步,我们已经可以使用Halide进行核函数的生成,但是还需要进一步研究如何调用生成的核函数。
边栏推荐
- jenkins多任務並發構建
- Install go environment under Kali
- C byte array and class mutual conversion
- On the possibility and limitation of defi in the metauniverse
- 深入了解-微信开发者工具
- Local empowerment learning
- Vantui, axiso, FAQs and usage:
- [untitled]
- How to synchronize MySQL database when easycvr platform is upgraded to the latest version v2.5.0?
- win11系统之win11亮点
猜你喜欢

Arm architecture and programming 7 -- exceptions and interrupts (based on Baiwen arm architecture and programming tutorial video)

win11之缺点

Hospital generic cabling

Hcip network type, PPP session, data link layer protocol

Structure the second operation of the actual combat battalion module

How CAD draws arrows with arcs

STM32概念和安装【第一天】

Spark partition operators partitionby, coalesce, repartition

Digicert code signing certificate

145-keep-alive的初步使用
随机推荐
C - structure
暑假第三周
Structure the second operation of the actual combat battalion module
xxl-job使用注意事项
浅谈元宇宙中DeFi的可能性和局限性
LiteSpeed Web服务器中安装SSL证书
Precautions for using XXL job
Express operates mysql. What is wrong with the SQL?
Magazine feature: the metauniverse will reshape our lives, and we need to make sure it gets better
Draw pictures with canvas
OSPF (sixth day notes)
MySQL Basics (operators, sorting and paging, multi table queries, functions)
20220723 记录一次SAP Oracle 监听服务莫名停掉的问题
"Guanghetong AI intelligent module sca825-w" with full AI performance accelerates the era of e-commerce live broadcast 2.0
Design of hospital wireless network system
How QT counts the frequency of letters in a string
[code case] website confession wall & to do list (including complete source code)
win11之缺点
Ora-12899 error caused by nchar character
STM32概念和安装【第一天】