@@ -1,12 +1,12 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
-__kernel void DepthwiseConv2d_IMG_NC4HW4(__write_only image2d_t dst_data, __read_only image2d_t src_data,
-                                         __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,
-                                         int2 padding, int2 dilation, int4 src_size, int4 dst_size, float relu_clip_min,
-                                         float relu_clip_max) {
-  int X = get_global_id(0);
-  int Y = get_global_id(1);
-  int Z = get_global_id(2);
+__kernel void DepthwiseConv2d_IMG_NHWC4(__write_only image2d_t dst_data, __read_only image2d_t src_data,
+                                        __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,
+                                        int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,
+                                        float relu_clip_min, float relu_clip_max) {
+  int X = get_global_id(1);
+  int Y = get_global_id(2);
+  int Z = get_global_id(0);
   if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
   FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
   int x_offset = X * stride.x + padding.x;
@@ -19,8 +19,8 @@ __kernel void DepthwiseConv2d_IMG_NC4HW4(__write_only image2d_t dst_data, __read
       int x_c = x_offset + kx * dilation.x;
       bool outside_x = x_c < 0 || x_c >= src_size.x;
       if (!outside_x && !outside_y) {
-        FLT4 flt_p = filter[fx_c];
-        FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(x_c, (Z * src_size.y + y_c)));
+        FLT4 flt_p = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z));
+        FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c));
         r += TO_FLT4(src_p * flt_p);
       }
       fx_c++;
@@ -29,9 +29,39 @@ __kernel void DepthwiseConv2d_IMG_NC4HW4(__write_only image2d_t dst_data, __read
   FLT4 bias_p = bias[Z];
   FLT4 res = TO_FLT4(r) + bias_p;
   res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
-  WRITE_IMAGE(dst_data, (int2)(X, (Z * dst_size.y + Y)), res);
+  WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res);
 }
+__kernel void DepthwiseConv2d_IMG_NHWC4_1x1(__write_only image2d_t dst_data, __read_only image2d_t src_data,
+                                            __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,
+                                            int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,
+                                            float relu_clip_min, float relu_clip_max) {
+  int X = get_global_id(1);
+  int Y = get_global_id(2);
+  int Z = get_global_id(0);
+  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
+  FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
+  int x_offset = X * stride.x + padding.x;
+  int y_offset = Y * stride.y + padding.y;
+  int fx_c = Z;
+  {
+    int y_c = y_offset;
+    bool outside_y = y_c < 0 || y_c >= src_size.y;
+    {
+      int x_c = x_offset;
+      bool outside_x = x_c < 0 || x_c >= src_size.x;
+      if (!outside_x && !outside_y) {
+        FLT4 flt_p = READ_IMAGE(filter, smp_zero, (int2)(0, Z));
+        FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c));
+        r += TO_FLT4(src_p * flt_p);
+      }
+    }
+  }
+  FLT4 bias_p = bias[Z];
+  FLT4 res = TO_FLT4(r) + bias_p;
+  res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
+  WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res);
+}
 
 __kernel void DepthwiseConv2d_IMG_NHWC4_b222(__write_only image2d_t dst_data, __read_only image2d_t src_data,
                                              __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,
                                              int2 padding, int2 dilation, int4 src_size, int4 dst_size,
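Aside (not part of the patch): the hunks above switch the IMG kernels from NC4HW4 to NHWC4 addressing, move the depthwise filter from a __global buffer into a 2D image, and take the channel slice from get_global_id(0). Reading the coordinates off the new code, a source/destination texel for column x, row y and channel slice z appears to live at image coordinate (x * slice_count + z, y), and a filter texel for slice z and kernel tap (kx, ky) at (ky * kernel_w + kx, z). A minimal host-side sketch of that mapping, with hypothetical helper names that are not part of the patch:

```cpp
#include <cstdio>

// Hypothetical helpers (not from the patch) that mirror the image coordinates
// computed by the NHWC4 kernels above. slice = c / 4, slices = UP_DIV(C, 4).
struct Coord {
  int x;
  int y;
};

Coord Nhwc4TexelCoord(int x, int y, int slice, int slices) {
  // matches READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c))
  return {x * slices + slice, y};
}

Coord Nc4hw4TexelCoord(int x, int y, int slice, int height) {
  // matches the removed NC4HW4 read: (int2)(x_c, Z * src_size.y + y_c)
  return {x, slice * height + y};
}

Coord FilterTexelCoord(int kx, int ky, int kernel_w, int slice) {
  // matches READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z))
  return {ky * kernel_w + kx, slice};
}

int main() {
  Coord c = Nhwc4TexelCoord(/*x=*/3, /*y=*/2, /*slice=*/1, /*slices=*/4);
  std::printf("NHWC4 texel for x=3, y=2, slice=1 -> (%d, %d)\n", c.x, c.y);  // (13, 2)
  return 0;
}
```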
@@ -264,65 +294,3 @@ __kernel void DepthwiseConv2d_BUF_NC4HW4(__global FLT4 *dst_data, __global FLT4
   res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
   dst_data[(((Z)*dst_size.y + (Y)) * dst_size.x + (X))] = res;
 }
-
-__kernel void DepthwiseConv2d_BUF_NHWC4(__global FLT4 *dst_data, __global FLT4 *src_data, __global FLT4 *filter,
-                                        __global FLT4 *bias, int2 kernel_size, int2 stride, int2 padding, int2 dilation,
-                                        int4 src_size, int4 dst_size, float relu_clip_min, float relu_clip_max) {
-  int X = get_global_id(0);
-  int Y = get_global_id(1);
-  int Z = get_global_id(2);
-  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
-  FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
-  int x_offset = X * stride.x + padding.x;
-  int y_offset = Y * stride.y + padding.y;
-  int fx_c = Z * kernel_size.x * kernel_size.y;
-  for (int ky = 0; ky < kernel_size.y; ++ky) {
-    int y_c = y_offset + ky * dilation.y;
-    bool outside_y = y_c < 0 || y_c >= src_size.y;
-    for (int kx = 0; kx < kernel_size.x; ++kx) {
-      int x_c = x_offset + kx * dilation.x;
-      bool outside_x = x_c < 0 || x_c >= src_size.x;
-      if (!outside_x && !outside_y) {
-        FLT4 flt_p = filter[fx_c];
-        FLT4 src_p = src_data[((y_c * src_size.x + x_c) * src_size.z + Z)];
-        r += TO_FLT4(src_p * flt_p);
-      }
-      fx_c++;
-    }
-  }
-  FLT4 bias_p = bias[Z];
-  FLT4 res = TO_FLT4(r) + bias_p;
-  res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
-  dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res;
-}
-
-__kernel void DepthwiseConv2d_BUF_NHWC4_1x1(__global FLT4 *dst_data, __global FLT4 *src_data, __global FLT4 *filter,
-                                            __global FLT4 *bias, int2 kernel_size, int2 stride, int2 padding,
-                                            int2 dilation, int4 src_size, int4 dst_size, float relu_clip_min,
-                                            float relu_clip_max) {
-  int X = get_global_id(0);
-  int Y = get_global_id(1);
-  int Z = get_global_id(2);
-  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
-  FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
-  int x_offset = X * stride.x + padding.x;
-  int y_offset = Y * stride.y + padding.y;
-  int fx_c = Z;
-  {
-    int y_c = y_offset;
-    bool outside_y = y_c < 0 || y_c >= src_size.y;
-    {
-      int x_c = x_offset;
-      bool outside_x = x_c < 0 || x_c >= src_size.x;
-      if (!outside_x && !outside_y) {
-        FLT4 flt_p = filter[fx_c];
-        FLT4 src_p = src_data[((y_c * src_size.x + x_c) * src_size.z + Z)];
-        r += TO_FLT4(src_p * flt_p);
-      }
-    }
-  }
-  FLT4 bias_p = bias[Z];
-  FLT4 res = TO_FLT4(r) + bias_p;
-  res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
-  dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res;
-}
@@ -73,7 +73,11 @@ int DepthwiseConv2dOpenCLKernel::Prepare() {
   if (parameter->kernel_h_ == 1 && parameter->kernel_w_ == 1) {
     kernel_name += "_1x1";
   }
-  kernel_name += "_b" + std::to_string(block_size_.H) + std::to_string(block_size_.W) + std::to_string(block_size_.C);
+  if (filter_type_ == lite::opencl::MemType::BUF) {
+    kernel_name += "_b" + std::to_string(block_size_.H) + std::to_string(block_size_.W) + std::to_string(block_size_.C);
+  } else {
+    block_size_.C = block_size_.H = block_size_.W = 1;
+  }
 #ifdef PROGRAM_WITH_IL
   kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
 #else
@@ -107,32 +111,42 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
   int CO4 = UP_DIV(out_info.C, C4NUM * block_size_.C);
   int pack_weight_size = C4NUM * CO4 * parameter->kernel_h_ * parameter->kernel_w_;
-  int plane = parameter->kernel_h_ * parameter->kernel_w_;
+  int plane_in = parameter->kernel_h_ * parameter->kernel_w_;
+  int plane_out = plane_in * C4NUM;
+  std::vector<size_t> img_size;
+  if (filter_type_ == MemType::IMG) {
+    int alignment = ocl_runtime_->GetImagePitchAlignment();
+    plane_out = UP_ROUND(plane_out, alignment) * C4NUM;
+    pack_weight_size = plane_out * CO4;
+    auto shape = in_tensors_[1]->shape();
+    size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
+    img_size = {(size_t)plane_out / C4NUM, (size_t)shape[0] * CO4, img_dtype};
+  }
   if (is_fp16) {
-    packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(int16_t));
+    packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(int16_t), img_size);
     packed_weight_ = allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true);
     if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) {
       std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; };
-      PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
+      PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
     } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
       std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
-      PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
+      PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
     } else {  // int8 or int16
       std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; };
-      PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
+      PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
     }
   } else {
-    packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float));
+    packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float), img_size);
     packed_weight_ = allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true);
     if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
       std::function<float(float)> to_dtype = [](float x) -> float { return x; };
-      PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
+      PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
     } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) {
       std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); };
-      PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
+      PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
    } else {  // int8 or int16
       std::function<float(float)> to_dtype = [](float x) -> float { return x; };
-      PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
+      PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
     }
   }
   allocator->UnmapBuffer(packed_weight_);
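Aside (not part of the patch): when the filter lives in an image, InitWeights() now pads each channel slice's packed row out to the device's image pitch alignment and allocates the weight image as (plane_out / C4NUM) x (batch * CO4) texels. A rough standalone re-run of that arithmetic, with made-up shapes and an assumed alignment of 64 (GetImagePitchAlignment() is device-specific):

```cpp
#include <cstdio>

// Illustrative re-run of the image-filter sizing in InitWeights() above.
// The shapes and the pitch alignment below are made-up example values.
constexpr int C4NUM = 4;
constexpr int UP_DIV(int x, int y) { return (x + y - 1) / y; }
constexpr int UP_ROUND(int x, int y) { return UP_DIV(x, y) * y; }

int main() {
  int kernel_h = 3, kernel_w = 3, out_channels = 24, batch = 1;
  int alignment = 64;                     // stand-in for GetImagePitchAlignment()
  int CO4 = UP_DIV(out_channels, C4NUM);  // block_size_.C is forced to 1 on the image path
  int plane_in = kernel_h * kernel_w;     // 9 taps per channel
  int plane_out = plane_in * C4NUM;       // 36 packed elements per channel slice
  plane_out = UP_ROUND(plane_out, alignment) * C4NUM;  // pad the slice row to the pitch alignment
  int pack_weight_size = plane_out * CO4;
  std::printf("filter image: %d x %d texels, packed size: %d elements\n",
              plane_out / C4NUM, batch * CO4, pack_weight_size);  // 64 x 6, 1536
  return 0;
}
```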
@@ -184,7 +198,7 @@ void DepthwiseConv2dOpenCLKernel::SetConstArgs() {
   cl_int4 dst_size = {(cl_int)out_info.W, (cl_int)out_info.H, (cl_int)CO4, (cl_int)out_info.N};
   int arg_cnt = 2;
-  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, lite::opencl::MemType::BUF);
+  ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, filter_type_);
   ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, bias_data_, lite::opencl::MemType::BUF);
   ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, kernel_size);
   ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, stride);
@@ -21,13 +21,18 @@
 #include "src/runtime/kernel/opencl/opencl_kernel.h"
 #include "nnacl/conv_parameter.h"
 
+using mindspore::lite::opencl::MemType;
+
 namespace mindspore::kernel {
 class DepthwiseConv2dOpenCLKernel : public OpenCLKernel {
  public:
  DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                              const std::vector<lite::Tensor *> &outputs)
-      : OpenCLKernel(parameter, inputs, outputs) {}
+      : OpenCLKernel(parameter, inputs, outputs) {
+    bool is_adreno = ocl_runtime_->GetGpuInfo().type == lite::opencl::GpuType::ADRENO;
+    filter_type_ = is_adreno ? MemType::IMG : MemType::BUF;
+  }
 
  ~DepthwiseConv2dOpenCLKernel() override = default;
@@ -47,6 +52,7 @@ class DepthwiseConv2dOpenCLKernel : public OpenCLKernel {
     int W{2};
     int C{1};
   } block_size_;
+  MemType filter_type_{MemType::BUF};
 };
 }  // namespace mindspore::kernel
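Aside (not part of the patch): Adreno GPUs service image reads through a dedicated texture cache, which is presumably why the constructor promotes the depthwise filter to MemType::IMG only when GetGpuInfo() reports an Adreno device; every other GPU keeps the buffer-resident filter and the blocked ("_b"-suffixed) kernel variants selected in Prepare().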
@@ -62,19 +62,20 @@ std::vector<int> GetNHWCShape(const std::vector<int> &tensor_shape);
 std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape, schema::Format format);
 
 template <class T1, class T2>
-void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane, int channel, const std::function<T2(T1)> &to_dtype) {
+void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane_in, int plane_out, int channel,
+                      const std::function<T2(T1)> &to_dtype) {
   MS_ASSERT(src);
   MS_ASSERT(dst);
   int c4 = UP_DIV(channel, C4NUM);
   for (int b = 0; b < batch; b++) {
-    int src_offset = b * plane * channel;
-    int dst_offset = b * plane * c4 * C4NUM;
+    int src_offset = b * plane_in * channel;
+    int dst_offset = b * plane_out * c4;
     for (int c = 0; c < channel; c++) {
       int c4_block_num = c / C4NUM;
       int c4_block_rem = c % C4NUM;
-      int src_c_offset = src_offset + c * plane;
-      int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM;
-      for (int k = 0; k < plane; k++) {
+      int src_c_offset = src_offset + c * plane_in;
+      int dst_c_offset = dst_offset + c4_block_num * plane_out;
+      for (int k = 0; k < plane_in; k++) {
        int src_kernel_offset = src_c_offset + k;
         int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem;
         (static_cast<T2 *>(dst) + dst_kernel_offset)[0] = to_dtype((static_cast<T1 *>(src) + src_kernel_offset)[0]);
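Aside (not part of the patch): the new plane_out parameter decouples the packed destination's per-slice row stride from the number of source taps (plane_in), which is what lets the image path pad each slice without disturbing how taps are interleaved. A self-contained sketch of the new signature with made-up sizes; the function body is copied from the header above, with MS_ASSERT swapped for assert so it compiles on its own:

```cpp
#include <cassert>
#include <cstdio>
#include <functional>

constexpr int C4NUM = 4;
constexpr int UP_DIV(int x, int y) { return (x + y - 1) / y; }

// Local copy of PackNCHWToNC4HW4 from the header above so the sketch is
// self-contained (MS_ASSERT swapped for assert).
template <class T1, class T2>
void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane_in, int plane_out, int channel,
                      const std::function<T2(T1)> &to_dtype) {
  assert(src);
  assert(dst);
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane_in * channel;
    int dst_offset = b * plane_out * c4;
    for (int c = 0; c < channel; c++) {
      int c4_block_num = c / C4NUM;
      int c4_block_rem = c % C4NUM;
      int src_c_offset = src_offset + c * plane_in;
      int dst_c_offset = dst_offset + c4_block_num * plane_out;
      for (int k = 0; k < plane_in; k++) {
        int src_kernel_offset = src_c_offset + k;
        int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem;
        (static_cast<T2 *>(dst) + dst_kernel_offset)[0] = to_dtype((static_cast<T1 *>(src) + src_kernel_offset)[0]);
      }
    }
  }
}

int main() {
  // 6 channels of a 2x2 depthwise kernel: plane_in = 4 taps per channel, and each
  // channel slice's packed row is padded out to plane_out = 32 elements.
  constexpr int channel = 6, plane_in = 4, plane_out = 32;
  constexpr int slices = UP_DIV(channel, C4NUM);  // 2
  float src[channel * plane_in];
  for (int i = 0; i < channel * plane_in; ++i) src[i] = static_cast<float>(i);
  float dst[slices * plane_out] = {};
  std::function<float(float)> to_dtype = [](float x) -> float { return x; };
  PackNCHWToNC4HW4<float, float>(src, dst, 1, plane_in, plane_out, channel, to_dtype);
  // Channel 5 (slice 1, lane 1), tap 1 lands at slice * plane_out + C4NUM * tap + lane = 32 + 4 + 1.
  std::printf("dst[37] = %g (source value %g)\n", dst[37], src[5 * plane_in + 1]);  // both 21
  return 0;
}
```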
@@ -187,13 +187,20 @@ class OpenCLRuntime {
   std::vector<size_t> max_work_item_sizes_;
   void *handle_{nullptr};
   TuningMode tuning_mode_{TuningMode::DEFAULT};
+#if MS_OPENCL_PROFILE
+  bool profiling_{true};
+#else
   bool profiling_{false};
+#endif
   // for cache
  private:
   void LoadCache();
   void StoreCache();
+#ifdef MS_OPENCL_BINARY_CACHE
+  bool enable_cache_{true};
+#else
   bool enable_cache_{false};
+#endif
   bool flush_cache_{false};
   std::string cache_path_{"/data/local/tmp/.opencl_cache"};
   const std::string cache_version_{"V0.1"};
@@ -81,7 +81,7 @@ TEST_F(TestOpenCL_DepthwiseConv2d, NoPad) {
     TestMain({{input_shape, input_data, VAR},
               {weight_shape, weight_data, CONST_TENSOR},
               {bias_shape, bias_data, CONST_TENSOR}},
-             {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5);
+             {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true);
   }
 }
@@ -128,7 +128,7 @@ TEST_F(TestOpenCL_DepthwiseConv2d, Pad) {
     TestMain({{input_shape, input_data, VAR},
               {weight_shape, weight_data, CONST_TENSOR},
               {bias_shape, bias_data, CONST_TENSOR}},
-             {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5);
+             {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true);
   }
 }