Browse Source

!8546 [MS][LITE][GPU]add lite gpu op onehot

From: @chenzupeng
Reviewed-by: @ddwsky,@zhang_xue_tong
Signed-off-by: @ddwsky
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
97c0b7eef6
6 changed files with 940 additions and 8 deletions
  1. +232
    -0
      mindspore/lite/src/runtime/kernel/opencl/cl/one_hot.cl
  2. +102
    -0
      mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.cc
  3. +52
    -0
      mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h
  4. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc
  5. +19
    -7
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
  6. +534
    -0
      mindspore/lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc

+ 232
- 0
mindspore/lite/src/runtime/kernel/opencl/cl/one_hot.cl View File

@@ -0,0 +1,232 @@
// Enable the half-precision extension when the compiler reports it as available.
#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

// Number of values packed into one 4-component texel.
#define C4NUM 4
// Integer division rounded up (for non-negative x and positive y).
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
// Sampler shared by all kernels: unnormalized coordinates, clamped addressing, nearest filtering.
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
// One-hot expansion along output axis 0 (the N coordinate of the output).
// Work-item layout: id(0) = channel slice, id(1) = W, id(2) = N * H row of the output image.
// out_shape is (x = N, y = H, z = W, w = channel slices); C is the unpadded output channel count.
// Each lane becomes on_value when the index read from the matching input position equals the
// work-item's N coordinate, otherwise off_value. Lanes at or beyond C stay 0. `depth` is unused here.
// NOTE(review): the texel is reinterpreted as four 32-bit ints, which presumes FLT4 is 16 bytes wide;
// confirm this for builds where FLT is half.
__kernel void OneHotAxis0(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
                          int4 out_shape, int depth, float on_value, float off_value, int C) {
  const int slice = get_global_id(0);
  const int col = get_global_id(1);
  const int row = get_global_id(2);
  if (slice >= out_shape.w || col >= out_shape.z || row >= out_shape.x * out_shape.y) {
    return;
  }
  const int n = row / out_shape.y;
  const int h = row % out_shape.y;
  // Flat texel offset inside the input, then folded onto the 2D input image by its width.
  const int src_offset = (h * out_shape.z + col) * out_shape.w + slice;
  const int2 src_pos = (int2)(src_offset % in_image2d_shape.x, src_offset / in_image2d_shape.x);
  FLT4 packed = READ_IMAGE(src_data, smp_zero, src_pos);
  int *lane_index = (int *)&packed;
  const int base = 4 * slice;
  FLT4 out = (FLT4)(0.f);
  if (base < C) {
    out.x = (FLT)(lane_index[0] == n ? on_value : off_value);
  }
  if (base + 1 < C) {
    out.y = (FLT)(lane_index[1] == n ? on_value : off_value);
  }
  if (base + 2 < C) {
    out.z = (FLT)(lane_index[2] == n ? on_value : off_value);
  }
  if (base + 3 < C) {
    out.w = (FLT)(lane_index[3] == n ? on_value : off_value);
  }
  WRITE_IMAGE(dst_data, (int2)(col * out_shape.w + slice, row), out);
}

// One-hot expansion along output axis 1 (the H coordinate of the output).
// Work-item layout: id(0) = channel slice, id(1) = W, id(2) = N * H row of the output image.
// out_shape is (x = N, y = H, z = W, w = channel slices); C is the unpadded output channel count.
// Each lane becomes on_value when the index read from the matching input position equals the
// work-item's H coordinate, otherwise off_value. Lanes at or beyond C stay 0. `depth` is unused here.
// NOTE(review): the texel is reinterpreted as four 32-bit ints, which presumes FLT4 is 16 bytes wide;
// confirm this for builds where FLT is half.
__kernel void OneHotAxis1(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
                          int4 out_shape, int depth, float on_value, float off_value, int C) {
  const int slice = get_global_id(0);
  const int col = get_global_id(1);
  const int row = get_global_id(2);
  if (slice >= out_shape.w || col >= out_shape.z || row >= out_shape.x * out_shape.y) {
    return;
  }
  const int n = row / out_shape.y;
  const int h = row % out_shape.y;
  // The input has no H axis here, so its texel offset is built from (n, col, slice).
  const int src_offset = (n * out_shape.z + col) * out_shape.w + slice;
  const int2 src_pos = (int2)(src_offset % in_image2d_shape.x, src_offset / in_image2d_shape.x);
  FLT4 packed = READ_IMAGE(src_data, smp_zero, src_pos);
  int *lane_index = (int *)&packed;
  const int base = 4 * slice;
  FLT4 out = (FLT4)(0.f);
  if (base < C) {
    out.x = (FLT)(lane_index[0] == h ? on_value : off_value);
  }
  if (base + 1 < C) {
    out.y = (FLT)(lane_index[1] == h ? on_value : off_value);
  }
  if (base + 2 < C) {
    out.z = (FLT)(lane_index[2] == h ? on_value : off_value);
  }
  if (base + 3 < C) {
    out.w = (FLT)(lane_index[3] == h ? on_value : off_value);
  }
  WRITE_IMAGE(dst_data, (int2)(col * out_shape.w + slice, row), out);
}

// One-hot expansion along output axis 2 (the W coordinate of the output).
// Work-item layout: id(0) = channel slice, id(1) = W, id(2) = N * H row of the output image.
// out_shape is (x = N, y = H, z = W, w = channel slices); C is the unpadded output channel count.
// Each lane becomes on_value when the index read from the matching input position equals the
// work-item's W coordinate, otherwise off_value. Lanes at or beyond C stay 0. `depth` is unused here.
// NOTE(review): the texel is reinterpreted as four 32-bit ints, which presumes FLT4 is 16 bytes wide;
// confirm this for builds where FLT is half.
__kernel void OneHotAxis2(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
                          int4 out_shape, int depth, float on_value, float off_value, int C) {
  const int slice = get_global_id(0);
  const int col = get_global_id(1);
  const int row = get_global_id(2);
  if (slice >= out_shape.w || col >= out_shape.z || row >= out_shape.x * out_shape.y) {
    return;
  }
  const int n = row / out_shape.y;
  const int h = row % out_shape.y;
  // The input has no W axis here, so its texel offset is built from (n, h, slice).
  const int src_offset = (n * out_shape.y + h) * out_shape.w + slice;
  const int2 src_pos = (int2)(src_offset % in_image2d_shape.x, src_offset / in_image2d_shape.x);
  FLT4 packed = READ_IMAGE(src_data, smp_zero, src_pos);
  int *lane_index = (int *)&packed;
  const int base = 4 * slice;
  FLT4 out = (FLT4)(0.f);
  if (base < C) {
    out.x = (FLT)(lane_index[0] == col ? on_value : off_value);
  }
  if (base + 1 < C) {
    out.y = (FLT)(lane_index[1] == col ? on_value : off_value);
  }
  if (base + 2 < C) {
    out.z = (FLT)(lane_index[2] == col ? on_value : off_value);
  }
  if (base + 3 < C) {
    out.w = (FLT)(lane_index[3] == col ? on_value : off_value);
  }
  WRITE_IMAGE(dst_data, (int2)(col * out_shape.w + slice, row), out);
}

// One-hot expansion along output axis 3 (the channel coordinate of the output).
// Work-item layout: id(0) = channel slice, id(1) = W, id(2) = N * H row of the output image.
// out_shape is (x = N, y = H, z = W, w = channel slices); C is the unpadded output channel count.
// A single index is fetched per work-item: it lives in input slice W / 4, lane W % 4. Each output lane
// becomes on_value when that index equals the lane's channel number, otherwise off_value.
// Lanes at or beyond C stay 0. `depth` is unused here.
// NOTE(review): the texel is reinterpreted as four 32-bit ints, which presumes FLT4 is 16 bytes wide;
// confirm this for builds where FLT is half.
__kernel void OneHotAxis3(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
                          int4 out_shape, int depth, float on_value, float off_value, int C) {
  const int slice = get_global_id(0);
  const int col = get_global_id(1);
  const int row = get_global_id(2);
  if (slice >= out_shape.w || col >= out_shape.z || row >= out_shape.x * out_shape.y) {
    return;
  }
  const int n = row / out_shape.y;
  const int h = row % out_shape.y;
  // The output W axis is the input's packed (innermost) axis, so it is split into slice + lane.
  const int src_slices = UP_DIV(out_shape.z, C4NUM);
  const int src_offset = (n * out_shape.y + h) * src_slices + col / 4;
  const int src_lane = col % 4;
  const int2 src_pos = (int2)(src_offset % in_image2d_shape.x, src_offset / in_image2d_shape.x);
  FLT4 packed = READ_IMAGE(src_data, smp_zero, src_pos);
  int *lane_index = (int *)&packed;
  const int hot = lane_index[src_lane];
  const int base = 4 * slice;
  FLT4 out = (FLT4)(0.f);
  if (base < C) {
    out.x = (FLT)(hot == base ? on_value : off_value);
  }
  if (base + 1 < C) {
    out.y = (FLT)(hot == base + 1 ? on_value : off_value);
  }
  if (base + 2 < C) {
    out.z = (FLT)(hot == base + 2 ? on_value : off_value);
  }
  if (base + 3 < C) {
    out.w = (FLT)(hot == base + 3 ? on_value : off_value);
  }
  WRITE_IMAGE(dst_data, (int2)(col * out_shape.w + slice, row), out);
}

// One-hot expansion along axis 0 for the variant the host selects when the indices tensor is 1-D.
// Work-item layout: id(0) = channel slice, id(1) = W, id(2) = output row.
// For every output channel below C the index is read from input texel (0, channel), lane 0, and the
// lane becomes on_value when that index equals the output row, otherwise off_value.
// Lanes at or beyond C stay 0. `depth` and in_image2d_shape are unused here.
// NOTE(review): the texel is reinterpreted as 32-bit ints, which presumes the first lane of FLT4 is
// 4 bytes wide; confirm this for builds where FLT is half.
__kernel void OneHot2DAxis0(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
                            int4 out_shape, int depth, float on_value, float off_value, int C) {
  const int slice = get_global_id(0);
  const int col = get_global_id(1);
  const int row = get_global_id(2);
  if (slice >= out_shape.w || col >= out_shape.z || row >= out_shape.x * out_shape.y) {
    return;
  }
  const int base = 4 * slice;
  FLT4 out = (FLT4)(0.f);
  if (base < C) {
    FLT4 texel = READ_IMAGE(src_data, smp_zero, (int2)(0, base));
    out.x = (FLT)(((int *)&texel)[0] == row ? on_value : off_value);
  }
  if (base + 1 < C) {
    FLT4 texel = READ_IMAGE(src_data, smp_zero, (int2)(0, base + 1));
    out.y = (FLT)(((int *)&texel)[0] == row ? on_value : off_value);
  }
  if (base + 2 < C) {
    FLT4 texel = READ_IMAGE(src_data, smp_zero, (int2)(0, base + 2));
    out.z = (FLT)(((int *)&texel)[0] == row ? on_value : off_value);
  }
  if (base + 3 < C) {
    FLT4 texel = READ_IMAGE(src_data, smp_zero, (int2)(0, base + 3));
    out.w = (FLT)(((int *)&texel)[0] == row ? on_value : off_value);
  }
  WRITE_IMAGE(dst_data, (int2)(col * out_shape.w + slice, row), out);
}

+ 102
- 0
mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.cc View File

@@ -0,0 +1,102 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <set>
#include <string>
#include <map>
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/kernel/one_hot.h"
#include "src/runtime/kernel/opencl/cl/one_hot.cl.inc"

using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_OneHot;

namespace mindspore::kernel {
int OneHotOpenCLKernel::CheckSpecs() { return RET_OK; }

int OneHotOpenCLKernel::Prepare() {
std::string kernel_name = "OneHot";
auto param = reinterpret_cast<OneHotParameter *>(op_parameter_);
in_shape_ = Image2DInfo(in_tensors_[0]);
out_shape_ = Image2DInfo(out_tensors_[0]);
axis_ = out_shape_.AlignAxis(param->axis_);
if (in_tensors_[0]->shape().size() == 1 && axis_ == 0) {
kernel_name += "2DAxis0";
} else {
kernel_name += "Axis" + std::to_string(axis_);
}
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
std::set<std::string> build_options;
std::string source = one_hot_source;
std::string program_name = "OneHot";
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
InitWeights();
SetConstArgs();
SetGlobalLocal();
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return mindspore::lite::RET_OK;
}

int OneHotOpenCLKernel::InitWeights() {
if (in_tensors_.size() <= 1) {
return RET_ERROR;
}
depth_ = static_cast<int32_t *>(in_tensors_[1]->data_c())[0];
if (in_tensors_.size() > 2) {
on_value_ = static_cast<float *>(in_tensors_[2]->data_c())[0];
}
if (in_tensors_.size() > 3) {
off_value_ = static_cast<float *>(in_tensors_[3]->data_c())[0];
}
return RET_OK;
}

void OneHotOpenCLKernel::SetConstArgs() {
cl_int2 cl_in_image2d_shape = {static_cast<cl_int>(in_shape_.width), static_cast<cl_int>(in_shape_.height)};
cl_int4 cl_out_shape = {static_cast<cl_int>(out_shape_.N), static_cast<cl_int>(out_shape_.H),
static_cast<cl_int>(out_shape_.W), static_cast<cl_int>(out_shape_.Slice)};
int arg_idx = 2;
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, cl_in_image2d_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, cl_out_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, depth_);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, on_value_);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, off_value_);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, static_cast<int>(out_shape_.C));
}
// Sets the global work size to (channel slices, W, H * N) of the output; local_range_ is not modified here.
void OneHotOpenCLKernel::SetGlobalLocal() {
  const auto slices = out_shape_.Slice;
  const auto width = out_shape_.W;
  const auto rows = out_shape_.H * out_shape_.N;
  global_range_ = {slices, width, rows};
}

int OneHotOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";
int arg_idx = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr);
return mindspore::lite::RET_OK;
}

// Register the OneHot GPU kernel creator for both float32 and float16 data types.
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_OneHot, OpenCLKernelCreator<OneHotOpenCLKernel>)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_OneHot, OpenCLKernelCreator<OneHotOpenCLKernel>)
} // namespace mindspore::kernel

+ 52
- 0
mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h View File

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ONE_HOT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ONE_HOT_H_

#include <vector>
#include <string>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "nnacl/fp32/one_hot.h"

namespace mindspore::kernel {
// OpenCL implementation of the OneHot operator.
class OneHotOpenCLKernel : public OpenCLKernel {
 public:
  OneHotOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                     const std::vector<lite::Tensor *> &outputs)
      : OpenCLKernel(parameter, inputs, outputs) {}
  ~OneHotOpenCLKernel() override = default;

  // Validation and one-time preparation.
  int CheckSpecs() override;
  int Prepare() override;
  int InitWeights() override;
  void SetConstArgs() override;
  void SetGlobalLocal() override;

  // Per-inference execution.
  int Run() override;

 private:
  // Compiled kernel variant chosen in Prepare().
  cl::Kernel kernel_;
  // Scalar operator inputs; overwritten by InitWeights() when the matching tensor is present.
  int depth_{0};
  float on_value_{1.0f};
  float off_value_{0.0f};
  // Axis after alignment by Image2DInfo::AlignAxis.
  int axis_{0};
  // Layout descriptions of the first input and the output tensor.
  Image2DInfo in_shape_ = Image2DInfo(nullptr);
  Image2DInfo out_shape_ = Image2DInfo(nullptr);
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ONE_HOT_H_

+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.cc View File

@@ -37,7 +37,7 @@ int SpaceToDepthOpenCLKernel::Prepare() {
std::string kernel_name;
in_shape_ = Image2DInfo(in_tensors_[0]);
out_shape_ = Image2DInfo(out_tensors_[0]);
if (in_shape_.C % 4 != 0) {
if (in_shape_.C % C4NUM != 0) {
kernel_name = "SpaceToDepth";
} else {
kernel_name = "SpaceToDepthAlign";


+ 19
- 7
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h View File

@@ -73,23 +73,23 @@ struct Image2DInfo {
return;
}
auto shape = tensor->shape();
auto ndim = shape.size();
if (ndim == 1) {
OriDim = shape.size();
if (OriDim == 1) {
N = shape[0];
} else if (ndim == 2) {
} else if (OriDim == 2) {
N = shape[0];
C = shape[1];
} else if (ndim == 3) {
} else if (OriDim == 3) {
N = shape[0];
W = shape[1];
C = shape[2];
} else if (ndim == 4) {
} else if (OriDim == 4) {
N = shape[0];
H = shape[1];
W = shape[2];
C = shape[3];
} else if (ndim >= 5) {
MS_LOG(ERROR) << "GPU doesn't support Tensor with ndim>=" << ndim;
} else if (OriDim >= 5) {
MS_LOG(ERROR) << "GPU doesn't support Tensor with ndim>=" << OriDim;
}
Slice = UP_DIV(C, C4NUM);

@@ -116,6 +116,17 @@ struct Image2DInfo {
return row_pitch;
}

// Maps an axis of the original tensor (rank OriDim, negative values count from the end) onto the
// 4-slot layout used by this struct: axis 0 stays 0, every other axis is right-aligned so that the
// last original axis lands on slot 3. Returns 0 when the rank is 0.
int AlignAxis(int oriAxis) const {
  // Do the arithmetic in signed ints: mixing int with the size_t rank converts a negative axis to a
  // huge unsigned value, which only happens to give the right remainder while oriAxis >= -OriDim.
  const int dims = static_cast<int>(OriDim);
  if (dims == 0) {
    return 0;
  }
  // Mathematical modulo, result always in [0, dims).
  const int no_neg_axis = ((oriAxis % dims) + dims) % dims;
  if (no_neg_axis == 0) {
    return 0;
  }
  return no_neg_axis + 4 - dims;
}

size_t N{1};
size_t H{1};
size_t W{1};
@@ -129,6 +140,7 @@ struct Image2DInfo {
size_t ElementsC4Num{};
size_t OriginSize{};
size_t Image2DSize{};
size_t OriDim{};
};

class OpenCLKernel : public LiteKernel {


+ 534
- 0
mindspore/lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc View File

@@ -0,0 +1,534 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"

namespace mindspore {
// Test fixture for the OpenCL OneHot kernel; adds nothing on top of CommonTest.
class TestOneHotOpenCL : public mindspore::CommonTest {
 public:
  TestOneHotOpenCL() = default;
};

void RunTestCaseOneHot(const std::vector<int> &shape_in, const std::vector<int> &shape_out, void *input_data,
void *output_data, int axis, int depth, float on_value, float off_value) {
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
auto param = static_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "param_ptr create error.";
return;
}
param->axis_ = axis;
auto tensor_x_ptr = std::make_unique<lite::Tensor>(kNumberTypeFloat32, shape_in, schema::Format_NHWC);
auto tensor_x = tensor_x_ptr.get();
if (tensor_x == nullptr) {
MS_LOG(ERROR) << "tensor_x create error.";
return;
}
std::vector<int> weight_shape = {};
auto tensor_depth_ptr = std::make_unique<lite::Tensor>(kNumberTypeInt32, weight_shape, schema::Format_NHWC);
auto tensor_depth = tensor_depth_ptr.get();
if (tensor_depth == nullptr) {
MS_LOG(ERROR) << "tensor_depth create error.";
return;
}
tensor_depth->set_data(&depth);
auto tensor_on_value_ptr = std::make_unique<lite::Tensor>(kNumberTypeFloat32, weight_shape, schema::Format_NHWC);
auto tensor_on_value = tensor_on_value_ptr.get();
if (tensor_on_value == nullptr) {
MS_LOG(ERROR) << "tensor_on_value create error.";
return;
}
tensor_on_value->set_data(&on_value);
auto tensor_off_value_ptr = std::make_unique<lite::Tensor>(kNumberTypeFloat32, weight_shape, schema::Format_NHWC);
auto tensor_off_value = tensor_off_value_ptr.get();
if (tensor_off_value == nullptr) {
MS_LOG(ERROR) << "tensor_off_value create error.";
return;
}
tensor_off_value->set_data(&off_value);
auto tensor_out_ptr = std::make_unique<lite::Tensor>(kNumberTypeFloat32, shape_out);
auto tensor_out = tensor_out_ptr.get();
if (tensor_out == nullptr) {
MS_LOG(ERROR) << "tensor_out create error.";
return;
}
std::vector<lite::Tensor *> inputs{tensor_x, tensor_depth, tensor_on_value, tensor_off_value};
std::vector<lite::Tensor *> outputs{tensor_out};
auto arith_kernel = kernel::OpenCLKernelCreator<kernel::OneHotOpenCLKernel>(
inputs, outputs, reinterpret_cast<OpParameter *>(param), nullptr, kernel::KernelKey(), nullptr);
if (arith_kernel == nullptr) {
MS_LOG(ERROR) << "arith_kernel create error.";
return;
}

inputs[0]->MallocData(allocator);

std::vector<kernel::LiteKernel *> kernels{arith_kernel};
std::vector<lite::Tensor *> inputs_g{tensor_x};
auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels);
auto pGraph = pGraph_ptr.get();
if (pGraph == nullptr) {
MS_LOG(ERROR) << "pGraph create error.";
return;
}
pGraph->Init();
memcpy(inputs[0]->MutableData(), input_data, inputs[0]->ElementsNum() * sizeof(int));
pGraph->Run();

CompareOutput(outputs[0]->MutableData(), output_data, outputs[0]->ElementsNum(), static_cast<float>(1e-5));
for (auto t : inputs) {
t->set_data(nullptr);
}
for (auto t : outputs) {
t->set_data(nullptr);
}

MS_LOG(INFO) << "Test OneHot passed";
}

// Indices of shape {1, 2, 2}, one-hot along axis -1 with depth 4.
TEST_F(TestOneHotOpenCL, OneHot4DAxis3Fp32) {
  const int axis = -1;
  const int depth = 4;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {1, 2, 2};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {3, 4, -1, 2};
  std::vector<float> expected = {-1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 2}, one-hot along axis -1 with depth 5 (channel count not a multiple of 4).
TEST_F(TestOneHotOpenCL, OneHot4DAxis3T2Fp32) {
  const int axis = -1;
  const int depth = 5;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {1, 2, 2};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {-1, 3, 4, 5};
  std::vector<float> expected = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 3}, one-hot along axis -1 with depth 9.
TEST_F(TestOneHotOpenCL, OneHot4DAxis3T3Fp32) {
  const int axis = -1;
  const int depth = 9;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {1, 2, 3};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {4, 9, 8, 9, 1, 8};
  std::vector<float> expected = {-1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 5}, one-hot along axis -1 with depth 6.
TEST_F(TestOneHotOpenCL, OneHot4DAxis3T4Fp32) {
  const int axis = -1;
  const int depth = 6;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {1, 2, 5};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {2, 4, 0, 6, 1, 6, 2, 2, 4, 5};
  std::vector<float> expected = {-1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f,
                                 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 2}, one-hot along axis 2 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot4DAxis2Fp32) {
  const int axis = 2;
  const int depth = 5;
  const float on_value = 2.f;
  const float off_value = 0.f;
  std::vector<int> shape_in = {1, 2, 2};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {2, 3, 0, 3};
  std::vector<float> expected = {0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f,
                                 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 6, 2}, one-hot along axis 2 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot4DAxis2T2Fp32) {
  const int axis = 2;
  const int depth = 5;
  const float on_value = 2.f;
  const float off_value = 0.f;
  std::vector<int> shape_in = {1, 6, 2};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {1, 1, 1, 0, 1, 1, 4, -1, 4, 4, -1, 1};
  std::vector<float> expected = {0.0f, 0.0f, 2.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f,
                                 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 2.0f,
                                 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
                                 0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
                                 2.0f, 2.0f, 0.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 2}, one-hot along axis 2 with depth 1.
TEST_F(TestOneHotOpenCL, OneHot4DAxis2T3Fp32) {
  const int axis = 2;
  const int depth = 1;
  const float on_value = 2.f;
  const float off_value = 0.f;
  std::vector<int> shape_in = {1, 2, 2};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {-1, 1, -1, 0};
  std::vector<float> expected = {0.0f, 0.0f, 0.0f, 2.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 5}, one-hot along axis 2 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot4DAxis2T4Fp32) {
  const int axis = 2;
  const int depth = 5;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {1, 2, 5};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {4, 0, -1, 2, 5, 4, -1, 4, 4, 4};
  std::vector<float> expected = {-1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, 1.0f,  1.0f,  1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 6, 6}, one-hot along axis 1 with depth 1.
TEST_F(TestOneHotOpenCL, OneHot4DAxis1T1Fp32) {
  const int axis = 1;
  const int depth = 1;
  const float on_value = 2.f;
  const float off_value = -2.f;
  std::vector<int> shape_in = {1, 6, 6};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {0,  -1, 1, 0, -1, -1, 0, 0,  -1, 1, 0, -1, -1, 1, 1, -1, 1, 1,
                              -1, 1,  1, 1, -1, 0,  0, -1, 0,  0, 1, 1,  1,  1, 0, 0,  0, -1};
  std::vector<float> expected = {2.0f,  -2.0f, -2.0f, 2.0f,  -2.0f, -2.0f, 2.0f,  2.0f,  -2.0f, -2.0f, 2.0f,  -2.0f,
                                 -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f,
                                 2.0f,  -2.0f, 2.0f,  2.0f,  -2.0f, -2.0f, -2.0f, -2.0f, 2.0f,  2.0f,  2.0f,  -2.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 2}, one-hot along axis 1 with depth 4.
TEST_F(TestOneHotOpenCL, OneHot4DAxis1T2Fp32) {
  const int axis = 1;
  const int depth = 4;
  const float on_value = 2.f;
  const float off_value = -2.f;
  std::vector<int> shape_in = {1, 2, 2};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {-1, 1, 1, 2};
  std::vector<float> expected = {-2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f,  2.0f,  -2.0f,
                                 -2.0f, -2.0f, -2.0f, 2.0f,  -2.0f, -2.0f, -2.0f, -2.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 5}, one-hot along axis 1 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot4DAxis1T3Fp32) {
  const int axis = 1;
  const int depth = 5;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {1, 2, 5};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {3, 5, 2, 0, 2, 2, -1, 0, 4, 3};
  std::vector<float> expected = {-1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, 1.0f,  -1.0f, 1.0f,  1.0f,  -1.0f, -1.0f, -1.0f, -1.0f,
                                 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 2}, one-hot along axis 0 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot4DAxis0Fp32) {
  const int axis = 0;
  const int depth = 5;
  const float on_value = 2.f;
  const float off_value = -2.f;
  std::vector<int> shape_in = {1, 2, 2};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {4, 0, 3, 3};
  std::vector<float> expected = {-2.0f, 2.0f,  -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f,
                                 -2.0f, -2.0f, -2.0f, -2.0f, 2.0f,  2.0f,  2.0f,  -2.0f, -2.0f, -2.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {1, 2, 5}, one-hot along axis 0 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot4DAxis0T2Fp32) {
  const int axis = 0;
  const int depth = 5;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {1, 2, 5};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {2, 4, 4, 3, 5, 0, 3, 3, -1, 2};
  std::vector<float> expected = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,
                                 -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, 1.0f,  1.0f,  -1.0f, -1.0f,
                                 -1.0f, 1.0f,  1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {2, 2, 5}, one-hot along axis 0 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot4DAxis0T3Fp32) {
  const int axis = 0;
  const int depth = 5;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {2, 2, 5};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {0, 3, 2, 0, 0, 3, 4, 1, 5, 1, 4, -1, 3, 3, 1, 1, 4, 2, 2, 4};
  std::vector<float> expected = {
    1.0f,  -1.0f, -1.0f, 1.0f,  1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
    -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, 1.0f,
    -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f,
    -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  1.0f,  -1.0f,
    -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  1.0f,  -1.0f,
    -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f,
    1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, 1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {2, 3}, one-hot along axis 0 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot3DAxis0Fp32) {
  const int axis = 0;
  const int depth = 5;
  const float on_value = 2.f;
  const float off_value = -2.f;
  std::vector<int> shape_in = {2, 3};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {4, 4, 3, 2, -1, 5};
  std::vector<float> expected = {-2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f,
                                 -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f,  -2.0f, -2.0f, -2.0f, -2.0f,
                                 2.0f,  -2.0f, -2.0f, -2.0f, 2.0f,  2.0f,  -2.0f, -2.0f, -2.0f, -2.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {2, 5}, one-hot along axis 0 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot3DAxis0T2Fp32) {
  const int axis = 0;
  const int depth = 5;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {2, 5};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {4, 2, 2, 3, -1, 5, 2, 4, 5, -1};
  std::vector<float> expected = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, 1.0f,  1.0f,  -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {2, 3}, one-hot along axis 1 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot3DAxis1Fp32) {
  const int axis = 1;
  const int depth = 5;
  const float on_value = 2.f;
  const float off_value = -2.f;
  std::vector<int> shape_in = {2, 3};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {0, 0, 0, 0, 4, -1};
  std::vector<float> expected = {2.0f,  2.0f,  2.0f,  -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f,
                                 -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f,  -2.0f, -2.0f, -2.0f, -2.0f,
                                 -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, -2.0f, 2.0f,  -2.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

// Indices of shape {2, 5}, one-hot along axis 1 with depth 5.
TEST_F(TestOneHotOpenCL, OneHot3DAxis1T2Fp32) {
  const int axis = 1;
  const int depth = 5;
  const float on_value = 1.f;
  const float off_value = -1.f;
  std::vector<int> shape_in = {2, 5};
  // The output shape is the input shape with `depth` inserted at the wrapped axis position.
  const auto depth_pos = (axis + shape_in.size() + 1) % (shape_in.size() + 1);
  std::vector<int> shape_out(shape_in);
  shape_out.insert(shape_out.begin() + depth_pos, depth);
  std::vector<int> indices = {1, -1, 3, 2, 5, 5, 4, 5, 0, -1};
  std::vector<float> expected = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,
                                 -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f,  -1.0f, -1.0f, -1.0f};

  RunTestCaseOneHot(shape_in, shape_out, indices.data(), expected.data(), axis, depth, on_value, off_value);
}

TEST_F(TestOneHotOpenCL, OneHot3DAxis2Fp32) {
  // One-hot over the last axis: a {2, 2} index tensor with depth 4 yields a {2, 2, 4} output.
  int axis = 2;
  int depth = 4;
  float hot = 2.f;
  float cold = -2.f;

  std::vector<int> in_shape = {2, 2};
  // The depth dimension is inserted at `axis`, wrapped into [0, output rank).
  const auto out_rank = in_shape.size() + 1;
  std::vector<int> out_shape(in_shape);
  out_shape.insert(out_shape.begin() + (axis + out_rank) % out_rank, depth);

  std::vector<int> indices = {0, 3, 4, 2};
  // One line of `depth` values per input element, in input order.
  std::vector<float> expected = {
    hot,  cold, cold, cold,  // index 0
    cold, cold, cold, hot,   // index 3
    cold, cold, cold, cold,  // index 4 matches no depth slot: expected all cold
    cold, cold, hot,  cold,  // index 2
  };

  RunTestCaseOneHot(in_shape, out_shape, indices.data(), expected.data(), axis, depth, hot, cold);
}

TEST_F(TestOneHotOpenCL, OneHot3DAxis2T2Fp32) {
  // One-hot over the last axis: a {2, 5} index tensor with depth 5 yields a {2, 5, 5} output.
  int axis = 2;
  int depth = 5;
  float hot = 1.f;
  float cold = -1.f;

  std::vector<int> in_shape = {2, 5};
  // The depth dimension is inserted at `axis`, wrapped into [0, output rank).
  const auto out_rank = in_shape.size() + 1;
  std::vector<int> out_shape(in_shape);
  out_shape.insert(out_shape.begin() + (axis + out_rank) % out_rank, depth);

  std::vector<int> indices = {0, -1, 2,  -1, 5,   // row 0
                              4, 2,  -1, 4,  -1};  // row 1
  // One line of `depth` values per input element, in input order.
  // The indices -1 and 5 match no depth slot, so this case expects an all-cold line for them.
  std::vector<float> expected = {
    hot,  cold, cold, cold, cold,  // index 0
    cold, cold, cold, cold, cold,  // index -1
    cold, cold, hot,  cold, cold,  // index 2
    cold, cold, cold, cold, cold,  // index -1
    cold, cold, cold, cold, cold,  // index 5
    cold, cold, cold, cold, hot,   // index 4
    cold, cold, hot,  cold, cold,  // index 2
    cold, cold, cold, cold, cold,  // index -1
    cold, cold, cold, cold, hot,   // index 4
    cold, cold, cold, cold, cold,  // index -1
  };

  RunTestCaseOneHot(in_shape, out_shape, indices.data(), expected.data(), axis, depth, hot, cold);
}

TEST_F(TestOneHotOpenCL, OneHot2DAxis0Fp32) {
  // One-hot over axis 0: a {3} index vector with depth 3 yields a {3, 3} output.
  int axis = 0;
  int depth = 3;
  float hot = 2.f;
  float cold = -2.f;

  std::vector<int> in_shape = {3};
  // The depth dimension is inserted at `axis`, wrapped into [0, output rank).
  const auto out_rank = in_shape.size() + 1;
  std::vector<int> out_shape(in_shape);
  out_shape.insert(out_shape.begin() + (axis + out_rank) % out_rank, depth);

  std::vector<int> indices = {2, 1, 3};
  // Laid out as expected[d][col]; a cell is hot when indices[col] == d.
  // The index 3 matches no depth slice, so its column is expected to stay cold.
  std::vector<float> expected = {
    cold, cold, cold,  // d = 0
    cold, hot,  cold,  // d = 1
    hot,  cold, cold,  // d = 2
  };

  RunTestCaseOneHot(in_shape, out_shape, indices.data(), expected.data(), axis, depth, hot, cold);
}

TEST_F(TestOneHotOpenCL, OneHot2DAxis0T2Fp32) {
  // One-hot over axis 0: a {5} index vector with depth 5 yields a {5, 5} output.
  int axis = 0;
  int depth = 5;
  float hot = 1.f;
  float cold = -1.f;

  std::vector<int> in_shape = {5};
  // The depth dimension is inserted at `axis`, wrapped into [0, output rank).
  const auto out_rank = in_shape.size() + 1;
  std::vector<int> out_shape(in_shape);
  out_shape.insert(out_shape.begin() + (axis + out_rank) % out_rank, depth);

  std::vector<int> indices = {2, 2, 0, 0, 4};
  // Laid out as expected[d][col]; a cell is hot when indices[col] == d.
  std::vector<float> expected = {
    cold, cold, hot,  hot,  cold,  // d = 0
    cold, cold, cold, cold, cold,  // d = 1
    hot,  hot,  cold, cold, cold,  // d = 2
    cold, cold, cold, cold, cold,  // d = 3
    cold, cold, cold, cold, hot,   // d = 4
  };

  RunTestCaseOneHot(in_shape, out_shape, indices.data(), expected.data(), axis, depth, hot, cold);
}

TEST_F(TestOneHotOpenCL, OneHot2DAxis1Fp32) {
  // One-hot with axis = -1: a {3} index vector with depth 3 yields a {3, 3} output.
  int axis = -1;
  int depth = 3;
  float hot = 2.f;
  float cold = -2.f;

  std::vector<int> in_shape = {3};
  // The sum below is evaluated in the unsigned size type, so axis = -1 wraps around and the
  // modulo lands on the last output position (1 here).
  const auto out_rank = in_shape.size() + 1;
  std::vector<int> out_shape(in_shape);
  out_shape.insert(out_shape.begin() + (axis + out_rank) % out_rank, depth);

  std::vector<int> indices = {1, 2, 0};
  // One line of `depth` values per input element, in input order.
  std::vector<float> expected = {
    cold, hot,  cold,  // index 1
    cold, cold, hot,   // index 2
    hot,  cold, cold,  // index 0
  };

  RunTestCaseOneHot(in_shape, out_shape, indices.data(), expected.data(), axis, depth, hot, cold);
}

TEST_F(TestOneHotOpenCL, OneHot2DAxis1T2Fp32) {
  // One-hot with axis = -1: a {5} index vector with depth 5 yields a {5, 5} output.
  int axis = -1;
  int depth = 5;
  float hot = 1.f;
  float cold = -1.f;

  std::vector<int> in_shape = {5};
  // The sum below is evaluated in the unsigned size type, so axis = -1 wraps around and the
  // modulo lands on the last output position (1 here).
  const auto out_rank = in_shape.size() + 1;
  std::vector<int> out_shape(in_shape);
  out_shape.insert(out_shape.begin() + (axis + out_rank) % out_rank, depth);

  std::vector<int> indices = {5, 4, 0, 4, -1};
  // One line of `depth` values per input element, in input order.
  // The indices 5 and -1 match no depth slot, so this case expects an all-cold line for them.
  std::vector<float> expected = {
    cold, cold, cold, cold, cold,  // index 5
    cold, cold, cold, cold, hot,   // index 4
    hot,  cold, cold, cold, cold,  // index 0
    cold, cold, cold, cold, hot,   // index 4
    cold, cold, cold, cold, cold,  // index -1
  };

  RunTestCaseOneHot(in_shape, out_shape, indices.data(), expected.data(), axis, depth, hot, cold);
}

TEST_F(TestOneHotOpenCL, OneHot1DAxis0Fp32) {
  // One-hot of a single scalar index (empty input shape) with depth 3 yields a {3} output.
  int axis = -1;
  int depth = 3;
  float hot = 2.f;
  float cold = -2.f;

  std::vector<int> in_shape = {};
  // With an empty input shape the output rank is 1, so the modulo always gives position 0;
  // axis = -1 wraps around in the unsigned sum and therefore behaves like axis = 0 here.
  const auto out_rank = in_shape.size() + 1;
  std::vector<int> out_shape(in_shape);
  out_shape.insert(out_shape.begin() + (axis + out_rank) % out_rank, depth);

  std::vector<int> indices = {1};
  std::vector<float> expected = {cold, hot, cold};  // only slot 1 is hot

  RunTestCaseOneHot(in_shape, out_shape, indices.data(), expected.data(), axis, depth, hot, cold);
}

TEST_F(TestOneHotOpenCL, OneHot1DAxis0T2Fp32) {
  // One-hot of a single scalar index (empty input shape) with depth 5 yields a {5} output.
  int axis = 0;
  int depth = 5;
  float hot = 1.f;
  float cold = -1.f;

  std::vector<int> in_shape = {};
  // With an empty input shape the output rank is 1, so depth is inserted at position 0.
  const auto out_rank = in_shape.size() + 1;
  std::vector<int> out_shape(in_shape);
  out_shape.insert(out_shape.begin() + (axis + out_rank) % out_rank, depth);

  std::vector<int> indices = {4};
  std::vector<float> expected = {cold, cold, cold, cold, hot};  // only slot 4 is hot

  RunTestCaseOneHot(in_shape, out_shape, indices.data(), expected.data(), axis, depth, hot, cold);
}
} // namespace mindspore

Loading…
Cancel
Save