Merge pull request !8148 from fuzhiye/tmp
tags/v1.1.0
@@ -53,6 +53,8 @@ void ComputeStrides(const int *shape, int *strides, const int ndim);
void CalcMultiplesAndStrides(ArithmeticParameter *param);
void TileOneDimensionUint8(uint8_t *inData, uint8_t *outData, int dim, size_t ndim, int *inShape, int *inStrides,
                           int *outStrides, int *multiple);
void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param);
void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1,
                         ArithmeticParameter *param);
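The tiling declarations above are what the int8 Scale kernel now relies on for broadcasting: CalcMultiplesAndStrides fills the per-dimension multiples and strides in ArithmeticParameter, and TileOneDimensionUint8 expands the smaller operand to the full output shape. A minimal, self-contained sketch of the broadcasting semantics this achieves (plain modulo indexing rather than the stride-based nnacl recursion; NaiveTileInt8 is an illustrative name, not part of nnacl):

#include <cstdint>
#include <vector>

// Reference semantics only: out[i0,...,ik] = in[i0 % in_shape[0], ..., ik % in_shape[k]].
// Assumes in_shape and out_shape have the same rank; dimensions of size 1 broadcast.
void NaiveTileInt8(const int8_t *in, int8_t *out, const std::vector<int> &in_shape,
                   const std::vector<int> &out_shape) {
  int out_elems = 1;
  for (int d : out_shape) out_elems *= d;
  const int ndim = static_cast<int>(out_shape.size());
  for (int i = 0; i < out_elems; ++i) {
    int rem = i;
    int in_index = 0;
    int in_stride = 1;
    // walk dimensions from last to first, folding the broadcast with a modulo
    for (int d = ndim - 1; d >= 0; --d) {
      int coord = rem % out_shape[d];
      rem /= out_shape[d];
      in_index += (coord % in_shape[d]) * in_stride;
      in_stride *= in_shape[d];
    }
    out[i] = in[in_index];
  }
}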
@@ -17,78 +17,148 @@
#include "nnacl/int8/scale_int8.h"
#include "nnacl/quantization/fixed_point.h"

void ScaleInnerInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int outer_start, int outer_end,
                    int axis_size, int inner_size, const ScaleParameter *scale_param, int max, int min) {
  for (int out = outer_start; out < outer_end; out++) {
    int out_offset = out * axis_size * inner_size;
    for (int i = 0; i < axis_size; i++) {
      int axis_offset = out_offset + i * inner_size;
      int in_index = 0;
      for (; in_index < inner_size; in_index++) {
        int in_offset = axis_offset + in_index;
        int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
        int input_mul_scale =
          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
                                scale_param->scale_mul_arg_.multiplier_),
                              scale_param->scale_mul_arg_.right_shift_);
        int tmp = input_mul_scale + scale_param->output_zp_;
        tmp = tmp > max ? max : tmp;
        tmp = tmp < min ? min : tmp;
        out_data[in_offset] = tmp;
      }
    }
  }
}
#ifdef ENABLE_NEON
int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,
                              int32x4_t output_multiplier_vec, const ScaleParameter *scale_param) {
  int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1);
  int32x4_t raw_sum = RoundingDivideByPOTInt32x4(
    SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec),
    scale_param->scale_mul_arg_.right_shift_);
  raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_));
  raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_));
  raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_));
  return vqmovn_s32(raw_sum);
}
void ScaleInnerWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
                            int outer_start, int outer_end, int axis_size, int inner_size,
                            const ScaleParameter *scale_param, int max, int min) {
  for (int out = outer_start; out < outer_end; out++) {
    int out_offset = out * axis_size * inner_size;
    for (int i = 0; i < axis_size; i++) {
      int axis_offset = out_offset + i * inner_size;
      int in_index = 0;
      for (; in_index < inner_size; in_index++) {
        int in_offset = axis_offset + in_index;
        int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
        int input_mul_scale =
          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
                                scale_param->scale_mul_arg_.multiplier_),
                              scale_param->scale_mul_arg_.right_shift_);
        int tmp_bias = offset[i] - scale_param->offset_zp_;
        int bias = RoundingDivideByPOT(
          SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_),
                                            scale_param->offset_mul_arg_.multiplier_),
          scale_param->offset_mul_arg_.right_shift_);
        int tmp = input_mul_scale + bias + scale_param->output_zp_;
        tmp = tmp > max ? max : tmp;
        tmp = tmp < min ? min : tmp;
        out_data[in_offset] = tmp;
      }
    }
  }
}
int16x4_t ClacSumHalfWordMul3(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t scaled_input2,
                              const ScaleParameter *scale_param) {
  int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_);
  int32x4_t output_multiplier_vec2 = vdupq_n_s32(scale_param->offset_mul_arg_.multiplier_);
  int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_);
  int32x4_t left_shift_out_vec2 = vdupq_n_s32(1 << scale_param->offset_mul_arg_.left_shift_);
  int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1);
  int32x4_t raw_sum = RoundingDivideByPOTInt32x4(
    SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec),
    scale_param->scale_mul_arg_.right_shift_);
  int32x4_t raw_sum2 = RoundingDivideByPOTInt32x4(
    SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(scaled_input2, left_shift_out_vec2), output_multiplier_vec2),
    scale_param->offset_mul_arg_.right_shift_);
  raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_));
  raw_sum = vaddq_s32(raw_sum, raw_sum2);
  raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_));
  raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_));
  return vqmovn_s32(raw_sum);
}
#endif
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleParameter *scale_param,
                 int real_dst_count) {
  int index = 0;
#ifdef ENABLE_NEON
  int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_);
  int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_);
  for (; index <= real_dst_count - 8; index += 8) {
    int8x8_t input_s8 = vld1_s8(in_data + index);
    int16x8_t input_s16 = vmovl_s8(input_s8);
    int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
    int8x8_t input1_s8 = vld1_s8(scale + index);
    int16x8_t input1_s16 = vmovl_s8(input1_s8);
    int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
    int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
    int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
    int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
    int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
    int16x4_t sum_low =
      ClacSumHalfWordMul2(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, scale_param);
    int16x4_t sum_high =
      ClacSumHalfWordMul2(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, scale_param);
    int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
    int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
    vst1_s8(out_data, res_u8_n0);
    out_data += 8;
  }
#endif
  for (; index < real_dst_count; ++index) {
    const int32_t input0_val = scale_param->input_zp_ + in_data[index];
    const int32_t input1_val = scale_param->scale_zp_ + scale[index];
    int32_t mul_result = RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
                                        scale_param->scale_mul_arg_.multiplier_),
      scale_param->scale_mul_arg_.right_shift_);
    mul_result += scale_param->output_zp_;
    if (mul_result > scale_param->output_activation_max_) {
      out_data[index] = scale_param->output_activation_max_;
    } else if (mul_result < scale_param->output_activation_min_) {
      out_data[index] = scale_param->output_activation_min_;
    } else {
      out_data[index] = (int8_t)mul_result;
    }
  }
  return;
}

void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
                 const ScaleParameter *scale_param, int max, int min) {
  int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
  int outer_start = task_id * outer_step;
  int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
  ScaleInnerInt8(in_data, out_data, scale, outer_start, outer_end, scale_param->axis_size_, scale_param->inner_size_,
                 scale_param, max, min);
}

void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
                         int task_id, const ScaleParameter *scale_param, int max, int min) {
  int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
  int outer_start = task_id * outer_step;
  int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
  ScaleInnerWithBiasInt8(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_,
                         scale_param->inner_size_, scale_param, max, min);
}

void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
                         const ScaleParameter *scale_param, int real_dst_count) {
  int index = 0;
#ifdef ENABLE_NEON
  for (; index <= real_dst_count - 8; index += 8) {
    int8x8_t input_s8 = vld1_s8(in_data + index);
    int16x8_t input_s16 = vmovl_s8(input_s8);
    int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
    int8x8_t input1_s8 = vld1_s8(scale + index);
    int16x8_t input1_s16 = vmovl_s8(input1_s8);
    int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
    int8x8_t input2_s8 = vld1_s8(offset + index);
    int16x8_t input2_s16 = vmovl_s8(input2_s8);
    int16x8_t input2_val = vaddq_s16(input2_s16, vdupq_n_s16(scale_param->offset_zp_));
    int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
    int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
    int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
    int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
    int32x4_t input2_low = vmovl_s16(vget_low_s16(input2_val));
    int32x4_t input2_high = vmovl_s16(vget_high_s16(input2_val));
    int16x4_t sum_low = ClacSumHalfWordMul3(input0_low, input1_low, input2_low, scale_param);
    int16x4_t sum_high = ClacSumHalfWordMul3(input0_high, input1_high, input2_high, scale_param);
    int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
    int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
    vst1_s8(out_data, res_u8_n0);
    out_data += 8;
  }
#endif
  for (; index < real_dst_count; ++index) {
    const int32_t input0_val = in_data[index] - scale_param->input_zp_;
    const int32_t input1_val = scale[index] - scale_param->scale_zp_;
    int32_t mul_result = RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
                                        scale_param->scale_mul_arg_.multiplier_),
      scale_param->scale_mul_arg_.right_shift_);
    int tmp_bias = offset[index] - scale_param->offset_zp_;
    int bias = RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_),
                                        scale_param->offset_mul_arg_.multiplier_),
      scale_param->offset_mul_arg_.right_shift_);
    mul_result += bias + scale_param->output_zp_;
    if (mul_result > scale_param->output_activation_max_) {
      out_data[index] = scale_param->output_activation_max_;
    } else if (mul_result < scale_param->output_activation_min_) {
      out_data[index] = scale_param->output_activation_min_;
    } else {
      out_data[index] = (int8_t)mul_result;
    }
  }
  return;
}
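Both scalar tails above requantize through SaturatingRoundingDoublingHighMul and RoundingDivideByPOT from nnacl/quantization/fixed_point.h. For reference, a self-contained sketch of the gemmlowp-style definitions these helpers follow (the exact nnacl implementations may differ in detail):

#include <cstdint>
#include <cstdio>
#include <limits>

// High 32 bits of a saturating, rounding, doubling multiply: roughly (a * b * 2) >> 32.
int32_t SatRoundingDoublingHighMul(int32_t a, int32_t b) {
  bool overflow = (a == b) && (a == std::numeric_limits<int32_t>::min());
  int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
  int32_t result = static_cast<int32_t>((ab + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<int32_t>::max() : result;
}

// Divide by 2^exponent with round-to-nearest.
int32_t RoundingDivByPOT(int32_t x, int exponent) {
  const int32_t mask = static_cast<int32_t>((1ll << exponent) - 1);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
  return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
}

int main() {
  // Requantize 1234 by a Q31 multiplier of 1717986918 (~0.8) followed by a right shift of 2.
  int32_t y = RoundingDivByPOT(SatRoundingDoublingHighMul(1234, 1717986918), 2);
  printf("%d\n", y);  // 247, i.e. roughly 1234 * 0.8 / 4
  return 0;
}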
@@ -22,10 +22,10 @@
#ifdef __cplusplus
extern "C" {
#endif
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
                 const ScaleParameter *scale_param, int max, int min);
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleParameter *scale_param,
                 int real_dst_count);
void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
                         int task_id, const ScaleParameter *scale_param, int max, int min);
void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
                         const ScaleParameter *scale_param, int real_dst_count);
#ifdef __cplusplus
}
#endif
@@ -34,6 +34,8 @@ typedef struct ScaleParameter {
  int offset_zp_;
  int output_zp_;
  int activation_type_;
  int output_activation_min_;
  int output_activation_max_;
} ScaleParameter;
#endif  // MINDSPORE_LITE_NNACL_SCALE_H_
@@ -19,6 +19,7 @@
#include <string.h>
#include <vector>
#include "nnacl/int8/scale_int8.h"
#include "nnacl/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@@ -35,63 +36,65 @@ constexpr size_t kScaleInputsSize = 2;
constexpr size_t kScaleBiasInputsSize = 3;
}  // namespace

ScaleInt8CPUKernel::~ScaleInt8CPUKernel() {
  if (scale_param_->const_scale_) {
    if (scale_ != nullptr) {
      free(scale_);
      scale_ = nullptr;
    }
  }
  if (has_bias_ && scale_param_->const_offset_) {
    if (offset_ != nullptr) {
      free(offset_);
      offset_ = nullptr;
    }
  }
}

ScaleInt8CPUKernel::~ScaleInt8CPUKernel() {
  if (tile_para != nullptr) {
    free(tile_para);
    tile_para = nullptr;
  }
  if (input1_data_ != nullptr && malloced_scale_) {
    free(input1_data_);
  }
  if (input2_data_ != nullptr && malloced_offset_) {
    free(input2_data_);
  }
}
int ScaleInt8CPUKernel::InitScaleOffset() {
  auto scale_tensor = in_tensors_.at(1);
  int8_t *scale_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
  // scale may be a const value, so it can be processed in the prepare stage
  if (scale_ptr != nullptr) {
    scale_param_->const_scale_ = true;
    if (scale_ != nullptr) {
      free(scale_);
      scale_ = nullptr;
    }
    scale_ = reinterpret_cast<int8_t *>(malloc(scale_tensor->ElementsNum() * sizeof(int8_t)));
    if (scale_ == nullptr) {
      MS_LOG(ERROR) << "Malloc buffer failed.";
      return RET_ERROR;
    }
    memcpy(scale_, scale_ptr, scale_tensor->ElementsNum() * sizeof(int8_t));
  } else {
    scale_param_->const_scale_ = false;
    scale_ = nullptr;
  }
  if (in_tensors_.size() == 3) {
    has_bias_ = true;
    auto offset_tensor = in_tensors_.at(2);
    int8_t *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
    // offset may be a const value, so it can be processed in the prepare stage
    if (offset_ptr != nullptr) {
      scale_param_->const_offset_ = true;
      if (offset_ != nullptr) {
        free(offset_);
        offset_ = nullptr;
      }
      offset_ = reinterpret_cast<int8_t *>(malloc(offset_tensor->ElementsNum() * sizeof(int8_t)));
      if (offset_ == nullptr) {
        MS_LOG(ERROR) << "Malloc buffer failed.";
        return RET_ERROR;
      }
      memcpy(offset_, offset_ptr, offset_tensor->ElementsNum() * sizeof(int8_t));
    } else {
      scale_param_->const_offset_ = false;
      offset_ = nullptr;
    }
  } else {
    has_bias_ = false;
  }
  return RET_OK;
}

int ScaleInt8CPUKernel::InitScaleOffset() {
  CalcMultiplesAndStrides(tile_para);
  scale_param_->const_scale_ = false;
  auto *scale_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
  // scale may be a const value, so it can be processed in the prepare stage
  if (scale_ptr != nullptr) {
    scale_param_->const_scale_ = true;
    input1_data_ = scale_ptr;
    // need broadcasting
    if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
      input1_data_ = reinterpret_cast<int8_t *>(malloc(out_tensors_.at(0)->Size()));
      if (input1_data_ == nullptr) {
        MS_LOG(ERROR) << "malloc input1_data_ failed.";
        return RET_ERROR;
      }
      malloced_scale_ = true;
      TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(1)->data_c()),
                            reinterpret_cast<uint8_t *>(input1_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
                            tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
    }
  }
  scale_param_->const_offset_ = false;
  if (in_tensors_.size() == 3) {
    has_bias_ = true;
    auto offset_tensor = in_tensors_.at(2);
    auto *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
    // offset may be a const value, so it can be processed in the prepare stage
    if (offset_ptr != nullptr) {
      scale_param_->const_offset_ = true;
      input2_data_ = offset_ptr;
      // need broadcasting
      if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(2)->ElementsNum()) {
        input2_data_ = reinterpret_cast<int8_t *>(malloc(out_tensors_.at(0)->Size()));
        if (input2_data_ == nullptr) {
          MS_LOG(ERROR) << "malloc input2_data_ failed.";
          free(input1_data_);
          return RET_ERROR;
        }
        malloced_offset_ = true;
        TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(2)->data_c()),
                              reinterpret_cast<uint8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
                              tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
      }
    }
  } else {
    has_bias_ = false;
  }
  return RET_OK;
}
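InitScaleOffset only tiles when the scale (or offset) input is constant and smaller than the output; after tiling, the kernels can walk the tiled buffer with the same flat index as the input. A toy trace of that layout, assuming a constant scale of shape {3} broadcast against an input of shape {2, 3}:

#include <cstdint>
#include <cstdio>

int main() {
  // input shape {2, 3}; constant scale shape {3} -> tiled to 6 entries at prepare time
  const int8_t scale[3] = {10, 20, 30};
  int8_t tiled_scale[6];
  for (int row = 0; row < 2; ++row) {
    for (int col = 0; col < 3; ++col) {
      tiled_scale[row * 3 + col] = scale[col];  // {10, 20, 30, 10, 20, 30}
    }
  }
  for (int i = 0; i < 6; ++i) {
    printf("%d ", tiled_scale[i]);  // the kernels now read tiled_scale[index] element-wise
  }
  printf("\n");
  return 0;
}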
@@ -102,29 +105,66 @@ int ScaleInt8CPUKernel::InitParameter() {
  auto scale_shape = scale_tensor->shape();
  if (scale_param_->axis_ < 0) {
    scale_param_->axis_ = scale_param_->axis_ + in_shape.size();
    scale_param_->axis_ += in_shape.size();
  }
  if (scale_shape.size() + scale_param_->axis_ > in_shape.size()) {
    MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
    return RET_ERROR;
  }
  scale_param_->outer_size_ = 1;
  scale_param_->axis_size_ = 1;
  scale_param_->inner_size_ = 1;
  for (int i = 0; i < scale_param_->axis_; i++) {
    scale_param_->outer_size_ *= in_shape[i];
  }
  for (size_t i = 0; i < scale_shape.size(); i++) {
    if (in_shape[i + scale_param_->axis_] != scale_shape[i]) {
      MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
      return RET_ERROR;
    }
    scale_param_->axis_size_ *= in_shape[i + scale_param_->axis_];
  }
  for (size_t i = scale_param_->axis_ + scale_shape.size(); i < in_shape.size(); i++) {
    scale_param_->inner_size_ *= in_shape[i];
  }
  scale_param_->op_parameter_.thread_num_ = MSMIN(scale_param_->op_parameter_.thread_num_, scale_param_->outer_size_);
  tile_para = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
  if (tile_para == nullptr) {
    MS_LOG(ERROR) << "malloc tile parameter failed.";
    return RET_ERROR;
  }
  size_t input0_size = in_tensors_.at(0)->shape().size();
  size_t input1_size = in_tensors_.at(1)->shape().size();
  size_t output_size = out_tensors_.at(0)->shape().size();
  auto input1_shape = in_tensors_.at(1)->shape();
  tile_para->ndim_ = output_size;
  // supplement shape of scale tensor with number 1
  size_t len = input0_size - scale_param_->axis_;
  second_in_shape_ = input1_shape;
  if (len != input1_size) {
    second_in_shape_.resize(len);
    size_t i = 0;
    for (; i < input1_size; ++i) {
      second_in_shape_[i] = input1_shape[i];
    }
    for (; i < len; ++i) {
      second_in_shape_[i] = 1;
    }
    input1_size = len;
  }
  if (input0_size == input1_size) {
    for (size_t i = 0; i < output_size; i++) {
      tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
      tile_para->in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
      tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
    }
  } else {
    MS_ASSERT(input0_size > input1_size);
    size_t fill_dim_num = input0_size - input1_size;
    int j = 0;
    for (size_t i = 0; i < output_size; i++) {
      tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
      if (i < fill_dim_num) {
        tile_para->in_shape1_[i] = 1;
      } else {
        tile_para->in_shape1_[i] = second_in_shape_[j++];
      }
      tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
    }
  }
  return RET_OK;
}
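The rewritten InitParameter aligns the scale shape with the input rank before the multiples and strides are computed: the scale shape is first padded with trailing 1s up to input0_size - axis_ (second_in_shape_), then padded with leading 1s to the full output rank (tile_para->in_shape1_). A small trace of that alignment, assuming input {2, 3, 4, 5}, scale {3, 4} and axis_ = 1:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> in_shape = {2, 3, 4, 5};
  std::vector<int> scale_shape = {3, 4};
  int axis = 1;
  // pad with trailing 1s up to in_shape.size() - axis  -> {3, 4, 1}
  std::vector<int> second_in_shape = scale_shape;
  second_in_shape.resize(in_shape.size() - axis, 1);
  // pad with leading 1s up to the full output rank     -> {1, 3, 4, 1}
  std::vector<int> in_shape1(in_shape.size() - second_in_shape.size(), 1);
  in_shape1.insert(in_shape1.end(), second_in_shape.begin(), second_in_shape.end());
  for (int d : in_shape1) printf("%d ", d);  // prints: 1 3 4 1
  printf("\n");
  return 0;
}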
@@ -156,6 +196,24 @@ int ScaleInt8CPUKernel::InitQuantArgs() {
    scale_param_->offset_mul_arg_.left_shift_ = shift > 0 ? shift : 0;
    scale_param_->offset_mul_arg_.right_shift_ = shift < 0 ? -shift : 0;
  }
  switch (scale_param_->activation_type_) {
    case schema::ActivationType_RELU:
      scale_param_->output_activation_min_ = 0;
      scale_param_->output_activation_max_ = INT8_MAX;
      break;
    case schema::ActivationType_RELU6:
      scale_param_->output_activation_min_ = 0;
      scale_param_->output_activation_max_ = 6;
      break;
    case schema::ActivationType_NO_ACTIVATION:
      scale_param_->output_activation_min_ = INT8_MIN;
      scale_param_->output_activation_max_ = INT8_MAX;
      break;
    default:
      MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
      return RET_ERROR;
  }
  return RET_OK;
}
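InitQuantArgs folds the input, scale and output quantization scales into an integer multiplier plus left/right shifts (scale_mul_arg_ / offset_mul_arg_) and now also precomputes the activation clamp range. A sketch of the usual way such a real-valued multiplier is split into a Q31 multiplier and a shift (assumed here for illustration; the actual nnacl helper may differ):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Split real_multiplier into multiplier * 2^shift, with multiplier stored as a Q31 value in [2^30, 2^31).
void QuantizeMultiplierSketch(double real_multiplier, int32_t *multiplier, int *shift) {
  if (real_multiplier == 0.0) {
    *multiplier = 0;
    *shift = 0;
    return;
  }
  const double q = std::frexp(real_multiplier, shift);  // real_multiplier = q * 2^shift, q in [0.5, 1)
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
  if (q_fixed == (1ll << 31)) {                          // handle q rounding up to exactly 1.0
    q_fixed /= 2;
    ++(*shift);
  }
  *multiplier = static_cast<int32_t>(q_fixed);
}

int main() {
  int32_t m = 0;
  int shift = 0;
  QuantizeMultiplierSketch(0.8, &m, &shift);
  // shift > 0 would map to scale_mul_arg_.left_shift_, shift < 0 to right_shift_ (see the assignments above)
  printf("multiplier = %d, shift = %d\n", m, shift);  // multiplier = 1717986918, shift = 0
  return 0;
}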
@@ -176,13 +234,13 @@ int ScaleInt8CPUKernel::Init() {
int ScaleInt8CPUKernel::ReSize() {
  auto ret = InitParameter();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Scale fp32 InitParameter failed.";
    MS_LOG(ERROR) << "Scale int8 InitParameter failed.";
    return RET_ERROR;
  }
  ret = InitScaleOffset();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed.";
    MS_LOG(ERROR) << "Scale int8 InitScaleOffset failed.";
    return RET_ERROR;
  }
@@ -195,38 +253,21 @@ int ScaleInt8CPUKernel::ReSize() {
}

int ScaleInt8CPUKernel::Scale(int task_id) {
  if (has_bias_) {
    switch (scale_param_->activation_type_) {
      case schema::ActivationType_RELU:
        DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, INT8_MAX, 0);
        break;
      case schema::ActivationType_RELU6:
        DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, 6, 0);
        break;
      case schema::ActivationType_NO_ACTIVATION:
        DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, INT8_MAX, INT8_MIN);
        break;
      default:
        MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
        return RET_ERROR;
    }
  } else {
    switch (scale_param_->activation_type_) {
      case schema::ActivationType_RELU:
        DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, INT8_MAX, 0);
        break;
      case schema::ActivationType_RELU6:
        DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, 6, 0);
        break;
      case schema::ActivationType_NO_ACTIVATION:
        DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, INT8_MAX, INT8_MIN);
        break;
      default:
        MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
        return RET_ERROR;
    }
  }
  return RET_OK;
}

int ScaleInt8CPUKernel::Scale(int task_id) {
  int real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
  if (real_dst_count <= 0) {
    return lite::RET_OK;
  }
  int8_t *cur_input0_data = input0_data_ + task_id * count_unit_;
  int8_t *cur_input1_data = input1_data_ + task_id * count_unit_;
  int8_t *cur_output_data = output_data_ + task_id * count_unit_;
  if (has_bias_) {
    int8_t *cur_input2_data = input2_data_ + task_id * count_unit_;
    DoScaleWithBiasInt8(cur_input0_data, cur_output_data, cur_input1_data, cur_input2_data, scale_param_,
                        real_dst_count);
  } else {
    DoScaleInt8(cur_input0_data, cur_output_data, cur_input1_data, scale_param_, real_dst_count);
  }
  return RET_OK;
}
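Scale(task_id) now slices the flattened output across threads instead of splitting outer_size_: count_unit_ = UP_DIV(elements_num_, thread_count_), and each task handles real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_) elements starting at offset task_id * count_unit_. A minimal check of that arithmetic, assuming elements_num_ = 10 and 4 threads:

#include <algorithm>
#include <cstdio>

int main() {
  const int elements_num = 10;
  const int thread_count = 4;
  const int count_unit = (elements_num + thread_count - 1) / thread_count;  // UP_DIV -> 3
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    int real_dst_count = std::min(elements_num - task_id * count_unit, count_unit);
    if (real_dst_count <= 0) continue;  // trailing tasks may get no work
    printf("task %d: offset %d, count %d\n", task_id, task_id * count_unit, real_dst_count);
  }
  // prints: task 0 -> (0, 3), task 1 -> (3, 3), task 2 -> (6, 3), task 3 -> (9, 1)
  return 0;
}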
@@ -241,18 +282,59 @@ int ScaleRunInt8(void *cdata, int task_id) {
}

int ScaleInt8CPUKernel::Run() {
  auto in_tensor = in_tensors_.front();
  input_ptr_ = reinterpret_cast<int8_t *>(in_tensor->data_c());
  if (scale_ == nullptr) {
    auto scale_tensor = in_tensors_[1];
    scale_ = reinterpret_cast<int8_t *>(scale_tensor->data_c());
  }
  if (has_bias_ && !scale_param_->const_offset_) {
    offset_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
  }
  auto out_tensor = out_tensors_.front();
  output_ptr_ = reinterpret_cast<int8_t *>(out_tensor->data_c());

int ScaleInt8CPUKernel::Run() {
  elements_num_ = out_tensors_.at(0)->ElementsNum();
  count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
  input0_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data_c());
  output_data_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
  // need broadcasting
  if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
    // scale is passed by the previous node, so broadcasting has to be done online
    if (!scale_param_->const_scale_) {
      input1_data_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
      if (input1_data_ == nullptr) {
        MS_LOG(ERROR) << "malloc input1_data_ failed.";
        return RET_ERROR;
      }
      TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(1)->data_c()),
                            reinterpret_cast<uint8_t *>(input1_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
                            tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
    }
    // if there is a bias passed by the previous node, it also has to be broadcast online
    if (has_bias_ && !scale_param_->const_offset_) {
      input2_data_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
      if (input2_data_ == nullptr) {
        MS_LOG(ERROR) << "malloc input2_data_ failed.";
        ctx_->allocator->Free(input1_data_);
        input1_data_ = nullptr;
        return RET_ERROR;
      }
      TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(2)->data_c()),
                            reinterpret_cast<uint8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
                            tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
    }
    auto ret = ParallelLaunch(this->context_->thread_pool_, ScaleRunInt8, this, op_parameter_->thread_num_);
    // free memory malloced from the memory pool
    if (!scale_param_->const_scale_) {
      ctx_->allocator->Free(input1_data_);
      input1_data_ = nullptr;
    }
    if (has_bias_ && !scale_param_->const_offset_) {
      ctx_->allocator->Free(input2_data_);
      input2_data_ = nullptr;
    }
    return ret;
  }

  // input1 has the same shape as input0
  if (input1_data_ == nullptr) {
    input1_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
  }
  if (has_bias_ && !scale_param_->const_offset_) {
    input2_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
  }
  auto ret = ParallelLaunch(this->context_->thread_pool_, ScaleRunInt8, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
@@ -260,6 +342,7 @@ int ScaleInt8CPUKernel::Run() {
  }
  return RET_OK;
}

kernel::LiteKernel *CpuScaleInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                              const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
                                              const lite::InnerContext *ctx, const kernel::KernelKey &desc,
@@ -21,6 +21,7 @@
#include "src/lite_kernel.h"
#include "nnacl/scale.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/arithmetic_common.h"

namespace mindspore::kernel {
@@ -29,7 +30,7 @@ class ScaleInt8CPUKernel : public LiteKernel {
  ScaleInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                     const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {
    scale_param_ = reinterpret_cast<ScaleParameter *>(op_parameter_);
  }
  ~ScaleInt8CPUKernel() override;
@@ -42,12 +43,20 @@ class ScaleInt8CPUKernel : public LiteKernel {
  int Scale(int task_id);

 private:
  int8_t *input_ptr_ = nullptr;
  int8_t *scale_ = nullptr;
  int8_t *offset_ = nullptr;
  int8_t *output_ptr_ = nullptr;
  bool has_bias_ = false;
  int8_t *input0_data_ = nullptr;
  int8_t *input1_data_ = nullptr;
  int8_t *input2_data_ = nullptr;
  int8_t *output_data_ = nullptr;
  const lite::InnerContext *ctx_;
  ScaleParameter *scale_param_;
  ArithmeticParameter *tile_para = nullptr;
  std::vector<int> second_in_shape_;
  int thread_count_;
  int64_t elements_num_;
  int64_t count_unit_;
  bool has_bias_ = false;
  bool malloced_scale_ = false;
  bool malloced_offset_ = false;

  int InitQuantArgs();
};