!6310 pad reflect and symmetric

Merge pull request !6310 from zhaozhenlong/lite/issue/pad_reflect_symmetric
tags/v1.0.0
mindspore-ci-bot committed 5 years ago
commit 00802a87a5
9 changed files with 584 additions and 56 deletions
  1. mindspore/lite/nnacl/fp32/pad.c (+37, -0)
  2. mindspore/lite/nnacl/fp32/pad.h (+2, -0)
  3. mindspore/lite/nnacl/pad_parameter.h (+4, -0)
  4. mindspore/lite/src/ops/pad.cc (+21, -2)
  5. mindspore/lite/src/ops/strided_slice.cc (+30, -9)
  6. mindspore/lite/src/populate_parameter.cc (+14, -14)
  7. mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc (+185, -28)
  8. mindspore/lite/src/runtime/kernel/arm/fp32/pad.h (+13, -3)
  9. mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pad_fp32_test.cc (+278, -0)

mindspore/lite/nnacl/fp32/pad.c (+37, -0)

@@ -33,3 +33,40 @@ void Pad(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
     }
   }
 }
+
+int TransOut2InputDimIndex(int out_dim_index, int left_pad, int in_dim, int offset) {
+  if (out_dim_index < left_pad) {
+    // left pad
+    const int index_sum = left_pad + offset - 1;
+    return MSMAX(index_sum - out_dim_index, offset);
+  }
+  out_dim_index -= left_pad;
+  if (out_dim_index < in_dim) {
+    return out_dim_index;
+  }
+  // right pad
+  out_dim_index -= in_dim;
+  const int index_sum = in_dim - 1 - offset;
+  return MSMAX(index_sum - out_dim_index, 0);
+}
+
+int GetInputFlattenIndex(int out_flatten_index, const int *input_shape, const PadParameter *pad_param) {
+  int in_flatten_index = 0;
+  int i;
+  for (i = 0; i < DEFAULT_PAD_NDIMS; ++i) {
+    int left_pad = pad_param->paddings_[i * 2];
+    int out_dim_index = out_flatten_index / pad_param->out_strides[i];
+    out_flatten_index %= pad_param->out_strides[i];
+    int in_dim_index = TransOut2InputDimIndex(out_dim_index, left_pad, input_shape[i], pad_param->mirror_offset_);
+    in_flatten_index += in_dim_index * pad_param->in_strides[i];
+  }
+  return in_flatten_index;
+}
+
+void MirrorPad(const float *input_data, float *output_data, const int *input_shape, const PadParameter *pad_param,
+               int begin, int end) {
+  int i = 0;
+  for (i = begin; i < end; ++i) {
+    output_data[i] = input_data[GetInputFlattenIndex(i, input_shape, pad_param)];
+  }
+}
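Note: the offset parameter is what separates the two mirror modes: REFLECT runs with mirror_offset_ = 1 (the edge element is not repeated), while SYMMETRIC uses 0 (it is). Below is a minimal standalone sketch of the same per-axis mapping; mirror_index is an illustrative name, not part of the patch.

// Standalone sketch of the per-dimension mirror index mapping above.
#include <algorithm>
#include <cstdio>

int mirror_index(int out_dim_index, int left_pad, int in_dim, int offset) {
  if (out_dim_index < left_pad) {  // left pad region
    return std::max(left_pad + offset - 1 - out_dim_index, offset);
  }
  out_dim_index -= left_pad;
  if (out_dim_index < in_dim) {  // copy region
    return out_dim_index;
  }
  out_dim_index -= in_dim;  // right pad region
  return std::max(in_dim - 1 - offset - out_dim_index, 0);
}

int main() {
  // in_dim = 4, left_pad = 2, two output positions of right pad:
  for (int offset : {1, 0}) {  // 1 = REFLECT, 0 = SYMMETRIC
    for (int i = 0; i < 8; ++i) {
      printf("%d ", mirror_index(i, 2, 4, offset));
    }
    printf("\n");  // REFLECT: 2 1 0 1 2 3 2 1   SYMMETRIC: 1 0 0 1 2 3 3 2
  }
  return 0;
}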

mindspore/lite/nnacl/fp32/pad.h (+2, -0)

@@ -29,6 +29,8 @@ extern "C" {
 #endif
 void Pad(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
          const int *paddings, const int tid, const int thread_num);
+void MirrorPad(const float *input_data, float *output_data, const int *input_shape, const PadParameter *pad_param,
+               int begin, int end);
 #ifdef __cplusplus
 }
 #endif


mindspore/lite/nnacl/pad_parameter.h (+4, -0)

@@ -26,8 +26,12 @@ typedef struct PadParameter {
   OpParameter op_parameter_;
   PadQuantArg pad_quant_arg_;
   int paddings_[MAX_PAD_SIZE];
+  int padding_length;
   int pad_mode_;
   float constant_value_;
+  int mirror_offset_;
+  int in_strides[DEFAULT_PAD_NDIMS];
+  int out_strides[DEFAULT_PAD_NDIMS];
 } PadParameter;
 
 #endif  // MINDSPORE_LITE_NNACL_PAD_PARAMETER_H_

mindspore/lite/src/ops/pad.cc (+21, -2)

@@ -64,8 +64,6 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
     return RET_NULL_PTR;
   }
 
-  auto paddings = GetPaddings();
-
   auto input = inputs.front();
   if (input == nullptr) {
     return RET_NULL_PTR;
@@ -79,6 +77,27 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
   if (!GetInferFlag()) {
     return RET_OK;
   }
+
+  std::vector<int> paddings;
+  if (GetPaddingMode() == static_cast<int>(schema::PaddingMode_CONSTANT)) {
+    paddings = GetPaddings();
+  } else {
+    // mirror pad
+    MS_ASSERT(inputs.size() == 2);
+    auto paddings_tensor = inputs.at(1);
+    int rank = static_cast<int>(inputs.front()->shape().size());
+    MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank);
+    int *paddings_data = reinterpret_cast<int *>(paddings_tensor->MutableData());
+    if (paddings_data == nullptr) {
+      return RET_INFER_ERR;
+    }
+    paddings.clear();
+    for (auto i = 0; i < rank; ++i) {
+      paddings.emplace_back(paddings_data[i * 2]);
+      paddings.emplace_back(paddings_data[i * 2 + 1]);
+    }
+  }
+
   auto input_shape = input->shape();
   std::vector<int> output_shape;
   MS_ASSERT(input->shape().size() <= 4);
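Whichever branch supplies the paddings (the primitive attribute for CONSTANT, the second input tensor for REFLECT/SYMMETRIC), the shape rule InferShape then applies is out[i] = in[i] + paddings[2i] + paddings[2i+1]. A small hedged sketch of that rule; PaddedShape is an illustrative helper, not the operator's API.

// Sketch of the padded-output-shape rule used by InferShape.
#include <cassert>
#include <vector>

std::vector<int> PaddedShape(const std::vector<int> &in_shape, const std::vector<int> &paddings) {
  assert(paddings.size() == 2 * in_shape.size());
  std::vector<int> out_shape;
  for (size_t i = 0; i < in_shape.size(); ++i) {
    // left pad + dim + right pad
    out_shape.push_back(in_shape[i] + paddings[2 * i] + paddings[2 * i + 1]);
  }
  return out_shape;
}

// e.g. PaddedShape({1, 4, 4, 3}, {0, 0, 3, 3, 3, 3, 0, 0}) -> {1, 10, 10, 3},
// matching the shapes used in the new unit tests.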


mindspore/lite/src/ops/strided_slice.cc (+30, -9)

@@ -156,8 +156,9 @@ int StridedSlice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
 }
 #endif
 namespace {
-constexpr int kStridedSliceOutputNum = 1;
-constexpr int kStridedSliceInputNum = 1;
+constexpr size_t kStridedSliceOutputNum = 1;
+constexpr size_t kStridedSliceInputNum = 1;
+constexpr size_t kStridedSliceMultiInputNum = 4;
 }  // namespace
 
 void StridedSlice::ApplyNewAxisMask() {
@@ -231,7 +232,7 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
     MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
     return RET_PARAM_INVALID;
   }
-  if (inputs.size() != kStridedSliceInputNum) {
+  if (inputs.size() != kStridedSliceInputNum && inputs.size() != kStridedSliceMultiInputNum) {
     MS_LOG(ERROR) << "Invalid input size " << inputs.size();
     return RET_PARAM_INVALID;
   }
@@ -244,13 +245,33 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
   MS_ASSERT(input != nullptr);
   auto input_shape = input->shape();
   std::vector<int> output_shape;
-  ndim_ = static_cast<int>(GetBegin().size());
-
-  for (int i = 0; i < ndim_; i++) {
-    in_shape_.emplace_back(input_shape.at(i));
-    begins_.emplace_back((GetBegin())[i]);
-    ends_.emplace_back((GetEnd())[i]);
-    strides_.emplace_back((GetStride())[i]);
+  if (inputs.size() == kStridedSliceInputNum) {
+    ndim_ = static_cast<int>(GetBegin().size());
+
+    for (int i = 0; i < ndim_; i++) {
+      in_shape_.emplace_back(input_shape.at(i));
+      begins_.emplace_back((GetBegin())[i]);
+      ends_.emplace_back((GetEnd())[i]);
+      strides_.emplace_back((GetStride())[i]);
+    }
+  } else {
+    auto begin_tensor = inputs.at(1);
+    int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
+    auto end_tensor = inputs.at(2);
+    int *end_data = reinterpret_cast<int *>(end_tensor->MutableData());
+    auto stride_tensor = inputs.at(3);
+    int *stride_data = reinterpret_cast<int *>(stride_tensor->MutableData());
+    if (begin_data == nullptr || end_data == nullptr || stride_data == nullptr) {
+      return RET_INFER_ERR;
+    }
+    ndim_ = begin_tensor->ElementsNum();
+    for (int i = 0; i < ndim_; ++i) {
+      in_shape_.emplace_back(input_shape.at(i));
+      begins_.emplace_back(begin_data[i]);
+      ends_.emplace_back(end_data[i]);
+      strides_.emplace_back(stride_data[i]);
+    }
   }
 
   // set all mask to original input shape
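With four inputs, begin/end/stride arrive as int tensors rather than flatbuffer attributes. For positive strides the sliced length per axis works out to ceil((end - begin) / stride); the sketch below assumes positive strides and no masks, and SlicedShape is an illustrative name, not this file's code.

// Sketch of the per-axis output-length math behind strided slicing.
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<int> SlicedShape(const std::vector<int> &begin, const std::vector<int> &end,
                             const std::vector<int> &stride) {
  std::vector<int> out;
  for (size_t i = 0; i < begin.size(); ++i) {
    // Positive strides assumed for this sketch.
    out.push_back(static_cast<int>(std::ceil((end[i] - begin[i]) / static_cast<double>(stride[i]))));
  }
  return out;
}

int main() {
  // begin = {0, 1}, end = {4, 4}, stride = {1, 2} on a 4x4 input -> 4x2 output.
  auto shape = SlicedShape({0, 1}, {4, 4}, {1, 2});
  printf("%d x %d\n", shape[0], shape[1]);  // prints "4 x 2"
  return 0;
}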


mindspore/lite/src/populate_parameter.cc (+14, -14)

@@ -601,24 +601,24 @@ OpParameter *PopulatePadParameter(const mindspore::lite::PrimitiveC *primitive)
   pad_param->op_parameter_.type_ = primitive->Type();
   auto pad_node = reinterpret_cast<mindspore::lite::Pad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
   pad_param->pad_mode_ = pad_node->GetPaddingMode();
-  if (pad_param->pad_mode_ == schema::PaddingMode_CONSTANT) {
+  if (pad_param->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
     pad_param->constant_value_ = pad_node->GetConstantValue();
-  } else {
-    MS_LOG(ERROR) << "Invalid padding mode: " << pad_param->pad_mode_;
-    free(pad_param);
-    return nullptr;
-  }
-
-  auto size = pad_node->GetPaddings().size();
-  if (size > MAX_PAD_SIZE) {
-    MS_LOG(ERROR) << "Invalid padding size: " << size;
-    free(pad_param);
-    return nullptr;
-  }
-
-  for (size_t i = 0; i < size; i++) {
-    pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i];
+    auto size = pad_node->GetPaddings().size();
+    if (size > MAX_PAD_SIZE) {
+      MS_LOG(ERROR) << "Invalid padding size: " << size;
+      free(pad_param);
+      return nullptr;
+    }
+    for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) {
+      pad_param->paddings_[i] = 0;
+    }
+    for (size_t i = 0; i < size; i++) {
+      pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i];
+    }
+    pad_param->padding_length = MAX_PAD_SIZE;
   }
   return reinterpret_cast<OpParameter *>(pad_param);
 }
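The paddings are right-aligned into the fixed-size array, so a lower-rank input behaves as if it were left-extended with size-1 dimensions. A standalone sketch of that layout, assuming MAX_PAD_SIZE is 8 (4 dimensions times 2 sides); FillPaddings is an illustrative helper.

// Sketch of the right-aligned padding layout produced above.
#include <cstdio>
#include <vector>

constexpr int kMaxPadSize = 8;  // assumption: MAX_PAD_SIZE is 8 (4 dims x 2 sides)

void FillPaddings(const std::vector<int> &src, int (&dst)[kMaxPadSize]) {
  const int size = static_cast<int>(src.size());
  for (int i = 0; i < kMaxPadSize - size; ++i) dst[i] = 0;               // leading zeros
  for (int i = 0; i < size; ++i) dst[kMaxPadSize - size + i] = src[i];   // right-align
}

int main() {
  int dst[kMaxPadSize];
  FillPaddings({3, 3, 3, 3, 0, 0}, dst);  // the rank-3 padding case from TestPad5
  for (int v : dst) printf("%d ", v);     // prints: 0 0 3 3 3 3 0 0
  printf("\n");
  return 0;
}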



mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc (+185, -28)

@@ -14,9 +14,10 @@
  * limitations under the License.
  */
 
+#include "src/runtime/kernel/arm/fp32/pad.h"
+#include <string>
 #include "src/kernel_registry.h"
 #include "schema/model_generated.h"
-#include "src/runtime/kernel/arm/fp32/pad.h"
 #include "include/errorcode.h"
 #include "nnacl/errorcode.h"
 #include "src/runtime/runtime_api.h"
@@ -30,17 +31,9 @@ using mindspore::schema::PrimitiveType_Pad;
 
 namespace mindspore::kernel {
 namespace {
-constexpr int kInputNum = 1;
-constexpr int kOutputNum = 1;
-}  // namespace
-
+constexpr size_t kMirrorPadInputSize = 2;
+}
 int PadCPUKernel::Init() {
-  if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) {
-    MS_LOG(ERROR) << "Pad input size should be " << kInputNum << ", got " << in_tensors_.size()
-                  << ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
-    return RET_ERROR;
-  }
-
   if (!InferShapeDone()) {
     return RET_OK;
   }
@@ -49,21 +42,58 @@
 
 int PadCPUKernel::ReSize() {
   auto input = in_tensors_.at(0);
-  auto output = out_tensors_.at(0);
-  if (input == nullptr || output == nullptr) {
-    MS_LOG(ERROR) << "Pad input or output nullptr";
-    return RET_NULL_PTR;
-  }
-
   auto rank = input->shape().size();
   if (rank > DEFAULT_PAD_NDIMS) {
     MS_LOG(ERROR) << "Pad input rank should <= " << DEFAULT_PAD_NDIMS << ", got " << rank;
     return RET_ERROR;
   }
-
-  for (size_t i = 0; i < rank; i++) {
-    in_[DEFAULT_PAD_NDIMS - rank + i] = input->shape()[i];
-    out_[DEFAULT_PAD_NDIMS - rank + i] = output->shape()[i];
-  }
+  auto output = out_tensors_.at(0);
+  if (pad_param_->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
+    auto ret = ExtendShape(in_, DEFAULT_PAD_NDIMS, input->shape().data(), rank);
+    if (ret != RET_OK) {
+      return ret;
+    }
+    ret = ExtendShape(out_, DEFAULT_PAD_NDIMS, output->shape().data(), rank);
+    if (ret != RET_OK) {
+      return ret;
+    }
+    if (pad_param_->padding_length < MAX_PAD_SIZE) {
+      int ori_paddings[MAX_PAD_SIZE];
+      for (auto i = 0; i < pad_param_->padding_length; ++i) {
+        ori_paddings[i] = pad_param_->paddings_[i];
+      }
+      ret = ExtendPaddings(pad_param_->paddings_, MAX_PAD_SIZE, ori_paddings, pad_param_->padding_length);
+      if (ret != RET_OK) {
+        return ret;
+      }
+      pad_param_->padding_length = MAX_PAD_SIZE;
+    }
+  }
   return RET_OK;
 }
+
+int PadCPUKernel::ExtendShape(int *shape, int length, const int *ori_shape, int rank) {
+  if (shape == nullptr || ori_shape == nullptr) {
+    return RET_NULL_PTR;
+  }
+  for (auto i = 0; i < length - rank; ++i) {
+    shape[i] = 1;
+  }
+  for (auto i = length - rank; i < length; ++i) {
+    shape[i] = ori_shape[i - (length - rank)];
+  }
+  return RET_OK;
+}
+
+int PadCPUKernel::ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length) {
+  if (paddings == nullptr || ori_paddings == nullptr) {
+    return RET_NULL_PTR;
+  }
+  for (auto i = 0; i < length - ori_length; ++i) {
+    paddings[i] = 0;
+  }
+  for (auto i = length - ori_length; i < length; ++i) {
+    paddings[i] = ori_paddings[i - (length - ori_length)];
+  }
+  return RET_OK;
+}
@@ -90,23 +120,150 @@ int PadCPUKernel::RunImpl(int task_id) {
   return RET_OK;
 }
 
+int MirrorPadImpl(void *cdata, int task_id) {
+  auto padKernel = reinterpret_cast<PadCPUKernel *>(cdata);
+  int error_code = padKernel->RunMirrorPadImpl(task_id);
+  if (error_code != NNACL_OK) {
+    MS_LOG(ERROR) << "Pad Run error task_id[" << task_id << "] error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int PadCPUKernel::RunMirrorPadImpl(int task_id) {
+  auto input = in_tensors_.at(0);
+  auto output = out_tensors_.at(0);
+  auto input_data = reinterpret_cast<float *>(input->MutableData());
+  auto output_data = reinterpret_cast<float *>(output->MutableData());
+
+  int unit = UP_DIV(output->ElementsNum(), context_->thread_num_);
+  int begin = unit * task_id;
+  int end = MSMIN(begin + unit, output->ElementsNum());
+  MirrorPad(input_data, output_data, in_, pad_param_, begin, end);
+  return RET_OK;
+}
+
+int PadCPUKernel::CheckPaddings(int *paddings, int length, int *input_shape, int mode) {
+  if (paddings == nullptr || input_shape == nullptr) {
+    return RET_NULL_PTR;
+  }
+  std::string prefix;
+  int offset;
+  if (mode == static_cast<int>(schema::PaddingMode_SYMMETRIC)) {
+    prefix = "For Pad SYMMETRIC ";
+    offset = 0;
+  } else {
+    prefix = "For Pad REFLECT ";
+    offset = 1;
+  }
+  for (auto i = 0; i < length; ++i) {
+    int max_valid = input_shape[i] - offset;
+    if (paddings[i * 2] > max_valid) {
+      MS_LOG(ERROR) << prefix << "paddings " << paddings[i * 2] << " should be less than " << max_valid + 1;
+      return RET_ERROR;
+    }
+    if (paddings[i * 2 + 1] > max_valid) {
+      MS_LOG(ERROR) << prefix << "paddings " << paddings[i * 2 + 1] << " should be less than " << max_valid + 1;
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+int PadCPUKernel::CopyPaddingFromInput() {
+  if (in_tensors_.size() != kMirrorPadInputSize) {
+    MS_LOG(ERROR) << "Pad Reflect or Symmetric mode need 2 inputs, got " << in_tensors_.size();
+    return RET_ERROR;
+  }
+  auto padding_tensor = in_tensors_.at(1);
+  auto paddings = reinterpret_cast<int *>(padding_tensor->MutableData());
+  if (paddings == nullptr) {
+    MS_LOG(ERROR) << "Pad second input data nullptr";
+    return RET_ERROR;
+  }
+  auto input_shape = in_tensors_.at(0)->shape();
+  int rank = static_cast<int>(input_shape.size());
+  if (padding_tensor->ElementsNum() != rank * 2) {
+    MS_LOG(ERROR) << "Pad second input elements num " << padding_tensor->ElementsNum() << ", should be " << rank * 2;
+    return RET_ERROR;
+  }
+
+  auto ret = ExtendShape(in_, DEFAULT_PAD_NDIMS, input_shape.data(), rank);
+  if (ret != RET_OK) {
+    return ret;
+  }
+  ret = ExtendPaddings(pad_param_->paddings_, MAX_PAD_SIZE, paddings, padding_tensor->ElementsNum());
+  if (ret != RET_OK) {
+    return ret;
+  }
+  pad_param_->padding_length = MAX_PAD_SIZE;
+  return RET_OK;
+}
+
+void PadCPUKernel::CalculateStrides() {
+  auto input_shape = in_tensors_.at(0)->shape();
+  pad_param_->in_strides[DEFAULT_PAD_NDIMS - 1] = 1;
+  for (auto i = DEFAULT_PAD_NDIMS - 2; i >= 0; --i) {
+    pad_param_->in_strides[i] = input_shape[i + 1] * pad_param_->in_strides[i + 1];
+  }
+  for (auto i = 0; i < DEFAULT_PAD_NDIMS; ++i) {
+    out_[i] = in_[i] + pad_param_->paddings_[i * 2] + pad_param_->paddings_[i * 2 + 1];
+  }
+  pad_param_->out_strides[DEFAULT_PAD_NDIMS - 1] = 1;
+  for (auto i = DEFAULT_PAD_NDIMS - 2; i >= 0; --i) {
+    pad_param_->out_strides[i] = out_[i + 1] * pad_param_->out_strides[i + 1];
+  }
+}
+
+int PadCPUKernel::HandleMirrorPad() {
+  auto ret = CopyPaddingFromInput();
+  if (ret != RET_OK) {
+    return ret;
+  }
+  ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_);
+  if (ret != RET_OK) {
+    return ret;
+  }
+  CalculateStrides();
+  pad_param_->mirror_offset_ = pad_param_->pad_mode_ == static_cast<int>(schema::PaddingMode_REFLECT) ? 1 : 0;
+  return RET_OK;
+}
+
 int PadCPUKernel::Run() {
   auto prepare_ret = Prepare();
   if (prepare_ret != RET_OK) {
     MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
     return prepare_ret;
   }
-  auto output = out_tensors_.at(0);
-  int output_size = output->ElementsNum();
-
-  auto output_data = reinterpret_cast<float *>(output->MutableData());
-  memset(output_data, 0, output_size * sizeof(float));
-
-  int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, PadImpl, this, context_->thread_num_);
-  if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "Pad run error, error_code[" << error_code << "]";
-    return RET_ERROR;
+  int error_code;
+  if (pad_param_->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
+    auto output = out_tensors_.at(0);
+    int output_size = output->ElementsNum();
+    auto output_data = reinterpret_cast<float *>(output->MutableData());
+    if (pad_param_->constant_value_ - 0.0f < 1e-5) {
+      memset(output_data, 0, output_size * sizeof(float));
+    } else {
+      for (auto i = 0; i < output_size; ++i) {
+        output_data[i] = pad_param_->constant_value_;
+      }
+    }
+    error_code = ParallelLaunch(THREAD_POOL_DEFAULT, PadImpl, this, context_->thread_num_);
+    if (error_code != RET_OK) {
+      MS_LOG(ERROR) << "Pad run error, error_code[" << error_code << "]";
+      return RET_ERROR;
+    }
+  } else {
+    // mirror pad case
+    HandleMirrorPad();
+
+    error_code = ParallelLaunch(THREAD_POOL_DEFAULT, MirrorPadImpl, this, context_->thread_num_);
+    if (error_code != RET_OK) {
+      MS_LOG(ERROR) << "Pad Reflect or Symmetric mode run error, error_code[" << error_code << "]";
+      return RET_ERROR;
+    }
   }
-
   return RET_OK;
 }
 }  // namespace mindspore::kernel
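CalculateStrides gives MirrorPad row-major strides over the 4-D padded shape, so a flat output index decomposes into per-axis indices with one division and one modulo per axis. A worked sketch using the shapes from the REFLECT unit test; the code itself is illustrative, not kernel code.

// Worked example of the stride bookkeeping behind MirrorPad's index math.
#include <cstdio>

int main() {
  const int ndims = 4;
  int in[ndims] = {1, 4, 4, 3};
  int paddings[2 * ndims] = {0, 0, 3, 3, 3, 3, 0, 0};
  int out[ndims], in_strides[ndims], out_strides[ndims];

  for (int i = 0; i < ndims; ++i) out[i] = in[i] + paddings[i * 2] + paddings[i * 2 + 1];
  in_strides[ndims - 1] = out_strides[ndims - 1] = 1;
  for (int i = ndims - 2; i >= 0; --i) {
    in_strides[i] = in[i + 1] * in_strides[i + 1];
    out_strides[i] = out[i + 1] * out_strides[i + 1];
  }
  // out = {1, 10, 10, 3}; in_strides = {48, 12, 3, 1}; out_strides = {300, 30, 3, 1}

  int flat = 157;  // decompose one flat output index into (n, h, w, c)
  for (int i = 0; i < ndims; ++i) {
    printf("%d ", flat / out_strides[i]);  // prints: 0 5 2 1
    flat %= out_strides[i];
  }
  printf("\n");
  return 0;
}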

mindspore/lite/src/runtime/kernel/arm/fp32/pad.h (+13, -3)

@@ -38,14 +38,24 @@ class PadCPUKernel : public LiteKernel {
   int ReSize() override;
   int Run() override;
   virtual int RunImpl(int task_id);
+  int RunMirrorPadImpl(int task_id);
+
+ private:
+  int HandleMirrorPad();
+  int CheckPaddings(int *paddings, int length, int *input_shape, int mode);
+  int CopyPaddingFromInput();
+  void CalculateStrides();
+  int ExtendShape(int *shape, int length, const int *ori_shape, int rank);
+  int ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length);
 
  protected:
-  const PadParameter *pad_param_;
-  int in_[4] = {1, 1, 1, 1};
-  int out_[4] = {1, 1, 1, 1};
+  PadParameter *pad_param_;
+  int in_[4];
+  int out_[4];
 };
 
 int PadImpl(void *cdata, int task_id);
+int MirrorPadImpl(void *cdata, int task_id);
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PAD_H_

mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pad_fp32_test.cc (+278, -0)

@@ -0,0 +1,278 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "mindspore/lite/src/lite_kernel.h"
#include "mindspore/lite/src/tensor.h"
#include "common/common_test.h"
#include "nnacl/pad_parameter.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/schema/ops_generated.h"

using mindspore::schema::Format_NHWC;
using mindspore::schema::PaddingMode;
using mindspore::schema::PaddingMode_CONSTANT;
using mindspore::schema::PaddingMode_REFLECT;
using mindspore::schema::PaddingMode_SYMMETRIC;
namespace mindspore {

class TestPadFp32 : public mindspore::CommonTest {
 public:
  TestPadFp32() = default;
  void Prepare(const std::vector<int> &input_shape, const std::vector<int> &output_shape, float *input_data,
               float *output_data, PaddingMode mode, int *paddings, int padding_length, float constant_value,
               const int thread_num);

  void TearDown() override;

 public:
  float err_tol = 1e-5;
  lite::Tensor in_tensor_;
  lite::Tensor paddings_tensor_;
  lite::Tensor out_tensor_;
  PadParameter param_;
  std::vector<lite::Tensor *> inputs_{&in_tensor_};
  std::vector<lite::Tensor *> outputs_{&out_tensor_};
  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Pad};
  lite::Context ctx_ = lite::Context();
  kernel::KernelCreator creator_ = nullptr;
  kernel::LiteKernel *kernel_ = nullptr;
};

void TestPadFp32::TearDown() {
  paddings_tensor_.SetData(nullptr);
  in_tensor_.SetData(nullptr);
  out_tensor_.SetData(nullptr);
}

void TestPadFp32::Prepare(const std::vector<int> &input_shape, const std::vector<int> &output_shape, float *input_data,
                          float *output_data, PaddingMode mode, int *paddings, int padding_length,
                          float constant_value, const int thread_num) {
  in_tensor_.set_data_type(kNumberTypeFloat32);
  in_tensor_.SetFormat(Format_NHWC);
  in_tensor_.set_shape(input_shape);
  out_tensor_.set_data_type(kNumberTypeFloat32);
  out_tensor_.set_shape(output_shape);
  in_tensor_.SetData(input_data);
  out_tensor_.SetData(output_data);

  param_.pad_mode_ = static_cast<int>(mode);
  if (mode == PaddingMode_CONSTANT) {
    param_.constant_value_ = constant_value;
    param_.padding_length = padding_length;
    for (auto i = 0; i < padding_length; ++i) {
      param_.paddings_[i] = paddings[i];
    }
  } else {
    paddings_tensor_.set_data_type(kNumberTypeInt32);
    paddings_tensor_.set_shape({4, 2});
    paddings_tensor_.SetData(paddings);
    inputs_.emplace_back(&paddings_tensor_);
  }

  desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Pad};
  ctx_ = lite::Context();
  ctx_.thread_num_ = thread_num;
  creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  ASSERT_NE(creator_, nullptr);
  kernel_ = creator_(inputs_, outputs_, reinterpret_cast<OpParameter *>(&param_), &ctx_, desc, nullptr);
  ASSERT_NE(kernel_, nullptr);
}

TEST_F(TestPadFp32, TestPad1) {
std::vector<int> input_shape{1, 4, 4, 3};
std::vector<int> output_shape{1, 12, 12, 3};
float in_data[48] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0};
float out_data[432] = {0};
int paddings[8] = {0, 0, 4, 4, 4, 4, 0, 0};
PaddingMode mode = PaddingMode_SYMMETRIC;
int thread_num = 2;
Prepare(input_shape, output_shape, in_data, out_data, mode, paddings, 8, 0.0f, thread_num);

auto ret = kernel_->Run();
EXPECT_EQ(0, ret);

std::vector<float> expect{
45.0, 46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0, 36.0, 37.0, 38.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0,
43.0, 44.0, 45.0, 46.0, 47.0, 45.0, 46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0, 36.0, 37.0, 38.0, 33.0, 34.0,
35.0, 30.0, 31.0, 32.0, 27.0, 28.0, 29.0, 24.0, 25.0, 26.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0,
33.0, 34.0, 35.0, 33.0, 34.0, 35.0, 30.0, 31.0, 32.0, 27.0, 28.0, 29.0, 24.0, 25.0, 26.0, 21.0, 22.0, 23.0, 18.0,
19.0, 20.0, 15.0, 16.0, 17.0, 12.0, 13.0, 14.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0,
23.0, 21.0, 22.0, 23.0, 18.0, 19.0, 20.0, 15.0, 16.0, 17.0, 12.0, 13.0, 14.0, 9.0, 10.0, 11.0, 6.0, 7.0, 8.0,
3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 9.0,
10.0, 11.0, 6.0, 7.0, 8.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 9.0, 10.0, 11.0, 6.0, 7.0, 8.0, 3.0, 4.0,
5.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 9.0, 10.0, 11.0,
6.0, 7.0, 8.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 21.0, 22.0, 23.0, 18.0, 19.0, 20.0, 15.0, 16.0, 17.0, 12.0,
13.0, 14.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 21.0, 22.0, 23.0, 18.0, 19.0,
20.0, 15.0, 16.0, 17.0, 12.0, 13.0, 14.0, 33.0, 34.0, 35.0, 30.0, 31.0, 32.0, 27.0, 28.0, 29.0, 24.0, 25.0, 26.0,
24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 33.0, 34.0, 35.0, 30.0, 31.0, 32.0, 27.0,
28.0, 29.0, 24.0, 25.0, 26.0, 45.0, 46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0, 36.0, 37.0, 38.0, 36.0, 37.0,
38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 45.0, 46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0,
36.0, 37.0, 38.0, 45.0, 46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0, 36.0, 37.0, 38.0, 36.0, 37.0, 38.0, 39.0,
40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 45.0, 46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0, 36.0, 37.0,
38.0, 33.0, 34.0, 35.0, 30.0, 31.0, 32.0, 27.0, 28.0, 29.0, 24.0, 25.0, 26.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 33.0, 34.0, 35.0, 30.0, 31.0, 32.0, 27.0, 28.0, 29.0, 24.0, 25.0, 26.0, 21.0,
22.0, 23.0, 18.0, 19.0, 20.0, 15.0, 16.0, 17.0, 12.0, 13.0, 14.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
20.0, 21.0, 22.0, 23.0, 21.0, 22.0, 23.0, 18.0, 19.0, 20.0, 15.0, 16.0, 17.0, 12.0, 13.0, 14.0, 9.0, 10.0, 11.0,
6.0, 7.0, 8.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
10.0, 11.0, 9.0, 10.0, 11.0, 6.0, 7.0, 8.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0};
CompareOutputData(out_data, expect.data(), 432, err_tol);
}

TEST_F(TestPadFp32, TestPad2) {
std::vector<int> input_shape{1, 4, 4, 3};
std::vector<int> output_shape{1, 10, 10, 3};
float in_data[48] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0};
float out_data[300] = {0};
int paddings[8] = {0, 0, 3, 3, 3, 3, 0, 0};
PaddingMode mode = PaddingMode_REFLECT;
int thread_num = 2;
Prepare(input_shape, output_shape, in_data, out_data, mode, paddings, 8, 0.0f, thread_num);

auto ret = kernel_->Run();
EXPECT_EQ(0, ret);

std::vector<float> expect{
45.0, 46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0,
46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0, 36.0, 37.0, 38.0, 33.0, 34.0, 35.0, 30.0, 31.0, 32.0, 27.0, 28.0,
29.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 30.0, 31.0, 32.0, 27.0, 28.0, 29.0,
24.0, 25.0, 26.0, 21.0, 22.0, 23.0, 18.0, 19.0, 20.0, 15.0, 16.0, 17.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
19.0, 20.0, 21.0, 22.0, 23.0, 18.0, 19.0, 20.0, 15.0, 16.0, 17.0, 12.0, 13.0, 14.0, 9.0, 10.0, 11.0, 6.0, 7.0,
8.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 6.0, 7.0, 8.0,
3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 21.0, 22.0, 23.0, 18.0, 19.0, 20.0, 15.0, 16.0, 17.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 18.0, 19.0, 20.0, 15.0, 16.0, 17.0, 12.0, 13.0, 14.0, 33.0, 34.0,
35.0, 30.0, 31.0, 32.0, 27.0, 28.0, 29.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0,
30.0, 31.0, 32.0, 27.0, 28.0, 29.0, 24.0, 25.0, 26.0, 45.0, 46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0, 36.0,
37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 42.0, 43.0, 44.0, 39.0, 40.0, 41.0, 36.0, 37.0,
38.0, 33.0, 34.0, 35.0, 30.0, 31.0, 32.0, 27.0, 28.0, 29.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0,
33.0, 34.0, 35.0, 30.0, 31.0, 32.0, 27.0, 28.0, 29.0, 24.0, 25.0, 26.0, 21.0, 22.0, 23.0, 18.0, 19.0, 20.0, 15.0,
16.0, 17.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 18.0, 19.0, 20.0, 15.0, 16.0,
17.0, 12.0, 13.0, 14.0, 9.0, 10.0, 11.0, 6.0, 7.0, 8.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 6.0, 7.0, 8.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0};
CompareOutputData(out_data, expect.data(), 300, err_tol);
}

TEST_F(TestPadFp32, TestPad3) {
std::vector<int> input_shape{1, 4, 4, 3};
std::vector<int> output_shape{1, 10, 10, 3};
float in_data[48] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0};
float out_data[300] = {0};
int paddings[8] = {0, 0, 3, 3, 3, 3, 0, 0};
PaddingMode mode = PaddingMode_CONSTANT;
float pad_value = 0.0f;
int thread_num = 2;
Prepare(input_shape, output_shape, in_data, out_data, mode, paddings, 8, pad_value, thread_num);

auto ret = kernel_->Run();
EXPECT_EQ(0, ret);

std::vector<float> expect{
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 36.0,
37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
CompareOutputData(out_data, expect.data(), 300, err_tol);
}

TEST_F(TestPadFp32, TestPad4) {
std::vector<int> input_shape{1, 4, 4, 3};
std::vector<int> output_shape{1, 10, 10, 3};
float in_data[48] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0};
float out_data[300] = {0};
int paddings[8] = {0, 0, 3, 3, 3, 3, 0, 0};
PaddingMode mode = PaddingMode_CONSTANT;
float pad_value = 1.0f;
int thread_num = 2;
Prepare(input_shape, output_shape, in_data, out_data, mode, paddings, 8, pad_value, thread_num);

auto ret = kernel_->Run();
EXPECT_EQ(0, ret);

std::vector<float> expect{
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 36.0,
37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
CompareOutputData(out_data, expect.data(), 300, err_tol);
}

TEST_F(TestPadFp32, TestPad5) {
std::vector<int> input_shape{4, 4, 3};
std::vector<int> output_shape{10, 10, 3};
float in_data[48] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0};
float out_data[300] = {0};
int paddings[8] = {3, 3, 3, 3, 0, 0, 0, 0};
PaddingMode mode = PaddingMode_CONSTANT;
float pad_value = 1.0f;
int thread_num = 2;
Prepare(input_shape, output_shape, in_data, out_data, mode, paddings, 6, pad_value, thread_num);

auto ret = kernel_->Run();
EXPECT_EQ(0, ret);

std::vector<float> expect{
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 36.0,
37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
CompareOutputData(out_data, expect.data(), 300, err_tol);
}
} // namespace mindspore
