| @@ -0,0 +1,81 @@ | |||
| #define INT2 int2 | |||
| #define INT4 int4 | |||
| #define FLT4 float4 | |||
| __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; | |||
| __kernel void slice(__read_only image2d_t input, __write_only image2d_t output, INT4 input_shape, INT4 out_shape, | |||
| INT4 begin, INT2 sharedNoUpdiv) { | |||
| int X = get_global_id(1); // H | |||
| int Y = get_global_id(2); // W | |||
| if (X >= out_shape.y || Y >= out_shape.z) { | |||
| return; | |||
| } | |||
| FLT4 result; | |||
| if (sharedNoUpdiv.x % 4 == 0) { | |||
| for (int i = 0; i < out_shape.w; i++) { | |||
| result = read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (i + begin.w), (X + begin.y))); | |||
| write_imagef(output, (INT2)((Y)*out_shape.w + i, (X)), result); | |||
| } | |||
| } else { | |||
| int begin_postion = sharedNoUpdiv.y % 4; | |||
| FLT4 first = read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + begin.w, (X + begin.y))); | |||
| if (begin_postion == 1) { | |||
| for (int i = 1; i <= out_shape.w; i++) { | |||
| FLT4 second = | |||
| read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y))); | |||
| result.x = first.y; | |||
| result.y = first.z; | |||
| result.z = first.w; | |||
| result.w = second.x; | |||
| write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result); | |||
| first.y = second.y; | |||
| first.z = second.z; | |||
| first.w = second.w; | |||
| } | |||
| } else if (begin_postion == 2) { | |||
| for (int i = 1; i <= out_shape.w; i++) { | |||
| FLT4 second = | |||
| read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y))); | |||
| result.x = first.z; | |||
| result.y = first.w; | |||
| result.z = second.x; | |||
| result.w = second.y; | |||
| write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result); | |||
| first.z = second.z; | |||
| first.w = second.w; | |||
| } | |||
| } else { | |||
| for (int i = 1; i <= out_shape.w; i++) { | |||
| FLT4 second = | |||
| read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y))); | |||
| result.x = first.w; | |||
| result.y = second.x; | |||
| result.z = second.y; | |||
| result.w = second.z; | |||
| write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result); | |||
| first.w = second.w; | |||
| } | |||
| } | |||
| } | |||
| // judge the line of size | |||
| int size = sharedNoUpdiv.y % 4; | |||
| FLT4 result_fill0; | |||
| if (size == 1) { | |||
| result_fill0.x = result.x; | |||
| result_fill0.y = 0; | |||
| result_fill0.z = 0; | |||
| result_fill0.w = 0; | |||
| write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0); | |||
| } else if (size == 2) { | |||
| result_fill0.x = result.x; | |||
| result_fill0.y = result.y; | |||
| result_fill0.z = 0; | |||
| result_fill0.w = 0; | |||
| write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0); | |||
| } else if (size == 3) { | |||
| result_fill0.x = result.x; | |||
| result_fill0.y = result.y; | |||
| result_fill0.z = result.z; | |||
| result_fill0.w = 0; | |||
| write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0); | |||
| } | |||
| } | |||
| @@ -137,7 +137,7 @@ kernel::LiteKernel *OpenCLBatchnormKernelCreator(const std::vector<lite::tensor: | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (0 != ret) { | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: Convolution"; | |||
| delete kernel; | |||
| return nullptr; | |||
| @@ -214,7 +214,7 @@ kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Te | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (0 != ret) { | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: Convolution"; | |||
| delete kernel; | |||
| return nullptr; | |||
| @@ -273,7 +273,7 @@ int ConvolutionOpenCLKernel::Run() { | |||
| } | |||
| if (use_winograd_) { | |||
| ocl_runtime->RunKernel(kernel_4x4to36, {size_t(TILES_XY), 6, size_t(CI_SLICES)}, {16, 6, 4}, nullptr); | |||
| ocl_runtime->RunKernel(kernel_4x4to36, {size_t(TILES_XY), 6, size_t(CI_SLICES)}, {8, 6, 4}, nullptr); | |||
| ocl_runtime->RunKernel(kernel_conv, {size_t(TILES_XY / 2), 36, size_t(CO_SLICES / 2)}, {8, 6, 2}, nullptr); | |||
| ocl_runtime->RunKernel(kernel_36to4x4, {size_t(TILES_XY), 4, size_t(CO_SLICES)}, {32, 4, 2}, nullptr); | |||
| } else { | |||
| @@ -674,7 +674,7 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::tenso | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (0 != ret) { | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: Convolution"; | |||
| delete kernel; | |||
| return nullptr; | |||
| @@ -0,0 +1,144 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cstring> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <set> | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/opencl/kernel/slice.h" | |||
| #include "src/runtime/kernel/opencl/cl/slice.cl.inc" | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_Slice; | |||
| namespace mindspore::kernel { | |||
| int SliceOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM); | |||
| size_t im_dst_x, im_dst_y; | |||
| if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) { | |||
| im_dst_x = out_tensors_[0]->Width() * CO4; | |||
| im_dst_y = out_tensors_[0]->Height(); | |||
| } else { | |||
| im_dst_y = out_tensors_[0]->Height() * CO4; | |||
| im_dst_x = out_tensors_[0]->Width(); | |||
| } | |||
| #ifdef ENABLE_FP16 | |||
| size_t img_dtype = CL_HALF_FLOAT; | |||
| #else | |||
| size_t img_dtype = CL_FLOAT; | |||
| #endif | |||
| img_size->clear(); | |||
| std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype}; | |||
| *img_size = vec; | |||
| return RET_OK; | |||
| } | |||
| int SliceOpenCLKernel::Init() { | |||
| std::set<std::string> build_options; | |||
| std::string source = slice_source; | |||
| std::string program_name = "slice"; | |||
| std::string kernel_name = "slice"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->LoadSource(program_name, source); | |||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| ori_format_ = out_tensors_[0]->GetFormat(); | |||
| out_tensors_[0]->SetFormat(schema::Format_NHWC4); | |||
| return RET_OK; | |||
| } | |||
| int SliceOpenCLKernel::ReSize() { return RET_OK; } | |||
| int SliceGetBiggestDividerWithPriority(int number, int max_divider) { | |||
| if (number % 8 == 0 && 8 <= max_divider) { | |||
| return number / 8; | |||
| } else if (number % 4 == 0 && 4 <= max_divider) { | |||
| return number / 4; | |||
| } else if (number % 2 == 0 && 2 <= max_divider) { | |||
| return number / 2; | |||
| } | |||
| for (int i = max_divider; i != 0; i--) { | |||
| if (number % i == 0) { | |||
| return i; | |||
| } | |||
| } | |||
| return 1; | |||
| } | |||
| void SlcieGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) { | |||
| const int max_divider = 8; | |||
| const int max_x = 4, max_y = 8; | |||
| int x = std::min(SliceGetBiggestDividerWithPriority(global[0], max_divider), max_x); | |||
| int yz = max_size / x; | |||
| int y = std::min(std::min(SliceGetBiggestDividerWithPriority(global[1], max_divider), yz), max_y); | |||
| int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2))); | |||
| local->clear(); | |||
| local->push_back(x); | |||
| local->push_back(y); | |||
| local->push_back(z); | |||
| } | |||
| int SliceOpenCLKernel::Run() { | |||
| MS_LOG(DEBUG) << this->name() << " Running!"; | |||
| auto param = reinterpret_cast<SliceParameter *>(this->op_parameter_); | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| auto input_shape = in_tensors_[0]->shape(); | |||
| cl_int4 input_shape_ = {input_shape[0], input_shape[1], input_shape[2], UP_DIV(input_shape[3], C4NUM)}; | |||
| cl_int4 size_ = {param->size_[0], param->size_[1], param->size_[2], UP_DIV(param->size_[3], C4NUM)}; | |||
| cl_int4 begin_ = {param->begin_[0], param->begin_[1], param->begin_[2], param->begin_[3] / 4}; | |||
| cl_int2 sharedNoUpdiv = {param->begin_[3], param->size_[3]}; | |||
| uint32_t OH = param->size_[1]; | |||
| uint32_t OW = param->size_[2]; | |||
| const std::vector<size_t> &max_global = ocl_runtime->GetWorkItemSize(); | |||
| std::vector<size_t> local = {1, 1, 1}; // init local | |||
| std::vector<size_t> global = {1, OH, OW}; | |||
| SlcieGetWorkGroup(global, &local, max_global[0]); | |||
| int arg_cn = 0; | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->Data()); // input tensor | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->Data()); // out tensor | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, size_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, begin_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, sharedNoUpdiv); | |||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||
| return RET_OK; | |||
| } // namespace mindspore::kernel | |||
| kernel::LiteKernel *OpenCLSliceKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| auto *kernel = new (std::nothrow) SliceOpenCLKernel(opParameter, inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new SliceOpenCLKernel failed"; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: Convolution"; | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Slice, OpenCLSliceKernelCreator); | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_SLICE_H_ | |||
| #define MINDSPORE_LITE_SRC_BACKEND_OPENCL_SLICE_H_ | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/slice.h" | |||
| namespace mindspore::kernel { | |||
| class SliceOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| explicit SliceOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| ~SliceOpenCLKernel() override{}; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | |||
| private: | |||
| cl::Kernel kernel_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif | |||
| @@ -46,8 +46,10 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor * | |||
| } | |||
| if (mem_type == OpenCLMemType::IMG) { | |||
| jv->set_in_tensors({}); | |||
| jv->SetInKernel({}); | |||
| } else { | |||
| jv->set_out_tensors({}); | |||
| jv->SetOutKernel({}); | |||
| } | |||
| } | |||
| } | |||
| @@ -120,13 +122,21 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor * | |||
| if (mem_type == OpenCLMemType::IMG) { | |||
| for (auto &iv : in_kernels[i]) { | |||
| in_opencl_op->AddOutKernel(iv); | |||
| reinterpret_cast<OpenCLKernel *>(iv)->SetInKernel({in_convert_op}); | |||
| reinterpret_cast<OpenCLKernel *>(iv)->set_in_tensors({new_tensor}); | |||
| auto kernels = iv->in_kernels(); | |||
| kernels.emplace_back(in_convert_op); | |||
| iv->SetInKernel(kernels); | |||
| auto tensors = iv->in_tensors(); | |||
| tensors.emplace_back(new_tensor); | |||
| iv->set_in_tensors(tensors); | |||
| } | |||
| } else { | |||
| for (auto &iv : in_kernels[i]) { | |||
| reinterpret_cast<OpenCLKernel *>(iv)->SetOutKernel({in_convert_op}); | |||
| reinterpret_cast<OpenCLKernel *>(iv)->set_out_tensors({new_tensor}); | |||
| auto kernels = iv->out_kernels(); | |||
| kernels.emplace_back(in_convert_op); | |||
| iv->SetOutKernel(kernels); | |||
| auto tensors = iv->out_tensors(); | |||
| tensors.emplace_back(new_tensor); | |||
| iv->set_out_tensors(tensors); | |||
| in_convert_op->AddInKernel(iv); | |||
| } | |||
| } | |||
| @@ -145,6 +145,7 @@ if (SUPPORT_GPU) | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/batchnorm.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/slice.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/activation.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc | |||
| @@ -318,6 +319,7 @@ if (SUPPORT_GPU) | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/batchnorm_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/slice_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc | |||
| @@ -43,8 +43,8 @@ TEST_F(TestBatchnormOpenCL, Batchnorminput_dim4) { | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| MS_LOG(INFO) << "Read tensors from .bin"; | |||
| std::vector<int> input_shape = {1, 256, 256, 48}; | |||
| std::vector<int> output_shape = {1, 256, 256, 48}; | |||
| std::vector<int> input_shape = {1, 256, 256, 16}; | |||
| std::vector<int> output_shape = {1, 256, 256, 16}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensor_type = schema::NodeType_ValueNode; | |||
| @@ -0,0 +1,149 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include "utils/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/common/file_utils.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/slice.h" | |||
| namespace mindspore { | |||
| class TestSliceOpenCL : public mindspore::CommonTest { | |||
| public: | |||
| TestSliceOpenCL() {} | |||
| }; | |||
| template <typename T> | |||
| void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) { | |||
| for (size_t i = 0; i < size; i++) { | |||
| T abs = fabs(output_data[i] - correct_data[i]); | |||
| ASSERT_LE(abs, err_bound); | |||
| } | |||
| } | |||
| TEST_F(TestSliceOpenCL, Sliceinput_dim4) { | |||
| MS_LOG(INFO) << "begin test"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| MS_LOG(INFO) << "Read tensors from .bin"; | |||
| std::vector<int> input_shape = {1, 256, 256, 48}; | |||
| std::vector<int> output_shape = {1, 255, 255, 15}; | |||
| std::vector<int> begin = {0, 1, 1, 7}; | |||
| std::vector<int> size = {1, 255, 255, 15}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensor_type = schema::NodeType_ValueNode; | |||
| // get the input from .bin | |||
| size_t input_size, output_size; | |||
| std::string input_path = "./test_data/in_data.bin"; | |||
| std::string output_path = "./test_data/out_data.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); | |||
| MS_LOG(INFO) << "construct tensors"; | |||
| lite::tensor::Tensor *tensor_data = | |||
| new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type); | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(INFO) << "init tensor failed"; | |||
| return; | |||
| } | |||
| auto *output_tensor = | |||
| new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type); | |||
| if (output_tensor == nullptr) { | |||
| delete tensor_data; | |||
| MS_LOG(INFO) << "init tensor failed"; | |||
| return; | |||
| } | |||
| std::vector<lite::tensor::Tensor *> inputs = {tensor_data}; | |||
| std::vector<lite::tensor::Tensor *> outputs = {output_tensor}; | |||
| MS_LOG(INFO) << "setting SliceParameter"; | |||
| auto param = new (std::nothrow) SliceParameter(); | |||
| if (param == nullptr) { | |||
| for (auto tensor : inputs) { | |||
| delete tensor; | |||
| } | |||
| for (auto tensor : outputs) { | |||
| delete tensor; | |||
| } | |||
| MS_LOG(INFO) << "new SliceParameter failed"; | |||
| return; | |||
| } | |||
| for (int i = 0; i < 4; i++) { | |||
| param->begin_[i] = begin[i]; | |||
| param->size_[i] = size[i]; | |||
| } | |||
| auto *slice_kernel = | |||
| new (std::nothrow) kernel::SliceOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (slice_kernel == nullptr) { | |||
| for (auto tensor : inputs) { | |||
| delete tensor; | |||
| } | |||
| for (auto tensor : outputs) { | |||
| delete tensor; | |||
| } | |||
| delete param; | |||
| MS_LOG(INFO) << "new kernel::slice_kernel failed"; | |||
| return; | |||
| } | |||
| slice_kernel->Init(); | |||
| // to do allocate memory for inputs and outputs | |||
| for (auto &input_tensor : inputs) { | |||
| input_tensor->MallocData(allocator); | |||
| } | |||
| MS_LOG(INFO) << "initialize sub_graph"; | |||
| std::vector<kernel::LiteKernel *> kernels{slice_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| for (auto tensor : inputs) { | |||
| delete tensor; | |||
| } | |||
| for (auto tensor : outputs) { | |||
| delete tensor; | |||
| } | |||
| delete param; | |||
| delete slice_kernel; | |||
| MS_LOG(INFO) << "new kernel::SubGraphOpenCLKernel failed"; | |||
| return; | |||
| } | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << "init tensors"; | |||
| memcpy(inputs[0]->Data(), input_data, input_size); | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data()); | |||
| CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); | |||
| for (auto tensor : inputs) { | |||
| delete tensor; | |||
| } | |||
| for (auto tensor : outputs) { | |||
| delete tensor; | |||
| } | |||
| delete slice_kernel; | |||
| delete sub_graph; | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| } // namespace mindspore | |||