From: @ddwsky Reviewed-by: @HilbertDavid,@zhanghaibo5 Signed-off-by: @zhanghaibo5tags/v1.2.0-rc1
| @@ -1,12 +1,12 @@ | |||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||
| __kernel void DepthwiseConv2d_IMG_NC4HW4(__write_only image2d_t dst_data, __read_only image2d_t src_data, | |||
| __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride, | |||
| int2 padding, int2 dilation, int4 src_size, int4 dst_size, float relu_clip_min, | |||
| float relu_clip_max) { | |||
| int X = get_global_id(0); | |||
| int Y = get_global_id(1); | |||
| int Z = get_global_id(2); | |||
| __kernel void DepthwiseConv2d_IMG_NHWC4(__write_only image2d_t dst_data, __read_only image2d_t src_data, | |||
| __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size, | |||
| int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size, | |||
| float relu_clip_min, float relu_clip_max) { | |||
| int X = get_global_id(1); | |||
| int Y = get_global_id(2); | |||
| int Z = get_global_id(0); | |||
| if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; | |||
| FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); | |||
| int x_offset = X * stride.x + padding.x; | |||
| @@ -19,8 +19,8 @@ __kernel void DepthwiseConv2d_IMG_NC4HW4(__write_only image2d_t dst_data, __read | |||
| int x_c = x_offset + kx * dilation.x; | |||
| bool outside_x = x_c < 0 || x_c >= src_size.x; | |||
| if (!outside_x && !outside_y) { | |||
| FLT4 flt_p = filter[fx_c]; | |||
| FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(x_c, (Z * src_size.y + y_c))); | |||
| FLT4 flt_p = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z)); | |||
| FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c)); | |||
| r += TO_FLT4(src_p * flt_p); | |||
| } | |||
| fx_c++; | |||
| @@ -29,9 +29,39 @@ __kernel void DepthwiseConv2d_IMG_NC4HW4(__write_only image2d_t dst_data, __read | |||
| FLT4 bias_p = bias[Z]; | |||
| FLT4 res = TO_FLT4(r) + bias_p; | |||
| res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max)); | |||
| WRITE_IMAGE(dst_data, (int2)(X, (Z * dst_size.y + Y)), res); | |||
| WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res); | |||
| } | |||
| __kernel void DepthwiseConv2d_IMG_NHWC4_1x1(__write_only image2d_t dst_data, __read_only image2d_t src_data, | |||
| __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size, | |||
| int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size, | |||
| float relu_clip_min, float relu_clip_max) { | |||
| int X = get_global_id(1); | |||
| int Y = get_global_id(2); | |||
| int Z = get_global_id(0); | |||
| if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; | |||
| FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); | |||
| int x_offset = X * stride.x + padding.x; | |||
| int y_offset = Y * stride.y + padding.y; | |||
| int fx_c = Z; | |||
| { | |||
| int y_c = y_offset; | |||
| bool outside_y = y_c < 0 || y_c >= src_size.y; | |||
| { | |||
| int x_c = x_offset; | |||
| bool outside_x = x_c < 0 || x_c >= src_size.x; | |||
| if (!outside_x && !outside_y) { | |||
| FLT4 flt_p = READ_IMAGE(filter, smp_zero, (int2)(0, Z)); | |||
| FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c)); | |||
| r += TO_FLT4(src_p * flt_p); | |||
| } | |||
| } | |||
| } | |||
| FLT4 bias_p = bias[Z]; | |||
| FLT4 res = TO_FLT4(r) + bias_p; | |||
| res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max)); | |||
| WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res); | |||
| } | |||
| __kernel void DepthwiseConv2d_IMG_NHWC4_b222(__write_only image2d_t dst_data, __read_only image2d_t src_data, | |||
| __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride, | |||
| int2 padding, int2 dilation, int4 src_size, int4 dst_size, | |||
| @@ -264,65 +294,3 @@ __kernel void DepthwiseConv2d_BUF_NC4HW4(__global FLT4 *dst_data, __global FLT4 | |||
| res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max)); | |||
| dst_data[(((Z)*dst_size.y + (Y)) * dst_size.x + (X))] = res; | |||
| } | |||
| __kernel void DepthwiseConv2d_BUF_NHWC4(__global FLT4 *dst_data, __global FLT4 *src_data, __global FLT4 *filter, | |||
| __global FLT4 *bias, int2 kernel_size, int2 stride, int2 padding, int2 dilation, | |||
| int4 src_size, int4 dst_size, float relu_clip_min, float relu_clip_max) { | |||
| int X = get_global_id(0); | |||
| int Y = get_global_id(1); | |||
| int Z = get_global_id(2); | |||
| if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; | |||
| FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); | |||
| int x_offset = X * stride.x + padding.x; | |||
| int y_offset = Y * stride.y + padding.y; | |||
| int fx_c = Z * kernel_size.x * kernel_size.y; | |||
| for (int ky = 0; ky < kernel_size.y; ++ky) { | |||
| int y_c = y_offset + ky * dilation.y; | |||
| bool outside_y = y_c < 0 || y_c >= src_size.y; | |||
| for (int kx = 0; kx < kernel_size.x; ++kx) { | |||
| int x_c = x_offset + kx * dilation.x; | |||
| bool outside_x = x_c < 0 || x_c >= src_size.x; | |||
| if (!outside_x && !outside_y) { | |||
| FLT4 flt_p = filter[fx_c]; | |||
| FLT4 src_p = src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; | |||
| r += TO_FLT4(src_p * flt_p); | |||
| } | |||
| fx_c++; | |||
| } | |||
| } | |||
| FLT4 bias_p = bias[Z]; | |||
| FLT4 res = TO_FLT4(r) + bias_p; | |||
| res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max)); | |||
| dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res; | |||
| } | |||
| __kernel void DepthwiseConv2d_BUF_NHWC4_1x1(__global FLT4 *dst_data, __global FLT4 *src_data, __global FLT4 *filter, | |||
| __global FLT4 *bias, int2 kernel_size, int2 stride, int2 padding, | |||
| int2 dilation, int4 src_size, int4 dst_size, float relu_clip_min, | |||
| float relu_clip_max) { | |||
| int X = get_global_id(0); | |||
| int Y = get_global_id(1); | |||
| int Z = get_global_id(2); | |||
| if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; | |||
| FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); | |||
| int x_offset = X * stride.x + padding.x; | |||
| int y_offset = Y * stride.y + padding.y; | |||
| int fx_c = Z; | |||
| { | |||
| int y_c = y_offset; | |||
| bool outside_y = y_c < 0 || y_c >= src_size.y; | |||
| { | |||
| int x_c = x_offset; | |||
| bool outside_x = x_c < 0 || x_c >= src_size.x; | |||
| if (!outside_x && !outside_y) { | |||
| FLT4 flt_p = filter[fx_c]; | |||
| FLT4 src_p = src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; | |||
| r += TO_FLT4(src_p * flt_p); | |||
| } | |||
| } | |||
| } | |||
| FLT4 bias_p = bias[Z]; | |||
| FLT4 res = TO_FLT4(r) + bias_p; | |||
| res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max)); | |||
| dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res; | |||
| } | |||
| @@ -73,7 +73,11 @@ int DepthwiseConv2dOpenCLKernel::Prepare() { | |||
| if (parameter->kernel_h_ == 1 && parameter->kernel_w_ == 1) { | |||
| kernel_name += "_1x1"; | |||
| } | |||
| kernel_name += "_b" + std::to_string(block_size_.H) + std::to_string(block_size_.W) + std::to_string(block_size_.C); | |||
| if (filter_type_ == lite::opencl::MemType::BUF) { | |||
| kernel_name += "_b" + std::to_string(block_size_.H) + std::to_string(block_size_.W) + std::to_string(block_size_.C); | |||
| } else { | |||
| block_size_.C = block_size_.H = block_size_.W = 1; | |||
| } | |||
| #ifdef PROGRAM_WITH_IL | |||
| kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name); | |||
| #else | |||
| @@ -107,32 +111,42 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() { | |||
| int CO4 = UP_DIV(out_info.C, C4NUM * block_size_.C); | |||
| int pack_weight_size = C4NUM * CO4 * parameter->kernel_h_ * parameter->kernel_w_; | |||
| int plane = parameter->kernel_h_ * parameter->kernel_w_; | |||
| int plane_in = parameter->kernel_h_ * parameter->kernel_w_; | |||
| int plane_out = plane_in * C4NUM; | |||
| std::vector<size_t> img_size; | |||
| if (filter_type_ == MemType::IMG) { | |||
| int alignment = ocl_runtime_->GetImagePitchAlignment(); | |||
| plane_out = UP_ROUND(plane_out, alignment) * C4NUM; | |||
| pack_weight_size = plane_out * CO4; | |||
| auto shape = in_tensors_[1]->shape(); | |||
| size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT; | |||
| img_size = {(size_t)plane_out / C4NUM, (size_t)shape[0] * CO4, img_dtype}; | |||
| } | |||
| if (is_fp16) { | |||
| packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(int16_t)); | |||
| packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(int16_t), img_size); | |||
| packed_weight_ = allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true); | |||
| if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) { | |||
| std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; }; | |||
| PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype); | |||
| PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype); | |||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) { | |||
| std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); }; | |||
| PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype); | |||
| PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype); | |||
| } else { // int8 or int16 | |||
| std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; }; | |||
| PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype); | |||
| PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype); | |||
| } | |||
| } else { | |||
| packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float)); | |||
| packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float), img_size); | |||
| packed_weight_ = allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true); | |||
| if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) { | |||
| std::function<float(float)> to_dtype = [](float x) -> float { return x; }; | |||
| PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype); | |||
| PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype); | |||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) { | |||
| std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); }; | |||
| PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype); | |||
| PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype); | |||
| } else { // int8 or int16 | |||
| std::function<float(float)> to_dtype = [](float x) -> float { return x; }; | |||
| PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype); | |||
| PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype); | |||
| } | |||
| } | |||
| allocator->UnmapBuffer(packed_weight_); | |||
| @@ -184,7 +198,7 @@ void DepthwiseConv2dOpenCLKernel::SetConstArgs() { | |||
| cl_int4 dst_size = {(cl_int)out_info.W, (cl_int)out_info.H, (cl_int)CO4, (cl_int)out_info.N}; | |||
| int arg_cnt = 2; | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, lite::opencl::MemType::BUF); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, filter_type_); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, bias_data_, lite::opencl::MemType::BUF); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, kernel_size); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, stride); | |||
| @@ -21,13 +21,18 @@ | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| using mindspore::lite::opencl::MemType; | |||
| namespace mindspore::kernel { | |||
| class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| : OpenCLKernel(parameter, inputs, outputs) { | |||
| bool is_adreno = ocl_runtime_->GetGpuInfo().type == lite::opencl::GpuType::ADRENO; | |||
| filter_type_ = is_adreno ? MemType::IMG : MemType::BUF; | |||
| } | |||
| ~DepthwiseConv2dOpenCLKernel() override = default; | |||
| @@ -47,6 +52,7 @@ class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { | |||
| int W{2}; | |||
| int C{1}; | |||
| } block_size_; | |||
| MemType filter_type_{MemType::BUF}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -62,19 +62,20 @@ std::vector<int> GetNHWCShape(const std::vector<int> &tensor_shape); | |||
| std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape, schema::Format format); | |||
| template <class T1, class T2> | |||
| void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane, int channel, const std::function<T2(T1)> &to_dtype) { | |||
| void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane_in, int plane_out, int channel, | |||
| const std::function<T2(T1)> &to_dtype) { | |||
| MS_ASSERT(src); | |||
| MS_ASSERT(dst); | |||
| int c4 = UP_DIV(channel, C4NUM); | |||
| for (int b = 0; b < batch; b++) { | |||
| int src_offset = b * plane * channel; | |||
| int dst_offset = b * plane * c4 * C4NUM; | |||
| int src_offset = b * plane_in * channel; | |||
| int dst_offset = b * plane_out * c4; | |||
| for (int c = 0; c < channel; c++) { | |||
| int c4_block_num = c / C4NUM; | |||
| int c4_block_rem = c % C4NUM; | |||
| int src_c_offset = src_offset + c * plane; | |||
| int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; | |||
| for (int k = 0; k < plane; k++) { | |||
| int src_c_offset = src_offset + c * plane_in; | |||
| int dst_c_offset = dst_offset + c4_block_num * plane_out; | |||
| for (int k = 0; k < plane_in; k++) { | |||
| int src_kernel_offset = src_c_offset + k; | |||
| int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; | |||
| (static_cast<T2 *>(dst) + dst_kernel_offset)[0] = to_dtype((static_cast<T1 *>(src) + src_kernel_offset)[0]); | |||
| @@ -187,13 +187,20 @@ class OpenCLRuntime { | |||
| std::vector<size_t> max_work_item_sizes_; | |||
| void *handle_{nullptr}; | |||
| TuningMode tuning_mode_{TuningMode::DEFAULT}; | |||
| #if MS_OPENCL_PROFILE | |||
| bool profiling_{true}; | |||
| #else | |||
| bool profiling_{false}; | |||
| #endif | |||
| // for cache | |||
| private: | |||
| void LoadCache(); | |||
| void StoreCache(); | |||
| #ifdef MS_OPENCL_BINARY_CACHE | |||
| bool enable_cache_{true}; | |||
| #else | |||
| bool enable_cache_{false}; | |||
| #endif | |||
| bool flush_cache_{false}; | |||
| std::string cache_path_{"/data/local/tmp/.opencl_cache"}; | |||
| const std::string cache_version_{"V0.1"}; | |||
| @@ -81,7 +81,7 @@ TEST_F(TestOpenCL_DepthwiseConv2d, NoPad) { | |||
| TestMain({{input_shape, input_data, VAR}, | |||
| {weight_shape, weight_data, CONST_TENSOR}, | |||
| {bias_shape, bias_data, CONST_TENSOR}}, | |||
| {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5); | |||
| {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); | |||
| } | |||
| } | |||
| @@ -128,7 +128,7 @@ TEST_F(TestOpenCL_DepthwiseConv2d, Pad) { | |||
| TestMain({{input_shape, input_data, VAR}, | |||
| {weight_shape, weight_data, CONST_TENSOR}, | |||
| {bias_shape, bias_data, CONST_TENSOR}}, | |||
| {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5); | |||
| {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); | |||
| } | |||
| } | |||