diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl index 4bc003799a..aef7bf8310 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl @@ -1,35 +1,30 @@ -#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable -#define ACCUM_FLT4 float4 +#ifdef ENABLE_FP16 +#define FLT half +#define FLT4 half4 +#define TO_FLT4 convert_half4 +#else #define FLT float -#define FLT2 float2 -#define FLT3 float3 #define FLT4 float4 #define TO_FLT4 convert_float4 -#define TO_ACCUM_TYPE convert_float4 -#define TO_ACCUM_FLT convert_float -#define READ_IMAGE read_imagef -#define WRITE_IMAGE write_imagef -__constant sampler_t smp_edge = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST; -__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; -__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; -__kernel void DepthwiseConv2d_NC4HW4( -__global float4* src_data, - __global FLT4* filters, -__global FLT4* biases, +#endif +__constant sampler_t sampler_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void DepthwiseConv2d_IMG_NC4HW4( +__read_only image2d_t src_data, +__global FLT4* filter, +__global FLT4* bias, float relu_clip1, -__global float4* dst_data, - int2 kernel_size, - int2 stride, - int2 padding, - int2 dilation, - int4 src_size, - int4 dst_size -) { +__write_only image2d_t dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size) { int X = get_global_id(0); int Y = get_global_id(1); int Z = get_global_id(2); if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; - ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); int 
x_offseted = X * stride.x + padding.x; int y_offseted = Y * stride.y + padding.y; int fx_c = Z * kernel_size.x * kernel_size.y; @@ -40,37 +35,160 @@ __global float4* dst_data, int x_c = x_offseted + kx * dilation.x; bool outside_x = x_c < 0 || x_c >= src_size.x; if (!outside_x && !outside_y) { - FLT4 f = filters[fx_c]; - FLT4 src_final =src_data[(((Z) * src_size.y + (y_c)) * src_size.x + (x_c))]; - r += TO_ACCUM_TYPE(src_final * f); + FLT4 f = filter[fx_c]; + //FLT4 src_final =src_data[(((Z) * src_size.y + (y_c)) * src_size.x + (x_c))]; + FLT4 src_final =read_imagef(src_data, sampler_zero, (int2)(x_c, (Z * src_size.y + y_c))); + r += TO_FLT4(src_final * f); }; fx_c++; } } - FLT4 bias_val = biases[Z]; + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + //dst_data[(((Z) * dst_size.y + (Y)) * dst_size.x + (X))] = res0; + write_imagef(dst_data, (int2)(X, (Z * dst_size.y + Y)), res0); +} + +__kernel void DepthwiseConv2d_IMG_NHWC4( +__read_only image2d_t src_data, +__global FLT4* filter, +__global FLT4* bias, + float relu_clip1, +__write_only image2d_t dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + //FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + FLT4 
src_final =read_imagef(src_data, sampler_zero, (int2)(Z+x_c*src_size.z, y_c)); + r += TO_FLT4(src_final * f); + }; + fx_c++; + } + } + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + //dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; + write_imagef(dst_data, (int2)(X*dst_size.z+Z, Y), res0); +} +// 1x1 depthwise convolution on NHWC4 image2d tensors: one filter tap per output texel, then bias add and clamp to [0, relu_clip1]. +__kernel void DepthwiseConv2d_IMG_NHWC4_1x1( +__read_only image2d_t src_data, +__global FLT4* filter, +__global FLT4* bias, + float relu_clip1, +__write_only image2d_t dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z; + { + int y_c = y_offseted; + bool outside_y = y_c < 0 || y_c >= src_size.y; + { + int x_c = x_offseted; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + // NHWC4 image addressing, same as DepthwiseConv2d_IMG_NHWC4: texel x = x_c * src_size.z + Z, texel y = y_c. + FLT4 src_final = read_imagef(src_data, sampler_zero, (int2)(Z + x_c * src_size.z, y_c)); + r += TO_FLT4(src_final * f); + }; + } + } + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + // Output uses the same addressing: texel x = X * dst_size.z + Z, texel y = Y. + write_imagef(dst_data, (int2)(X * dst_size.z + Z, Y), res0); +} +__kernel void DepthwiseConv2d_BUF_NC4HW4( +__global FLT4* src_data, +__global FLT4* filter, +__global FLT4* bias, + float relu_clip1, +__global FLT4* dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = 
get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + FLT4 src_final =src_data[(((Z) * src_size.y + (y_c)) * src_size.x + (x_c))]; + r += TO_FLT4(src_final * f); + }; + fx_c++; + } + } + FLT4 bias_val = bias[Z]; FLT4 res0 = TO_FLT4(r) + bias_val; res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); dst_data[(((Z) * dst_size.y + (Y)) * dst_size.x + (X))] = res0; } -__kernel void DepthwiseConv2d_NHWC4( -__global float4* src_data, - __global FLT4* filters, -__global FLT4* biases, +__kernel void DepthwiseConv2d_BUF_NHWC4( +__global FLT4* src_data, +__global FLT4* filter, +__global FLT4* bias, float relu_clip1, -__global float4* dst_data, - int2 kernel_size, - int2 stride, - int2 padding, - int2 dilation, - int4 src_size, - int4 dst_size -) { +__global FLT4* dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size) { int X = get_global_id(0); int Y = get_global_id(1); int Z = get_global_id(2); if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; - ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); int x_offseted = X * stride.x + padding.x; int y_offseted = Y * stride.y + padding.y; int fx_c = Z * kernel_size.x * kernel_size.y; @@ -81,14 +199,53 @@ __global float4* dst_data, int x_c = x_offseted + kx * dilation.x; bool outside_x = x_c < 0 || x_c >= src_size.x; if (!outside_x && !outside_y) { - FLT4 f = filters[fx_c]; - 
FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; - r += TO_ACCUM_TYPE(src_final * f); + FLT4 f = filter[fx_c]; + FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + r += TO_FLT4(src_final * f); }; fx_c++; } } - FLT4 bias_val = biases[Z]; + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; +} + +__kernel void DepthwiseConv2d_BUF_NHWC4_1x1( +__global FLT4* src_data, +__global FLT4* filter, +__global FLT4* bias, + float relu_clip1, +__global FLT4* dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z; + { + int y_c = y_offseted; + bool outside_y = y_c < 0 || y_c >= src_size.y; + { + int x_c = x_offseted; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + r += TO_FLT4(src_final * f); + }; + } + } + FLT4 bias_val = bias[Z]; FLT4 res0 = TO_FLT4(r) + bias_val; res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc index 44f169e2ef..7c44f9c824 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc @@ -21,9 +21,12 @@ #include "src/runtime/opencl/opencl_runtime.h" #include "src/runtime/kernel/arm/fp32/convolution_depthwise.h" 
#include "src/runtime/kernel/arm/opclib/pack.h" + #ifndef PROGRAM_WITH_IL + #include "src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl.inc" #include "src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl.inc" + #endif using mindspore::kernel::KERNEL_ARCH::kGPU; @@ -35,20 +38,31 @@ namespace mindspore::kernel { int DepthwiseConv2dOpenCLKernel::Init() { auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); - std::string kernel_name = "DepthwiseConv2d_NHWC4"; + std::string kernel_name = "DepthwiseConv2d"; auto in_format = inputs_[0]->GetFormat(); outputs_[0]->SetFormat(in_format); if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) { MS_LOG(ERROR) << "input format(" << in_format << ") " << "format not support!"; } + if (mem_type_ == MEM_TYPE::BUF) { + kernel_name += "_BUF"; + } else { + kernel_name += "_IMG"; + } if (in_format == schema::Format_NC4HW4) { - kernel_name = "DepthwiseConv2d_NC4HW4"; + kernel_name += "_NC4HW4"; + } else if (in_format == schema::Format_NHWC4) { + kernel_name += "_NHWC4"; + } + auto parameter = reinterpret_cast(opParameter); + if (parameter->kernel_h_ == 1) { + kernel_name += "_1x1"; } #ifdef PROGRAM_WITH_IL ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); #else std::string program_name = "DepthwiseConv2d"; - std::set build_options; + std::set build_options; #ifdef ENABLE_FP16 std::string source = depthwise_conv2d_source_fp16; #else @@ -61,8 +75,9 @@ int DepthwiseConv2dOpenCLKernel::Init() { MS_LOG(DEBUG) << kernel_name << " Init Done!"; return 0; } + int DepthwiseConv2dOpenCLKernel::InitBuffer() { - auto parameter = reinterpret_cast(opParameter); + auto parameter = reinterpret_cast(opParameter); auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); auto allocator = ocl_runtime->GetAllocator(); @@ -89,54 +104,101 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() { size_t up_co_size = C4NUM * CO4 * sizeof(FLOAT_t); memset_s(bias_data_, up_co_size, 0, up_co_size); auto ori_bias = 
reinterpret_cast(inputs_.at(kBiasIndex)->Data()); - memcpy_s(bias_data_, outputs_[0]->Channel() * sizeof(FLOAT_t), ori_bias, outputs_[0]->Channel() * sizeof(FLOAT_t)); + memcpy_s(bias_data_, outputs_[0]->Channel() * sizeof(FLOAT_t), ori_bias, + outputs_[0]->Channel() * sizeof(FLOAT_t)); allocator->UnmapBuffer(bias_data_); } else { MS_ASSERT(inputs_.size() == kInputSize1); } return 0; } + int DepthwiseConv2dOpenCLKernel::ReSize() { return 0; } int DepthwiseConv2dOpenCLKernel::Run() { MS_LOG(DEBUG) << this->Name() << " Running!"; - auto parameter = reinterpret_cast(opParameter); + auto parameter = reinterpret_cast(opParameter); auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); size_t CI4 = UP_DIV(inputs_[0]->Channel(), C4NUM); - std::vector global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4}; - std::vector local = {1, 1, 1}; + std::vector global = {(size_t) outputs_[0]->Width(), (size_t) outputs_[0]->Height(), CO4}; + std::vector local = {1, 1, CO4}; float relu_clip1 = 6.0; cl_int2 kernel_size = {parameter->kernel_h_, parameter->kernel_w_}; cl_int2 stride = {parameter->stride_h_, parameter->stride_w_}; cl_int2 padding = {-parameter->pad_h_, -parameter->pad_w_}; cl_int2 dilation = {parameter->dilation_h_, parameter->dilation_w_}; - cl_int4 src_size = {inputs_[0]->Width(), inputs_[0]->Height(), (cl_int)CI4, inputs_[0]->Batch()}; - cl_int4 dst_size = {(cl_int)outputs_[0]->Width(), (cl_int)outputs_[0]->Height(), (cl_int)CO4, - (cl_int)outputs_[0]->Batch()}; - ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); + cl_int4 src_size = {inputs_[0]->Width(), inputs_[0]->Height(), (cl_int) CI4, inputs_[0]->Batch()}; + cl_int4 dst_size = {(cl_int) outputs_[0]->Width(), (cl_int) outputs_[0]->Height(), (cl_int) CO4, + (cl_int) outputs_[0]->Batch()}; + ocl_runtime->SetKernelArg(kernel_, 1, packed_weight_); ocl_runtime->SetKernelArg(kernel_, 2, bias_data_); 
ocl_runtime->SetKernelArg(kernel_, 3, relu_clip1); - ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, 5, kernel_size); ocl_runtime->SetKernelArg(kernel_, 6, stride); ocl_runtime->SetKernelArg(kernel_, 7, padding); ocl_runtime->SetKernelArg(kernel_, 8, dilation); ocl_runtime->SetKernelArg(kernel_, 9, src_size); ocl_runtime->SetKernelArg(kernel_, 10, dst_size); + if (mem_type_ == MEM_TYPE::BUF) { + ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data()); + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + } else { + cl::ImageFormat image_format; + { + image_format.image_channel_order = CL_RGBA; + image_format.image_channel_data_type = CL_FLOAT; + } + cl_int in_error_code; + size_t im_src_x, im_src_y; + size_t im_dst_x, im_dst_y; + if (inputs_[0]->GetFormat() == schema::Format_NHWC4) { + im_src_x = inputs_[0]->Width() * CI4; + im_src_y = inputs_[0]->Height(); + im_dst_x = outputs_[0]->Width() * CO4; + im_dst_y = outputs_[0]->Height(); + } else { + im_src_y = inputs_[0]->Height() * CI4; + im_src_x = inputs_[0]->Width(); + im_dst_y = outputs_[0]->Height() * CO4; + im_dst_x = outputs_[0]->Width(); + } + cl::Image2D in_mem(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, + im_src_x, im_src_y, 0, inputs_[0]->Data(), &in_error_code); + cl_int out_error_code; + cl::Image2D out_mem(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format, + im_dst_x, im_dst_y, 0, nullptr, &out_error_code); + if (in_error_code != CL_SUCCESS) { + MS_LOG(DEBUG) << "in Image2D Failed, error=" << in_error_code; + return 1; + } + if (out_error_code != CL_SUCCESS) { + MS_LOG(DEBUG) << "out Image2D Failed, error= " << out_error_code; + return 1; + } + auto origin = cl::array < cl::size_type, + 3U > {0, 0, 0}; + auto region = cl::array < cl::size_type, + 3U > {im_dst_x, im_dst_y, 1}; + ocl_runtime->SetKernelArg(kernel_, 0, in_mem); + 
ocl_runtime->SetKernelArg(kernel_, 4, out_mem); - ocl_runtime->RunKernel(kernel_, global, local, nullptr); + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(out_mem, CL_TRUE, origin, region, 0, 0, + outputs_[0]->Data()); + } return 0; } kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { auto *kernel = new DepthwiseConv2dOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); auto ret = kernel->Init(); if (0 != ret) { @@ -147,6 +209,7 @@ kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector - #include "src/lite_kernel.h" #include "src/runtime/kernel/arm/opclib/conv_parameter.h" #include "src/runtime/opencl/opencl_runtime.h" - namespace mindspore::kernel { class DepthwiseConv2dOpenCLKernel : public LiteKernel { @@ -31,18 +29,25 @@ class DepthwiseConv2dOpenCLKernel : public LiteKernel { explicit DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs) : LiteKernel(parameter, inputs, outputs), - packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {} + packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {} + ~DepthwiseConv2dOpenCLKernel() override {}; int Init() override; + int ReSize() override; + int Run() override; + int InitBuffer(); private: FLOAT_t *packed_weight_; FLOAT_t *bias_data_; cl::Kernel kernel_; + enum class MEM_TYPE { + BUF, IMG + } mem_type_{MEM_TYPE::BUF}; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index e01e4079ad..4db4444ec3 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -264,6 +264,8 @@ if (SUPPORT_GPU) 
set(TEST_SRC ${TEST_SRC} ${TEST_DIR}/ut/stc/runtime/kernel/opencl/matmul_tests.cc + ${TEST_DIR}/ut/stc/runtime/kernel/opencl/depthwise_conv2d_tests.cc + ${TEST_DIR}/ut/stc/runtime/kernel/opencl/concat_tests.cc ${TEST_DIR}/ut/stc/runtime/kernel/opencl/softmax_cl_tests.cc ) endif() diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc new file mode 100755 index 0000000000..cc54fd3c07 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc @@ -0,0 +1,809 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h" +// Using-directives (not using-declarations): the tests below name DepthwiseConv2dOpenCLKernel and opencl::OpenCLRuntime unqualified. +using namespace mindspore::kernel; +using namespace mindspore::lite; +using namespace mindspore; + +#define SAFE_DELETE_ARRAY(a) \ + if (a != nullptr) { \ + delete[] a; \ + a = nullptr; \ + } +#define SAFE_DELETE_PTR(a) \ + if (a != nullptr) { \ + delete a; \ + a = nullptr; \ + } + +namespace mindspore { +class TestConvolutionDwOpenCL : public UT::Common { + public: + TestConvolutionDwOpenCL(){} +}; + +void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, float_t *weight_data, float_t *gnd_data, + schema::Format format, bool is_compare = true) { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + // pack input + int IC4 = UP_DIV(conv_param->input_channel_, C4NUM); + int pack_input_size = C4NUM * IC4 * conv_param->input_h_ * conv_param->input_w_; + float *packed_input = new float[pack_input_size]; + memset(packed_input, 0, pack_input_size * sizeof(float)); + int plane = conv_param->input_w_ * conv_param->input_h_; + if (format == schema::Format_NHWC4) { + PackNHWCToNHWC4Fp32(input_data, packed_input, 1, plane, conv_param->input_channel_); + } else { + PackNHWCToNC4HW4Fp32(input_data, packed_input, 1, plane, conv_param->input_channel_); + } + + // pack weight + int OC4 = UP_DIV(conv_param->output_channel_, C4NUM); + int pack_weight_size = conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_; + float *packed_weight = weight_data; + + // float bias_data[] = {0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.0, 0.0, 0.0}; + float bias_data[] = {0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + size_t packed_output_size = conv_param->output_batch_ * C4NUM * UP_DIV(conv_param->output_channel_, C4NUM) * + conv_param->output_h_ * conv_param->output_w_; + + std::vector shape_in = {conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, + conv_param->input_channel_}; // Note!!!actual is NHWC4 + std::vector shape_filter = {1, conv_param->kernel_h_, conv_param->kernel_w_, conv_param->output_channel_}; + std::vector shape_bias = {conv_param->output_channel_}; + std::vector shape_out = {conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, + conv_param->output_channel_}; + lite::tensor::Tensor *tensor_a = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_in, format); // Note!!!actual is NHWC4 + lite::tensor::Tensor *tensor_b = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_filter, schema::Format_NHWC); + lite::tensor::Tensor *tensor_c = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_bias, schema::Format_NHWC); + lite::tensor::Tensor *tensor_d = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_out, format); + std::vector inputs{tensor_a, tensor_b, tensor_c}; + std::vector outputs{tensor_d}; + + // freamework to do!!! + inputs[1]->SetData(packed_weight); + inputs[2]->SetData(bias_data); + + OpParameter * parameter = conv_param; + auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); + pKernel->Init(); + + std::vector kernels{pKernel}; + std::vector inputs_{tensor_a}; + auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); + pGraph->Init(); + + // freamework to do!!! 
+ memcpy(inputs[0]->Data(), packed_input, sizeof(float) * pack_input_size); + + pGraph->Run(); + if (is_compare) { + float* packed_output = reinterpret_cast(outputs[0]->Data()); + float *packed_correct_data = new float[packed_output_size]; + memset(packed_correct_data, 0, packed_output_size * sizeof(float)); + if (format == schema::Format_NC4HW4) { + PackNHWCToNC4HW4Fp32(gnd_data, packed_correct_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_); + } else { + PackNHWCToNHWC4Fp32(gnd_data, packed_correct_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_); + } + + printf("==================input_data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_input_size; i++) { + std::cout << packed_input[i] << ", "; + } + std::cout << std::endl; + printf("==================weight data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_weight_size; i++) { + std::cout << packed_weight[i] << ", "; + } + std::cout << std::endl; + printf("==================output data=================\n"); + std::cout << std::endl; + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_output[i] << ", "; + } + std::cout << std::endl; + printf("==================expected output data=================\n"); + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_correct_data[i] << ", "; + } + std::cout << std::endl; + // compare + UT::Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); + SAFE_DELETE_ARRAY(packed_correct_data) + } + + SAFE_DELETE_ARRAY(packed_input); + for (auto tensor : inputs) { + tensor->SetData(nullptr); + SAFE_DELETE_PTR(tensor) + } + for (auto tensor : outputs) { + tensor->SetData(nullptr); + SAFE_DELETE_PTR(tensor) + } + SAFE_DELETE_PTR(pKernel) + SAFE_DELETE_PTR(pGraph) + return; +} + +TEST_F(TestConvolutionDwOpenCL, NoPadNC4HW4Fp32) { + 
ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 4; + conv_param->input_w_ = 4; + conv_param->input_channel_ = 4; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 2; + conv_param->output_w_ = 2; + conv_param->output_channel_ = 4; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 0; + conv_param->pad_w_ = 0; + } + + // nhwc + float input_data[] = {0.5488135, 0.0202184, 0.45615032, 0.31542835, 0.71518934, 0.83261985, 0.56843394, 0.36371076, + 0.60276335, 0.77815676, 0.0187898, 0.57019675, 0.5448832, 0.87001216, 0.6176355, 0.43860152, + 0.4236548, 0.9786183, 0.6120957, 0.9883738, 0.6458941, 0.7991586, 0.616934, 0.10204481, + 0.4375872, 0.46147937, 0.94374806, 0.20887676, 0.891773, 0.7805292, 0.6818203, 0.16130951, + 0.96366274, 0.11827443, 0.3595079, 0.6531083, 0.3834415, 0.639921, 0.43703195, 0.2532916, + 0.79172504, 0.14335328, 0.6976312, 0.46631077, 0.5288949, 0.9446689, 0.06022547, 0.2444256, + 0.56804454, 0.5218483, 0.6667667, 0.15896958, 0.92559665, 0.41466194, 0.67063785, 0.11037514, + 0.07103606, 0.2645556, 0.21038257, 0.6563296, 0.0871293, 0.7742337, 0.12892629, 0.13818295}; + + // co h w ci + float weight_data[] = {0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, + 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, + 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, + 0.5759465, 0.9292962, 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, + 0.5865129, 0.02010755, 0.82894003, 0.00469548}; + + // pack correct data, nhwc + float gnd_data[] = {3.3848767, 1.4446403, 1.8428744, 1.3194335, 2.5873442, 2.1384869, 2.04022, 1.1872686, + 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 
2.188921, 1.0601988}; + + DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); + opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) { + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 3; + conv_param->input_w_ = 3; + conv_param->input_channel_ = 5; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 3; + conv_param->output_w_ = 3; + conv_param->output_channel_ = 5; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; + } + + // nhwc + float input_data[] = {0.5488135, 0.3834415, 0.77815676, 0.9446689, 0.6120957, 0.71518934, 0.79172504, 0.87001216, + 0.5218483, 0.616934, 0.60276335, 0.5288949, 0.9786183, 0.41466194, 0.94374806, 0.5448832, + 0.56804454, 0.7991586, 0.2645556, 0.6818203, 0.4236548, 0.92559665, 0.46147937, 0.7742337, + 0.3595079, 0.6458941, 0.07103606, 0.7805292, 0.45615032, 0.43703195, 0.4375872, 0.0871293, + 0.11827443, 0.56843394, 0.6976312, 0.891773, 0.0202184, 0.639921, 0.0187898, 0.06022547, + 0.96366274, 0.83261985, 0.14335328, 0.6176355, 0.6667667}; + // float input_data[]={ + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 }; + // co h w ci + float weight_data[] = {0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, + 0.10204481, 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, + 0.11037514, 0.6563296, 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, + 0.09609841, 0.97645944, 0.4686512, 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, + 0.12019656, 
0.2961402, 0.11872772, 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, + 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962}; + // float weight_data[]={ + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 }; + // pack correct data, nhwc + float gnd_data[] = {1.189188, 1.0425153, 1.8012011, 0.6074867, 1.2120346, 1.5005531, 0.8346756, 2.4365785, + 0.54975945, 1.6815965, 1.2690231, 0.60214907, 1.6158017, 0.42115876, 0.8854959, 1.1709145, + 1.0929465, 1.3534508, 1.1985044, 1.2932993, 2.4621446, 1.7086457, 2.6977584, 2.1960166, + 2.3769147, 2.3185873, 0.6133741, 0.9687358, 0.9987654, 1.0254729, 0.8368954, 0.74171704, + 0.8749627, 0.8953936, 0.5093431, 1.5496738, 0.54936385, 0.7683113, 1.165742, 1.3682933, + 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; + + DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); + opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) { + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 4; + conv_param->input_w_ = 4; + conv_param->input_channel_ = 4; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 2; + conv_param->output_w_ = 2; + conv_param->output_channel_ = 4; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 0; + conv_param->pad_w_ = 0; + } + + // nhwc + float input_data[] = {0.5488135, 0.0202184, 0.45615032, 0.31542835, 0.71518934, 0.83261985, 0.56843394, 0.36371076, + 0.60276335, 0.77815676, 0.0187898, 0.57019675, 0.5448832, 0.87001216, 0.6176355, 0.43860152, + 0.4236548, 0.9786183, 0.6120957, 0.9883738, 0.6458941, 
0.7991586, 0.616934, 0.10204481, + 0.4375872, 0.46147937, 0.94374806, 0.20887676, 0.891773, 0.7805292, 0.6818203, 0.16130951, + 0.96366274, 0.11827443, 0.3595079, 0.6531083, 0.3834415, 0.639921, 0.43703195, 0.2532916, + 0.79172504, 0.14335328, 0.6976312, 0.46631077, 0.5288949, 0.9446689, 0.06022547, 0.2444256, + 0.56804454, 0.5218483, 0.6667667, 0.15896958, 0.92559665, 0.41466194, 0.67063785, 0.11037514, + 0.07103606, 0.2645556, 0.21038257, 0.6563296, 0.0871293, 0.7742337, 0.12892629, 0.13818295}; + + // co h w ci + float weight_data[] = {0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, + 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, + 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, + 0.5759465, 0.9292962, 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, + 0.5865129, 0.02010755, 0.82894003, 0.00469548}; + + // pack correct data, nhwc + float gnd_data[] = {3.3848767, 1.4446403, 1.8428744, 1.3194335, 2.5873442, 2.1384869, 2.04022, 1.1872686, + 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; + + DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); + opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) { + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 3; + conv_param->input_w_ = 3; + conv_param->input_channel_ = 5; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 3; + conv_param->output_w_ = 3; + conv_param->output_channel_ = 5; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; + } + + // nhwc + float input_data[] = {0.5488135, 0.3834415, 0.77815676, 
0.9446689, 0.6120957, 0.71518934, 0.79172504, 0.87001216, + 0.5218483, 0.616934, 0.60276335, 0.5288949, 0.9786183, 0.41466194, 0.94374806, 0.5448832, + 0.56804454, 0.7991586, 0.2645556, 0.6818203, 0.4236548, 0.92559665, 0.46147937, 0.7742337, + 0.3595079, 0.6458941, 0.07103606, 0.7805292, 0.45615032, 0.43703195, 0.4375872, 0.0871293, + 0.11827443, 0.56843394, 0.6976312, 0.891773, 0.0202184, 0.639921, 0.0187898, 0.06022547, + 0.96366274, 0.83261985, 0.14335328, 0.6176355, 0.6667667}; + // float input_data[]={ + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 }; + // co h w ci + float weight_data[] = {0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, + 0.10204481, 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, + 0.11037514, 0.6563296, 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, + 0.09609841, 0.97645944, 0.4686512, 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, + 0.12019656, 0.2961402, 0.11872772, 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, + 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962}; + // float weight_data[]={ + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 }; + // pack correct data, nhwc + float gnd_data[] = {1.189188, 1.0425153, 1.8012011, 0.6074867, 1.2120346, 1.5005531, 0.8346756, 2.4365785, + 0.54975945, 1.6815965, 1.2690231, 0.60214907, 1.6158017, 0.42115876, 0.8854959, 1.1709145, + 1.0929465, 1.3534508, 1.1985044, 1.2932993, 2.4621446, 1.7086457, 2.6977584, 2.1960166, + 2.3769147, 2.3185873, 0.6133741, 0.9687358, 0.9987654, 1.0254729, 0.8368954, 
0.74171704, + 0.8749627, 0.8953936, 0.5093431, 1.5496738, 0.54936385, 0.7683113, 1.165742, 1.3682933, + 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; + + DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); + opencl::OpenCLRuntime::DeleteInstance(); +} + + +TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 4; + conv_param->input_w_ = 4; + conv_param->input_channel_ = 4; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 2; + conv_param->output_w_ = 2; + conv_param->output_channel_ = 4; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 0; + conv_param->pad_w_ = 0; + } + + // nhwc + float input_data[] = {0.5488135, 0.0202184, 0.45615032, 0.31542835, 0.71518934, 0.83261985, 0.56843394, 0.36371076, + 0.60276335, 0.77815676, 0.0187898, 0.57019675, 0.5448832, 0.87001216, 0.6176355, 0.43860152, + 0.4236548, 0.9786183, 0.6120957, 0.9883738, 0.6458941, 0.7991586, 0.616934, 0.10204481, + 0.4375872, 0.46147937, 0.94374806, 0.20887676, 0.891773, 0.7805292, 0.6818203, 0.16130951, + 0.96366274, 0.11827443, 0.3595079, 0.6531083, 0.3834415, 0.639921, 0.43703195, 0.2532916, + 0.79172504, 0.14335328, 0.6976312, 0.46631077, 0.5288949, 0.9446689, 0.06022547, 0.2444256, + 0.56804454, 0.5218483, 0.6667667, 0.15896958, 0.92559665, 0.41466194, 0.67063785, 0.11037514, + 0.07103606, 0.2645556, 0.21038257, 0.6563296, 0.0871293, 0.7742337, 0.12892629, 0.13818295}; + + // pack input + int IC4 = UP_DIV(conv_param->input_channel_, C4NUM); + int pack_input_size = C4NUM * IC4 * conv_param->input_h_ * conv_param->input_w_; + float *packed_input = input_data; + + // co h w ci + float weight_data[] = 
{0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, + 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, + 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, + 0.5759465, 0.9292962, 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, + 0.5865129, 0.02010755, 0.82894003, 0.00469548}; + + // pack weight + int OC4 = UP_DIV(conv_param->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param->kernel_h_ * conv_param->kernel_w_; + float *packed_weight = weight_data; + + // float bias_data[] = {0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.0, 0.0, 0.0}; + float bias_data[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + size_t packed_output_size = conv_param->output_batch_ * C4NUM * UP_DIV(conv_param->output_channel_, C4NUM) * + conv_param->output_h_ * conv_param->output_w_; + + std::vector shape_in = {conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, + IC4 * C4NUM}; // Note!!!actual is NHWC4 + std::vector shape_filter = {1, conv_param->kernel_h_, conv_param->kernel_w_, conv_param->output_channel_}; + std::vector shape_bias = {conv_param->output_channel_}; + std::vector shape_out = {conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, + conv_param->output_channel_}; + lite::tensor::Tensor *tensor_a = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_in, schema::Format_NC4HW4); // Note!!!actual is NHWC4 + lite::tensor::Tensor *tensor_b = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_filter, schema::Format_NHWC); + lite::tensor::Tensor *tensor_c = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_bias, schema::Format_NHWC); + lite::tensor::Tensor *tensor_d = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_out, schema::Format_NC4HW4); + std::vector inputs{tensor_a, tensor_b, tensor_c}; + std::vector outputs{tensor_d}; 
+ + // freamework to do!!! + inputs[1]->SetData(packed_weight); + inputs[2]->SetData(bias_data); + + OpParameter * parameter = conv_param; + auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); + pKernel->Init(); + + std::vector kernels{pKernel}; + std::vector inputs_{tensor_a}; + auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); + pGraph->Init(); + + // freamework to do!!! + memcpy(inputs[0]->Data(), packed_input, sizeof(float) * pack_input_size); + + pGraph->Run(); + float *packed_output = reinterpret_cast(outputs[0]->Data()); + + // pack correct data, nhwc + float packed_correct_data[] = {3.3848767, 1.4446403, 1.8428744, 1.3194335, 2.5873442, 2.1384869, 2.04022, 1.1872686, + 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; + + printf("==================input_data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_input_size; i++) { + std::cout << packed_input[i] << ", "; + } + std::cout << std::endl; + printf("==================packed_weight data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_weight_size; i++) { + std::cout << packed_weight[i] << ", "; + } + std::cout << std::endl; + printf("==================output data=================\n"); + std::cout << std::endl; + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_output[i] << ", "; + } + std::cout << std::endl; + printf("==================expected output data=================\n"); + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_correct_data[i] << ", "; + } + std::cout << std::endl; + // compare + CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); + + for (auto tensor : inputs) { + tensor->SetData(nullptr); + SAFE_DELETE_PTR(tensor) + } + for (auto tensor : outputs) { + tensor->SetData(nullptr); + SAFE_DELETE_PTR(tensor) + } + SAFE_DELETE_PTR(pKernel) + SAFE_DELETE_PTR(pGraph) + 
MS_LOG(INFO) << "TestConvolutionDwNoPadFp32 passed"; + opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 3; + conv_param->input_w_ = 3; + conv_param->input_channel_ = 5; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 3; + conv_param->output_w_ = 3; + conv_param->output_channel_ = 5; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; + } + + // nhwc + float input_data[] = {0.5488135, 0.3834415, 0.77815676, 0.9446689, 0.6120957, 0.71518934, 0.79172504, 0.87001216, + 0.5218483, 0.616934, 0.60276335, 0.5288949, 0.9786183, 0.41466194, 0.94374806, 0.5448832, + 0.56804454, 0.7991586, 0.2645556, 0.6818203, 0.4236548, 0.92559665, 0.46147937, 0.7742337, + 0.3595079, 0.6458941, 0.07103606, 0.7805292, 0.45615032, 0.43703195, 0.4375872, 0.0871293, + 0.11827443, 0.56843394, 0.6976312, 0.891773, 0.0202184, 0.639921, 0.0187898, 0.06022547, + 0.96366274, 0.83261985, 0.14335328, 0.6176355, 0.6667667}; + // float input_data[]={ + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 }; + + // pack input + int IC4 = UP_DIV(conv_param->input_channel_, C4NUM); + int pack_input_size = C4NUM * IC4 * conv_param->input_h_ * conv_param->input_w_; + float *packed_input = new float[pack_input_size]; + memset(packed_input, 0, pack_input_size * sizeof(float)); + int plane = conv_param->input_w_ * conv_param->input_h_; + PackNHWCToNC4HW4Fp32(input_data, packed_input, 1, plane, 
conv_param->input_channel_); + + // co h w ci + float weight_data[] = {0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, + 0.10204481, 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, + 0.11037514, 0.6563296, 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, + 0.09609841, 0.97645944, 0.4686512, 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, + 0.12019656, 0.2961402, 0.11872772, 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, + 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962}; + // float weight_data[]={ + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 }; + + // pack weight + int OC4 = UP_DIV(conv_param->output_channel_, C4NUM); + int pack_weight_size = conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_; + float *packed_weight = weight_data; + + // float bias_data[] = {0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.0, 0.0, 0.0}; + float bias_data[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + size_t packed_output_size = conv_param->output_batch_ * C4NUM * UP_DIV(conv_param->output_channel_, C4NUM) * + conv_param->output_h_ * conv_param->output_w_; + + std::vector shape_in = {conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, + IC4 * C4NUM}; // Note!!!actual is NHWC4 + std::vector shape_filter = {1, conv_param->kernel_h_, conv_param->kernel_w_, conv_param->output_channel_}; + std::vector shape_bias = {conv_param->output_channel_}; + std::vector shape_out = {conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, + conv_param->output_channel_}; + lite::tensor::Tensor *tensor_a = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_in, schema::Format_NC4HW4); // 
Note!!!actual is NHWC4 + lite::tensor::Tensor *tensor_b = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_filter, schema::Format_NHWC); + lite::tensor::Tensor *tensor_c = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_bias, schema::Format_NHWC); + lite::tensor::Tensor *tensor_d = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_out, schema::Format_NC4HW4); + std::vector inputs{tensor_a, tensor_b, tensor_c}; + std::vector outputs{tensor_d}; + + // freamework to do!!! + inputs[1]->SetData(packed_weight); + inputs[2]->SetData(bias_data); + + OpParameter * parameter = conv_param; + auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); + pKernel->Init(); + + std::vector kernels{pKernel}; + std::vector inputs_{tensor_a}; + auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); + pGraph->Init(); + + // freamework to do!!! + memcpy(inputs[0]->Data(), packed_input, sizeof(float) * pack_input_size); + + pGraph->Run(); + float *packed_output = reinterpret_cast(outputs[0]->Data()); + + // pack correct data, nhwc + float correct_data[] = {1.189188, 1.0425153, 1.8012011, 0.6074867, 1.2120346, 1.5005531, 0.8346756, 2.4365785, + 0.54975945, 1.6815965, 1.2690231, 0.60214907, 1.6158017, 0.42115876, 0.8854959, 1.1709145, + 1.0929465, 1.3534508, 1.1985044, 1.2932993, 2.4621446, 1.7086457, 2.6977584, 2.1960166, + 2.3769147, 2.3185873, 0.6133741, 0.9687358, 0.9987654, 1.0254729, 0.8368954, 0.74171704, + 0.8749627, 0.8953936, 0.5093431, 1.5496738, 0.54936385, 0.7683113, 1.165742, 1.3682933, + 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; + float *packed_correct_data = new float[packed_output_size]; + memset(packed_correct_data, 0, packed_output_size * sizeof(float)); + PackNHWCToNC4HW4Fp32(correct_data, packed_correct_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_); + + 
printf("==================input_data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_input_size; i++) { + std::cout << packed_input[i] << ", "; + } + std::cout << std::endl; + printf("==================weight data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_weight_size; i++) { + std::cout << packed_weight[i] << ", "; + } + std::cout << std::endl; + printf("==================output data=================\n"); + std::cout << std::endl; + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_output[i] << ", "; + } + std::cout << std::endl; + printf("==================expected output data=================\n"); + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_correct_data[i] << ", "; + } + std::cout << std::endl; + // compare + CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); + + SAFE_DELETE_ARRAY(packed_input); + SAFE_DELETE_ARRAY(packed_correct_data) + for (auto tensor : inputs) { + tensor->SetData(nullptr); + SAFE_DELETE_PTR(tensor) + } + for (auto tensor : outputs) { + tensor->SetData(nullptr); + SAFE_DELETE_PTR(tensor) + } + SAFE_DELETE_PTR(pKernel) + SAFE_DELETE_PTR(pGraph) + MS_LOG(INFO) << "TestConvolutionDwPadFp32 passed"; + opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { + std::vector> src_shape{ + {1, 32, 112, 112}, + {1, 96, 112, 112}, + {1, 144, 56, 56}, + {1, 144, 56, 56}, + {1, 192, 28, 28}, + {1, 192, 28, 28}, + {1, 384, 14, 14}, + {1, 576, 14, 14}, + {1, 576, 14, 14}, + {1, 960, 7, 7}, + }; + std::vector> dst_shape{ + {1, 32, 112, 112}, + {1, 96, 56, 56}, + {1, 144, 56, 56}, + {1, 144, 28, 28}, + {1, 192, 28, 28}, + {1, 192, 14, 14}, + {1, 384, 14, 14}, + {1, 576, 14, 14}, + {1, 576, 7, 7}, + {1, 960, 7, 7}, + }; + std::vector> filter_shape{ + {32, 1, 1, 1}, + {96, 3, 3, 1}, + {144, 1, 1, 1}, + {144, 3, 3, 1}, + {192, 1, 1, 1}, + {192, 3, 3, 1}, + {384, 1, 1, 1}, + {576, 
1, 1, 1}, + {576, 3, 3, 1}, + {960, 1, 1, 1}, + }; + + // nhwc + float_t *input_data = new float_t[96*112*112]{ + 0.5488135 , 0.3834415 , 0.77815676, 0.9446689 , 0.6120957 , + 0.71518934, 0.79172504, 0.87001216, 0.5218483 , 0.616934 , + 0.60276335, 0.5288949 , 0.9786183 , 0.41466194, 0.94374806, + 0.5448832 , 0.56804454, 0.7991586 , 0.2645556 , 0.6818203 , + 0.4236548 , 0.92559665, 0.46147937, 0.7742337 , 0.3595079 , + 0.6458941 , 0.07103606, 0.7805292 , 0.45615032, 0.43703195, + 0.4375872 , 0.0871293 , 0.11827443, 0.56843394, 0.6976312 , + 0.891773 , 0.0202184 , 0.639921 , 0.0187898 , 0.06022547, + 0.96366274, 0.83261985, 0.14335328, 0.6176355 , 0.6667667 }; + // co h w ci + float_t *weight_data = new float_t[576*3*3]{ + 0.67063785, 0.21038257, 0.12892629, + 0.31542835, 0.36371076, 0.57019675, + 0.43860152, 0.9883738 , 0.10204481, + 0.20887676, 0.16130951, 0.6531083 , + 0.2532916 , 0.46631077, 0.2444256 , + 0.15896958, 0.11037514, 0.6563296 , + 0.13818295, 0.19658236, 0.36872518, + 0.82099324, 0.09710128, 0.8379449 , + 0.09609841, 0.97645944, 0.4686512 , + 0.9767611 , 0.6048455 , 0.7392636 , + 0.03918779, 0.28280696, 0.12019656, + 0.2961402 , 0.11872772, 0.31798318, + 0.41426298, 0.06414749, 0.6924721 , + 0.56660146, 0.2653895 , 0.5232481 , + 0.09394051, 0.5759465 , 0.9292962 }; + for (size_t i = 0; i < src_shape.size(); ++i) { + const int MAX_RUN_TIMES = 10; + for (int j = 0; j < MAX_RUN_TIMES; ++j) { + printf("========profiling depthwise, in shape(%d,%d,%d,%d), out shape(%d,%d,%d,%d), iter%d========\n", + src_shape[i][0], src_shape[i][1], src_shape[i][2], src_shape[i][3], + dst_shape[i][0], dst_shape[i][1], dst_shape[i][2], dst_shape[i][3], j); + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = src_shape[i][2]; + conv_param->input_w_ = src_shape[i][3]; + conv_param->input_channel_ = src_shape[i][1]; + conv_param->output_batch_ = 1; + conv_param->output_h_ = dst_shape[i][2]; + conv_param->output_w_ = 
dst_shape[i][3]; + conv_param->output_channel_ = dst_shape[i][1]; + conv_param->kernel_h_ = filter_shape[i][1]; + conv_param->kernel_w_ = filter_shape[i][2]; + conv_param->stride_h_ = conv_param->output_h_/conv_param->input_h_; + conv_param->stride_w_ = conv_param->output_w_/conv_param->input_w_; + conv_param->pad_h_ = (conv_param->kernel_h_-1)/2; + conv_param->pad_w_ = (conv_param->kernel_w_-1)/2; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + } + DepthWiseTestMain(conv_param, input_data, weight_data, nullptr, schema::Format_NC4HW4, false); + // DepthWiseTestMain(conv_param, input_data, weight_data, nullptr, schema::Format_NHWC4, false); + } + } + SAFE_DELETE_ARRAY(input_data); + SAFE_DELETE_ARRAY(weight_data); + opencl::OpenCLRuntime::DeleteInstance(); +} + +} // namespace mindspore