| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "src/runtime/kernel/arm/int8/arithmetic_int8.h" | |||
| #include "src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h" | |||
| #include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| @@ -42,7 +41,7 @@ int ArithmeticsInt8Launch(int thread_id, LiteParallelGroupEnv *penv, void *cdata | |||
| auto error_code = arithmetic_kernel->DoArithmetic(thread_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "ArithmeticsRun error thread_id[" << thread_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| return error_code; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -79,28 +78,43 @@ ArithmeticInt8CPUKernel::~ArithmeticInt8CPUKernel() { | |||
| int ArithmeticInt8CPUKernel::Init() { | |||
| switch (op_parameter_->type_) { | |||
| case PrimitiveType_Equal: | |||
| arithmetic_run_ = ElementEqual; | |||
| arithmetic_run_ = ElementEqualInt8; | |||
| break; | |||
| case PrimitiveType_NotEqual: | |||
| arithmetic_run_ = ElementNotEqual; | |||
| arithmetic_run_ = ElementNotEqualInt8; | |||
| break; | |||
| case PrimitiveType_Less: | |||
| arithmetic_run_ = ElementLess; | |||
| arithmetic_run_ = ElementLessInt8; | |||
| break; | |||
| case PrimitiveType_LessEqual: | |||
| arithmetic_run_ = ElementLessEqual; | |||
| arithmetic_run_ = ElementLessEqualInt8; | |||
| break; | |||
| case PrimitiveType_Greater: | |||
| arithmetic_run_ = ElementGreater; | |||
| arithmetic_run_ = ElementGreaterInt8; | |||
| break; | |||
| case PrimitiveType_GreaterEqual: | |||
| arithmetic_run_ = ElementGreaterEqual; | |||
| arithmetic_run_ = ElementGreaterEqualInt8; | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_; | |||
| arithmetic_run_ = nullptr; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto *input0_tensor = in_tensors_.at(0); | |||
| auto in0_quant_args = input0_tensor->GetQuantParams(); | |||
| quant_args_.in0_args_.scale_ = in0_quant_args.front().scale; | |||
| quant_args_.in0_args_.zp_ = in0_quant_args.front().zeroPoint; | |||
| auto *input1_tensor = in_tensors_.at(1); | |||
| auto in1_quant_args = input1_tensor->GetQuantParams(); | |||
| quant_args_.in1_args_.scale_ = in1_quant_args.front().scale; | |||
| quant_args_.in1_args_.zp_ = in1_quant_args.front().zeroPoint; | |||
| auto *out_tensor = out_tensors_.at(kOutputIndex); | |||
| auto out_quant_args = out_tensor->GetQuantParams(); | |||
| quant_args_.out_args_.scale_ = out_quant_args.front().scale; | |||
| quant_args_.out_args_.zp_ = out_quant_args.front().zeroPoint; | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| @@ -142,16 +156,16 @@ int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) { | |||
| } | |||
| int error_code = arithmetic_run_(tile_data0_ + stride * thread_id, tile_data1_ + stride * thread_id, | |||
| output_data + stride * thread_id, count); | |||
| output_data + stride * thread_id, count, &quant_args_); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Arithmetic run fail! ret: " << error_code; | |||
| return RET_ERROR; | |||
| return error_code; | |||
| } | |||
| } else if (arithmetic_run_ != nullptr) { | |||
| int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num); | |||
| int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num, &quant_args_); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Arithmetic run fail!ret: " << error_code; | |||
| return RET_ERROR; | |||
| return error_code; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; | |||
| @@ -20,10 +20,12 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h" | |||
| namespace mindspore::kernel { | |||
| class ArithmeticInt8CPUKernel : public LiteKernel { | |||
| typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size); | |||
| typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| public: | |||
| ArithmeticInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| @@ -39,10 +41,10 @@ class ArithmeticInt8CPUKernel : public LiteKernel { | |||
| private: | |||
| void FreeTileData(); | |||
| int thread_count_; | |||
| int8_t *tile_data0_; | |||
| int8_t *tile_data1_; | |||
| ArithmeticRunInt8 arithmetic_run_; | |||
| ArithmeticQuantArg quant_args_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_INT8_H_ | |||
| @@ -17,6 +17,8 @@ | |||
| #include "nnacl/fp32/arithmetic.h" | |||
| #include <math.h> | |||
| #define ACCURACY_DATA 0.00000001 | |||
| int ElementMul(float *input0, float *input1, float *output, int element_size) { | |||
| int block_mod = element_size % C4NUM; | |||
| int block_c4 = element_size - block_mod; | |||
| @@ -549,6 +551,14 @@ int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *ti | |||
| return ElementMinimum(tile_input0, tile_input1, output, element_size); | |||
| } | |||
| float FloatNotEqualCheck(float in0, float in1) { | |||
| float minus = in0 - in1; | |||
| if (minus <= ACCURACY_DATA && minus >= -ACCURACY_DATA) { | |||
| return (float)false; | |||
| } | |||
| return (float)true; | |||
| } | |||
| int ElementNotEqual(float *input0, float *input1, float *output, int element_size) { | |||
| int block_mod = element_size % C4NUM; | |||
| int block_c4 = element_size - block_mod; | |||
| @@ -563,10 +573,10 @@ int ElementNotEqual(float *input0, float *input1, float *output, int element_siz | |||
| float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vfalse, vtrue); | |||
| vst1q_f32(output, vout); | |||
| #else | |||
| output[0] = (float)(input0[0] != input1[0]); | |||
| output[1] = (float)(input0[1] != input1[1]); | |||
| output[2] = (float)(input0[2] != input1[2]); | |||
| output[3] = (float)(input0[3] != input1[3]); | |||
| output[0] = FloatNotEqualCheck(input0[0], input1[0]); | |||
| output[1] = FloatNotEqualCheck(input0[1], input1[1]); | |||
| output[2] = FloatNotEqualCheck(input0[2], input1[2]); | |||
| output[3] = FloatNotEqualCheck(input0[3], input1[3]); | |||
| #endif | |||
| input0 += C4NUM; | |||
| input1 += C4NUM; | |||
| @@ -584,6 +594,14 @@ int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *t | |||
| return ElementNotEqual(tile_input0, tile_input1, output, element_size); | |||
| } | |||
| float FloatEqualCheck(float in0, float in1) { | |||
| float minus = in0 - in1; | |||
| if (minus <= ACCURACY_DATA && minus >= -ACCURACY_DATA) { | |||
| return (float)true; | |||
| } | |||
| return (float)false; | |||
| } | |||
| int ElementEqual(float *input0, float *input1, float *output, int element_size) { | |||
| int block_mod = element_size % C4NUM; | |||
| int block_c4 = element_size - block_mod; | |||
| @@ -598,10 +616,10 @@ int ElementEqual(float *input0, float *input1, float *output, int element_size) | |||
| float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vtrue, vfalse); | |||
| vst1q_f32(output, vout); | |||
| #else | |||
| output[0] = (float)(input0[0] == input1[0]); | |||
| output[1] = (float)(input0[1] == input1[1]); | |||
| output[2] = (float)(input0[2] == input1[2]); | |||
| output[3] = (float)(input0[3] == input1[3]); | |||
| output[0] = FloatEqualCheck(input0[0], input1[0]); | |||
| output[1] = FloatEqualCheck(input0[1], input1[1]); | |||
| output[2] = FloatEqualCheck(input0[2], input1[2]); | |||
| output[3] = FloatEqualCheck(input0[3], input1[3]); | |||
| #endif | |||
| input0 += C4NUM; | |||
| input1 += C4NUM; | |||
| @@ -758,3 +776,5 @@ int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, floa | |||
| TileDimensions(input0, input1, tile_input0, tile_input1, param); | |||
| return ElementGreaterEqual(tile_input0, tile_input1, output, element_size); | |||
| } | |||
| #undef ACCURACY_DATA | |||
| @@ -20,44 +20,102 @@ | |||
| #endif | |||
| #include "nnacl/errorcode.h" | |||
| int ElementNotEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { | |||
| #define ACCURACY_DATA 0.00000001 | |||
| int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| output[index] = (int8_t)(input0[index] != input1[index]); | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float minus_inputs = in0_real - in1_real; | |||
| float out_real = (float)true; | |||
| if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { | |||
| out_real = (float)false; | |||
| } | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { | |||
| int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| output[index] = (int8_t)(input0[index] == input1[index]); | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float minus_inputs = in0_real - in1_real; | |||
| float out_real = (float)false; | |||
| if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { | |||
| out_real = (float)true; | |||
| } | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementLess(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { | |||
| int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| output[index] = (int8_t)(input0[index] < input1[index]); | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float out_real = (float)(in0_real < in1_real); | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementLessEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { | |||
| int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| output[index] = (int8_t)(input0[index] <= input1[index]); | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float out_real = (float)(in0_real <= in1_real); | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementGreater(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { | |||
| int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| output[index] = (int8_t)(input0[index] > input1[index]); | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float out_real = (float)(in0_real > in1_real); | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ElementGreaterEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { | |||
| int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg) { | |||
| float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; | |||
| float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; | |||
| float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; | |||
| float out_zp = quant_arg->out_args_.zp_; | |||
| for (int index = 0; index < element_size; ++index) { | |||
| output[index] = (int8_t)(input0[index] >= input1[index]); | |||
| float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; | |||
| float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; | |||
| float out_real = (float)(in0_real >= in1_real); | |||
| output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| #undef ACCURACY_DATA | |||
| @@ -17,16 +17,21 @@ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| int ElementNotEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); | |||
| int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| int ElementEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); | |||
| int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); | |||
| int ElementLess(int8_t *input0, int8_t *input1, int8_t *output, int element_size); | |||
| int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); | |||
| int ElementLessEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); | |||
| int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| int ElementGreater(int8_t *input0, int8_t *input1, int8_t *output, int element_size); | |||
| int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| int ElementGreaterEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); | |||
| int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, | |||
| ArithmeticQuantArg *quant_arg); | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_ | |||
| @@ -193,6 +193,12 @@ typedef struct SubQuantArg { | |||
| int right_shift_out_; | |||
| } SubQuantArg; | |||
// Quantization arguments for a two-input int8 arithmetic op: one QuantArg per input
// tensor and one for the output tensor. The int8 arithmetic kernel fills the scale_
// and zp_ of each member from the corresponding tensor's first quant param.
| typedef struct ArithmeticQuantArg { | |||
// Quantization arguments of the first input tensor.
| QuantArg in0_args_; | |||
// Quantization arguments of the second input tensor.
| QuantArg in1_args_; | |||
// Quantization arguments of the output tensor.
| QuantArg out_args_; | |||
| } ArithmeticQuantArg; | |||
| void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift); | |||
| inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, | |||