From 91568de9ebe77fdb9d15499832b50a04ab105101 Mon Sep 17 00:00:00 2001 From: chenzupeng Date: Thu, 12 Nov 2020 15:28:50 +0800 Subject: [PATCH] add gpu op onehot --- .../src/runtime/kernel/opencl/cl/one_hot.cl | 232 ++++++++ .../runtime/kernel/opencl/kernel/one_hot.cc | 102 ++++ .../runtime/kernel/opencl/kernel/one_hot.h | 52 ++ .../kernel/opencl/kernel/space_to_depth.cc | 2 +- .../src/runtime/kernel/opencl/opencl_kernel.h | 26 +- .../runtime/kernel/opencl/one_hot_tests.cc | 534 ++++++++++++++++++ 6 files changed, 940 insertions(+), 8 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/one_hot.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/one_hot.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/one_hot.cl new file mode 100644 index 0000000000..872cd054ea --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/one_hot.cl @@ -0,0 +1,232 @@ +#ifdef cl_khr_fp16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#endif + +#define C4NUM 4 +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void OneHotAxis0(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape, + int4 out_shape, int depth, float on_value, float off_value, int C) { + int X = get_global_id(0); // C4 + int Y = get_global_id(1); // W + int Z = get_global_id(2); // H * N + if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return; + int N = Z / out_shape.y; + int H = Z % out_shape.y; + int in_index = (H * out_shape.z + Y) * out_shape.w + X; + FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(in_index % in_image2d_shape.x, in_index / in_image2d_shape.x)); + int 
*indices_int = (int *)&indices; + FLT4 result = (FLT4)(0.f); + if (4 * X < C) { + if (indices_int[0] == N) { + result.x = (FLT)(on_value); + } else { + result.x = (FLT)(off_value); + } + } + if (4 * X + 1 < C) { + if (indices_int[1] == N) { + result.y = (FLT)(on_value); + } else { + result.y = (FLT)(off_value); + } + } + if (4 * X + 2 < C) { + if (indices_int[2] == N) { + result.z = (FLT)(on_value); + } else { + result.z = (FLT)(off_value); + } + } + if (4 * X + 3 < C) { + if (indices_int[3] == N) { + result.w = (FLT)(on_value); + } else { + result.w = (FLT)(off_value); + } + } + WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result); +} + +__kernel void OneHotAxis1(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape, + int4 out_shape, int depth, float on_value, float off_value, int C) { + int X = get_global_id(0); // C4 + int Y = get_global_id(1); // W + int Z = get_global_id(2); // H * N + if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return; + int N = Z / out_shape.y; + int H = Z % out_shape.y; + int in_index = (N * out_shape.z + Y) * out_shape.w + X; + FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(in_index % in_image2d_shape.x, in_index / in_image2d_shape.x)); + int *indices_int = (int *)&indices; + FLT4 result = (FLT4)(0.f); + if (4 * X < C) { + if (indices_int[0] == H) { + result.x = (FLT)(on_value); + } else { + result.x = (FLT)(off_value); + } + } + if (4 * X + 1 < C) { + if (indices_int[1] == H) { + result.y = (FLT)(on_value); + } else { + result.y = (FLT)(off_value); + } + } + if (4 * X + 2 < C) { + if (indices_int[2] == H) { + result.z = (FLT)(on_value); + } else { + result.z = (FLT)(off_value); + } + } + if (4 * X + 3 < C) { + if (indices_int[3] == H) { + result.w = (FLT)(on_value); + } else { + result.w = (FLT)(off_value); + } + } + WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result); +} + +__kernel void OneHotAxis2(__read_only image2d_t src_data, __write_only 
image2d_t dst_data, int2 in_image2d_shape, + int4 out_shape, int depth, float on_value, float off_value, int C) { + int X = get_global_id(0); // C4 + int Y = get_global_id(1); // W + int Z = get_global_id(2); // H * N + if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return; + int N = Z / out_shape.y; + int H = Z % out_shape.y; + int in_index = (N * out_shape.y + H) * out_shape.w + X; + FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(in_index % in_image2d_shape.x, in_index / in_image2d_shape.x)); + int *indices_int = (int *)&indices; + FLT4 result = (FLT4)(0.f); + if (4 * X < C) { + if (indices_int[0] == Y) { + result.x = (FLT)(on_value); + } else { + result.x = (FLT)(off_value); + } + } + if (4 * X + 1 < C) { + if (indices_int[1] == Y) { + result.y = (FLT)(on_value); + } else { + result.y = (FLT)(off_value); + } + } + if (4 * X + 2 < C) { + if (indices_int[2] == Y) { + result.z = (FLT)(on_value); + } else { + result.z = (FLT)(off_value); + } + } + if (4 * X + 3 < C) { + if (indices_int[3] == Y) { + result.w = (FLT)(on_value); + } else { + result.w = (FLT)(off_value); + } + } + WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result); +} + +__kernel void OneHotAxis3(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape, + int4 out_shape, int depth, float on_value, float off_value, int C) { + int X = get_global_id(0); // C4 + int Y = get_global_id(1); // W + int Z = get_global_id(2); // H * N + if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return; + int N = Z / out_shape.y; + int H = Z % out_shape.y; + int ci4_size = UP_DIV(out_shape.z, C4NUM); + int in_index_c4 = (N * out_shape.y + H) * ci4_size + Y / 4; + int in_index_c4_remainder = Y % 4; + FLT4 indices = + READ_IMAGE(src_data, smp_zero, (int2)(in_index_c4 % in_image2d_shape.x, in_index_c4 / in_image2d_shape.x)); + int *indices_int = (int *)&indices; + int index_one = indices_int[in_index_c4_remainder]; + FLT4 
result = (FLT4)(0.f); + if (4 * X < C) { + if (index_one == 4 * X) { + result.x = (FLT)(on_value); + } else { + result.x = (FLT)(off_value); + } + } + if (4 * X + 1 < C) { + if (index_one == 4 * X + 1) { + result.y = (FLT)(on_value); + } else { + result.y = (FLT)(off_value); + } + } + if (4 * X + 2 < C) { + if (index_one == 4 * X + 2) { + result.z = (FLT)(on_value); + } else { + result.z = (FLT)(off_value); + } + } + if (4 * X + 3 < C) { + if (index_one == 4 * X + 3) { + result.w = (FLT)(on_value); + } else { + result.w = (FLT)(off_value); + } + } + WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result); +} + +__kernel void OneHot2DAxis0(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape, + int4 out_shape, int depth, float on_value, float off_value, int C) { + int X = get_global_id(0); // C4 + int Y = get_global_id(1); // W + int Z = get_global_id(2); // N + if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return; + FLT4 result = (FLT4)(0.f); + int channel = 4 * X; + if (channel < C) { + FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(0, channel)); + int index = ((int *)&indices)[0]; + if (index == Z) { + result.x = (FLT)(on_value); + } else { + result.x = (FLT)(off_value); + } + } + channel++; + if (channel < C) { + FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(0, channel)); + int index = ((int *)&indices)[0]; + if (index == Z) { + result.y = (FLT)(on_value); + } else { + result.y = (FLT)(off_value); + } + } + channel++; + if (channel < C) { + FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(0, channel)); + int index = ((int *)&indices)[0]; + if (index == Z) { + result.z = (FLT)(on_value); + } else { + result.z = (FLT)(off_value); + } + } + channel++; + if (channel < C) { + FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(0, channel)); + int index = ((int *)&indices)[0]; + if (index == Z) { + result.w = (FLT)(on_value); + } else { + result.w = (FLT)(off_value); + } + } + 
WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result); +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.cc new file mode 100644 index 0000000000..236f3981c3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "include/errorcode.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/opencl/kernel/one_hot.h" +#include "src/runtime/kernel/opencl/cl/one_hot.cl.inc" + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_OneHot; + +namespace mindspore::kernel { +int OneHotOpenCLKernel::CheckSpecs() { return RET_OK; } + +int OneHotOpenCLKernel::Prepare() { + std::string kernel_name = "OneHot"; + auto param = reinterpret_cast(op_parameter_); + in_shape_ = Image2DInfo(in_tensors_[0]); + out_shape_ = Image2DInfo(out_tensors_[0]); + axis_ = out_shape_.AlignAxis(param->axis_); + if (in_tensors_[0]->shape().size() == 1 && axis_ == 0) { + kernel_name += "2DAxis0"; + } else { + kernel_name += "Axis" + std::to_string(axis_); + } +#ifdef PROGRAM_WITH_IL + kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name); +#else + std::set build_options; 
+ std::string source = one_hot_source; + std::string program_name = "OneHot"; + ocl_runtime_->LoadSource(program_name, source); + ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + InitWeights(); + SetConstArgs(); + SetGlobalLocal(); + MS_LOG(DEBUG) << kernel_name << " Init Done!"; + return mindspore::lite::RET_OK; +} + +int OneHotOpenCLKernel::InitWeights() { + if (in_tensors_.size() <= 1) { + return RET_ERROR; + } + depth_ = static_cast(in_tensors_[1]->data_c())[0]; + if (in_tensors_.size() > 2) { + on_value_ = static_cast(in_tensors_[2]->data_c())[0]; + } + if (in_tensors_.size() > 3) { + off_value_ = static_cast(in_tensors_[3]->data_c())[0]; + } + return RET_OK; +} + +void OneHotOpenCLKernel::SetConstArgs() { + cl_int2 cl_in_image2d_shape = {static_cast(in_shape_.width), static_cast(in_shape_.height)}; + cl_int4 cl_out_shape = {static_cast(out_shape_.N), static_cast(out_shape_.H), + static_cast(out_shape_.W), static_cast(out_shape_.Slice)}; + int arg_idx = 2; + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, cl_in_image2d_shape); + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, cl_out_shape); + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, depth_); + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, on_value_); + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, off_value_); + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, static_cast(out_shape_.C)); +} +void OneHotOpenCLKernel::SetGlobalLocal() { + global_range_ = {out_shape_.Slice, out_shape_.W, out_shape_.H * out_shape_.N}; +} + +int OneHotOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->name() << " Running!"; + int arg_idx = 0; + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c()); + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c()); + ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr); + return mindspore::lite::RET_OK; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_OneHot, OpenCLKernelCreator) 
+REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_OneHot, OpenCLKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h new file mode 100644 index 0000000000..8f8c9c1f2a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ONE_HOT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ONE_HOT_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "nnacl/fp32/one_hot.h" + +namespace mindspore::kernel { +class OneHotOpenCLKernel : public OpenCLKernel { + public: + OneHotOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : OpenCLKernel(parameter, inputs, outputs) {} + ~OneHotOpenCLKernel() override = default; + + int Run() override; + int Prepare() override; + int InitWeights() override; + int CheckSpecs() override; + void SetConstArgs() override; + void SetGlobalLocal() override; + + private: + cl::Kernel kernel_; + int depth_{0}; + float on_value_{1.0f}; + float off_value_{0.0f}; + int axis_{0}; + Image2DInfo in_shape_ = Image2DInfo(nullptr); + Image2DInfo out_shape_ = Image2DInfo(nullptr); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ONE_HOT_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc index 91eacc98f3..77acb688f1 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc @@ -37,7 +37,7 @@ int SpaceToDepthOpenCLKernel::Prepare() { std::string kernel_name; in_shape_ = Image2DInfo(in_tensors_[0]); out_shape_ = Image2DInfo(out_tensors_[0]); - if (in_shape_.C % 4 != 0) { + if (in_shape_.C % C4NUM != 0) { kernel_name = "SpaceToDepth"; } else { kernel_name = "SpaceToDepthAlign"; diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h index 8421cb5269..f2ae765362 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h @@ 
-73,23 +73,23 @@ struct Image2DInfo { return; } auto shape = tensor->shape(); - auto ndim = shape.size(); - if (ndim == 1) { + OriDim = shape.size(); + if (OriDim == 1) { N = shape[0]; - } else if (ndim == 2) { + } else if (OriDim == 2) { N = shape[0]; C = shape[1]; - } else if (ndim == 3) { + } else if (OriDim == 3) { N = shape[0]; W = shape[1]; C = shape[2]; - } else if (ndim == 4) { + } else if (OriDim == 4) { N = shape[0]; H = shape[1]; W = shape[2]; C = shape[3]; - } else if (ndim >= 5) { - MS_LOG(ERROR) << "GPU doesn't support Tensor with ndim>=" << ndim; + } else if (OriDim >= 5) { + MS_LOG(ERROR) << "GPU doesn't support Tensor with ndim>=" << OriDim; } Slice = UP_DIV(C, C4NUM); @@ -116,6 +116,17 @@ struct Image2DInfo { return row_pitch; } + int AlignAxis(int oriAxis) const { + if (OriDim == 0) { + return 0; + } + int no_neg_axis = (oriAxis + OriDim) % OriDim; + if (no_neg_axis == 0) { + return 0; + } + return no_neg_axis + 4 - OriDim; + } + size_t N{1}; size_t H{1}; size_t W{1}; @@ -129,6 +140,7 @@ struct Image2DInfo { size_t ElementsC4Num{}; size_t OriginSize{}; size_t Image2DSize{}; + size_t OriDim{}; }; class OpenCLKernel : public LiteKernel { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc new file mode 100644 index 0000000000..150cde5620 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc @@ -0,0 +1,534 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "src/common/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h" +#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" + +namespace mindspore { +class TestOneHotOpenCL : public mindspore::CommonTest { + public: + TestOneHotOpenCL() {} +}; + +void RunTestCaseOneHot(const std::vector &shape_in, const std::vector &shape_out, void *input_data, + void *output_data, int axis, int depth, float on_value, float off_value) { + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + auto param = static_cast(malloc(sizeof(OneHotParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "param_ptr create error."; + return; + } + param->axis_ = axis; + auto tensor_x_ptr = std::make_unique(kNumberTypeFloat32, shape_in, schema::Format_NHWC); + auto tensor_x = tensor_x_ptr.get(); + if (tensor_x == nullptr) { + MS_LOG(ERROR) << "tensor_x create error."; + return; + } + std::vector weight_shape = {}; + auto tensor_depth_ptr = std::make_unique(kNumberTypeInt32, weight_shape, schema::Format_NHWC); + auto tensor_depth = tensor_depth_ptr.get(); + if (tensor_depth == nullptr) { + MS_LOG(ERROR) << "tensor_depth create error."; + return; + } + tensor_depth->set_data(&depth); + auto tensor_on_value_ptr = std::make_unique(kNumberTypeFloat32, weight_shape, schema::Format_NHWC); + auto tensor_on_value = tensor_on_value_ptr.get(); + if (tensor_on_value == nullptr) { + MS_LOG(ERROR) << "tensor_on_value create error."; + return; + } + tensor_on_value->set_data(&on_value); + auto tensor_off_value_ptr 
= std::make_unique(kNumberTypeFloat32, weight_shape, schema::Format_NHWC); + auto tensor_off_value = tensor_off_value_ptr.get(); + if (tensor_off_value == nullptr) { + MS_LOG(ERROR) << "tensor_off_value create error."; + return; + } + tensor_off_value->set_data(&off_value); + auto tensor_out_ptr = std::make_unique(kNumberTypeFloat32, shape_out); + auto tensor_out = tensor_out_ptr.get(); + if (tensor_out == nullptr) { + MS_LOG(ERROR) << "tensor_out create error."; + return; + } + std::vector inputs{tensor_x, tensor_depth, tensor_on_value, tensor_off_value}; + std::vector outputs{tensor_out}; + auto arith_kernel = kernel::OpenCLKernelCreator( + inputs, outputs, reinterpret_cast(param), nullptr, kernel::KernelKey(), nullptr); + if (arith_kernel == nullptr) { + MS_LOG(ERROR) << "arith_kernel create error."; + return; + } + + inputs[0]->MallocData(allocator); + + std::vector kernels{arith_kernel}; + std::vector inputs_g{tensor_x}; + auto pGraph_ptr = std::make_unique(inputs_g, outputs, kernels, kernels, kernels); + auto pGraph = pGraph_ptr.get(); + if (pGraph == nullptr) { + MS_LOG(ERROR) << "pGraph create error."; + return; + } + pGraph->Init(); + memcpy(inputs[0]->MutableData(), input_data, inputs[0]->ElementsNum() * sizeof(int)); + pGraph->Run(); + + CompareOutput(outputs[0]->MutableData(), output_data, outputs[0]->ElementsNum(), static_cast(1e-5)); + for (auto t : inputs) { + t->set_data(nullptr); + } + for (auto t : outputs) { + t->set_data(nullptr); + } + + MS_LOG(INFO) << "Test OneHot passed"; +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis3Fp32) { + int depth = 4; + int axis = -1; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {1, 2, 2}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {3, 4, -1, 2}; + std::vector output_data = {-1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 
1.0f, -1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis3T2Fp32) { + int depth = 5; + int axis = -1; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {1, 2, 2}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {-1, 3, 4, 5}; + std::vector output_data = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis3T3Fp32) { + int depth = 9; + int axis = -1; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {1, 2, 3}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {4, 9, 8, 9, 1, 8}; + std::vector output_data = {-1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis3T4Fp32) { + int depth = 6; + int axis = -1; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {1, 2, 5}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {2, 
4, 0, 6, 1, 6, 2, 2, 4, 5}; + std::vector output_data = {-1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis2Fp32) { + int depth = 5; + int axis = 2; + float on_value = 2.f; + float off_value = 0.f; + std::vector shape_in = {1, 2, 2}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {2, 3, 0, 3}; + std::vector output_data = {0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f, + 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis2T2Fp32) { + int depth = 5; + int axis = 2; + float on_value = 2.f; + float off_value = 0.f; + std::vector shape_in = {1, 6, 2}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {1, 1, 1, 0, 1, 1, 4, -1, 4, 4, -1, 1}; + std::vector output_data = {0.0f, 0.0f, 2.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, + 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 2.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 2.0f, 2.0f, 0.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + + 
RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis2T3Fp32) { + int depth = 1; + int axis = 2; + float on_value = 2.f; + float off_value = 0.f; + std::vector shape_in = {1, 2, 2}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {-1, 1, -1, 0}; + std::vector output_data = {0.0f, 0.0f, 0.0f, 2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis2T4Fp32) { + int depth = 5; + int axis = 2; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {1, 2, 5}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {4, 0, -1, 2, 5, 4, -1, 4, 4, 4}; + std::vector output_data = {-1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis1T1Fp32) { + int depth = 1; + int axis = 1; + float on_value = 2.f; + float off_value = -2.f; + std::vector shape_in = {1, 6, 6}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {0, -1, 1, 0, -1, -1, 0, 0, -1, 1, 0, -1, -1, 1, 1, -1, 1, 1, + -1, 1, 1, 1, -1, 0, 0, -1, 0, 0, 1, 1, 1, 1, 0, 0, 0, -1}; + std::vector output_data = {2.0f, 
-2.0f, -2.0f, 2.0f, -2.0f, -2.0f, 2.0f, 2.0f, -2.0f, -2.0f, 2.0f, -2.0f, + -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f, + 2.0f, -2.0f, 2.0f, 2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f, 2.0f, 2.0f, -2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis1T2Fp32) { + int depth = 4; + int axis = 1; + float on_value = 2.f; + float off_value = -2.f; + std::vector shape_in = {1, 2, 2}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {-1, 1, 1, 2}; + std::vector output_data = {-2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f, 2.0f, -2.0f, + -2.0f, -2.0f, -2.0f, 2.0f, -2.0f, -2.0f, -2.0f, -2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis1T3Fp32) { + int depth = 5; + int axis = 1; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {1, 2, 5}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {3, 5, 2, 0, 2, 2, -1, 0, 4, 3}; + std::vector output_data = {-1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis0Fp32) { + int depth = 5; + int axis = 0; + float on_value = 2.f; + float off_value = -2.f; + std::vector shape_in = {1, 2, 
2}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {4, 0, 3, 3}; + std::vector output_data = {-2.0f, 2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, + -2.0f, -2.0f, -2.0f, -2.0f, 2.0f, 2.0f, 2.0f, -2.0f, -2.0f, -2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis0T2Fp32) { + int depth = 5; + int axis = 0; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {1, 2, 5}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {2, 4, 4, 3, 5, 0, 3, 3, -1, 2}; + std::vector output_data = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, + -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, + -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} + +TEST_F(TestOneHotOpenCL, OneHot4DAxis0T3Fp32) { + int depth = 5; + int axis = 0; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {2, 2, 5}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {0, 3, 2, 0, 0, 3, 4, 1, 5, 1, 4, -1, 3, 3, 1, 1, 4, 2, 2, 4}; + std::vector output_data = { + 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, + -1.0f, -1.0f, -1.0f, 
-1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, + -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot3DAxis0Fp32: depth 5 at axis 0 over a {2, 3} index tensor. shape_out is shape_in with depth inserted at offset (axis + rank + 1) % (rank + 1), where rank is shape_in.size(); the offset is 0 here, giving {5, 2, 3}. on/off values are 2/-2. The input includes -1 and 5, which lie outside [0, depth). All data is handed to RunTestCaseOneHot. */ +TEST_F(TestOneHotOpenCL, OneHot3DAxis0Fp32) { + int depth = 5; + int axis = 0; + float on_value = 2.f; + float off_value = -2.f; + std::vector shape_in = {2, 3}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {4, 4, 3, 2, -1, 5}; + std::vector output_data = {-2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, + -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f, -2.0f, -2.0f, -2.0f, -2.0f, + 2.0f, -2.0f, -2.0f, -2.0f, 2.0f, 2.0f, -2.0f, -2.0f, -2.0f, -2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot3DAxis0T2Fp32: depth 5 at axis 0 over a {2, 5} index tensor; the insert offset is 0, so shape_out is {5, 2, 5}. on/off values are 1/-1. The input includes -1 and 5, which lie outside [0, depth). */ +TEST_F(TestOneHotOpenCL, OneHot3DAxis0T2Fp32) { + int depth = 5; + int axis = 0; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {2, 5}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {4, 2, 2, 3, -1, 5, 2, 4, 5, -1}; + std::vector output_data = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, -1.0f, -1.0f, -1.0f, 
-1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot3DAxis1Fp32: depth 5 at axis 1 over a {2, 3} index tensor; the insert offset (axis + rank + 1) % (rank + 1) is 1, so shape_out is {2, 5, 3}. on/off values are 2/-2. The input includes -1, which lies outside [0, depth). */ +TEST_F(TestOneHotOpenCL, OneHot3DAxis1Fp32) { + int depth = 5; + int axis = 1; + float on_value = 2.f; + float off_value = -2.f; + std::vector shape_in = {2, 3}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {0, 0, 0, 0, 4, -1}; + std::vector output_data = {2.0f, 2.0f, 2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, + -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f, -2.0f, -2.0f, -2.0f, -2.0f, + -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f, -2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot3DAxis1T2Fp32: depth 5 at axis 1 over a {2, 5} index tensor; the insert offset is 1, so shape_out is {2, 5, 5}. on/off values are 1/-1. The input includes -1 and 5, which lie outside [0, depth). */ +TEST_F(TestOneHotOpenCL, OneHot3DAxis1T2Fp32) { + int depth = 5; + int axis = 1; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {2, 5}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {1, -1, 3, 2, 5, 5, 4, 5, 0, -1}; + std::vector output_data = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot3DAxis2Fp32: depth 4 at axis 2 over a {2, 2} index tensor; the insert offset (axis + rank + 1) % (rank + 1) is 2, so depth is appended and shape_out is {2, 2, 4}. on/off values are 2/-2. The input includes 4, which equals depth and so lies outside [0, depth). */ +TEST_F(TestOneHotOpenCL, OneHot3DAxis2Fp32) { + int depth = 4; + int axis = 2; + float on_value = 2.f; + float off_value = -2.f; + std::vector shape_in = {2, 2}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + 
shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {0, 3, 4, 2}; + std::vector output_data = {2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f, + -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f, -2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot3DAxis2T2Fp32: depth 5 at axis 2 over a {2, 5} index tensor; the insert offset is 2, so depth is appended and shape_out is {2, 5, 5}. on/off values are 1/-1. The input includes -1 and 5, which lie outside [0, depth). */ +TEST_F(TestOneHotOpenCL, OneHot3DAxis2T2Fp32) { + int depth = 5; + int axis = 2; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {2, 5}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {0, -1, 2, -1, 5, 4, 2, -1, 4, -1}; + std::vector output_data = {1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, + -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot2DAxis0Fp32: depth 3 at axis 0 over a {3} index tensor; the insert offset (axis + rank + 1) % (rank + 1) is 0, so shape_out is {3, 3}. on/off values are 2/-2. The input includes 3, which equals depth and so lies outside [0, depth). */ +TEST_F(TestOneHotOpenCL, OneHot2DAxis0Fp32) { + int depth = 3; + int axis = 0; + float on_value = 2.f; + float off_value = -2.f; + std::vector shape_in = {3}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {2, 1, 3}; + std::vector output_data = {-2.0f, -2.0f, -2.0f, -2.0f, 2.0f, -2.0f, 2.0f, -2.0f, -2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot2DAxis0T2Fp32: depth 5 at axis 0 over a {5} index tensor; the insert offset is 0, so shape_out is {5, 5}. on/off values are 1/-1. */ +TEST_F(TestOneHotOpenCL, OneHot2DAxis0T2Fp32) { + int depth = 5; + int axis = 0; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {5}; + std::vector shape_out = shape_in; + 
shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {2, 2, 0, 0, 4}; + std::vector output_data = {-1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot2DAxis1Fp32: depth 3 with axis = -1 over a {3} index tensor. With rank 1 the insert offset (axis + rank + 1) % (rank + 1) evaluates to 1, so depth is appended and shape_out is {3, 3}. on/off values are 2/-2. */ +TEST_F(TestOneHotOpenCL, OneHot2DAxis1Fp32) { + int depth = 3; + int axis = -1; + float on_value = 2.f; + float off_value = -2.f; + std::vector shape_in = {3}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {1, 2, 0}; + std::vector output_data = {-2.0f, 2.0f, -2.0f, -2.0f, -2.0f, 2.0f, 2.0f, -2.0f, -2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot2DAxis1T2Fp32: depth 5 with axis = -1 over a {5} index tensor; the insert offset evaluates to 1, so depth is appended and shape_out is {5, 5}. on/off values are 1/-1. The input includes 5 and -1, which lie outside [0, depth). */ +TEST_F(TestOneHotOpenCL, OneHot2DAxis1T2Fp32) { + int depth = 5; + int axis = -1; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {5}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {5, 4, 0, 4, -1}; + std::vector output_data = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot1DAxis0Fp32: depth 3 over an empty shape_in with a single input index; shape_out becomes {3}. on/off values are 2/-2. NOTE(review): the test name says Axis0 but axis is set to -1; with an empty shape_in the insert offset is 0 for either value - confirm which was intended. */ +TEST_F(TestOneHotOpenCL, OneHot1DAxis0Fp32) { + int depth = 3; + int axis = -1; + float on_value = 2.f; + float off_value = -2.f; + std::vector shape_in = {}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + 
std::vector input_data = {1}; + std::vector output_data = {-2.0f, 2.0f, -2.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +/* OneHot1DAxis0T2Fp32: depth 5 at axis 0 over an empty shape_in with a single input index; the insert offset is 0, so shape_out is {5}. on/off values are 1/-1. */ +TEST_F(TestOneHotOpenCL, OneHot1DAxis0T2Fp32) { + int depth = 5; + int axis = 0; + float on_value = 1.f; + float off_value = -1.f; + std::vector shape_in = {}; + std::vector shape_out = shape_in; + shape_out.insert(shape_out.begin() + (axis + shape_in.size() + 1) % (shape_in.size() + 1), depth); + std::vector input_data = {4}; + std::vector output_data = {-1.0f, -1.0f, -1.0f, -1.0f, 1.0f}; + + RunTestCaseOneHot(shape_in, shape_out, input_data.data(), output_data.data(), axis, depth, on_value, off_value); +} +} // namespace mindspore