@@ -346,8 +346,12 @@ int LiteSession::Init(Context *context) {
   if (context_->device_type_ == DT_GPU) {
     auto opencl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
     opencl_runtime->SetFp16Enable(context_->float16_priority);
-    opencl_runtime->Init();
-    MS_LOG(INFO) << "Init OpenCL runtime.";
+    if (opencl_runtime->Init() != RET_OK) {
+      context_->device_type_ = DT_CPU;
+      MS_LOG(WARNING) << "Init OpenCL runtime failed, change to CPU mode.";
+    } else {
+      MS_LOG(INFO) << "Init OpenCL runtime success.";
+    }
   }
 #endif
   executor = new Executor();

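Note: Init() failures previously went unchecked, so a broken OpenCL environment would only surface later, deep inside kernel compilation. With the status returned and handled here, the session degrades to CPU mode instead, and the IsInitOK() getter removed further down becomes redundant. A minimal caller-side sketch of the new contract (abbreviated, not verbatim from the tree):

    auto *rt = lite::opencl::OpenCLRuntime::GetInstance();
    if (rt->Init() != RET_OK) {
      // GPU unavailable: keep running, the scheduler picks CPU kernels
    }
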
@@ -97,6 +97,7 @@ void ArithmeticSelfOpenCLKernel::GetKernelName(std::string *kernel_name, Arithme
       break;
     case PrimitiveType_Round:
       kernel_name[0] += "_ElementRound";
+      break;
     case PrimitiveType_Neg:
       kernel_name[0] += "_ElementNeg";
       break;

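Note: this is a missing-break fall-through fix. Before the change, PrimitiveType_Round continued into the PrimitiveType_Neg case, so the kernel name was built wrong. Illustrative sketch of the old behavior:

    switch (type) {
      case PrimitiveType_Round:
        name += "_ElementRound";  // no break: execution fell through and
      case PrimitiveType_Neg:     // also appended "_ElementNeg"
        name += "_ElementNeg";
        break;
    }
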
@@ -68,7 +68,7 @@ int ConvolutionOpenCLKernel::Init() {
   TILES_X_ = UP_DIV(OW_, 4);
   TILES_Y_ = UP_DIV(OH_, 4);
   TILES_XY_ = TILES_X_ * TILES_Y_;
-  use_winograd_ = UseWinograd4x4To6x6() && use_fp16_;
+  use_winograd_ = UseWinograd4x4To6x6();

   // build kernel
   if (use_winograd_) {

@@ -247,7 +247,7 @@ int ConvolutionOpenCLKernel::InitBuffer() {

 int ConvolutionOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
   size_t im_dst_x, im_dst_y;
   if (in_tensors_[0]->GetFormat() == Format_NHWC4) {
-    if (out_tensors_[0]->Width() * CO_SLICES_ < 65536) {
+    if (out_tensors_[0]->Width() * CO_SLICES_ <= MAX_IMAGE2D_SIZE) {
       {
         im_dst_x = out_tensors_[0]->Width() * CO_SLICES_;
         im_dst_y = out_tensors_[0]->Height();

@@ -314,7 +314,8 @@ int ConvolutionOpenCLKernel::Run() {
   if (use_winograd_) {
     ocl_runtime_->RunKernel(kernel_4x4to36_, {size_t(TILES_XY_), 6, size_t(CI_SLICES_)}, {8, 6, 4}, nullptr);
-    ocl_runtime_->RunKernel(kernel_conv_, {size_t(TILES_XY_ / 2), 36, size_t(CO_SLICES_ / 2)}, {8, 6, 2}, nullptr);
+    ocl_runtime_->RunKernel(kernel_conv_, {size_t(UP_DIV(TILES_XY_, 2)), 36, size_t(UP_DIV(CO_SLICES_, 2))}, {8, 6, 2},
+                            nullptr);
     ocl_runtime_->RunKernel(kernel_36to4x4_, {size_t(TILES_XY_), 4, size_t(CO_SLICES_)}, {32, 4, 2}, nullptr);
   } else {
     std::vector<size_t> global, local;

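Note: plain integer division truncates, so an odd TILES_XY_ or CO_SLICES_ dispatched one work-group too few and left the last tiles/slices unwritten. UP_DIV rounds up; a sketch assuming its conventional definition in this codebase:

    #define UP_DIV(x, y) (((x) + (y) - 1) / (y))
    // TILES_XY_ = 33:  33 / 2        = 16  (17th pair of tiles dropped)
    //                  UP_DIV(33, 2) = 17  (all tiles covered)
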
@@ -414,7 +415,7 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNHWC4() {
     code += "    out0_c4_bias = clamp(out0_c4_bias, (FLT4)(0.0f), (FLT4)(6.0f));\n";
   }

-  if (OW_ * CO_SLICES_ < 65536) {
+  if (OW_ * CO_SLICES_ <= MAX_IMAGE2D_SIZE) {
     code += "    WRITE_IMAGE(output, (int2)(ow * CO_SLICES + co_slice, oh), out0_c4_bias);// NHWC4: H WC\n}";
   } else {
     code += "    WRITE_IMAGE(output, (int2)(oh * CO_SLICES + co_slice, ow), out0_c4_bias);// NHWC4: H WC\n}";

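Note: the two WRITE_IMAGE branches exist because image2d width is capped, so when the W*C product is too wide the generated kernel swaps the x/y image coordinates and maps it to the y axis instead. Worked example with hypothetical sizes:

    // OW_ = 512, CO_SLICES_ = 192: 512 * 192 = 98304  > 65535 -> write (oh * CO_SLICES + co_slice, ow)
    // OW_ = 512, CO_SLICES_ =  96: 512 *  96 = 49152 <= 65535 -> write (ow * CO_SLICES + co_slice, oh)
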
@@ -616,23 +617,27 @@ std::string ConvolutionOpenCLKernel::CodeGenWinograd4x4To36() {
     "        FLT4 BtD_row[6] = {0};\n"
     "        for (int y = 0; y < 6; y++)\n"
     "        {\n"
-    "            int y_idx = tile_y * 4 - PAD + y;\n";
+    "            int ih = tile_y * 4 - PAD + y;\n";
   if (op_format_ == Format_NHWC4) {
-    code +=
-      "            for (int x = 0; x < 6; x++)\n"
-      "            {\n"
-      "                int x_idx = (tile_x * 4 - PAD + x) * SLICES + slice;\n";
+    code += "            int y_idx = ih;\n";
   } else if (op_format_ == Format_NC4HW4) {
     code +=
-      "            if(y_idx < 0 || y_idx >= IH)\n"
-      "            {\n"
-      "                continue;\n"
-      "            }\n"
-      "            y_idx += slice * IH;\n"
-      "            for (int x = 0; x < 6; x++)\n"
-      "            {\n"
-      "                int x_idx = tile_x * 4 - PAD + x;\n";
+      "            if(ih < 0 || ih >= IH) {continue;}\n"
+      "            int y_idx = slice * IH + ih;\n";
+  }
+  code +=
+    "            for (int x = 0; x < 6; x++)\n"
+    "            {\n"
+    "                int iw = tile_x * 4 - PAD + x;\n";
+  if (op_format_ == Format_NHWC4) {
+    code +=
+      "                if(iw < 0 || iw >= IW) {continue;}\n"
+      "                int x_idx = iw * SLICES + slice;\n";
+  } else if (op_format_ == Format_NC4HW4) {
+    code += "                int x_idx = iw;\n";
   }

   code +=

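For reference, concatenating the NHWC4 strings above yields this generated fragment; the refactor shares one x-loop between formats and only varies how ih/iw fold into image coordinates:

    FLT4 BtD_row[6] = {0};
    for (int y = 0; y < 6; y++)
    {
        int ih = tile_y * 4 - PAD + y;
        int y_idx = ih;
        for (int x = 0; x < 6; x++)
        {
            int iw = tile_x * 4 - PAD + x;
            if(iw < 0 || iw >= IW) {continue;}
            int x_idx = iw * SLICES + slice;

NHWC4 now guards iw but not ih, presumably because an out-of-range x_idx (which interleaves SLICES) would land in a neighboring slice's column rather than the zero border, while an out-of-range y just reads the border; NC4HW4 has the mirrored layout (y interleaves IH), hence its ih guard.
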
@@ -792,9 +797,9 @@ std::string ConvolutionOpenCLKernel::CodeGenWinograd36To4x4() {
   auto param = reinterpret_cast<ConvParameter *>(op_parameter_);
   if (param->act_type_ == ActType_Relu) {
-    code += "    acc = max(acc, (FLT4)(0.0f));\n";
+    code += "    acc = max(acc, (FLT4)(0.0f));\n\n";
   } else if (param->act_type_ == ActType_Relu6) {
-    code += "    acc = clamp(acc, (FLT4)(0.0f), (FLT4)(6.0f));\n";
+    code += "    acc = clamp(acc, (FLT4)(0.0f), (FLT4)(6.0f));\n\n";
   }

   code +=

@@ -838,7 +843,7 @@ int ConvolutionOpenCLKernel::SetGlobalLocalConv(std::vector<size_t> *global, std
   }
   if (op_format_ == Format_NHWC4) {
-    if (OW_ * CO_SLICES_ > 65536) {
+    if (OW_ * CO_SLICES_ > MAX_IMAGE2D_SIZE) {
       local_w = 4;
     }
   }

@@ -81,8 +81,10 @@ class ConvolutionOpenCLKernel : public OpenCLKernel {
   bool UseWinograd4x4To6x6() {
     auto param = reinterpret_cast<ConvParameter *>(op_parameter_);
-    const bool attr_valid = param->kernel_h_ == 3 && param->kernel_w_ == 3 && param->dilation_h_ == 1 &&
-                            param->dilation_w_ == 1 && param->stride_h_ == 1 && param->stride_w_ == 1;
+    const bool attr_valid = param->kernel_h_ == 3 && param->kernel_w_ == 3 && param->stride_h_ == 1 &&
+                            param->stride_w_ == 1 && param->pad_u_ == 1 && param->pad_d_ == 1 && param->pad_l_ == 1 &&
+                            param->pad_r_ == 1 && param->dilation_h_ == 1 && param->dilation_w_ == 1 && IH_ == OH_ &&
+                            IW_ == OW_;
     const bool channel_good = CI_SLICES_ >= 12 && CO_SLICES_ >= 12;
     const bool hw_good = TILES_X_ * TILES_Y_ >= 16;
     return attr_valid && channel_good && hw_good;

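Note: the predicate now pins the exact F(4x4, 3x3) "same" geometry instead of only kernel/stride/dilation, and the fp16-only restriction is dropped in the Init() change above. With kernel 3, stride 1, dilation 1 and pad 1 on every side, standard convolution output arithmetic gives:

    // OH = (IH + pad_u + pad_d - kernel_h) / stride_h + 1
    //    = (IH + 1 + 1 - 3) / 1 + 1 = IH        (and likewise OW = IW)

so IH_ == OH_ && IW_ == OW_ is exactly the shape the 4x4-to-6x6 tiling assumes; any other geometry now falls back to the direct kernel.
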
@@ -16,7 +16,8 @@ * you may not use this file except in compliance with the License.
 #ifndef MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_
 #define MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_
+// Get from Device?
+#define MAX_IMAGE2D_SIZE 65535
 #include <vector>
 #include <map>
 #include <memory>

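Note: for the `<` comparisons the rename is behavior-preserving (`< 65536` and `<= MAX_IMAGE2D_SIZE` accept the same integers), and `> 65536` becoming `> MAX_IMAGE2D_SIZE` also closes a boundary gap: a product of exactly 65536 previously took the swapped write layout in GetImageSize but skipped the matching local_w = 4 adjustment in SetGlobalLocalConv. The "Get from Device?" comment is fair: the real cap is device-specific and could be queried instead of hard-coded. A hedged sketch with the standard OpenCL C++ bindings:

    // Sketch only; OpenCLRuntime already tracks the selected cl::Device.
    size_t max_w = device.getInfo<CL_DEVICE_IMAGE2D_MAX_WIDTH>();
    size_t max_h = device.getInfo<CL_DEVICE_IMAGE2D_MAX_HEIGHT>();
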
@@ -127,7 +128,6 @@ class OpenCLRuntime {
   int UnmapBuffer(const cl::Memory &buffer, void *host_ptr, cl::CommandQueue *command_queue = nullptr) const;
   int UnmapBuffer(void *host_ptr, cl::CommandQueue *command_queue = nullptr) const;
   bool SyncCommandQueue(cl::CommandQueue *command_queue = nullptr);
-  bool IsInitOK() {return init_done_;}

   /**
    * Get kernel max worker group size.

@@ -6,3 +6,4 @@ mtk_AADB_HADB_MBV2_model_fp32.tflite
 hiai_cn_recognize_modify_padv2.tflite
 hiai_cv_focusShootOCRModel_08.tflite
 hiai_model_normalize_object_scene_ps_20200519.tflite
+inception_v3.tflite

@@ -44,10 +44,10 @@ void RunTestTranspose(const std::vector<int> &shape, void *input_data, void *out
     return;
   }
   param->num_axes_ = 4;
-  param->perm_[0] = 0;
-  param->perm_[1] = 3;
-  param->perm_[2] = 1;
-  param->perm_[3] = 2;
+  param->perm_[0] = shape[3];
+  param->perm_[1] = shape[4];
+  param->perm_[2] = shape[5];
+  param->perm_[3] = shape[6];
   auto allocator = ocl_runtime->GetAllocator();
   int h = shape[0];
   int w = shape[1];

@@ -60,9 +60,10 @@ void RunTestTranspose(const std::vector<int> &shape, void *input_data, void *out
     MS_LOG(ERROR) << "tensor_x create error.";
     return;
   }
-  std::vector<int> out_shape = {1, c, h, w};
+  std::vector<int> out_shape = {input_shape[param->perm_[0]], input_shape[param->perm_[1]],
+                                input_shape[param->perm_[2]], input_shape[param->perm_[3]]};
   auto tensor_out_ptr = std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
-                                                       out_shape, schema::Format_NCHW);
+                                                       out_shape, schema::Format_NHWC);
   auto tensor_out = tensor_out_ptr.get();
   if (tensor_out == nullptr) {
     MS_LOG(ERROR) << "tensor_out create error.";

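Note: deriving out_shape from the permutation is what lets one harness drive both directions. With input_shape = {1, h, w, c} = {1, 2, 2, 3} and perm = {0, 3, 1, 2} (NHWC2NCHW), out_shape = {input_shape[0], input_shape[3], input_shape[1], input_shape[2]} = {1, 3, 2, 2}; with perm = {0, 2, 3, 1} (NCHW2NHWC) it is {1, 2, 3, 2}.
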
@@ -105,25 +106,63 @@ void RunTestTranspose(const std::vector<int> &shape, void *input_data, void *out
   lite::opencl::OpenCLRuntime::DeleteInstance();
 }

-TEST_F(TestTransposeOpenCL, TransposeFp32) {
+TEST_F(TestTransposeOpenCL, TransposeNHWC2NCHWFp32) {
   int h = 2;
   int w = 2;
   int c = 3;
-  std::vector<int> shape = {h, w, c};
+  int perm0 = 0;
+  int perm1 = 3;
+  int perm2 = 1;
+  int perm3 = 2;
+  std::vector<int> shape = {h, w, c, perm0, perm1, perm2, perm3};
   std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
   std::vector<float> output_data = {0.0f, 3.0f, 6.0f, 9.0f, 1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f, 8.0f, 11.0f};

   RunTestTranspose(shape, input_data.data(), output_data.data(), false);
 }

-TEST_F(TestTransposeOpenCL, TransposeFp16) {
+TEST_F(TestTransposeOpenCL, TransposeNHWC2NCHWFp16) {
   int h = 2;
   int w = 2;
   int c = 3;
-  std::vector<int> shape = {h, w, c};
+  int perm0 = 0;
+  int perm1 = 3;
+  int perm2 = 1;
+  int perm3 = 2;
+  std::vector<int> shape = {h, w, c, perm0, perm1, perm2, perm3};
   std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
   std::vector<float16_t> output_data = {0.0f, 3.0f, 6.0f, 9.0f, 1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f, 8.0f, 11.0f};

   RunTestTranspose(shape, input_data.data(), output_data.data(), true);
 }

+TEST_F(TestTransposeOpenCL, TransposeNCHW2NHWCFp32) {
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  int perm0 = 0;
+  int perm1 = 2;
+  int perm2 = 3;
+  int perm3 = 1;
+  std::vector<int> shape = {h, w, c, perm0, perm1, perm2, perm3};
+  std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
+  std::vector<float> output_data = {0.0f, 6.0f, 1.0f, 7.0f, 2.0f, 8.0f, 3.0f, 9.0f, 4.0f, 10.0f, 5.0f, 11.0f};
+
+  RunTestTranspose(shape, input_data.data(), output_data.data(), false);
+}
+
+TEST_F(TestTransposeOpenCL, TransposeNCHW2NHWCFp16) {
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  int perm0 = 0;
+  int perm1 = 2;
+  int perm2 = 3;
+  int perm3 = 1;
+  std::vector<int> shape = {h, w, c, perm0, perm1, perm2, perm3};
+  std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
+  std::vector<float16_t> output_data = {0.0f, 6.0f, 1.0f, 7.0f, 2.0f, 8.0f, 3.0f, 9.0f, 4.0f, 10.0f, 5.0f, 11.0f};
+
+  RunTestTranspose(shape, input_data.data(), output_data.data(), true);
+}
+
 }  // namespace mindspore

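The new NCHW2NHWC expectations can be cross-checked with a plain reference transpose, independent of the OpenCL kernel under test (self-contained sketch; shapes and perm taken from the tests above):

    #include <cstdio>
    // out.shape[k] = in.shape[perm[k]]; out[o] = in[i] where i[perm[k]] = o[k].
    int main() {
      const int in_shape[4] = {1, 2, 2, 3}, perm[4] = {0, 2, 3, 1};
      float in[12], out[12];
      for (int v = 0; v < 12; ++v) in[v] = v;
      int out_shape[4];
      for (int k = 0; k < 4; ++k) out_shape[k] = in_shape[perm[k]];
      for (int o0 = 0; o0 < out_shape[0]; ++o0)
        for (int o1 = 0; o1 < out_shape[1]; ++o1)
          for (int o2 = 0; o2 < out_shape[2]; ++o2)
            for (int o3 = 0; o3 < out_shape[3]; ++o3) {
              int o[4] = {o0, o1, o2, o3}, i[4];
              for (int k = 0; k < 4; ++k) i[perm[k]] = o[k];
              int src = ((i[0] * in_shape[1] + i[1]) * in_shape[2] + i[2]) * in_shape[3] + i[3];
              int dst = ((o[0] * out_shape[1] + o[1]) * out_shape[2] + o[2]) * out_shape[3] + o[3];
              out[dst] = in[src];
            }
      for (int v = 0; v < 12; ++v) printf("%g ", out[v]);  // prints: 0 6 1 7 2 8 3 9 4 10 5 11
      return 0;
    }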