From edc4ac2c257d73b2ecc09900b1be81cbf00e283d Mon Sep 17 00:00:00 2001
From: chenzupeng
Date: Wed, 11 Nov 2020 09:35:40 +0800
Subject: [PATCH] add spacetodepth

---
 .../kernel/opencl/cl/space_to_depth.cl        |  65 +++++
 .../kernel/opencl/kernel/space_to_depth.cc    | 108 +++++++
 .../kernel/opencl/kernel/space_to_depth.h     |  47 +++
 .../kernel/opencl/space_to_depth_tests.cc     | 268 ++++++++++++++++++
 4 files changed, 488 insertions(+)
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/space_to_depth.cl
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h
 create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc

diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/space_to_depth.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/space_to_depth.cl
new file mode 100644
index 0000000000..a968cec81e
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/space_to_depth.cl
@@ -0,0 +1,65 @@
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
+#define C4NUM 4
+__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+// SpaceToDepth for inputs whose channel count is NOT a multiple of 4: every output lane is gathered individually.
+// in_shape / out_shape are (N, H, W, C4 slices) and ci_size is the input channel count, as bound by the host.
+// Output channel co takes input channel co % ci_size of the pixel at offset
+// (hw_block / block_size, hw_block % block_size) inside the block, where hw_block = co / ci_size.
+__kernel void SpaceToDepth(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_shape,
+                           int4 out_shape, int block_size, int ci_size) {
+  int X = get_global_id(0);  // C4
+  int Y = get_global_id(1);  // W
+  int Z = get_global_id(2);  // H * N
+  if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;
+  int N = Z / out_shape.y;
+  int H = Z % out_shape.y;
+  int co_base = X * C4NUM;
+  int co_size = ci_size * block_size * block_size;  // number of valid output channels
+  FLT result[C4NUM] = {0.f};
+  for (int i = 0; i < C4NUM; i++) {
+    int co = co_base + i;
+    // Lanes past the last valid output channel stay zero instead of sampling pixels outside this block.
+    if (co >= co_size) break;
+    int ci = co % ci_size;
+    int hw_block = co / ci_size;
+    int hi = H * block_size + hw_block / block_size;
+    int wi = Y * block_size + hw_block % block_size;
+    int ci4 = ci / C4NUM;
+    int ci4_remainder = ci % C4NUM;
+    FLT4 tmp = READ_IMAGE(src_data, smp_zero, (int2)(wi * in_shape.w + ci4, N * in_shape.y + hi));
+    if (ci4_remainder == 0) {
+      result[i] = tmp.x;
+    } else if (ci4_remainder == 1) {
+      result[i] = tmp.y;
+    } else if (ci4_remainder == 2) {
+      result[i] = tmp.z;
+    } else {
+      result[i] = tmp.w;
+    }
+  }
+  FLT4 result_flt4 = {result[0], result[1], result[2], result[3]};
+  WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result_flt4);
+}
+
+// SpaceToDepth for inputs whose channel count is a multiple of 4: whole FLT4 slices are copied unchanged.
+// Output slice X takes input slice X % in_shape.w of the pixel selected by hw_block = X / in_shape.w.
+__kernel void SpaceToDepthAlign(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_shape,
+                                int4 out_shape, int block_size, int ci_size) {
+  int X = get_global_id(0);  // C4
+  int Y = get_global_id(1);  // W
+  int Z = get_global_id(2);  // H * N
+  if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;
+
+  int N = Z / out_shape.y;
+  int H = Z % out_shape.y;
+  int ci4 = X % in_shape.w;
+  int hw_block = X / in_shape.w;
+  int hi = H * block_size + hw_block / block_size;
+  int wi = Y * block_size + hw_block % block_size;
+  WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z),
+              READ_IMAGE(src_data, smp_zero, (int2)(wi * in_shape.w + ci4, N * in_shape.y + hi)));
+}
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc
new file mode 100644
index 0000000000..91eacc98f3
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc
@@ -0,0 +1,108 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <set>
+#include <string>
+#include <utility>
+#include "include/errorcode.h"
+#include "src/kernel_registry.h"
+#include "src/runtime/kernel/opencl/kernel/space_to_depth.h"
+#include "src/runtime/kernel/opencl/cl/space_to_depth.cl.inc"
+
+using mindspore::kernel::KERNEL_ARCH::kGPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NULL_PTR;
+using mindspore::lite::RET_OK;
+using mindspore::lite::RET_PARAM_INVALID;
+using mindspore::schema::PrimitiveType_SpaceToDepth;
+
+namespace mindspore::kernel {
+// Validates what Prepare() and the device code rely on: exactly one input, one output and a positive block size
+// (the device code divides by block_size).
+int SpaceToDepthOpenCLKernel::CheckSpecs() {
+  if (in_tensors_.size() != 1 || out_tensors_.size() != 1) {
+    MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size();
+    return RET_ERROR;
+  }
+  auto param = reinterpret_cast<SpaceToDepthParameter *>(op_parameter_);
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "SpaceToDepth parameter is nullptr.";
+    return RET_NULL_PTR;
+  }
+  if (param->block_size_ <= 0) {
+    MS_LOG(ERROR) << "Invalid block_size: " << param->block_size_;
+    return RET_PARAM_INVALID;
+  }
+  return RET_OK;
+}
+
+// Chooses the device kernel by channel alignment, builds it and binds the run-invariant arguments.
+int SpaceToDepthOpenCLKernel::Prepare() {
+  in_shape_ = Image2DInfo(in_tensors_[0]);
+  out_shape_ = Image2DInfo(out_tensors_[0]);
+  // Channel counts that are not a multiple of 4 need the per-lane gather kernel; aligned ones copy whole slices.
+  std::string kernel_name = in_shape_.C % 4 != 0 ? "SpaceToDepth" : "SpaceToDepthAlign";
+#ifdef PROGRAM_WITH_IL
+  kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
+#else
+  std::set<std::string> build_options;
+  std::string source = space_to_depth_source;
+  std::string program_name = "SpaceToDepth";
+  ocl_runtime_->LoadSource(program_name, source);
+  if (ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options) != RET_OK) {
+    MS_LOG(ERROR) << "Build kernel " << kernel_name << " failed.";
+    return RET_ERROR;
+  }
+#endif
+  SetConstArgs();
+  SetGlobalLocal();
+  MS_LOG(DEBUG) << kernel_name << " Init Done!";
+  return RET_OK;
+}
+
+// Binds kernel arguments 2..5: input shape and output shape (both as N, H, W, slices), block size, input channels.
+void SpaceToDepthOpenCLKernel::SetConstArgs() {
+  cl_int4 cl_in_shape = {static_cast<cl_int>(in_shape_.N), static_cast<cl_int>(in_shape_.H),
+                         static_cast<cl_int>(in_shape_.W), static_cast<cl_int>(in_shape_.Slice)};
+  cl_int4 cl_out_shape = {static_cast<cl_int>(out_shape_.N), static_cast<cl_int>(out_shape_.H),
+                          static_cast<cl_int>(out_shape_.W), static_cast<cl_int>(out_shape_.Slice)};
+  auto param = reinterpret_cast<SpaceToDepthParameter *>(op_parameter_);
+  int ci_size = in_shape_.C;
+  int arg_idx = 2;  // arguments 0 and 1 are the input/output memory bound in Run()
+  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, cl_in_shape);
+  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, cl_out_shape);
+  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, param->block_size_);
+  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, ci_size);
+}
+
+// Global range is (output slices, output width, output height * batch).
+void SpaceToDepthOpenCLKernel::SetGlobalLocal() {
+  global_range_ = {out_shape_.Slice, out_shape_.W, out_shape_.H * out_shape_.N};
+}
+
+// Binds the input/output tensor memory as arguments 0 and 1, runs the kernel and returns RunKernel's status.
+int SpaceToDepthOpenCLKernel::Run() {
+  MS_LOG(DEBUG) << this->name() << " Running!";
+  int arg_idx = 0;
+  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c());
+  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
+  return ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr);
+}
+
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SpaceToDepth, OpenCLKernelCreator<SpaceToDepthOpenCLKernel>)
+REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_SpaceToDepth, OpenCLKernelCreator<SpaceToDepthOpenCLKernel>)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h
new file mode 100644
index 0000000000..af9ff202e1
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPACE_TO_DEPTH_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPACE_TO_DEPTH_H_
+
+#include <string>
+#include <vector>
+#include "src/lite_kernel.h"
+#include "src/runtime/kernel/opencl/opencl_kernel.h"
+#include "nnacl/fp32/space_to_depth.h"
+
+namespace mindspore::kernel {
+// OpenCL kernel for the SpaceToDepth operator; Prepare() picks one of two device kernels by channel alignment.
+class SpaceToDepthOpenCLKernel : public OpenCLKernel {
+ public:
+  SpaceToDepthOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                           const std::vector<lite::Tensor *> &outputs) : OpenCLKernel(parameter, inputs, outputs) {}
+  ~SpaceToDepthOpenCLKernel() override = default;
+
+  int Run() override;              // binds the input/output memory and runs the device kernel
+  int Prepare() override;          // selects and builds the device kernel, then binds the constant arguments
+  int CheckSpecs() override;       // reports whether this kernel supports the given configuration
+  void SetConstArgs() override;    // binds input/output shapes, block size and input channel count
+  void SetGlobalLocal() override;  // derives the global work size from the output shape
+
+ private:
+  cl::Kernel kernel_;
+  Image2DInfo in_shape_ = Image2DInfo(nullptr);   // replaced in Prepare() with the input tensor's layout
+  Image2DInfo out_shape_ = Image2DInfo(nullptr);  // replaced in Prepare() with the output tensor's layout
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPACE_TO_DEPTH_H_
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc
new file mode 100644
index 0000000000..83ba480679
--- /dev/null
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc
@@ -0,0 +1,268 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <memory>
+#include "src/common/log_adapter.h"
+#include "common/common_test.h"
+#include "mindspore/lite/src/common/file_utils.h"
+#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h"
+#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
+
+namespace mindspore {
+class TestSpaceToDepthOpenCL : public mindspore::CommonTest {
+ public:
+  TestSpaceToDepthOpenCL() {}
+};
+// Runs the SpaceToDepth OpenCL kernel on input_data inside a one-kernel subgraph and compares it with output_data.
+void RunTestCaseSpaceToDepth(const std::vector<int> &shape_in, const std::vector<int> &shape_out, void *input_data,
+                             void *output_data, bool enable_fp16, int block_size) {
+  auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
+  ocl_runtime->Init();
+  size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float);  // element size of the caller's buffers
+  ocl_runtime->SetFp16Enable(enable_fp16);
+  auto allocator = ocl_runtime->GetAllocator();
+  auto param = static_cast<SpaceToDepthParameter *>(calloc(1, sizeof(SpaceToDepthParameter)));  // zeroed fields
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "param_ptr create error.";
+    return;
+  }
+  param->block_size_ = block_size;
+  auto tensor_x_ptr = std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
+                                                     shape_in, schema::Format_NHWC);
+  auto tensor_x = tensor_x_ptr.get();
+  if (tensor_x == nullptr) {
+    MS_LOG(ERROR) << "tensor_x create error.";
+    return;
+  }
+  auto tensor_out_ptr =
+    std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), shape_out);
+  auto tensor_out = tensor_out_ptr.get();
+  if (tensor_out == nullptr) {
+    MS_LOG(ERROR) << "tensor_out create error.";
+    return;
+  }
+  std::vector<lite::Tensor *> inputs{tensor_x};
+  std::vector<lite::Tensor *> outputs{tensor_out};
+  auto arith_kernel = kernel::OpenCLKernelCreator<kernel::SpaceToDepthOpenCLKernel>(
+    inputs, outputs, reinterpret_cast<OpParameter *>(param), nullptr, kernel::KernelKey(), nullptr);
+  if (arith_kernel == nullptr) {
+    MS_LOG(ERROR) << "arith_kernel create error.";
+    return;
+  }
+
+  inputs[0]->MallocData(allocator);
+
+  std::vector<kernel::LiteKernel *> kernels{arith_kernel};
+  auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs, outputs, kernels, kernels, kernels);
+  auto pGraph = pGraph_ptr.get();
+  if (pGraph == nullptr) {
+    MS_LOG(ERROR) << "pGraph create error.";
+    return;
+  }
+  pGraph->Init();
+  memcpy(inputs[0]->MutableData(), input_data, inputs[0]->ElementsNum() * dtype_size);
+  pGraph->Run();
+
+  if (enable_fp16) {
+    CompareOutput(outputs[0]->MutableData(), output_data, outputs[0]->ElementsNum(), static_cast<float16_t>(1e-3),
+                  2e-2);
+  } else {
+    CompareOutput(outputs[0]->MutableData(), output_data, outputs[0]->ElementsNum(), static_cast<float>(1e-5));
+  }
+  for (auto t : inputs) {
+    t->set_data(nullptr);
+  }
+  for (auto t : outputs) {
+    t->set_data(nullptr);
+  }
+
+  MS_LOG(INFO) << "Test SpaceToDepth passed";
+}
+
+TEST_F(TestSpaceToDepthOpenCL, AlignTest1Fp32) {
+  std::vector<int> shape_in = {1, 2, 2, 4};
+  std::vector<int> shape_out = {1, 1, 1, 16};
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+                                   9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f};
+  std::vector<float> output_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+                                    9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), false, 2);
+}
+
+TEST_F(TestSpaceToDepthOpenCL, AlignTest1Fp16) {
+  std::vector<int> shape_in = {1, 2, 2, 4};
+  std::vector<int> shape_out = {1, 1, 1, 16};
+  std::vector<float16_t> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+                                       9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f};
+  std::vector<float16_t> output_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+                                        9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), true, 2);
+}
+
+TEST_F(TestSpaceToDepthOpenCL, AlignTest2Fp32) {
+  std::vector<int> shape_in = {1, 4, 4, 4};
+  std::vector<int> shape_out = {1, 2, 2, 16};
+  std::vector<float> input_data = {
+    0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+    16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f,
+    32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f,
+    48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 61.0f, 62.0f, 63.0f};
+  std::vector<float> output_data = {
+    0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
+    8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f,
+    32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
+    40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 61.0f, 62.0f, 63.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), false, 2);
+}
+
+TEST_F(TestSpaceToDepthOpenCL, AlignTest2Fp16) {
+  std::vector<int> shape_in = {1, 4, 4, 4};
+  std::vector<int> shape_out = {1, 2, 2, 16};
+  std::vector<float16_t> input_data = {
+    0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+    16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f,
+    32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f,
+    48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 61.0f, 62.0f, 63.0f};
+  std::vector<float16_t> output_data = {
+    0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
+    8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f,
+    32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
+    40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 61.0f, 62.0f, 63.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), true, 2);
+}
+
+TEST_F(TestSpaceToDepthOpenCL, AlignTest3Fp32) {
+  std::vector<int> shape_in = {1, 6, 6, 4};
+  std::vector<int> shape_out = {1, 2, 2, 36};
+  std::vector<float> input_data = {
+    0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f,
+    14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f,
+    28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f,
+    42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
+    56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f,
+    70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f, 80.0f, 81.0f, 82.0f, 83.0f,
+    84.0f, 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f, 97.0f,
+    98.0f, 99.0f, 100.0f, 101.0f, 102.0f, 103.0f, 104.0f, 105.0f, 106.0f, 107.0f, 108.0f, 109.0f, 110.0f, 111.0f,
+    112.0f, 113.0f, 114.0f, 115.0f, 116.0f, 117.0f, 118.0f, 119.0f, 120.0f, 121.0f, 122.0f, 123.0f, 124.0f, 125.0f,
+    126.0f, 127.0f, 128.0f, 129.0f, 130.0f, 131.0f, 132.0f, 133.0f, 134.0f, 135.0f, 136.0f, 137.0f, 138.0f, 139.0f,
+    140.0f, 141.0f, 142.0f, 143.0f};
+  std::vector<float> output_data = {
+    0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 24.0f, 25.0f,
+    26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 48.0f, 49.0f, 50.0f, 51.0f,
+    52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
+    18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f,
+    44.0f, 45.0f, 46.0f, 47.0f, 60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f,
+    70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f, 80.0f, 81.0f, 82.0f, 83.0f,
+    96.0f, 97.0f, 98.0f, 99.0f, 100.0f, 101.0f, 102.0f, 103.0f, 104.0f, 105.0f, 106.0f, 107.0f, 120.0f, 121.0f,
+    122.0f, 123.0f, 124.0f, 125.0f, 126.0f, 127.0f, 128.0f, 129.0f, 130.0f, 131.0f, 84.0f, 85.0f, 86.0f, 87.0f,
+    88.0f, 89.0f, 90.0f, 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 108.0f, 109.0f, 110.0f, 111.0f, 112.0f, 113.0f,
+    114.0f, 115.0f, 116.0f, 117.0f, 118.0f, 119.0f, 132.0f, 133.0f, 134.0f, 135.0f, 136.0f, 137.0f, 138.0f, 139.0f,
+    140.0f, 141.0f, 142.0f, 143.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), false, 3);
+}
+
+TEST_F(TestSpaceToDepthOpenCL, NotAlignTest1Fp32) {
+  std::vector<int> shape_in = {1, 2, 2, 1};
+  std::vector<int> shape_out = {1, 1, 1, 4};
+  std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
+  std::vector<float> output_data = {0.0f, 1.0f, 2.0f, 3.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), false, 2);
+}
+
+TEST_F(TestSpaceToDepthOpenCL, NotAlignTest1Fp16) {
+  std::vector<int> shape_in = {1, 2, 2, 1};
+  std::vector<int> shape_out = {1, 1, 1, 4};
+  std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
+  std::vector<float16_t> output_data = {0.0f, 1.0f, 2.0f, 3.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), true, 2);
+}
+
+TEST_F(TestSpaceToDepthOpenCL, NotAlignTest2Fp32) {
+  std::vector<int> shape_in = {1, 2, 2, 3};
+  std::vector<int> shape_out = {1, 1, 1, 12};
+  std::vector<float> input_data = {
+    0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f,
+  };
+  std::vector<float> output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), false, 2);
+}
+
+TEST_F(TestSpaceToDepthOpenCL, NotAlignTest3Fp32) {
+  std::vector<int> shape_in = {1, 4, 4, 3};
+  std::vector<int> shape_out = {1, 2, 2, 12};
+  std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f,
+                                   12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
+                                   24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
+                                   36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f};
+  std::vector<float> output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
+                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
+                                    24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f,
+                                    30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), false, 2);
+}
+
+TEST_F(TestSpaceToDepthOpenCL, NotAlignTest4Fp32) {
+  std::vector<int> shape_in = {1, 6, 6, 6};
+  std::vector<int> shape_out = {1, 2, 2, 54};
+  std::vector<float> input_data = {
+    0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f,
+    14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f,
+    28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f,
+    42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
+    56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f,
+    70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f, 80.0f, 81.0f, 82.0f, 83.0f,
+    84.0f, 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f, 97.0f,
+    98.0f, 99.0f, 100.0f, 101.0f, 102.0f, 103.0f, 104.0f, 105.0f, 106.0f, 107.0f, 108.0f, 109.0f, 110.0f, 111.0f,
+    112.0f, 113.0f, 114.0f, 115.0f, 116.0f, 117.0f, 118.0f, 119.0f, 120.0f, 121.0f, 122.0f, 123.0f, 124.0f, 125.0f,
+    126.0f, 127.0f, 128.0f, 129.0f, 130.0f, 131.0f, 132.0f, 133.0f, 134.0f, 135.0f, 136.0f, 137.0f, 138.0f, 139.0f,
+    140.0f, 141.0f, 142.0f, 143.0f, 144.0f, 145.0f, 146.0f, 147.0f, 148.0f, 149.0f, 150.0f, 151.0f, 152.0f, 153.0f,
+    154.0f, 155.0f, 156.0f, 157.0f, 158.0f, 159.0f, 160.0f, 161.0f, 162.0f, 163.0f, 164.0f, 165.0f, 166.0f, 167.0f,
+    168.0f, 169.0f, 170.0f, 171.0f, 172.0f, 173.0f, 174.0f, 175.0f, 176.0f, 177.0f, 178.0f, 179.0f, 180.0f, 181.0f,
+    182.0f, 183.0f, 184.0f, 185.0f, 186.0f, 187.0f, 188.0f, 189.0f, 190.0f, 191.0f, 192.0f, 193.0f, 194.0f, 195.0f,
+    196.0f, 197.0f, 198.0f, 199.0f, 200.0f, 201.0f, 202.0f, 203.0f, 204.0f, 205.0f, 206.0f, 207.0f, 208.0f, 209.0f,
+    210.0f, 211.0f, 212.0f, 213.0f, 214.0f, 215.0f};
+  std::vector<float> output_data = {
+    0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f,
+    14.0f, 15.0f, 16.0f, 17.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
+    46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f,
+    78.0f, 79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f, 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 18.0f, 19.0f,
+    20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f,
+    34.0f, 35.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
+    66.0f, 67.0f, 68.0f, 69.0f, 70.0f, 71.0f, 90.0f, 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f, 97.0f,
+    98.0f, 99.0f, 100.0f, 101.0f, 102.0f, 103.0f, 104.0f, 105.0f, 106.0f, 107.0f, 108.0f, 109.0f, 110.0f, 111.0f,
+    112.0f, 113.0f, 114.0f, 115.0f, 116.0f, 117.0f, 118.0f, 119.0f, 120.0f, 121.0f, 122.0f, 123.0f, 124.0f, 125.0f,
+    144.0f, 145.0f, 146.0f, 147.0f, 148.0f, 149.0f, 150.0f, 151.0f, 152.0f, 153.0f, 154.0f, 155.0f, 156.0f, 157.0f,
+    158.0f, 159.0f, 160.0f, 161.0f, 180.0f, 181.0f, 182.0f, 183.0f, 184.0f, 185.0f, 186.0f, 187.0f, 188.0f, 189.0f,
+    190.0f, 191.0f, 192.0f, 193.0f, 194.0f, 195.0f, 196.0f, 197.0f, 126.0f, 127.0f, 128.0f, 129.0f, 130.0f, 131.0f,
+    132.0f, 133.0f, 134.0f, 135.0f, 136.0f, 137.0f, 138.0f, 139.0f, 140.0f, 141.0f, 142.0f, 143.0f, 162.0f, 163.0f,
+    164.0f, 165.0f, 166.0f, 167.0f, 168.0f, 169.0f, 170.0f, 171.0f, 172.0f, 173.0f, 174.0f, 175.0f, 176.0f, 177.0f,
+    178.0f, 179.0f, 198.0f, 199.0f, 200.0f, 201.0f, 202.0f, 203.0f, 204.0f, 205.0f, 206.0f, 207.0f, 208.0f, 209.0f,
+    210.0f, 211.0f, 212.0f, 213.0f, 214.0f, 215.0f};
+
+  RunTestCaseSpaceToDepth(shape_in, shape_out, input_data.data(), output_data.data(), false, 3);
+}
+}  // namespace mindspore