| @@ -510,8 +510,12 @@ gene_ocl_program() { | |||
| build_opencl() { | |||
| cd ${BASEPATH} | |||
| git submodule update --init third_party/OpenCL-Headers | |||
| git submodule update --init third_party/OpenCL-CLHPP | |||
| if [[ ! -d "third_party/OpenCL-Headers" ]]; then | |||
| git submodule update --init third_party/OpenCL-Headers | |||
| fi | |||
| if [[ ! -d "third_party/OpenCL-CLHPP" ]]; then | |||
| git submodule update --init third_party/OpenCL-CLHPP | |||
| fi | |||
| if [[ "${OPENCL_OFFLINE_COMPILE}" == "on" ]]; then | |||
| gene_ocl_program | |||
| else | |||
| @@ -524,6 +528,7 @@ build_lite() | |||
| echo "start build mindspore lite project" | |||
| if [[ "${ENABLE_GPU}" == "on" ]]; then | |||
| echo "start build opencl" | |||
| build_opencl | |||
| fi | |||
| if [[ "${LITE_PLATFORM}" == "x86_64" ]]; then | |||
| @@ -49,7 +49,6 @@ endif () | |||
| if (SUPPORT_GPU) | |||
| add_definitions(-DUSE_OPENCL_WRAPPER) | |||
| add_definitions(-DMS_OPENCL_PROFILE=false) | |||
| add_definitions(-DCL_HPP_TARGET_OPENCL_VERSION=200) | |||
| add_compile_definitions(SUPPORT_GPU) | |||
| if(OFFLINE_COMPILE) | |||
| add_compile_definitions(PROGRAM_WITH_IL) | |||
| @@ -19,7 +19,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/opencl/kernel/concat.h" | |||
| #include "src/backend/opencl/cl/fp32/concat.cl.inc" | |||
| #include "src/runtime/kernel/opencl/cl/fp32/concat.cl.inc" | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -115,12 +115,16 @@ int GetBiggestDividerWithPriority(int number, int max_divider) { | |||
| return 1; | |||
| } | |||
| void ConcatGetWorkGroup(const std::vector<size_t> &global, const std::vector<size_t> &local, int max_size) { | |||
| void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) { | |||
| int x = std::min(GetBiggestDividerWithPriority(global[0], 8), 4); | |||
| int yz = max_size / x; | |||
| int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], 8), yz), 8); | |||
| int z = std::min(yz / y, DivideRoundUp(global[2], 2)); | |||
| local = {static_cast<unsigned int>(x), static_cast<unsigned int>(y), static_cast<unsigned int>(z)}; | |||
| local->clear(); | |||
| local->push_back(x); | |||
| local->push_back(y); | |||
| local->push_back(z); | |||
| } | |||
| int ConcatOpenCLKernel::Run() { | |||
| auto param = reinterpret_cast<ConcatParameter *>(this->opParameter); | |||
| @@ -144,7 +148,7 @@ int ConcatOpenCLKernel::Run() { | |||
| uint32_t OW = output_shape[2]; | |||
| uint32_t OC = output_shape[3]; | |||
| global = {OH, OW, OC}; // HWC | |||
| ConcatGetWorkGroup(global, local, 384); | |||
| ConcatGetWorkGroup(global, &local, 384); | |||
| std::cout << "local size=:" << std::endl; | |||
| for (int i = 0; i < local.size(); i++) { | |||
| std::cout << local[i] << " "; | |||
| @@ -174,7 +178,7 @@ int ConcatOpenCLKernel::Run() { | |||
| uint32_t OW = output_shape[2]; | |||
| uint32_t OC = output_shape[3]; | |||
| global = {OH, OW, OC}; // HWC | |||
| ConcatGetWorkGroup(global, local, 384); | |||
| ConcatGetWorkGroup(global, &local, 384); | |||
| std::cout << "local size=:" << std::endl; | |||
| for (int i = 0; i < local.size(); i++) { | |||
| std::cout << local[i] << " "; | |||
| @@ -196,8 +200,9 @@ int ConcatOpenCLKernel::Run() { | |||
| return 0; | |||
| } | |||
| kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<tensor::Tensor *> &inputs, | |||
| const std::vector<tensor::Tensor *> &outputs, OpParameter *opParameter, | |||
| kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, | |||
| const lite::Context *ctx, const kernel::KernelKey &desc) { | |||
| auto *kernel = new ConcatOpenCLKernel(opParameter, inputs, outputs); | |||
| auto ret = kernel->Init(); | |||
| @@ -20,24 +20,21 @@ | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "src/backend/arm/opclib/conv_parameter.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/backend/arm/opclib/concat.h" | |||
| #include "src/runtime/kernel/arm/base/concat_base.h" | |||
| namespace mindspore::kernel { | |||
| class ConcatOpenCLKernel : public LiteKernel { | |||
| public: | |||
| explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector<tensor::Tensor *> &inputs, | |||
| const std::vector<tensor::Tensor *> &outputs) | |||
| explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : LiteKernel(parameter, inputs, outputs) {} | |||
| ~ConcatOpenCLKernel() override{}; | |||
| int Init() override; | |||
| // int InferShape() { return {}; }; | |||
| int InferShape() {} | |||
| int ReSize() override; | |||
| int Run_axis0(); | |||
| @@ -39,7 +39,6 @@ class Conv2dTransposeOpenCLKernel : public LiteKernel { | |||
| ~Conv2dTransposeOpenCLKernel() override {}; | |||
| int Init() override; | |||
| int InferShape() {} | |||
| int ReSize() override; | |||
| int Run() override; | |||
| void PadWeight(); | |||
| @@ -41,7 +41,6 @@ class MatMulOpenCLKernel : public LiteKernel { | |||
| ~MatMulOpenCLKernel() override{}; | |||
| int Init() override; | |||
| int InferShape() {} | |||
| int ReSize() override; | |||
| int Run() override; | |||
| void PadWeight(); | |||
| @@ -265,10 +265,10 @@ endif() | |||
| if (SUPPORT_GPU) | |||
| set(TEST_SRC | |||
| ${TEST_SRC} | |||
| ${TEST_DIR}/ut/stc/runtime/kernel/opencl/matmul_tests.cc | |||
| ${TEST_DIR}/ut/stc/runtime/kernel/opencl/depthwise_conv2d_tests.cc | |||
| ${TEST_DIR}/ut/stc/runtime/kernel/opencl/concat_tests.cc | |||
| ${TEST_DIR}/ut/stc/runtime/kernel/opencl/softmax_cl_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_cl_tests.cc | |||
| ) | |||
| endif() | |||
| @@ -18,12 +18,10 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/backend/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/backend/opencl/kernel/concat.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h" | |||
| using mindspore::kernel; | |||
| using mindspore::lite; | |||
| using mindspore; | |||
| int DivideRoundUp(int n, int div) { | |||
| int q = n / div; | |||
| return n % div == 0 ? q : q + 1; | |||
| @@ -96,7 +94,7 @@ void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *i | |||
| } | |||
| namespace mindspore { | |||
| class TestConcatOpenCL : public UT::Common { | |||
| class TestConcatOpenCL : public mindspore::Common { | |||
| public: | |||
| TestConcatOpenCL(){} | |||
| }; | |||
| @@ -113,30 +111,31 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { | |||
| output_shape[3] = DivideRoundUp(output_shape[3], 4) * 4; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensor_type = schema::NodeType_ValueNode; | |||
| std::vector<tensor::Tensor *> inputs; | |||
| std::vector<lite::tensor::Tensor *> inputs; | |||
| for (auto &shape : input_shapes) { | |||
| inputs.push_back(new tensor::Tensor(data_type, shape, schema::Format_NHWC, tensor_type)); | |||
| inputs.push_back(new lite::tensor::Tensor(data_type, shape, schema::Format_NHWC, tensor_type)); | |||
| } | |||
| auto *output_tensor = new tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); | |||
| std::vector<tensor::Tensor *> outputs{output_tensor}; | |||
| auto *output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); | |||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||
| std::cout << "input_shapes size=: " << input_shapes.size() << std::endl; | |||
| MS_LOG(INFO) << "initialize tensors"; | |||
| auto param = new ConcatParameter(); | |||
| param->axis_ = 3; | |||
| auto *concat_kernel = new ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| auto *concat_kernel = new kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| concat_kernel->Init(); | |||
| MS_LOG(INFO) << "initialize sub_graph"; | |||
| std::vector<LiteKernel *> kernels{concat_kernel}; | |||
| auto *sub_graph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| std::vector<kernel::LiteKernel *> kernels{concat_kernel}; | |||
| auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << "initialize input data"; | |||
| srand(time(NULL)); | |||
| for (auto &input_tensor : inputs) { | |||
| auto input_data = reinterpret_cast<float *>(input_tensor->Data()); | |||
| static unsigned int seed = 123; | |||
| for (int i = 0; i < input_tensor->ElementsNum(); ++i) { | |||
| input_data[i] = static_cast<float>(rand_r() % 10 + 1); | |||
| input_data[i] = static_cast<float>(rand_r(&seed) % 10 + 1); | |||
| } | |||
| printf("\n"); | |||
| } | |||
| @@ -23,9 +23,6 @@ | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h" | |||
| using mindspore::kernel; | |||
| using mindspore::lite; | |||
| using mindspore; | |||
| #define SAFE_DELETE_ARRAY(a) \ | |||
| if (a != nullptr) { \ | |||
| @@ -39,12 +36,12 @@ using mindspore; | |||
| } | |||
| namespace mindspore { | |||
| class TestConvolutionDwOpenCL : public UT::Common { | |||
| class TestConvolutionDwOpenCL : public mindspore::Common { | |||
| public: | |||
| TestConvolutionDwOpenCL(){} | |||
| }; | |||
| void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, float_t *weight_data, float_t *gnd_data, | |||
| void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t *weight_data, float_t *gnd_data, | |||
| schema::Format format, bool is_compare = true) { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| @@ -92,13 +89,13 @@ void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, flo | |||
| inputs[1]->SetData(packed_weight); | |||
| inputs[2]->SetData(bias_data); | |||
| OpParameter * parameter = conv_param; | |||
| auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||
| OpParameter * parameter = reinterpret_cast<OpParameter *>(conv_param); | |||
| auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||
| pKernel->Init(); | |||
| std::vector<LiteKernel *> kernels{pKernel}; | |||
| std::vector<kernel::LiteKernel *> kernels{pKernel}; | |||
| std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | |||
| auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||
| pGraph->Init(); | |||
| // freamework to do!!! | |||
| @@ -141,7 +138,7 @@ void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, flo | |||
| } | |||
| std::cout << std::endl; | |||
| // compare | |||
| UT::Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||
| Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||
| SAFE_DELETE_ARRAY(packed_correct_data) | |||
| } | |||
| @@ -202,7 +199,7 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNC4HW4Fp32) { | |||
| 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; | |||
| DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); | |||
| opencl::OpenCLRuntime::DeleteInstance(); | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) { | |||
| @@ -275,7 +272,7 @@ TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) { | |||
| 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; | |||
| DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); | |||
| opencl::OpenCLRuntime::DeleteInstance(); | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) { | |||
| @@ -321,7 +318,7 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) { | |||
| 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; | |||
| DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); | |||
| opencl::OpenCLRuntime::DeleteInstance(); | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) { | |||
| @@ -394,7 +391,7 @@ TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) { | |||
| 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; | |||
| DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); | |||
| opencl::OpenCLRuntime::DeleteInstance(); | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| @@ -474,13 +471,13 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { | |||
| inputs[1]->SetData(packed_weight); | |||
| inputs[2]->SetData(bias_data); | |||
| OpParameter * parameter = conv_param; | |||
| auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||
| OpParameter * parameter = reinterpret_cast<OpParameter *>(conv_param); | |||
| auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||
| pKernel->Init(); | |||
| std::vector<LiteKernel *> kernels{pKernel}; | |||
| std::vector<kernel::LiteKernel *> kernels{pKernel}; | |||
| std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | |||
| auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||
| pGraph->Init(); | |||
| // freamework to do!!! | |||
| @@ -517,7 +514,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { | |||
| } | |||
| std::cout << std::endl; | |||
| // compare | |||
| CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||
| Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||
| for (auto tensor : inputs) { | |||
| tensor->SetData(nullptr); | |||
| @@ -530,7 +527,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { | |||
| SAFE_DELETE_PTR(pKernel) | |||
| SAFE_DELETE_PTR(pGraph) | |||
| MS_LOG(INFO) << "TestConvolutionDwNoPadFp32 passed"; | |||
| opencl::OpenCLRuntime::DeleteInstance(); | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | |||
| @@ -637,13 +634,13 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | |||
| inputs[1]->SetData(packed_weight); | |||
| inputs[2]->SetData(bias_data); | |||
| OpParameter * parameter = conv_param; | |||
| auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||
| OpParameter * parameter = reinterpret_cast<OpParameter *>(conv_param); | |||
| auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); | |||
| pKernel->Init(); | |||
| std::vector<LiteKernel *> kernels{pKernel}; | |||
| std::vector<kernel::LiteKernel *> kernels{pKernel}; | |||
| std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | |||
| auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||
| pGraph->Init(); | |||
| // freamework to do!!! | |||
| @@ -688,7 +685,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | |||
| } | |||
| std::cout << std::endl; | |||
| // compare | |||
| CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||
| Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||
| SAFE_DELETE_ARRAY(packed_input); | |||
| SAFE_DELETE_ARRAY(packed_correct_data) | |||
| @@ -703,7 +700,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | |||
| SAFE_DELETE_PTR(pKernel) | |||
| SAFE_DELETE_PTR(pGraph) | |||
| MS_LOG(INFO) << "TestConvolutionDwPadFp32 passed"; | |||
| opencl::OpenCLRuntime::DeleteInstance(); | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { | |||
| @@ -803,7 +800,7 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { | |||
| } | |||
| SAFE_DELETE_ARRAY(input_data); | |||
| SAFE_DELETE_ARRAY(weight_data); | |||
| opencl::OpenCLRuntime::DeleteInstance(); | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -22,10 +22,6 @@ | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h" | |||
| // using namespace mindspore::kernel; | |||
| // using namespace mindspore::lite; | |||
| // using namespace mindspore; | |||
| namespace mindspore { | |||
| class TestMatMulOpenCL : public mindspore::Common { | |||
| public: | |||
| @@ -53,11 +49,11 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) { | |||
| lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, co}); | |||
| std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w}; | |||
| std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | |||
| auto *arith_kernel = new MatMulOpenCLKernel(nullptr, inputs, outputs, false); | |||
| auto *arith_kernel = new kernel::MatMulOpenCLKernel(nullptr, inputs, outputs, false); | |||
| arith_kernel->Init(); | |||
| std::vector<LiteKernel *> kernels{arith_kernel}; | |||
| auto *pGraph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| pGraph->Init(); | |||
| memcpy(inputs[0]->Data(), input_data, sizeof(float) * ci); | |||
| @@ -22,10 +22,6 @@ | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h" | |||
| // using namespace mindspore::kernel; | |||
| // using namespace mindspore::lite; | |||
| // using namespace mindspore; | |||
| namespace mindspore { | |||
| class TestSoftmaxOpenCL : public mindspore::Common {}; | |||
| @@ -53,12 +49,12 @@ TEST_F(TestSoftmaxOpenCL, SoftmaxFp32) { | |||
| std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | |||
| MS_LOG(INFO) << "create OpenCL Kernel"; | |||
| auto *Softmax_kernel = new SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| auto *Softmax_kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| Softmax_kernel->Init(); | |||
| std::vector<LiteKernel *> kernels{Softmax_kernel}; | |||
| MS_LOG(INFO) << "create SubGraphOpenCLKernel"; | |||
| auto *pGraph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| pGraph->Init(); | |||
| MS_LOG(INFO) << "initialize data"; | |||