@@ -0,0 +1,194 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/int8/div_int8.h"
#include <limits>
#include <algorithm>
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
#include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Div;

namespace mindspore::kernel {
int DivInt8CPUKernel::Init() {
  lite::tensor::Tensor *input0 = in_tensors_.at(0);
  lite::tensor::Tensor *input1 = in_tensors_.at(1);
  lite::tensor::Tensor *output = out_tensors_.at(0);
  MS_ASSERT(input0);
  MS_ASSERT(input1);
  MS_ASSERT(output);

  broadcast_ = input0->ElementsNum() != input1->ElementsNum();

  param_.in0_args_.scale_ = input0->GetQuantParams().front().scale;
  param_.in0_args_.zp_ = -input0->GetQuantParams().front().zeroPoint;
  param_.in1_args_.scale_ = input1->GetQuantParams().front().scale;
  param_.in1_args_.zp_ = -input1->GetQuantParams().front().zeroPoint;
  param_.out_args_.scale_ = output->GetQuantParams().front().scale;
  param_.out_args_.zp_ = output->GetQuantParams().front().zeroPoint;
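  // Quantized division: real0 / real1 = (s0 * (q0 - z0)) / (s1 * (q1 - z1)), requantized by dividing
  // by s_out.  The input zp_ fields above hold the negated zero points, so the per-element code can
  // simply add them.  The combined scale s0 / (s1 * s_out) is folded into a fixed-point multiplier
  // and shift below.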
  const double real_multiplier = param_.in0_args_.scale_ / (param_.in1_args_.scale_ * param_.out_args_.scale_);
  QuantizeMultiplier(real_multiplier, &param_.output_multiplier_, &param_.output_shift_);

  param_.output_activation_min_ = std::numeric_limits<int8_t>::min();
  param_.output_activation_max_ = std::numeric_limits<int8_t>::max();

  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}
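
// When broadcasting, both inputs are tiled to the full output shape in Run(), so the temporary
// buffers allocated here are sized to the output tensor.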
int DivInt8CPUKernel::ReSize() {
  if (broadcast_) {
    if (tile0_data_ != nullptr) {
      if (context_ != nullptr && context_->allocator != nullptr) {
        context_->allocator->Free(tile0_data_);
      } else {
        free(tile0_data_);
      }
    }
    if (tile1_data_ != nullptr) {
      if (context_ != nullptr && context_->allocator != nullptr) {
        context_->allocator->Free(tile1_data_);
      } else {
        free(tile1_data_);
      }
    }
    if (context_ != nullptr && context_->allocator != nullptr) {
      tile0_data_ = static_cast<int8_t *>(context_->allocator->Malloc(out_tensors_.at(0)->Size()));
      tile1_data_ = static_cast<int8_t *>(context_->allocator->Malloc(out_tensors_.at(0)->Size()));
    } else {
      tile0_data_ = static_cast<int8_t *>(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size()));
      tile1_data_ = static_cast<int8_t *>(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size()));
    }
    if (tile0_data_ == nullptr || tile1_data_ == nullptr) {
      if (tile0_data_ != nullptr) {
        if (context_ != nullptr && context_->allocator != nullptr) {
          context_->allocator->Free(tile0_data_);
        } else {
          free(tile0_data_);
        }
        tile0_data_ = nullptr;
      }
      if (tile1_data_ != nullptr) {
        if (context_ != nullptr && context_->allocator != nullptr) {
          context_->allocator->Free(tile1_data_);
        } else {
          free(tile1_data_);
        }
        tile1_data_ = nullptr;
      }
      MS_LOG(ERROR) << "malloc memory failed!";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

int DivInt8CPUKernel::DoExecute(int task_id) {
  auto input0_data_ = static_cast<int8_t *>(in_tensors_.at(0)->Data());
  auto input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->Data());
  auto output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->Data());
  auto element_num = out_tensors_[0]->ElementsNum();

  MS_ASSERT(op_parameter_->thread_num_ != 0);
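  // Split the output elements evenly across threads: each task handles at most stride elements
  // starting at task_id * stride; the last task may get fewer (or none).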
  int stride = UP_DIV(element_num, op_parameter_->thread_num_);
  int count = MSMIN(stride, element_num - stride * task_id);
  if (count <= 0) {
    return RET_OK;
  }
  int offset = task_id * stride;

  auto ret = RET_OK;
  if (broadcast_) {
    ret = DivInt8(tile0_data_ + offset, tile1_data_ + offset, output_data_ + offset, count, &param_);
  } else {
    ret = DivInt8(input0_data_ + offset, input1_data_ + offset, output_data_ + offset, count, &param_);
  }
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DivInt8 function error, error_code[" << ret << "]";
  }
  return ret;
}

int DivInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto div_kernel = reinterpret_cast<DivInt8CPUKernel *>(cdata);
  auto ret = div_kernel->DoExecute(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DivInt8 DoExecute error task_id[" << task_id << "] error_code[" << ret << "]";
  }
  return ret;
}

int DivInt8CPUKernel::Run() {
  auto ret = Prepare();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed.";
    return RET_ERROR;
  }
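  // Broadcast case: tile both inputs to the full output shape so the parallel worker can run a
  // plain element-wise division.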
  if (broadcast_) {
    ArithmeticParameter tile_para = {0};
    tile_para.ndim_ = out_tensors_.at(0)->shape().size();
    for (size_t i = 0; i < tile_para.ndim_; i++) {
      tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
      tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
      tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
    }
    TileDimensionsUint8(static_cast<uint8_t *>(in_tensors_.at(0)->Data()),
                        static_cast<uint8_t *>(in_tensors_.at(1)->Data()), reinterpret_cast<uint8_t *>(tile0_data_),
                        reinterpret_cast<uint8_t *>(tile1_data_), &tile_para);
  }
  ret = LiteBackendParallelLaunch(DivInt8Run, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DivInt8Run function error, error_code[" << ret << "]";
  }
  return ret;
}

kernel::LiteKernel *CpuDivInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                            const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
                                            const lite::Context *ctx, const KernelKey &desc,
                                            const lite::Primitive *primitive) {
  if (parameter == nullptr || ctx == nullptr) {
    MS_LOG(ERROR) << "parameter or ctx is nullptr";
    return nullptr;
  }
  MS_ASSERT(desc.type == PrimitiveType_Div);
  auto *kernel = new (std::nothrow) DivInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Div, CpuDivInt8KernelCreator)
} // namespace mindspore::kernel

@@ -0,0 +1,46 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DIV_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DIV_INT8_H_

#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/nnacl/int8/div_int8.h"
#include "src/runtime/runtime_api.h"

namespace mindspore::kernel {
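// Int8 element-wise division kernel.  When the two inputs have different element counts
// (broadcast_), both are tiled to the output shape before the per-element division runs.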
class DivInt8CPUKernel : public LiteKernel {
 public:
  explicit DivInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                            const lite::Primitive *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~DivInt8CPUKernel() override {}

  int Init() override;
  int ReSize() override;
  int Run() override;
  int DoExecute(int task_id);

 private:
  DivQuantArg param_;
  int8_t *tile0_data_ = nullptr;
  int8_t *tile1_data_ = nullptr;
  bool broadcast_ = false;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DIV_INT8_H_

@@ -0,0 +1,43 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/int8/div_int8.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/errorcode.h"
int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para) {
  int index = 0;
  for (; index < real_dst_count; ++index) {
    const int32_t input0_val = para->in0_args_.zp_ + input0_data[index];
    const int32_t input1_val = para->in1_args_.zp_ + input1_data[index];
    if (input1_val == 0) {
      return NNACL_ERRCODE_DIVISOR_ZERO;
    }

    int recip_shift;
    const int32_t input1_inv = (input1_val > 0) ? ComputerReciproal(input1_val, 31, &recip_shift)
                                                : -ComputerReciproal(-input1_val, 31, &recip_shift);
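    // ComputerReciproal returns roughly 2^(31 + recip_shift) / |input1_val|; the ternary restores the
    // divisor's sign.  The extra powers of two (recip_shift and the headroom shift below) are undone
    // again via total_shift.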
    const int leading_bits = CountLeadingSignBits(input0_val);
    const int32_t raw_data =
      SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv);
    const int total_shift = para->output_shift_ - recip_shift - leading_bits;
    const int32_t raw_output =
      RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data, para->output_multiplier_), -total_shift) +
      para->out_args_.zp_;
    output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_));
  }
  return NNACL_OK;
}

@@ -0,0 +1,31 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DIV_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DIV_INT8_H_

#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"

#ifdef __cplusplus
extern "C" {
#endif
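// Divides real_dst_count int8 elements of input0_data by input1_data into output_data using the
// quantization parameters in para.  Returns NNACL_ERRCODE_DIVISOR_ZERO if a zero-point-adjusted
// divisor element is zero.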
int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para);
#ifdef __cplusplus
}
#endif

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DIV_INT8_H_

@@ -64,6 +64,114 @@ inline int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int3
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
}
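
// Fixed-point helpers added for the int8 Div kernel: bit-mask selection, saturating power-of-two
// multiplies, leading-bit counts, and a Newton-Raphson reciprocal.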
inline int FractionsBits(int kIntegerBits) {
  int totalBits = 8 * sizeof(int32_t) - 1;
  return totalBits - kIntegerBits;
}

inline int FixedPoint_One(int kIntegerBits, int kFractionsBits) {
  return kIntegerBits == 0 ? INT32_MAX : ((1) << (uint32_t)kFractionsBits);
}

inline int RoundingHalfSum(int a, int b) {
  int64_t a64 = a;
  int64_t b64 = b;
  int64_t sum = a64 + b64;
  int64_t sign = sum > 0 ? 1 : -1;
  return (int32_t)((sum + sign) / 2);
}

inline int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; }

inline int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; }

inline int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; }

inline int32_t BitNot(int32_t a) { return ~(uint32_t)a; }

inline int SelectUsingMask(int mask, int bound, int val) {
  return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val));
}

inline int32_t MaskNonZero(int32_t a) {
  const int32_t zero = 0;
  return a ? BitNot(zero) : zero;
}
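
// Multiplies x by 2^Exponent: positive exponents saturate to INT32_MIN/INT32_MAX instead of
// overflowing, negative exponents use rounding division by a power of two.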
inline int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) {
  int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0);
  if (ExponentSign == 0) {
    return x;
  } else if (ExponentSign == 1) {
    const int min = INT32_MIN;
    const int max = INT32_MAX;
    const int threshold = ((1 << (uint32_t)(31 - Exponent)) - 1);
    const int positive_mask = MaskNonZero(x > threshold);
    const int negative_mask = MaskNonZero(x < -threshold);
    int result = x << Exponent;
    result = SelectUsingMask(positive_mask, max, result);
    result = SelectUsingMask(negative_mask, min, result);
    return result;
  } else if (ExponentSign == -1) {
    return RoundingDivideByPOT(x, -Exponent);
  } else {
    return 0;
  }
}

inline int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst) {
  int kExponent = kIntegerBitsSrc - kIntegerBitsDst;
  int result = SaturatingRoundingMultiplyByPOT(x, kExponent);
  return result;
}
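
// Newton-Raphson evaluation of 1 / (1 + a) for a in [0, 1) in Q0.31, following the gemmlowp scheme:
// half_denominator = (1 + a) / 2 lies in [0.5, 1), the constants 48/17 and -32/17 (Q2.29) give the
// initial reciprocal estimate, three iterations of x <- x + x * (1 - half_denominator * x) refine
// it, and the final Rescale halves the result (2 / (1 + a) -> 1 / (1 + a)) in Q0.31.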
static inline int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) {
  int one = FixedPoint_One(0, FractionsBits(0));
  int half_denominator = RoundingHalfSum(a, one);
  const int constant_48_over_17 = 1515870810;
  const int constant_neg_32_over_17 = -1010580540;
  int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_denominator, constant_neg_32_over_17);
  for (int i = 0; i < 3; i++) {
    int half_denominator_times_x = SaturatingRoundingDoublingHighMul(half_denominator, x);
    int one_minus_half_denominator_times_x = FixedPoint_One(2, FractionsBits(2)) - half_denominator_times_x;
    x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_denominator_times_x), 2 + 2, 2);
  }
  return Rescale(x, 2 - 1, 0);
}

inline int CountLeadingZeroBits(uint32_t x) {
#if defined(__GNUC__)
  return x ? __builtin_clz(x) : 8 * sizeof(uint32_t);
#else
  if (x == 0) {
    return 8 * sizeof(uint32_t);
  }
  const uint32_t leading_positive = (uint32_t)(1) << (8 * sizeof(uint32_t) - 1);
  int leading_zeros = 0;
  while (x < leading_positive) {
    x <<= 1;
    leading_zeros++;
  }
  return leading_zeros;
#endif
}

inline int CountLeadingSignBits(int32_t x) {
#if defined(__GNUC__) && !defined(__clang__)
  return x ? __builtin_clrsb(x) : 8 * sizeof(int32_t);
#else
  return x >= 0 ? CountLeadingZeroBits((uint32_t)x) - 1 : x != INT32_MIN ? CountLeadingZeroBits(2 * (uint32_t)(-x)) : 0;
#endif
}
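
// Fixed-point reciprocal (in the spirit of TFLite's GetReciprocal): x is shifted so its highest set
// bit reaches bit 31, the fractional remainder feeds one_over_one_plus_x_for_x_in_0_1, and
// *recip_shift records the normalization.  For the x_digits = 31 case used by DivInt8, the return
// value is roughly 2^(31 + *recip_shift) / x.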
static inline int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift) {
  int leading_zeros_plus_one = CountLeadingZeroBits((uint32_t)x);
  *recip_shift = x_digits - leading_zeros_plus_one;
  const int32_t shifted_minus_one = (int32_t)(((uint32_t)x << leading_zeros_plus_one) - ((uint32_t)(1) << 31));
  const int32_t shifted_scaled = one_over_one_plus_x_for_x_in_0_1(shifted_minus_one);
  return shifted_scaled;
}

#ifdef __cplusplus
}
#endif

@@ -197,6 +197,16 @@ typedef struct ArithmeticQuantArg {
  QuantArg in1_args_;
  QuantArg out_args_;
} ArithmeticQuantArg;
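
// Quantization parameters for int8 Div: per-tensor scale/zero-point arguments for both inputs and
// the output (the input zp_ values are stored negated by the kernel), the fixed-point
// multiplier/shift encoding in0_scale / (in1_scale * out_scale), and the output activation clamp range.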
typedef struct DivQuantArg {
  QuantArg in0_args_;
  QuantArg in1_args_;
  QuantArg out_args_;
  int output_activation_min_;
  int output_activation_max_;
  int output_multiplier_;
  int output_shift_;
} DivQuantArg;

#ifdef __cplusplus
extern "C" {
#endif

@@ -0,0 +1,74 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <iostream>
#include <memory>
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/int8/div_int8.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/include/context.h"

namespace mindspore {
class TestDivInt8 : public mindspore::CommonTest {
 public:
  TestDivInt8() {}
};

TEST_F(TestDivInt8, DivInt8) {
  lite::tensor::Tensor in_tensor0(kNumberTypeInt8, {1, 1, 2, 5});
  lite::tensor::Tensor in_tensor1(kNumberTypeInt8, {1, 1, 2, 5});
  lite::tensor::Tensor out_tensor(kNumberTypeInt8, {1, 1, 2, 5});

  int8_t input_data0[] = {105, 35, -27, 0, -63, 99, 16, 45, 67, -49};
  int8_t input_data1[] = {126, -38, -115, 106, -98, 119, 103, 81, -114, 68};
  int8_t output_data[10] = {0};
  in_tensor0.SetData(input_data0);
  in_tensor1.SetData(input_data1);
  out_tensor.SetData(output_data);
  const lite::tensor::QuantArg quant_in0 = {0.00784314f, 0};  // scale = 2/255, zero point 0: int8 maps to about [-1.0, 1.0]
  const lite::tensor::QuantArg quant_in1 = {0.00784314f, 0};
  const lite::tensor::QuantArg quant_out = {0.00784314f, 0};
  in_tensor0.AddQuantParam(quant_in0);
  in_tensor1.AddQuantParam(quant_in1);
  out_tensor.AddQuantParam(quant_out);

  std::vector<lite::tensor::Tensor *> inputs = {&in_tensor0, &in_tensor1};
  std::vector<lite::tensor::Tensor *> outputs = {&out_tensor};

  OpParameter parameter = {};
  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Div};

  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  ASSERT_NE(creator, nullptr);

  auto ctx = std::make_shared<lite::Context>();
  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
  ASSERT_NE(kernel, nullptr);

  auto ret = kernel->Run();
  EXPECT_EQ(0, ret);
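  // With zero points of 0 and identical scales, the expected int8 output is
  // round((q0 / q1) / 0.00784314), clamped to [-128, 127].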
  int8_t expect0[10] = {106, -117, 30, 0, 82, 106, 20, 71, -75, -92};
  for (int i = 0; i < 10; ++i) {
    EXPECT_EQ(output_data[i], expect0[i]);
  }

  in_tensor0.SetData(nullptr);
  in_tensor1.SetData(nullptr);
  out_tensor.SetData(nullptr);
}
} // namespace mindspore