| @@ -0,0 +1,44 @@ | |||||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||||
| __kernel void batch_to_space_nd_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 src_size, | |||||
| int4 dst_size, int2 block_size, int4 paddings) { | |||||
| int X = get_global_id(0); // c | |||||
| int Y = get_global_id(1); // w | |||||
| int Z = get_global_id(2); // h*n | |||||
| if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) { | |||||
| return; | |||||
| } | |||||
| for (int i = 0; i < block_size.x; ++i) { | |||||
| for (int j = 0; j < block_size.y; ++j) { | |||||
| int Y_dst = (Y * block_size.y + j); | |||||
| int Z_dst = Z * block_size.x + i; | |||||
| int Y_org = (Y_dst + paddings.z) / block_size.y; | |||||
| int Z_org = (Z_dst + paddings.x) / block_size.x; | |||||
| int Z_com = (i * block_size.y + j) * src_size.z + Z_org; | |||||
| FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); | |||||
| res_data = READ_IMAGE(src_data, smp_zero, (int2)(Y_org * dst_size.x + X, Z_com)); | |||||
| WRITE_IMAGE(dst_data, (int2)((Y * block_size.y + j) * dst_size.x + X, Z * block_size.x + i), res_data); | |||||
| } | |||||
| } | |||||
| } | |||||
| __kernel void batch_to_space_nd_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 src_size, | |||||
| int4 dst_size, int2 block_size, int4 paddings) { | |||||
| int X = get_global_id(0); // c | |||||
| int Y = get_global_id(1); // w | |||||
| int Z = get_global_id(2); // h*n | |||||
| if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) { | |||||
| return; | |||||
| } | |||||
| for (int i = 0; i < block_size.x; ++i) { | |||||
| for (int j = 0; j < block_size.y; ++j) { | |||||
| int Y_dst = (Y * block_size.y + j); | |||||
| int Z_dst = Z * block_size.x + i; | |||||
| int Y_org = (Y_dst + paddings.z) / block_size.y; | |||||
| int Z_org = (Z_dst + paddings.x) / block_size.x; | |||||
| int Z_com = (i * block_size.y + j) * src_size.z + Z_org; | |||||
| FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); | |||||
| res_data = READ_IMAGE(src_data, smp_zero, (int2)(Y_org * dst_size.x + X, Z_com)); | |||||
| WRITE_IMAGE(dst_data, (int2)((Y * block_size.y + j) * dst_size.x + X, Z * block_size.x + i), res_data); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,149 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cstring> | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <set> | |||||
| #include <utility> | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/kernel/opencl/kernel/batch_to_space_nd.h" | |||||
| #include "src/runtime/kernel/opencl/cl/batch_to_space_nd.cl.inc" | |||||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_BatchToSpace; | |||||
| using mindspore::schema::PrimitiveType_BatchToSpaceND; | |||||
| namespace mindspore::kernel { | |||||
| int BatchToSpaceNDOpenCLKernel::Init() { | |||||
| std::string kernel_name = "batch_to_space_nd"; | |||||
| auto in_format = op_format_; | |||||
| if (in_tensors_[0]->shape().size() != 4 && out_tensors_[0]->shape().size() != 4) { | |||||
| MS_LOG(ERROR) << "input/output shape size must be 4, actual: " << in_tensors_[0]->shape().size() << ", " | |||||
| << out_tensors_[0]->shape().size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) { | |||||
| MS_LOG(ERROR) << "input format(" << in_format << ") " | |||||
| << "format not support!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_); | |||||
| if (param->block_shape_[0] < 1 || param->block_shape_[1] < 1) { | |||||
| MS_LOG(ERROR) << "block_sizes_ must > 1, actual " << param->block_shape_[0] << ", " << param->block_shape_[1]; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (in_tensors_[0]->shape()[kNHWC_H] * param->block_shape_[0] <= (param->crops_[0] + param->crops_[1]) || | |||||
| in_tensors_[0]->shape()[kNHWC_W] * param->block_shape_[1] <= (param->crops_[2] + param->crops_[3])) { | |||||
| MS_LOG(ERROR) << "crop shape error!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| in_ori_format_ = in_tensors_[0]->GetFormat(); | |||||
| out_ori_format_ = out_tensors_[0]->GetFormat(); | |||||
| in_tensors_[0]->SetFormat(op_format_); | |||||
| out_tensors_[0]->SetFormat(op_format_); | |||||
| #ifdef PROGRAM_WITH_IL | |||||
| kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name); | |||||
| #else | |||||
| if (in_format == schema::Format_NC4HW4) { | |||||
| kernel_name += "_NC4HW4"; | |||||
| } else { | |||||
| kernel_name += "_NHWC4"; | |||||
| } | |||||
| std::set<std::string> build_options; | |||||
| std::string source = batch_to_space_nd_source; | |||||
| std::string program_name = "batch_to_space_nd"; | |||||
| ocl_runtime_->LoadSource(program_name, source); | |||||
| ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||||
| #endif | |||||
| return RET_OK; | |||||
| } | |||||
| int BatchToSpaceNDOpenCLKernel::InitBuffer() { return RET_OK; } | |||||
| int BatchToSpaceNDOpenCLKernel::ReSize() { return RET_OK; } | |||||
| int BatchToSpaceNDOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||||
| size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM); | |||||
| size_t im_dst_x, im_dst_y; | |||||
| if (in_tensors_[0]->GetFormat() == schema::Format::Format_NHWC4) { | |||||
| im_dst_x = out_tensors_[0]->Width() * CO4; | |||||
| im_dst_y = out_tensors_[0]->Height() * out_tensors_[0]->Batch(); | |||||
| } else { | |||||
| im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4; | |||||
| im_dst_x = out_tensors_[0]->Width(); | |||||
| } | |||||
| size_t img_dtype = CL_FLOAT; | |||||
| auto enable_fp16_ = ocl_runtime_->GetFp16Enable(); | |||||
| if (enable_fp16_) { | |||||
| img_dtype = CL_HALF_FLOAT; | |||||
| } | |||||
| img_size->clear(); | |||||
| std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype}; | |||||
| *img_size = std::move(vec); | |||||
| return RET_OK; | |||||
| } | |||||
| int BatchToSpaceNDOpenCLKernel::Run() { | |||||
| MS_LOG(DEBUG) << this->name() << " Running! "; | |||||
| auto param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_); | |||||
| size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM); | |||||
| size_t CI4 = UP_DIV(in_tensors_[0]->Channel(), C4NUM); | |||||
| cl_int4 src_size = { | |||||
| (cl_int)CI4, in_tensors_[0]->Width(), | |||||
| in_tensors_[0]->Height() * in_tensors_[0]->Batch() / param->block_shape_[0] / param->block_shape_[1], 1}; | |||||
| cl_int4 dst_size = {(cl_int)CO4, out_tensors_[0]->Width() / param->block_shape_[1], | |||||
| out_tensors_[0]->Height() / param->block_shape_[0] * out_tensors_[0]->Batch(), 1}; | |||||
| cl_int2 block_size = {param->block_shape_[0], param->block_shape_[1]}; | |||||
| cl_int4 paddings = {param->crops_[0], param->crops_[1], param->crops_[2], param->crops_[3]}; | |||||
| std::vector<size_t> local = {1, 1, 1}; | |||||
| std::vector<size_t> global = {(size_t)dst_size.s[0], (size_t)dst_size.s[1], (size_t)dst_size.s[2]}; | |||||
| int arg_cn = 0; | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c(), lite::opencl::MemType::IMG); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::IMG); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, src_size); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, dst_size); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, block_size); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, paddings); | |||||
| ocl_runtime_->RunKernel(kernel_, global, local, nullptr); | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *OpenCLBatchToSpaceNDKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::InnerContext *ctx, | |||||
| const kernel::KernelKey &desc, | |||||
| const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto *kernel = new (std::nothrow) BatchToSpaceNDOpenCLKernel(opParameter, inputs, outputs); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Kernel " << opParameter->name_ << " new failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Kernel " << opParameter->name_ << " init failed."; | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BatchToSpaceND, OpenCLBatchToSpaceNDKernelCreator); | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_BatchToSpaceND, OpenCLBatchToSpaceNDKernelCreator); | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, OpenCLBatchToSpaceNDKernelCreator); | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_BatchToSpace, OpenCLBatchToSpaceNDKernelCreator); | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BATCH_TO_SPACE_ND_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BATCH_TO_SPACE_ND_H_ | |||||
| #include <vector> | |||||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||||
| #include "nnacl/batch_to_space.h" | |||||
| namespace mindspore::kernel { | |||||
| class BatchToSpaceNDOpenCLKernel : public OpenCLKernel { | |||||
| public: | |||||
| explicit BatchToSpaceNDOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs) | |||||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||||
| ~BatchToSpaceNDOpenCLKernel() override{}; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | |||||
| int InitBuffer(); | |||||
| private: | |||||
| cl::Kernel kernel_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif | |||||
| @@ -24,6 +24,7 @@ | |||||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | using mindspore::kernel::KERNEL_ARCH::kGPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| using mindspore::schema::PrimitiveType_SpaceToBatch; | |||||
| using mindspore::schema::PrimitiveType_SpaceToBatchND; | using mindspore::schema::PrimitiveType_SpaceToBatchND; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| @@ -103,8 +104,6 @@ int SpaceToBatchNDOpenCLKernel::Run() { | |||||
| MS_LOG(DEBUG) << this->name() << " Running! "; | MS_LOG(DEBUG) << this->name() << " Running! "; | ||||
| auto param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_); | auto param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_); | ||||
| auto input_shape = in_tensors_[0]->shape(); | |||||
| auto output_shape = out_tensors_[0]->shape(); | |||||
| size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM); | size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM); | ||||
| size_t CI4 = UP_DIV(in_tensors_[0]->Channel(), C4NUM); | size_t CI4 = UP_DIV(in_tensors_[0]->Channel(), C4NUM); | ||||
| cl_int4 src_size = {(cl_int)CI4, in_tensors_[0]->Width(), in_tensors_[0]->Height(), in_tensors_[0]->Batch()}; | cl_int4 src_size = {(cl_int)CI4, in_tensors_[0]->Width(), in_tensors_[0]->Height(), in_tensors_[0]->Batch()}; | ||||
| @@ -146,5 +145,7 @@ kernel::LiteKernel *OpenCLSpaceToBatchNDKernelCreator(const std::vector<lite::Te | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SpaceToBatchND, OpenCLSpaceToBatchNDKernelCreator); | REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SpaceToBatchND, OpenCLSpaceToBatchNDKernelCreator); | ||||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_SpaceToBatchND, OpenCLSpaceToBatchNDKernelCreator); | REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_SpaceToBatchND, OpenCLSpaceToBatchNDKernelCreator); | ||||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SpaceToBatch, OpenCLSpaceToBatchNDKernelCreator); | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_SpaceToBatch, OpenCLSpaceToBatchNDKernelCreator); | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -174,8 +174,8 @@ class OpenCLRuntimeWrapper { | |||||
| public: | public: | ||||
| OpenCLRuntimeWrapper() { ocl_runtime_ = OpenCLRuntime::GetInstance(); } | OpenCLRuntimeWrapper() { ocl_runtime_ = OpenCLRuntime::GetInstance(); } | ||||
| ~OpenCLRuntimeWrapper() { OpenCLRuntime::DeleteInstance(); } | ~OpenCLRuntimeWrapper() { OpenCLRuntime::DeleteInstance(); } | ||||
| explicit OpenCLRuntimeWrapper(const OpenCLRuntime &) = delete; | |||||
| OpenCLRuntimeWrapper &operator=(const OpenCLRuntime &) = delete; | |||||
| OpenCLRuntimeWrapper(const OpenCLRuntimeWrapper &) = delete; | |||||
| OpenCLRuntimeWrapper &operator=(const OpenCLRuntimeWrapper &) = delete; | |||||
| OpenCLRuntime *GetInstance() { return ocl_runtime_; } | OpenCLRuntime *GetInstance() { return ocl_runtime_; } | ||||
| private: | private: | ||||
| @@ -0,0 +1,174 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "src/runtime/kernel/opencl/utils.h" | |||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/batch_to_space_nd.h" | |||||
| namespace mindspore { | |||||
| class TestBatchToSpaceNDOpenCL : public mindspore::CommonTest { | |||||
| public: | |||||
| TestBatchToSpaceNDOpenCL() {} | |||||
| }; | |||||
| template <typename T> | |||||
| void test_main_batch_to_space_nd(void *input_data, void *correct_data, const std::vector<int> &input_shape, | |||||
| BatchToSpaceParameter *param, TypeId data_type, schema::Format format) { | |||||
| MS_LOG(INFO) << " begin test "; | |||||
| auto ocl_runtime_wrap = lite::opencl::OpenCLRuntimeWrapper(); | |||||
| auto ocl_runtime = ocl_runtime_wrap.GetInstance(); | |||||
| ocl_runtime->Init(); | |||||
| auto allocator = ocl_runtime->GetAllocator(); | |||||
| std::vector<int> output_shape = input_shape; | |||||
| output_shape[0] = input_shape[0] / param->block_shape_[0] / param->block_shape_[1]; | |||||
| output_shape[1] = input_shape[1] * param->block_shape_[0] - param->crops_[0] - param->crops_[1]; | |||||
| output_shape[2] = input_shape[2] * param->block_shape_[1] - param->crops_[2] - param->crops_[3]; | |||||
| auto tensor_a = lite::Tensor(TypeId(data_type), input_shape, format); | |||||
| auto tensor_c = lite::Tensor(TypeId(data_type), output_shape, format); | |||||
| std::vector<lite::Tensor *> inputs{&tensor_a}; | |||||
| std::vector<lite::Tensor *> outputs{&tensor_c}; | |||||
| size_t input_size = tensor_a.Size(); | |||||
| auto *pkernel = | |||||
| new (std::nothrow) kernel::BatchToSpaceNDOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| if (pkernel == nullptr) { | |||||
| MS_LOG(INFO) << "new BatchToSpaceNDOpenCLKernel failed "; | |||||
| return; | |||||
| } | |||||
| pkernel->Init(); | |||||
| // to do allocate memory for inputs and outputs | |||||
| for (auto &input_tensor : inputs) { | |||||
| input_tensor->MallocData(allocator); | |||||
| } | |||||
| MS_LOG(INFO) << " initialize sub_graph "; | |||||
| std::vector<kernel::LiteKernel *> kernels{pkernel}; | |||||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| if (sub_graph == nullptr) { | |||||
| delete pkernel; | |||||
| MS_LOG(INFO) << " new SubGraphOpenCLKernel failed "; | |||||
| return; | |||||
| } | |||||
| sub_graph->Init(); | |||||
| MS_LOG(INFO) << " init tensors "; | |||||
| T *input_ptr = reinterpret_cast<T *>(inputs[0]->MutableData()); | |||||
| memcpy(input_ptr, input_data, input_size); | |||||
| std::cout << "==================input data================" << std::endl; | |||||
| for (auto i = 0; i < inputs[0]->ElementsNum(); ++i) { | |||||
| std::cout << input_ptr[i] << ", "; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| sub_graph->Run(); | |||||
| auto *output_data = reinterpret_cast<T *>(outputs[0]->MutableData()); | |||||
| std::cout << "==================output data================" << std::endl; | |||||
| for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) { | |||||
| std::cout << output_data[i] << ", "; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::cout << "==================correct data================" << std::endl; | |||||
| for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) { | |||||
| std::cout << static_cast<T *>(correct_data)[i] << ", "; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| CommonTest::CompareOutputData<T>(output_data, static_cast<T *>(correct_data), outputs[0]->ElementsNum(), 0.0001); | |||||
| delete sub_graph; | |||||
| } | |||||
| TEST_F(TestBatchToSpaceNDOpenCL, NHWC4H2W2Pad2222) { | |||||
| std::vector<int> input_shape{4, 5, 5, 4}; | |||||
| BatchToSpaceParameter *param = std::make_unique<BatchToSpaceParameter>().release(); | |||||
| if (param == nullptr) { | |||||
| return; | |||||
| } | |||||
| param->block_shape_[0] = 2; | |||||
| param->block_shape_[1] = 2; | |||||
| param->crops_[0] = 2; | |||||
| param->crops_[1] = 2; | |||||
| param->crops_[2] = 2; | |||||
| param->crops_[3] = 2; | |||||
| float input_data[] = { | |||||
| 172, 47, 117, 192, 67, 251, 195, 103, 9, 211, 21, 242, 36, 87, 70, 216, 88, 140, 58, 193, 230, 39, 87, | |||||
| 174, 88, 81, 165, 25, 77, 72, 9, 148, 115, 208, 243, 197, 254, 79, 175, 192, 82, 99, 216, 177, 243, 29, | |||||
| 147, 147, 142, 167, 32, 193, 9, 185, 127, 32, 31, 202, 244, 151, 163, 254, 203, 114, 183, 28, 34, 128, 128, | |||||
| 164, 53, 133, 38, 232, 244, 17, 79, 132, 105, 42, 186, 31, 120, 1, 65, 231, 169, 57, 35, 102, 119, 11, | |||||
| 174, 82, 91, 128, 142, 99, 53, 140, 121, 170, 84, 203, 68, 6, 196, 47, 127, 244, 131, 204, 100, 180, 232, | |||||
| 78, 143, 148, 227, 186, 23, 207, 141, 117, 85, 48, 49, 69, 169, 163, 192, 95, 197, 94, 0, 113, 178, 36, | |||||
| 162, 48, 93, 131, 98, 42, 205, 112, 231, 149, 201, 127, 0, 138, 114, 43, 186, 127, 23, 187, 130, 121, 98, | |||||
| 62, 163, 222, 123, 195, 82, 174, 227, 148, 209, 50, 155, 14, 41, 58, 193, 36, 10, 86, 43, 104, 11, 2, | |||||
| 51, 80, 32, 182, 128, 38, 19, 174, 42, 115, 184, 188, 232, 77, 30, 24, 125, 2, 3, 94, 226, 107, 13, | |||||
| 112, 40, 72, 19, 95, 72, 154, 194, 248, 180, 67, 236, 61, 14, 96, 4, 195, 237, 139, 252, 86, 205, 121, | |||||
| 109, 75, 184, 16, 152, 157, 149, 110, 25, 208, 188, 121, 118, 117, 189, 83, 161, 104, 160, 228, 251, 251, 121, | |||||
| 70, 213, 31, 13, 71, 184, 152, 79, 41, 18, 40, 182, 207, 11, 166, 111, 93, 249, 129, 223, 118, 44, 216, | |||||
| 125, 24, 67, 210, 239, 3, 234, 204, 230, 35, 214, 254, 189, 197, 215, 43, 32, 11, 104, 212, 138, 182, 235, | |||||
| 165, 125, 156, 111, 232, 2, 27, 211, 217, 151, 53, 51, 174, 148, 181, 29, 67, 35, 39, 137, 73, 41, 151, | |||||
| 131, 46, 218, 178, 108, 3, 31, 9, 138, 27, 173, 199, 167, 61, 85, 97, 44, 34, 162, 88, 33, 133, 232, | |||||
| 36, 0, 203, 34, 197, 126, 181, 254, 80, 190, 136, 189, 129, 209, 112, 35, 120, 91, 168, 116, 36, 176, 25, | |||||
| 67, 103, 252, 35, 114, 30, 29, 241, 33, 146, 17, 221, 84, 253, 2, 69, 101, 140, 44, 117, 253, 66, 111, | |||||
| 91, 85, 167, 39, 203, 150, 158, 145, 198, | |||||
| }; | |||||
| float correct_data[] = {88, 81, 165, 25, 85, 48, 49, 69, 77, 72, 9, 148, 169, 163, 192, 95, 115, 208, | |||||
| 243, 197, 197, 94, 0, 113, 237, 139, 252, 86, 218, 178, 108, 3, 205, 121, 109, 75, | |||||
| 31, 9, 138, 27, 184, 16, 152, 157, 173, 199, 167, 61, 243, 29, 147, 147, 205, 112, | |||||
| 231, 149, 142, 167, 32, 193, 201, 127, 0, 138, 9, 185, 127, 32, 114, 43, 186, 127, | |||||
| 189, 83, 161, 104, 232, 36, 0, 203, 160, 228, 251, 251, 34, 197, 126, 181, 121, 70, | |||||
| 213, 31, 254, 80, 190, 136, 183, 28, 34, 128, 123, 195, 82, 174, 128, 164, 53, 133, | |||||
| 227, 148, 209, 50, 38, 232, 244, 17, 155, 14, 41, 58, 182, 207, 11, 166, 116, 36, | |||||
| 176, 25, 111, 93, 249, 129, 67, 103, 252, 35, 223, 118, 44, 216, 114, 30, 29, 241}; | |||||
| TypeId data_type = kNumberTypeFloat32; | |||||
| schema::Format format = schema::Format_NHWC; | |||||
| test_main_batch_to_space_nd<float>(input_data, correct_data, input_shape, param, data_type, format); | |||||
| } | |||||
| TEST_F(TestBatchToSpaceNDOpenCL, NC4HW4H2W2Pad2222) { | |||||
| std::vector<int> input_shape{4, 5, 5, 4}; | |||||
| BatchToSpaceParameter *param = std::make_unique<BatchToSpaceParameter>().release(); | |||||
| if (param == nullptr) { | |||||
| return; | |||||
| } | |||||
| param->block_shape_[0] = 2; | |||||
| param->block_shape_[1] = 2; | |||||
| param->crops_[0] = 2; | |||||
| param->crops_[1] = 2; | |||||
| param->crops_[2] = 2; | |||||
| param->crops_[3] = 2; | |||||
| float input_data[] = {172, 47, 117, 192, 67, 251, 195, 103, 9, 211, 21, 242, 36, 87, 70, 216, 88, 140, | |||||
| 58, 193, 230, 39, 87, 174, 88, 81, 165, 25, 77, 72, 9, 148, 115, 208, 243, 197, | |||||
| 254, 79, 175, 192, 82, 99, 216, 177, 243, 29, 147, 147, 142, 167, 32, 193, 9, 185, | |||||
| 127, 32, 31, 202, 244, 151, 163, 254, 203, 114, 183, 28, 34, 128, 128, 164, 53, 133, | |||||
| 38, 232, 244, 17, 79, 132, 105, 42, 186, 31, 120, 1, 65, 231, 169, 57, 35, 102, | |||||
| 119, 11, 174, 82, 91, 128, 142, 99, 53, 140, 121, 170, 84, 203, 68, 6, 196, 47, | |||||
| 127, 244, 131, 204, 100, 180, 232, 78, 143, 148, 227, 186, 23, 207, 141, 117, 85, 48, | |||||
| 49, 69, 169, 163, 192, 95, 197, 94, 0, 113, 178, 36, 162, 48, 93, 131, 98, 42}; | |||||
| float correct_data[] = {88, 81, 165, 25, 85, 48, 49, 69, 77, 72, 9, 148, 169, 163, 192, 95, 115, 208, | |||||
| 243, 197, 197, 94, 0, 113, 237, 139, 252, 86, 218, 178, 108, 3, 205, 121, 109, 75, | |||||
| 31, 9, 138, 27, 184, 16, 152, 157, 173, 199, 167, 61, 243, 29, 147, 147, 205, 112, | |||||
| 231, 149, 142, 167, 32, 193, 201, 127, 0, 138, 9, 185, 127, 32, 114, 43, 186, 127, | |||||
| 189, 83, 161, 104, 232, 36, 0, 203, 160, 228, 251, 251, 34, 197, 126, 181, 121, 70, | |||||
| 213, 31, 254, 80, 190, 136, 183, 28, 34, 128, 123, 195, 82, 174, 128, 164, 53, 133, | |||||
| 227, 148, 209, 50, 38, 232, 244, 17, 155, 14, 41, 58, 182, 207, 11, 166, 116, 36, | |||||
| 176, 25, 111, 93, 249, 129, 67, 103, 252, 35, 223, 118, 44, 216, 114, 30, 29, 241}; | |||||
| TypeId data_type = kNumberTypeFloat32; | |||||
| schema::Format format = schema::Format_NCHW; | |||||
| test_main_batch_to_space_nd<float>(input_data, correct_data, input_shape, param, data_type, format); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -396,7 +396,7 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_3input_dim4_axis1) { | |||||
| TEST_F(TestConcatOpenCLfp16, ConcatFp16_6input_dim4_axis1) { | TEST_F(TestConcatOpenCLfp16, ConcatFp16_6input_dim4_axis1) { | ||||
| MS_LOG(INFO) << " begin test "; | MS_LOG(INFO) << " begin test "; | ||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); | |||||
| ocl_runtime->SetFp16Enable(true); | ocl_runtime->SetFp16Enable(true); | ||||
| ocl_runtime->Init(); | ocl_runtime->Init(); | ||||
| auto allocator = ocl_runtime->GetAllocator(); | auto allocator = ocl_runtime->GetAllocator(); | ||||
| @@ -523,7 +523,6 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_6input_dim4_axis1) { | |||||
| sub_graph->Run(); | sub_graph->Run(); | ||||
| auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->MutableData()); | auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->MutableData()); | ||||
| CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); | CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); | ||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| for (auto tensor : inputs) { | for (auto tensor : inputs) { | ||||
| tensor->SetData(nullptr); | tensor->SetData(nullptr); | ||||
| delete tensor; | delete tensor; | ||||
| @@ -138,7 +138,7 @@ TEST_F(TestSpaceToBatchNDOpenCL, NHWC4H2W2Pad2222) { | |||||
| schema::Format format = schema::Format_NHWC; | schema::Format format = schema::Format_NHWC; | ||||
| test_main_space_to_batch_nd<float>(input_data, correct_data, input_shape, param, data_type, format); | test_main_space_to_batch_nd<float>(input_data, correct_data, input_shape, param, data_type, format); | ||||
| } | } | ||||
| TEST_F(TestSpaceToBatchNDOpenCL, Nc4HW4H2W2Pad2222) { | |||||
| TEST_F(TestSpaceToBatchNDOpenCL, NC4HW4H2W2Pad2222) { | |||||
| std::vector<int> input_shape{1, 6, 6, 4}; | std::vector<int> input_shape{1, 6, 6, 4}; | ||||
| SpaceToBatchParameter *param = std::make_unique<SpaceToBatchParameter>().release(); | SpaceToBatchParameter *param = std::make_unique<SpaceToBatchParameter>().release(); | ||||
| if (param == nullptr) { | if (param == nullptr) { | ||||