| @@ -510,8 +510,12 @@ gene_ocl_program() { | |||||
| build_opencl() { | build_opencl() { | ||||
| cd ${BASEPATH} | cd ${BASEPATH} | ||||
| git submodule update --init third_party/OpenCL-Headers | |||||
| git submodule update --init third_party/OpenCL-CLHPP | |||||
| if [[ ! -d "third_party/OpenCL-Headers" ]]; then | |||||
| git submodule update --init third_party/OpenCL-Headers | |||||
| fi | |||||
| if [[ ! -d "third_party/OpenCL-CLHPP" ]]; then | |||||
| git submodule update --init third_party/OpenCL-CLHPP | |||||
| fi | |||||
| if [[ "${OPENCL_OFFLINE_COMPILE}" == "on" ]]; then | if [[ "${OPENCL_OFFLINE_COMPILE}" == "on" ]]; then | ||||
| gene_ocl_program | gene_ocl_program | ||||
| else | else | ||||
| @@ -524,6 +528,7 @@ build_lite() | |||||
| echo "start build mindspore lite project" | echo "start build mindspore lite project" | ||||
| if [[ "${ENABLE_GPU}" == "on" ]]; then | if [[ "${ENABLE_GPU}" == "on" ]]; then | ||||
| echo "start build opencl" | |||||
| build_opencl | build_opencl | ||||
| fi | fi | ||||
| if [[ "${LITE_PLATFORM}" == "x86_64" ]]; then | if [[ "${LITE_PLATFORM}" == "x86_64" ]]; then | ||||
| @@ -49,7 +49,6 @@ endif () | |||||
| if (SUPPORT_GPU) | if (SUPPORT_GPU) | ||||
| add_definitions(-DUSE_OPENCL_WRAPPER) | add_definitions(-DUSE_OPENCL_WRAPPER) | ||||
| add_definitions(-DMS_OPENCL_PROFILE=false) | add_definitions(-DMS_OPENCL_PROFILE=false) | ||||
| add_definitions(-DCL_HPP_TARGET_OPENCL_VERSION=200) | |||||
| add_compile_definitions(SUPPORT_GPU) | add_compile_definitions(SUPPORT_GPU) | ||||
| if(OFFLINE_COMPILE) | if(OFFLINE_COMPILE) | ||||
| add_compile_definitions(PROGRAM_WITH_IL) | add_compile_definitions(PROGRAM_WITH_IL) | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/opencl/opencl_runtime.h" | #include "src/runtime/opencl/opencl_runtime.h" | ||||
| #include "src/runtime/kernel/opencl/kernel/concat.h" | #include "src/runtime/kernel/opencl/kernel/concat.h" | ||||
| #include "src/backend/opencl/cl/fp32/concat.cl.inc" | |||||
| #include "src/runtime/kernel/opencl/cl/fp32/concat.cl.inc" | |||||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | using mindspore::kernel::KERNEL_ARCH::kGPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -115,12 +115,16 @@ int GetBiggestDividerWithPriority(int number, int max_divider) { | |||||
| return 1; | return 1; | ||||
| } | } | ||||
| void ConcatGetWorkGroup(const std::vector<size_t> &global, const std::vector<size_t> &local, int max_size) { | |||||
| void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) { | |||||
| int x = std::min(GetBiggestDividerWithPriority(global[0], 8), 4); | int x = std::min(GetBiggestDividerWithPriority(global[0], 8), 4); | ||||
| int yz = max_size / x; | int yz = max_size / x; | ||||
| int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], 8), yz), 8); | int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], 8), yz), 8); | ||||
| int z = std::min(yz / y, DivideRoundUp(global[2], 2)); | int z = std::min(yz / y, DivideRoundUp(global[2], 2)); | ||||
| local = {static_cast<unsigned int>(x), static_cast<unsigned int>(y), static_cast<unsigned int>(z)}; | |||||
| local->clear(); | |||||
| local->push_back(x); | |||||
| local->push_back(y); | |||||
| local->push_back(z); | |||||
| } | } | ||||
| int ConcatOpenCLKernel::Run() { | int ConcatOpenCLKernel::Run() { | ||||
| auto param = reinterpret_cast<ConcatParameter *>(this->opParameter); | auto param = reinterpret_cast<ConcatParameter *>(this->opParameter); | ||||
| @@ -144,7 +148,7 @@ int ConcatOpenCLKernel::Run() { | |||||
| uint32_t OW = output_shape[2]; | uint32_t OW = output_shape[2]; | ||||
| uint32_t OC = output_shape[3]; | uint32_t OC = output_shape[3]; | ||||
| global = {OH, OW, OC}; // HWC | global = {OH, OW, OC}; // HWC | ||||
| ConcatGetWorkGroup(global, local, 384); | |||||
| ConcatGetWorkGroup(global, &local, 384); | |||||
| std::cout << "local size=:" << std::endl; | std::cout << "local size=:" << std::endl; | ||||
| for (int i = 0; i < local.size(); i++) { | for (int i = 0; i < local.size(); i++) { | ||||
| std::cout << local[i] << " "; | std::cout << local[i] << " "; | ||||
| @@ -174,7 +178,7 @@ int ConcatOpenCLKernel::Run() { | |||||
| uint32_t OW = output_shape[2]; | uint32_t OW = output_shape[2]; | ||||
| uint32_t OC = output_shape[3]; | uint32_t OC = output_shape[3]; | ||||
| global = {OH, OW, OC}; // HWC | global = {OH, OW, OC}; // HWC | ||||
| ConcatGetWorkGroup(global, local, 384); | |||||
| ConcatGetWorkGroup(global, &local, 384); | |||||
| std::cout << "local size=:" << std::endl; | std::cout << "local size=:" << std::endl; | ||||
| for (int i = 0; i < local.size(); i++) { | for (int i = 0; i < local.size(); i++) { | ||||
| std::cout << local[i] << " "; | std::cout << local[i] << " "; | ||||
| @@ -196,8 +200,9 @@ int ConcatOpenCLKernel::Run() { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<tensor::Tensor *> &inputs, | |||||
| const std::vector<tensor::Tensor *> &outputs, OpParameter *opParameter, | |||||
| kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, | |||||
| const lite::Context *ctx, const kernel::KernelKey &desc) { | const lite::Context *ctx, const kernel::KernelKey &desc) { | ||||
| auto *kernel = new ConcatOpenCLKernel(opParameter, inputs, outputs); | auto *kernel = new ConcatOpenCLKernel(opParameter, inputs, outputs); | ||||
| auto ret = kernel->Init(); | auto ret = kernel->Init(); | ||||
| @@ -20,24 +20,21 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "src/backend/arm/opclib/conv_parameter.h" | |||||
| #include "src/runtime/opencl/opencl_runtime.h" | #include "src/runtime/opencl/opencl_runtime.h" | ||||
| #include "src/backend/arm/opclib/concat.h" | |||||
| #include "src/runtime/kernel/arm/base/concat_base.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| class ConcatOpenCLKernel : public LiteKernel { | class ConcatOpenCLKernel : public LiteKernel { | ||||
| public: | public: | ||||
| explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector<tensor::Tensor *> &inputs, | |||||
| const std::vector<tensor::Tensor *> &outputs) | |||||
| explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs) {} | : LiteKernel(parameter, inputs, outputs) {} | ||||
| ~ConcatOpenCLKernel() override{}; | ~ConcatOpenCLKernel() override{}; | ||||
| int Init() override; | int Init() override; | ||||
| // int InferShape() { return {}; }; | |||||
| int InferShape() {} | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run_axis0(); | int Run_axis0(); | ||||
| @@ -39,7 +39,6 @@ class Conv2dTransposeOpenCLKernel : public LiteKernel { | |||||
| ~Conv2dTransposeOpenCLKernel() override {}; | ~Conv2dTransposeOpenCLKernel() override {}; | ||||
| int Init() override; | int Init() override; | ||||
| int InferShape() {} | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| void PadWeight(); | void PadWeight(); | ||||
| @@ -41,7 +41,6 @@ class MatMulOpenCLKernel : public LiteKernel { | |||||
| ~MatMulOpenCLKernel() override{}; | ~MatMulOpenCLKernel() override{}; | ||||
| int Init() override; | int Init() override; | ||||
| int InferShape() {} | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| void PadWeight(); | void PadWeight(); | ||||
| @@ -265,10 +265,10 @@ endif() | |||||
| if (SUPPORT_GPU) | if (SUPPORT_GPU) | ||||
| set(TEST_SRC | set(TEST_SRC | ||||
| ${TEST_SRC} | ${TEST_SRC} | ||||
| ${TEST_DIR}/ut/stc/runtime/kernel/opencl/matmul_tests.cc | |||||
| ${TEST_DIR}/ut/stc/runtime/kernel/opencl/depthwise_conv2d_tests.cc | |||||
| ${TEST_DIR}/ut/stc/runtime/kernel/opencl/concat_tests.cc | |||||
| ${TEST_DIR}/ut/stc/runtime/kernel/opencl/softmax_cl_tests.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_cl_tests.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| @@ -18,12 +18,10 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | ||||
| #include "mindspore/lite/src/backend/opencl/subgraph_opencl_kernel.h" | |||||
| #include "mindspore/lite/src/backend/opencl/kernel/concat.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h" | |||||
| using mindspore::kernel; | |||||
| using mindspore::lite; | |||||
| using mindspore; | |||||
| int DivideRoundUp(int n, int div) { | int DivideRoundUp(int n, int div) { | ||||
| int q = n / div; | int q = n / div; | ||||
| return n % div == 0 ? q : q + 1; | return n % div == 0 ? q : q + 1; | ||||
| @@ -96,7 +94,7 @@ void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *i | |||||
| } | } | ||||
| namespace mindspore { | namespace mindspore { | ||||
| class TestConcatOpenCL : public UT::Common { | |||||
| class TestConcatOpenCL : public mindspore::Common { | |||||
| public: | public: | ||||
| TestConcatOpenCL(){} | TestConcatOpenCL(){} | ||||
| }; | }; | ||||
| @@ -113,30 +111,31 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { | |||||
| output_shape[3] = DivideRoundUp(output_shape[3], 4) * 4; | output_shape[3] = DivideRoundUp(output_shape[3], 4) * 4; | ||||
| auto data_type = kNumberTypeFloat32; | auto data_type = kNumberTypeFloat32; | ||||
| auto tensor_type = schema::NodeType_ValueNode; | auto tensor_type = schema::NodeType_ValueNode; | ||||
| std::vector<tensor::Tensor *> inputs; | |||||
| std::vector<lite::tensor::Tensor *> inputs; | |||||
| for (auto &shape : input_shapes) { | for (auto &shape : input_shapes) { | ||||
| inputs.push_back(new tensor::Tensor(data_type, shape, schema::Format_NHWC, tensor_type)); | |||||
| inputs.push_back(new lite::tensor::Tensor(data_type, shape, schema::Format_NHWC, tensor_type)); | |||||
| } | } | ||||
| auto *output_tensor = new tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); | |||||
| std::vector<tensor::Tensor *> outputs{output_tensor}; | |||||
| auto *output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); | |||||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||||
| std::cout << "input_shapes size=: " << input_shapes.size() << std::endl; | std::cout << "input_shapes size=: " << input_shapes.size() << std::endl; | ||||
| MS_LOG(INFO) << "initialize tensors"; | MS_LOG(INFO) << "initialize tensors"; | ||||
| auto param = new ConcatParameter(); | auto param = new ConcatParameter(); | ||||
| param->axis_ = 3; | param->axis_ = 3; | ||||
| auto *concat_kernel = new ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| auto *concat_kernel = new kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| concat_kernel->Init(); | concat_kernel->Init(); | ||||
| MS_LOG(INFO) << "initialize sub_graph"; | MS_LOG(INFO) << "initialize sub_graph"; | ||||
| std::vector<LiteKernel *> kernels{concat_kernel}; | |||||
| auto *sub_graph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| std::vector<kernel::LiteKernel *> kernels{concat_kernel}; | |||||
| auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| sub_graph->Init(); | sub_graph->Init(); | ||||
| MS_LOG(INFO) << "initialize input data"; | MS_LOG(INFO) << "initialize input data"; | ||||
| srand(time(NULL)); | srand(time(NULL)); | ||||
| for (auto &input_tensor : inputs) { | for (auto &input_tensor : inputs) { | ||||
| auto input_data = reinterpret_cast<float *>(input_tensor->Data()); | auto input_data = reinterpret_cast<float *>(input_tensor->Data()); | ||||
| static unsigned int seed = 123; | |||||
| for (int i = 0; i < input_tensor->ElementsNum(); ++i) { | for (int i = 0; i < input_tensor->ElementsNum(); ++i) { | ||||
| input_data[i] = static_cast<float>(rand_r() % 10 + 1); | |||||
| input_data[i] = static_cast<float>(rand_r(&seed) % 10 + 1); | |||||
| } | } | ||||
| printf("\n"); | printf("\n"); | ||||
| } | } | ||||
| @@ -23,9 +23,6 @@ | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h" | #include "mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h" | ||||
| using mindspore::kernel; | |||||
| using mindspore::lite; | |||||
| using mindspore; | |||||
| #define SAFE_DELETE_ARRAY(a) \ | #define SAFE_DELETE_ARRAY(a) \ | ||||
| if (a != nullptr) { \ | if (a != nullptr) { \ | ||||
| @@ -39,12 +36,12 @@ using mindspore; | |||||
| } | } | ||||
| namespace mindspore { | namespace mindspore { | ||||
| class TestConvolutionDwOpenCL : public UT::Common { | |||||
| class TestConvolutionDwOpenCL : public mindspore::Common { | |||||
| public: | public: | ||||
| TestConvolutionDwOpenCL(){} | TestConvolutionDwOpenCL(){} | ||||
| }; | }; | ||||
| void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, float_t *weight_data, float_t *gnd_data, | |||||
| void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t *weight_data, float_t *gnd_data, | |||||
| schema::Format format, bool is_compare = true) { | schema::Format format, bool is_compare = true) { | ||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | ||||
| ocl_runtime->Init(); | ocl_runtime->Init(); | ||||
| @@ -92,13 +89,13 @@ void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, flo | |||||
| inputs[1]->SetData(packed_weight); | inputs[1]->SetData(packed_weight); | ||||
| inputs[2]->SetData(bias_data); | inputs[2]->SetData(bias_data); | ||||
| OpParameter * parameter = conv_param; | |||||
| auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||||
| OpParameter * parameter = reinterpret_cast<OpParameter *>(conv_param); | |||||
| auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||||
| pKernel->Init(); | pKernel->Init(); | ||||
| std::vector<LiteKernel *> kernels{pKernel}; | |||||
| std::vector<kernel::LiteKernel *> kernels{pKernel}; | |||||
| std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | ||||
| auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||||
| pGraph->Init(); | pGraph->Init(); | ||||
| // freamework to do!!! | // freamework to do!!! | ||||
| @@ -141,7 +138,7 @@ void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, flo | |||||
| } | } | ||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| // compare | // compare | ||||
| UT::Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||||
| Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||||
| SAFE_DELETE_ARRAY(packed_correct_data) | SAFE_DELETE_ARRAY(packed_correct_data) | ||||
| } | } | ||||
| @@ -202,7 +199,7 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNC4HW4Fp32) { | |||||
| 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; | 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; | ||||
| DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); | DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); | ||||
| opencl::OpenCLRuntime::DeleteInstance(); | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) { | TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) { | ||||
| @@ -275,7 +272,7 @@ TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) { | |||||
| 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; | 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; | ||||
| DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); | DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); | ||||
| opencl::OpenCLRuntime::DeleteInstance(); | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) { | TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) { | ||||
| @@ -321,7 +318,7 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) { | |||||
| 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; | 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; | ||||
| DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); | DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); | ||||
| opencl::OpenCLRuntime::DeleteInstance(); | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) { | TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) { | ||||
| @@ -394,7 +391,7 @@ TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) { | |||||
| 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; | 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; | ||||
| DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); | DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); | ||||
| opencl::OpenCLRuntime::DeleteInstance(); | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| @@ -474,13 +471,13 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { | |||||
| inputs[1]->SetData(packed_weight); | inputs[1]->SetData(packed_weight); | ||||
| inputs[2]->SetData(bias_data); | inputs[2]->SetData(bias_data); | ||||
| OpParameter * parameter = conv_param; | |||||
| auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||||
| OpParameter * parameter = reinterpret_cast<OpParameter *>(conv_param); | |||||
| auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||||
| pKernel->Init(); | pKernel->Init(); | ||||
| std::vector<LiteKernel *> kernels{pKernel}; | |||||
| std::vector<kernel::LiteKernel *> kernels{pKernel}; | |||||
| std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | ||||
| auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||||
| pGraph->Init(); | pGraph->Init(); | ||||
| // freamework to do!!! | // freamework to do!!! | ||||
| @@ -517,7 +514,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { | |||||
| } | } | ||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| // compare | // compare | ||||
| CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||||
| Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||||
| for (auto tensor : inputs) { | for (auto tensor : inputs) { | ||||
| tensor->SetData(nullptr); | tensor->SetData(nullptr); | ||||
| @@ -530,7 +527,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { | |||||
| SAFE_DELETE_PTR(pKernel) | SAFE_DELETE_PTR(pKernel) | ||||
| SAFE_DELETE_PTR(pGraph) | SAFE_DELETE_PTR(pGraph) | ||||
| MS_LOG(INFO) << "TestConvolutionDwNoPadFp32 passed"; | MS_LOG(INFO) << "TestConvolutionDwNoPadFp32 passed"; | ||||
| opencl::OpenCLRuntime::DeleteInstance(); | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | ||||
| @@ -637,13 +634,13 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | |||||
| inputs[1]->SetData(packed_weight); | inputs[1]->SetData(packed_weight); | ||||
| inputs[2]->SetData(bias_data); | inputs[2]->SetData(bias_data); | ||||
| OpParameter * parameter = conv_param; | |||||
| auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||||
| OpParameter * parameter = reinterpret_cast<OpParameter *>(conv_param); | |||||
| auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||||
| pKernel->Init(); | pKernel->Init(); | ||||
| std::vector<LiteKernel *> kernels{pKernel}; | |||||
| std::vector<kernel::LiteKernel *> kernels{pKernel}; | |||||
| std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | ||||
| auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||||
| pGraph->Init(); | pGraph->Init(); | ||||
| // freamework to do!!! | // freamework to do!!! | ||||
| @@ -688,7 +685,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | |||||
| } | } | ||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| // compare | // compare | ||||
| CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||||
| Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||||
| SAFE_DELETE_ARRAY(packed_input); | SAFE_DELETE_ARRAY(packed_input); | ||||
| SAFE_DELETE_ARRAY(packed_correct_data) | SAFE_DELETE_ARRAY(packed_correct_data) | ||||
| @@ -703,7 +700,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | |||||
| SAFE_DELETE_PTR(pKernel) | SAFE_DELETE_PTR(pKernel) | ||||
| SAFE_DELETE_PTR(pGraph) | SAFE_DELETE_PTR(pGraph) | ||||
| MS_LOG(INFO) << "TestConvolutionDwPadFp32 passed"; | MS_LOG(INFO) << "TestConvolutionDwPadFp32 passed"; | ||||
| opencl::OpenCLRuntime::DeleteInstance(); | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { | TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { | ||||
| @@ -803,7 +800,7 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { | |||||
| } | } | ||||
| SAFE_DELETE_ARRAY(input_data); | SAFE_DELETE_ARRAY(input_data); | ||||
| SAFE_DELETE_ARRAY(weight_data); | SAFE_DELETE_ARRAY(weight_data); | ||||
| opencl::OpenCLRuntime::DeleteInstance(); | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,10 +22,6 @@ | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h" | #include "mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h" | ||||
| // using namespace mindspore::kernel; | |||||
| // using namespace mindspore::lite; | |||||
| // using namespace mindspore; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class TestMatMulOpenCL : public mindspore::Common { | class TestMatMulOpenCL : public mindspore::Common { | ||||
| public: | public: | ||||
| @@ -53,11 +49,11 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) { | |||||
| lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, co}); | lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, co}); | ||||
| std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w}; | std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w}; | ||||
| std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | ||||
| auto *arith_kernel = new MatMulOpenCLKernel(nullptr, inputs, outputs, false); | |||||
| auto *arith_kernel = new kernel::MatMulOpenCLKernel(nullptr, inputs, outputs, false); | |||||
| arith_kernel->Init(); | arith_kernel->Init(); | ||||
| std::vector<LiteKernel *> kernels{arith_kernel}; | std::vector<LiteKernel *> kernels{arith_kernel}; | ||||
| auto *pGraph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| pGraph->Init(); | pGraph->Init(); | ||||
| memcpy(inputs[0]->Data(), input_data, sizeof(float) * ci); | memcpy(inputs[0]->Data(), input_data, sizeof(float) * ci); | ||||
| @@ -22,10 +22,6 @@ | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h" | #include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h" | ||||
| // using namespace mindspore::kernel; | |||||
| // using namespace mindspore::lite; | |||||
| // using namespace mindspore; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class TestSoftmaxOpenCL : public mindspore::Common {}; | class TestSoftmaxOpenCL : public mindspore::Common {}; | ||||
| @@ -53,12 +49,12 @@ TEST_F(TestSoftmaxOpenCL, SoftmaxFp32) { | |||||
| std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | ||||
| MS_LOG(INFO) << "create OpenCL Kernel"; | MS_LOG(INFO) << "create OpenCL Kernel"; | ||||
| auto *Softmax_kernel = new SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| auto *Softmax_kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| Softmax_kernel->Init(); | Softmax_kernel->Init(); | ||||
| std::vector<LiteKernel *> kernels{Softmax_kernel}; | std::vector<LiteKernel *> kernels{Softmax_kernel}; | ||||
| MS_LOG(INFO) << "create SubGraphOpenCLKernel"; | MS_LOG(INFO) << "create SubGraphOpenCLKernel"; | ||||
| auto *pGraph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| pGraph->Init(); | pGraph->Init(); | ||||
| MS_LOG(INFO) << "initialize data"; | MS_LOG(INFO) << "initialize data"; | ||||