From 9b4cce82081dac16f0a77408558e30ac6bbdff19 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Thu, 10 Sep 2020 17:11:46 +0800 Subject: [PATCH] l2norm multi thread and trailing axis --- mindspore/lite/nnacl/l2_norm.c | 76 +++++--- mindspore/lite/nnacl/l2_norm.h | 6 +- mindspore/lite/nnacl/l2_norm_parameter.h | 5 +- mindspore/lite/schema/ops.fbs | 1 + mindspore/lite/src/ops/l2_norm.cc | 5 + mindspore/lite/src/ops/l2_norm.h | 2 + mindspore/lite/src/populate_parameter.cc | 11 +- .../src/runtime/kernel/arm/fp32/l2_norm.cc | 169 +++++++++++++++--- .../src/runtime/kernel/arm/fp32/l2_norm.h | 19 +- .../kernel/arm/fp32/l2norm_fp32_test.cc | 160 +++++++++++++++++ 10 files changed, 395 insertions(+), 59 deletions(-) create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/l2norm_fp32_test.cc diff --git a/mindspore/lite/nnacl/l2_norm.c b/mindspore/lite/nnacl/l2_norm.c index a45658e64a..22fc6fde0d 100644 --- a/mindspore/lite/nnacl/l2_norm.c +++ b/mindspore/lite/nnacl/l2_norm.c @@ -16,36 +16,60 @@ #include "nnacl/l2_norm.h" #include +#include "nnacl/errorcode.h" -int L2NormFp32(const float *input_ptr, float *output_ptr, - L2NormParameter *param) { - int *axis = param->axis_; - size_t axis_num = param->axis_num_; - float epsilon = param->epsilon_; - int shape_num = param->shape_num_; +int CalcThreadSquareSum(const float *input_ptr, float *sum, int begin, int end) { + *sum = 0.0f; + int i; + for (i = begin; i < end; ++i) { + *sum += input_ptr[i] * input_ptr[i]; + } + return NNACL_OK; +} - // default case, axis is set default - if (shape_num == axis_num) { - bool default_case_flag = true; - for (int i = 0; i < axis_num; i++) { - if (axis[i] != i) { - default_case_flag = false; - } +int ThreadDivSqrtSum(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const float sqrt_sum, + const int begin, const int end) { + bool is_relu = param->act_type_ == ActType_Relu; + bool is_relu6 = param->act_type_ == ActType_Relu6; + int i; + for (i = begin; i < end; i++) { + float tmp = input_ptr[i] / sqrt_sum; + if (is_relu) { + output_ptr[i] = MSMAX(0, tmp); + } else if (is_relu6) { + output_ptr[i] = MSMIN(6, MSMAX(0, tmp)); + } else { + output_ptr[i] = tmp; } - if (default_case_flag) { - int data_num = param->data_num_; - float sum = 0; - for (int i = 0; i < data_num; i++) { - sum = sum + input_ptr[i] * input_ptr[i]; - } - float res = sqrt(sum > epsilon ? sum : epsilon); - for (int i = 0; i < data_num; i++) { - output_ptr[i] = input_ptr[i] / res; + } + return NNACL_OK; +} + +int ThreadTrailingAxis(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const int begin, + const int end) { + bool is_relu = param->act_type_ == ActType_Relu; + bool is_relu6 = param->act_type_ == ActType_Relu6; + + const int c = param->shape_[param->shape_num_ - 1]; + int i = 0; + for (i = begin; i < end; ++i) { + float square_sum = 0.0f; + int j = 0; + for (j = 0; j < c; ++j) { + const float val = input_ptr[i * c + j]; + square_sum += val * val; + } + float sqrt_sum = sqrt(square_sum > param->epsilon_ ? square_sum : param->epsilon_); + for (j = 0; j < c; ++j) { + float tmp = input_ptr[i * c + j] / sqrt_sum; + if (is_relu) { + output_ptr[i * c + j] = MSMAX(0, tmp); + } else if (is_relu6) { + output_ptr[i * c + j] = MSMIN(6, MSMAX(0, tmp)); + } else { + output_ptr[i * c + j] = tmp; } - return 0; } - } else { - return -1; } - return 0; + return NNACL_OK; } diff --git a/mindspore/lite/nnacl/l2_norm.h b/mindspore/lite/nnacl/l2_norm.h index 2895cc418f..5932af687f 100644 --- a/mindspore/lite/nnacl/l2_norm.h +++ b/mindspore/lite/nnacl/l2_norm.h @@ -21,7 +21,11 @@ #ifdef __cplusplus extern "C" { #endif -int L2NormFp32(const float *input_ptr, float *output_ptr, L2NormParameter *param); +int CalcThreadSquareSum(const float *input_ptr, float *sum, int begin, int end); +int ThreadDivSqrtSum(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const float sqrt_sum, + const int begin, const int end); +int ThreadTrailingAxis(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const int begin, + const int end); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/l2_norm_parameter.h b/mindspore/lite/nnacl/l2_norm_parameter.h index 1ffc0cf78d..8ce43c0740 100644 --- a/mindspore/lite/nnacl/l2_norm_parameter.h +++ b/mindspore/lite/nnacl/l2_norm_parameter.h @@ -24,9 +24,10 @@ typedef struct L2NormParameter { int *axis_; size_t axis_num_; float epsilon_; - float data_num_; + int data_num_; int *shape_; - int shape_num_; + size_t shape_num_; + ActType act_type_; } L2NormParameter; #endif // MINDSPORE_LITE_NNACL_L2NORM_PARAMETER_H_ diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index a2e25b327d..82a5e8d4eb 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -773,6 +773,7 @@ table FloorMod { table L2Norm { axis: [int]; epsilon: float; + activationType: ActivationType; } table LogicalAnd { diff --git a/mindspore/lite/src/ops/l2_norm.cc b/mindspore/lite/src/ops/l2_norm.cc index 4c8431ae68..4fc1486daa 100644 --- a/mindspore/lite/src/ops/l2_norm.cc +++ b/mindspore/lite/src/ops/l2_norm.cc @@ -21,9 +21,13 @@ namespace lite { #ifdef PRIMITIVE_WRITEABLE std::vector L2Norm::GetAxis() const { return this->primitive_->value.AsL2Norm()->axis; } float L2Norm::GetEpsilon() const { return this->primitive_->value.AsL2Norm()->epsilon; } +int L2Norm::GetActivationType() const { return this->primitive_->value.AsL2Norm()->activationType; } void L2Norm::SetAxis(const std::vector &axis) { this->primitive_->value.AsL2Norm()->axis = axis; } void L2Norm::SetEpsilon(float epsilon) { this->primitive_->value.AsL2Norm()->epsilon = epsilon; } +void L2Norm::SetActivationType(int activationType) { + this->primitive_->value.AsL2Norm()->activationType = (schema::ActivationType)activationType; +} #else int L2Norm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { @@ -51,6 +55,7 @@ std::vector L2Norm::GetAxis() const { return std::vector(fb_vector->begin(), fb_vector->end()); } float L2Norm::GetEpsilon() const { return this->primitive_->value_as_L2Norm()->epsilon(); } +int L2Norm::GetActivationType() const { return this->primitive_->value_as_L2Norm()->activationType(); } #endif } // namespace lite diff --git a/mindspore/lite/src/ops/l2_norm.h b/mindspore/lite/src/ops/l2_norm.h index e44579d574..19556b00a8 100644 --- a/mindspore/lite/src/ops/l2_norm.h +++ b/mindspore/lite/src/ops/l2_norm.h @@ -34,6 +34,7 @@ class L2Norm : public PrimitiveC { explicit L2Norm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetAxis(const std::vector &axis); void SetEpsilon(float epsilon); + void SetActivationType(int activationType); #else L2Norm() = default; @@ -41,6 +42,7 @@ class L2Norm : public PrimitiveC { #endif std::vector GetAxis() const; float GetEpsilon() const; + int GetActivationType() const; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 57efab75b3..488c6f54d2 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -1535,11 +1535,18 @@ OpParameter *PopulateL2NormParameter(const mindspore::lite::PrimitiveC *primitiv for (size_t i = 0; i < axis_vec.size(); i++) { l2_norm_parameter->axis_[i] = axis_vec[i]; } - if (param->GetEpsilon() < 1e-12) { - l2_norm_parameter->epsilon_ = 1e-12; + if (param->GetEpsilon() < 1e-6) { + l2_norm_parameter->epsilon_ = 1e-6; } else { l2_norm_parameter->epsilon_ = param->GetEpsilon(); } + if (param->GetActivationType() == static_cast(schema::ActivationType_RELU)) { + l2_norm_parameter->act_type_ = ActType_Relu; + } else if (param->GetActivationType() == static_cast(schema::ActivationType_RELU6)) { + l2_norm_parameter->act_type_ = ActType_Relu6; + } else { + l2_norm_parameter->act_type_ = ActType_No; + } return reinterpret_cast(l2_norm_parameter); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.cc index 381f4cff47..2154be4b5f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.cc @@ -15,9 +15,11 @@ */ #include +#include #include "src/runtime/kernel/arm/fp32/l2_norm.h" #include "include/errorcode.h" #include "nnacl/l2_norm.h" +#include "src/runtime/runtime_api.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -26,14 +28,154 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_L2Norm; namespace mindspore::kernel { +namespace { +const int kMaxThreadNum = 8; +} int L2NormCPUKernel::Init() { - l2_norm_param_->data_num_ = in_tensors_.at(kInputIndex)->ElementsNum(); + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int L2NormCPUKernel::MallocTmpBuffer() { auto shape = in_tensors_.at(kInputIndex)->shape(); l2_norm_param_->shape_ = reinterpret_cast(malloc(shape.size() * sizeof(int))); + if (l2_norm_param_->shape_ == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + + tmp_sum_ = reinterpret_cast(malloc(kMaxThreadNum * sizeof(float))); + if (tmp_sum_ == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + return RET_OK; +} + +void L2NormCPUKernel::FreeTmpBuffer() { + if (l2_norm_param_->shape_ != nullptr) { + free(l2_norm_param_->shape_); + l2_norm_param_->shape_ = nullptr; + } + if (tmp_sum_ != nullptr) { + free(tmp_sum_); + tmp_sum_ = nullptr; + } +} + +int L2NormCPUKernel::ReSize() { + FreeTmpBuffer(); + auto ret = MallocTmpBuffer(); + if (ret != RET_OK) { + FreeTmpBuffer(); + return ret; + } + + l2_norm_param_->data_num_ = in_tensors_.at(kInputIndex)->ElementsNum(); + auto shape = in_tensors_.at(kInputIndex)->shape(); l2_norm_param_->shape_num_ = shape.size(); for (size_t i = 0; i < shape.size(); i++) { l2_norm_param_->shape_[i] = shape[i]; } + for (size_t i = 0; i < l2_norm_param_->axis_num_; ++i) { + if (l2_norm_param_->axis_[i] < 0) { + l2_norm_param_->axis_[i] += static_cast(shape.size()); + } + } + return RET_OK; +} + +int L2NormCPUKernel::CalcSquareSum(int task_id) { + int unit = UP_DIV(l2_norm_param_->data_num_, context_->thread_num_); + int begin = task_id * unit; + int end = MSMIN(begin + unit, l2_norm_param_->data_num_); + return CalcThreadSquareSum(input_ptr_, tmp_sum_ + task_id, begin, end); +} + +int L2NormCPUKernel::DivSqrtSum(int task_id) { + int unit = UP_DIV(l2_norm_param_->data_num_, context_->thread_num_); + int begin = task_id * unit; + int end = MSMIN(begin + unit, l2_norm_param_->data_num_); + return ThreadDivSqrtSum(input_ptr_, output_ptr_, l2_norm_param_, sqrt_sum_, begin, end); +} + +int L2NormCPUKernel::CalcL2NormTrailingAxis(int task_id) { + auto input = in_tensors_.at(0); + int outer_size = input->ElementsNum() / input->shape().back(); + int unit = UP_DIV(outer_size, context_->thread_num_); + int begin = task_id * unit; + int end = MSMIN(begin + unit, outer_size); + return ThreadTrailingAxis(input_ptr_, output_ptr_, l2_norm_param_, begin, end); +} + +int SquareSumRun(void *cdata, int task_id) { + auto kernel = reinterpret_cast(cdata); + auto ret = kernel->CalcSquareSum(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "L2Norm SquareSumRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int L2NormRun(void *cdata, int task_id) { + auto kernel = reinterpret_cast(cdata); + auto ret = kernel->DivSqrtSum(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "L2Norm L2NormRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int L2NormTrailingAxisRun(void *cdata, int task_id) { + auto kernel = reinterpret_cast(cdata); + auto ret = kernel->CalcL2NormTrailingAxis(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "L2Norm TrailingAxisRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int L2NormCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail! ret: " << ret; + return ret; + } + auto input_shape = in_tensors().at(kInputIndex)->shape(); + input_ptr_ = reinterpret_cast(in_tensors_.at(kInputIndex)->MutableData()); + output_ptr_ = reinterpret_cast(out_tensors_.at(kOutputIndex)->MutableData()); + if (l2_norm_param_->axis_num_ == 0 || l2_norm_param_->axis_num_ == input_shape.size()) { + // all axis + ret = ParallelLaunch(THREAD_POOL_DEFAULT, SquareSumRun, this, context_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "L2Norm error: error_code[" << ret << "]"; + return RET_ERROR; + } + float sum = 0.0f; + for (int i = 0; i < context_->thread_num_; ++i) { + sum += tmp_sum_[i]; + } + sqrt_sum_ = sqrt(sum > l2_norm_param_->epsilon_ ? sum : l2_norm_param_->epsilon_); + ret = ParallelLaunch(THREAD_POOL_DEFAULT, L2NormRun, this, context_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "L2Norm error: error_code[" << ret << "]"; + return RET_ERROR; + } + } else if (l2_norm_param_->axis_num_ == 1 && l2_norm_param_->axis_[0] == static_cast(input_shape.size()) - 1) { + ret = ParallelLaunch(THREAD_POOL_DEFAULT, L2NormTrailingAxisRun, this, context_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "L2Norm error: error_code[" << ret << "]"; + return RET_ERROR; + } + } else { + MS_LOG(ERROR) << "L2Norm only support reduce on all axis and trailing axis with trailing axis"; + return RET_ERROR; + } return RET_OK; } @@ -61,30 +203,5 @@ kernel::LiteKernel *CpuL2NormFp32KernelCreator(const std::vector return kernel; } -int L2NormCPUKernel::Run() { - auto ret = Prepare(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Prepare fail!ret: " << ret; - return ret; - } - auto input_ptr = reinterpret_cast(in_tensors_.at(kInputIndex)->MutableData()); - auto output_ptr = reinterpret_cast(out_tensors_.at(kOutputIndex)->MutableData()); - ret = L2NormFp32(input_ptr, output_ptr, l2_norm_param_); - if (ret != 0) { - MS_LOG_ERROR << "unsupported axis setting, more work will be done"; - return ret; - } - return RET_OK; -} - -L2NormCPUKernel::~L2NormCPUKernel() { - if (l2_norm_param_->shape_ != nullptr) { - free(l2_norm_param_->shape_); - } - if (l2_norm_param_->axis_ != nullptr) { - free(l2_norm_param_->axis_); - } -} - REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_L2Norm, CpuL2NormFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.h index 7c08beccb7..4ccd37c2e9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.h @@ -35,14 +35,29 @@ class L2NormCPUKernel : public LiteKernel { : LiteKernel(parameter, inputs, outputs, ctx, primitive) { l2_norm_param_ = reinterpret_cast(op_parameter_); } - ~L2NormCPUKernel(); + ~L2NormCPUKernel() { + FreeTmpBuffer(); + if (l2_norm_param_->axis_ != nullptr) { + free(l2_norm_param_->axis_); + } + } + + int CalcSquareSum(int task_id); + int DivSqrtSum(int task_id); + int CalcL2NormTrailingAxis(int task_id); int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; private: + int MallocTmpBuffer(); + void FreeTmpBuffer(); L2NormParameter *l2_norm_param_; + float sqrt_sum_; + float *input_ptr_; + float *output_ptr_; + float *tmp_sum_ = nullptr; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/l2norm_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/l2norm_fp32_test.cc new file mode 100644 index 0000000000..180c0a28fc --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/l2norm_fp32_test.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/l2_norm.h" +#include "src/kernel_registry.h" +#include "src/lite_kernel.h" +using mindspore::schema::Format_NHWC; + +namespace mindspore { +class TestL2NormFp32 : public mindspore::CommonTest { + public: + TestL2NormFp32() = default; + void Init(const std::vector &input_shape, const std::vector &output_shape, float *input_data, + float *output_data, const int axis_num, ActType activation_type, const int thread_num); + void TearDown() override; + + public: + float err_tol_ = 1e-5; + lite::Tensor in_tensor_; + lite::Tensor out_tensor_; + std::vector inputs_{&in_tensor_}; + std::vector outputs_{&out_tensor_}; + L2NormParameter param_; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Resize}; + lite::Context ctx_ = lite::Context(); + kernel::KernelCreator creator_ = nullptr; + kernel::LiteKernel *kernel_ = nullptr; +}; + +void TestL2NormFp32::TearDown() { + in_tensor_.SetData(nullptr); + out_tensor_.SetData(nullptr); +} + +void TestL2NormFp32::Init(const std::vector &input_shape, const std::vector &output_shape, float *input_data, + float *output_data, const int axis_num, ActType activation_type, const int thread_num) { + in_tensor_.set_data_type(kNumberTypeFloat32); + in_tensor_.SetFormat(Format_NHWC); + in_tensor_.set_shape(input_shape); + out_tensor_.set_data_type(kNumberTypeFloat32); + out_tensor_.set_shape(output_shape); + in_tensor_.SetData(input_data); + out_tensor_.SetData(output_data); + + param_.axis_num_ = axis_num; + if (axis_num == 1) { + param_.axis_ = reinterpret_cast(malloc(sizeof(int))); + param_.axis_[0] = -1; + } + param_.epsilon_ = 1e-6; + param_.act_type_ = activation_type; + + desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_L2Norm}; + ctx_ = lite::Context(); + ctx_.thread_num_ = thread_num; + creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator_, nullptr); + kernel_ = creator_(inputs_, outputs_, reinterpret_cast(¶m_), &ctx_, desc, nullptr); + ASSERT_NE(kernel_, nullptr); +} + +// 2thread all axis no_activation +TEST_F(TestL2NormFp32, Test1) { + float input_data[18] = {-9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0, + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + float output_data[18] = {0}; + std::vector input_shape = {1, 3, 2, 3}; + std::vector output_shape = {1, 3, 2, 3}; + std::vector expect = {-0.40699407, -0.3617725, -0.31655094, -0.27132937, -0.22610782, -0.18088625, + -0.13566469, -0.09044313, -0.045221563, 0.0, 0.045221563, 0.09044313, + 0.13566469, 0.18088625, 0.22610782, 0.27132937, 0.31655094, 0.3617725}; + auto output_size = 18; + int axis_num = 0; + ActType act_type = ActType_No; + int thread_num = 2; + Init(input_shape, output_shape, input_data, output_data, axis_num, act_type, thread_num); + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + + CompareOutputData(output_data, expect.data(), output_size, err_tol_); +} + +// 2thread all axis relu +TEST_F(TestL2NormFp32, Test2) { + float input_data[18] = {-9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0, + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + float output_data[18] = {0}; + std::vector input_shape = {1, 3, 2, 3}; + std::vector output_shape = {1, 3, 2, 3}; + std::vector expect = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.045221563, 0.09044313, + 0.13566469, 0.18088625, 0.22610782, 0.27132937, 0.31655094, 0.3617725}; + auto output_size = 18; + int axis_num = 0; + ActType act_type = ActType_Relu; + int thread_num = 2; + Init(input_shape, output_shape, input_data, output_data, axis_num, act_type, thread_num); + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + + CompareOutputData(output_data, expect.data(), output_size, err_tol_); +} + +// 4 thread trailing axis no activation +TEST_F(TestL2NormFp32, Test3) { + float input_data[18] = {-9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0, + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + float output_data[18] = {0}; + std::vector input_shape = {1, 3, 2, 3}; + std::vector output_shape = {1, 3, 2, 3}; + std::vector expect = {-0.6461623, -0.57436645, -0.5025706, -0.6837635, -0.5698029, -0.45584232, + -0.8017837, -0.5345225, -0.26726124, 0.0, 0.4472136, 0.8944272, + 0.42426407, 0.56568545, 0.7071068, 0.49153918, 0.57346237, 0.65538555}; + auto output_size = 18; + int axis_num = 1; + ActType act_type = ActType_No; + int thread_num = 4; + Init(input_shape, output_shape, input_data, output_data, axis_num, act_type, thread_num); + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + + CompareOutputData(output_data, expect.data(), output_size, err_tol_); +} + +// 1 thread trailing axis no activation +TEST_F(TestL2NormFp32, Test4) { + float input_data[18] = {-9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0, + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + float output_data[18] = {0}; + std::vector input_shape = {1, 3, 2, 3}; + std::vector output_shape = {1, 3, 2, 3}; + std::vector expect = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.4472136, 0.8944272, + 0.42426407, 0.56568545, 0.7071068, 0.49153918, 0.57346237, 0.65538555}; + auto output_size = 18; + int axis_num = 1; + ActType act_type = ActType_Relu6; + int thread_num = 1; + Init(input_shape, output_shape, input_data, output_data, axis_num, act_type, thread_num); + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + + CompareOutputData(output_data, expect.data(), output_size, err_tol_); +} + +} // namespace mindspore