diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc
index 5124980a68..446495ca98 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc
@@ -37,17 +37,24 @@ int SoftmaxInt8CPUKernel::Init() {
   auto in_quant_args = input_tensor->GetQuantParams();
   quant_params_.in_quant_args_.scale_ = in_quant_args.front().scale;
-  quant_params_.in_quant_args_.zp_ = in_quant_args.front().zeroPoint;
+  quant_params_.in_quant_args_.zp_ = -in_quant_args.front().zeroPoint;
 
   auto *out_tensor = out_tensors_.at(kOutputIndex);
   MS_ASSERT(out_tensor);
 
   auto out_quant_args = out_tensor->GetQuantParams();
   quant_params_.out_quant_arg_.scale_ = out_quant_args.front().scale;
-  quant_params_.out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
+  quant_params_.out_quant_arg_.zp_ = -out_quant_args.front().zeroPoint;
   quant_params_.output_activation_min_ = std::numeric_limits<int8_t>::min();
   quant_params_.output_activation_max_ = std::numeric_limits<int8_t>::max();
 
+  const double input_real_multiplier =
+      MSMIN(quant_params_.in_quant_args_.scale_ * (1 << (unsigned int)(31 - 5)), (1ll << 31) - 1.0);
+  int right_shift = 0;
+  QuantizeMultiplierSmallerThanOne(input_real_multiplier, &quant_params_.output_multiplier_, &right_shift);
+  quant_params_.shift_left_ = right_shift < 0 ? -right_shift : 0;
+  quant_params_.shift_right_ = right_shift > 0 ? right_shift : 0;
+
   if (!InferShapeDone()) {
     return RET_OK;
   }
@@ -72,12 +79,12 @@ int SoftmaxInt8CPUKernel::ReSize() {
     return ret;
   }
   FreeTmpBuffer();
-  exp_data_ = reinterpret_cast<float *>(malloc(softmax_param_->element_size_ * sizeof(float)));
+  exp_data_ = reinterpret_cast<int *>(malloc(softmax_param_->element_size_ * sizeof(int)));
   int inner_size = 1;
   for (int i = softmax_param_->axis_ + 1; i < softmax_param_->n_dim_; i++) {
     inner_size *= softmax_param_->input_shape_[i];
   }
-  sum_data_ = reinterpret_cast<float *>(malloc(inner_size * sizeof(float)));
+  sum_data_ = reinterpret_cast<int *>(malloc(inner_size * sizeof(int)));
 
   return RET_OK;
 }
@@ -125,12 +132,7 @@ int SoftmaxInt8CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
     return RET_ERROR;
   }
-  auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->Data());
-  int ele_size = softmax_param_->element_size_;
-  for (int i = 0; i < ele_size; i++) {
-    float input_scaled = ((input_ptr[i] - quant_params_.in_quant_args_.zp_) * quant_params_.in_quant_args_.scale_);
-    exp_data_[i] = exp(input_scaled);
-  }
+
   int error_code = LiteBackendParallelLaunch(SoftmaxRun, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Softmax function error error_code[" << error_code << "]";
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h
index c0aaea1b14..acb7fb895d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h
@@ -37,8 +37,8 @@ class SoftmaxInt8CPUKernel : public SoftmaxBaseCPUKernel {
  private:
   void FreeTmpBuffer();
 
-  float *sum_data_ = nullptr;
-  float *exp_data_ = nullptr;
+  int *sum_data_ = nullptr;
+  int *exp_data_ = nullptr;
   SoftmaxQuantArg quant_params_;
 };
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.c
index ed79bb26a5..0ffa437d8b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.c
@@ -16,17 +16,17 @@
 
 #include "nnacl/int8/softmax_int8.h"
 #include <math.h>
+#include "nnacl/quantization/fixed_point.h"
+#include "nnacl/quantization/quantize.h"
+#include "nnacl/errorcode.h"
 
-int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
+int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data,
                 SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) {
   int32_t axis = parameter->axis_;
   int n_dim = parameter->n_dim_;
   int *input_shape = parameter->input_shape_;
   int axis_shape_size = input_shape[axis];
-  double output_scale = quant_param.out_quant_arg_.scale_;
-  int32_t output_zp = quant_param.out_quant_arg_.zp_;
-
   int inner_size = 1;
   for (int i = axis + 1; i < n_dim; i++) {
     inner_size *= input_shape[i];
@@ -34,22 +34,37 @@ int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *e
 
   for (int o = 0; o < count; o++) {
     int outter_offset = o * axis_shape_size * inner_size;
-    for (int i = 0; i < inner_size; i++) {
-      float sum = 0;
-      for (int j = 0; j < axis_shape_size; j++) {
-        int axis_offset = outter_offset + i + j * inner_size;
-        sum += exp_data[axis_offset];
+
+    for (int c = 0; c < inner_size; c++) {
+      int8_t max_row = quant_param.output_activation_min_;
+      for (int i = 0; i < axis_shape_size; ++i) {
+        int axis_offset = outter_offset + c + i * inner_size;
+        max_row = MSMAX(max_row, input_ptr[axis_offset]);
       }
-      sum_data[i] = sum;
+
+      int32_t exp_sum = 0;
+      for (int i = 0; i < axis_shape_size; ++i) {
+        int axis_offset = outter_offset + c + i * inner_size;
+        const int32_t input_val = input_ptr[axis_offset] - max_row;
+        const int32_t input_scaled = SaturatingRoundingDoublingHighMul(
+            input_val * (1 << (unsigned int)quant_param.shift_left_), quant_param.output_multiplier_);
+        int exp_val = exp_on_negative_values(input_scaled, 5);
+        exp_data[axis_offset] = exp_val;
+        exp_sum = exp_sum + Rescale(exp_val, 0, 12);
+      }
+      sum_data[c] = exp_sum;
     }
-    for (int j = 0; j < axis_shape_size; j++) {
-      int axis_offset = outter_offset + j * inner_size;
-      for (int i = 0; i < inner_size; i++) {
-        int inner_offset = axis_offset + i;
-        float real_output = exp_data[inner_offset] / sum_data[i];
-        int32_t output_scaled = round(real_output / output_scale) + output_zp;
-        output_ptr[inner_offset] =
-          MSMAX(quant_param.output_activation_min_, MSMIN(quant_param.output_activation_max_, output_scaled));
+    for (int i = 0; i < axis_shape_size; ++i) {
+      int axis_offset = outter_offset + i * inner_size;
+      for (int c = 0; c < inner_size; ++c) {
+        int num_bits_over_unit;
+        int shifted_scale = ComputerReciproal(sum_data[c], 12, &num_bits_over_unit);
+        int unsat_output = RoundingDivideByPOT(
+            SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8);
+
+        int raw_output = unsat_output + quant_param.output_activation_min_;
+        output_ptr[axis_offset + c] =
+            (int8_t)MSMAX(quant_param.output_activation_min_, MSMIN(raw_output, quant_param.output_activation_max_));
       }
     }
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.h
index e824ad96e4..5d517df14d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.h
@@ -24,7 +24,7 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
+int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data,
                 SoftmaxQuantArg quant_param, SoftmaxParameter *parameter);
 #ifdef __cplusplus
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.c
index 3e2010ab2c..c12bac9111 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.c
@@ -86,24 +86,22 @@ int32_t MaskNonZero(int32_t a) { return a ? BitNot(zreo) : zreo; }
 
-int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) {
-  int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0);
-  if (ExponentSign == 0) {
-    return x;
-  } else if (ExponentSign == 1) {
+static inline int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) {
+  if (Exponent > 0) {
     const int min = INT32_MIN;
     const int max = INT32_MAX;
-    const int thresold = ((1 << (uint32_t)(31 - Exponent)) - 1);
+    const int scalar_int_bits = 8 * sizeof(int32_t);
+    const int thresold = ((1 << (uint32_t)(scalar_int_bits - 1 - Exponent)) - 1);
     const int postive_mask = MaskNonZero(x > thresold);
     const int negative_mask = MaskNonZero(x < -thresold);
-    int result = x << Exponent;
+    int result = x * ((int32_t)(1) << (uint32_t)Exponent);
     result = SelectUsingMask(postive_mask, max, result);
     result = SelectUsingMask(negative_mask, min, result);
     return result;
-  } else if (ExponentSign == -1) {
+  } else if (Exponent < 0) {
     return RoundingDivideByPOT(x, -Exponent);
   } else {
-    return 0;
+    return x;
   }
 }
 
@@ -113,7 +111,7 @@ int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst) {
   return result;
 }
 
-static int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) {
+int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) {
   int one = FixedPoint_One(0, FractionsBits(0));
   int half_denominator = RoundingHalfSum(a, one);
   const int constant_48_over_17 = 1515870810;
@@ -159,6 +157,71 @@ int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift) {
   const int32_t shifted_scaled = one_over_one_plus_x_for_x_in_0_1(shifted_minus_one);
   return shifted_scaled;
 }
+
+int ConstantPOT(int fractional_bits, int exponent) {
+  int offset = fractional_bits + exponent;
+  return (1 << (uint32_t)offset);
+}
+
+int32_t MaskIfNonZero(int32_t a) { return a ? BitNot(0) : 0; }
+
+int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); }
+
+int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); }
+
+int exp_on_interval_between_negative_one_quarter_and_0_excl(int a) {
+  const int constant_term = 1895147668;
+  const int constant_1_over_3 = 715827883;
+  // We're evaluating a Taylor expansion around -1/8, so we do the change of
+  // variable: x = a + 1/8.
+  // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
+  int kFractionalBits = FractionsBits(0);
+  int x = a + ConstantPOT(kFractionalBits, -3);
+  int x2 = SaturatingRoundingDoublingHighMul(x, x);
+  int x3 = SaturatingRoundingDoublingHighMul(x2, x);
+  int x4 = SaturatingRoundingDoublingHighMul(x2, x2);
+  int x4_over_4 = SaturatingRoundingMultiplyByPOT(x4, -2);
+  int x4_over_24_plus_x3_over_6_plus_x2_over_2 =
+      SaturatingRoundingMultiplyByPOT((SaturatingRoundingDoublingHighMul((x4_over_4 + x3), constant_1_over_3) + x2), -1);
+  return constant_term +
+         SaturatingRoundingDoublingHighMul(constant_term, (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
+}
+
+int exp_on_negative_values(int a, const int tIntegerBits) {
+  int kIntegerBits = tIntegerBits;
+  int kFractionalBits = FractionsBits(tIntegerBits);
+  const int kOneQuarter = ConstantPOT(kFractionalBits, -2);
+  int mask = kOneQuarter - 1;
+  int a_mod_quarter_minus_one_quarter = ((unsigned)(a)&mask) - kOneQuarter;
+  int result =
+      exp_on_interval_between_negative_one_quarter_and_0_excl(Rescale(a_mod_quarter_minus_one_quarter, tIntegerBits, 0));
+  int remainder = a_mod_quarter_minus_one_quarter - a;
+
+#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)                             \
+  if (kIntegerBits > Exponent) {                                                                \
+    const int kMultiplier = FixedPointMultiplier;                                               \
+    int kShiftAmount = kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;                \
+    result = SelectUsingMask(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)kShiftAmount))),   \
+                             SaturatingRoundingDoublingHighMul(result, kMultiplier), result);   \
+  }
+  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
+  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
+#undef GEMMLOWP_EXP_BARREL_SHIFTER
+
+  int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
+  if (kIntegerBits > 5) {
+    const int clamp = -(1 << (uint32_t)clampB);
+    result = SelectUsingMask(MaskIfLessThan(a, clamp), 0, result);
+  }
+
+  result = SelectUsingMask(MaskIfZero(a), FixedPoint_One(0, kFractionalBits), result);
+  return result;
+}
+
 #ifdef ENABLE_NEON
 int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) {
   const int32x4_t shift_vec = vdupq_n_s32(-exponent);
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h
index 3cd47a3f2b..f3425de514 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h
@@ -60,11 +60,9 @@ int SelectUsingMask(int mask, int bound, int val);
 
 int32_t MaskNonZero(int32_t a);
 
-int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent);
-
 int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst);
 
-static int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a);
+int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a);
 
 int CountLeadingZeroBits(uint32_t x);
 
@@ -72,6 +70,18 @@ int CountLeadingSignBits(int32_t x);
 
 int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift);
 
+int exp_on_negative_values(int a, const int tIntegerBits);
+
+int ConstantPOT(int fractional_bits, int exponent);
+
+int32_t MaskIfNonZero(int32_t a);
+
+int32_t MaskIfZero(int32_t a);
+
+int32_t MaskIfLessThan(int32_t a, int32_t b);
+
+int exp_on_interval_between_negative_one_quarter_and_0_excl(int a);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc
index ff9a1e7528..c87a3ede50 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc
@@ -80,9 +80,8 @@ TEST_F(TestSoftmaxInt8, SoftmaxInt8) {
   auto output_tensor_shape = output0_tensor.shape();
   kernel->Run();
 
-  std::vector<int8_t> except_result = {-126, -126, -124, -124, -123, -124, -116, -116, 121, 121, 111, 111,
-                                       -127, -127, -127, -127, -59,  -59,  -61,  -59,  57,  57,  59,  57};
-
+  std::vector<int8_t> except_result = {-126, -126, -124, -124, -123, -124, -116, -116, 122, 122, 112, 112,
+                                       -127, -127, -127, -127, -59,  -59,  -61,  -59,  58,  58,  59,  58};
   CompareOutputData(output.data(), except_result.data(), input.size(), 0.000001);
   input0_tensor.SetData(nullptr);