Browse Source

mindspore lite:solve argmax int8 problem

r1.7
liu lili 4 years ago
parent
commit
e2a79ab004
3 changed files with 70 additions and 40 deletions
  1. +50
    -24
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/arg_min_max_int8.c
  2. +10
    -10
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/arg_min_max_int8.h
  3. +10
    -6
      mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc

+ 50
- 24
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/arg_min_max_int8.c View File

@@ -31,8 +31,24 @@ void CalcParameter(const int *shape, int dims_number, int axis, int *pre_axis_co
}
}

void DoArgMinMaxQuant(const int8_t *input, int8_t *output, const ArgMinMaxParameter *param, int pre_axis_count,
int axis_count, int after_axis_count, const QuantArg *in_quant_arg,
void SetOutputValue(float value, int32_t index, int8_t *output1, int8_t *output2, int offset,
float output_inverse_scale, float output_zp, bool out_value) {
if (output2) {
int32_t *output1_index = (int32_t *)output1;
output1_index[offset] = index;
output2[offset] = value * output_inverse_scale + output_zp;
} else {
if (out_value) {
output1[offset] = value * output_inverse_scale + output_zp;
} else {
int32_t *output1_index = (int32_t *)output1;
output1_index[offset] = index;
}
}
}

void DoArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const ArgMinMaxParameter *param,
int pre_axis_count, int axis_count, int after_axis_count, const QuantArg *in_quant_arg,
const QuantArg *out_quant_arg) {
bool out_value = param->out_value_;
const float output_inverse_scale = 1.f / out_quant_arg->scale_;
@@ -46,7 +62,7 @@ void DoArgMinMaxQuant(const int8_t *input, int8_t *output, const ArgMinMaxParame
if (!param->get_max_) {
value = FLT_MAX;
}
float index = 0.0f;
int32_t index = 0.0f;
for (int k = 0; k < axis_count; ++k) {
float value_tmp = input[input_offset + k * after_axis_count + j] * in_quant_arg->scale_ + bias;
if (param->get_max_) {
@@ -61,19 +77,21 @@ void DoArgMinMaxQuant(const int8_t *input, int8_t *output, const ArgMinMaxParame
}
}
}
float real_out = out_value ? value : index;
output[output_offset + j] = real_out * output_inverse_scale + output_zp;
SetOutputValue(value, index, output1, output2, output_offset + j, output_inverse_scale, output_zp, out_value);
// float real_out = out_value ? value : index;
// output[output_offset + j] = real_out * output_inverse_scale + output_zp;
}
}
}

void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output, const int *in_shape, const ArgMinMaxParameter *param,
const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
const ArgMinMaxParameter *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
int pre_axis_count = 1;
int axis_count = 1;
int after_axis_count = 1;
CalcParameter(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
DoArgMinMaxQuant(input, output, param, pre_axis_count, axis_count, after_axis_count, in_quant_arg, out_quant_arg);
DoArgMinMaxQuant(input, output1, output2, param, pre_axis_count, axis_count, after_axis_count, in_quant_arg,
out_quant_arg);
return;
}

@@ -89,8 +107,8 @@ int8_t GetInt8Output(float real_out, float output_inverse_scale, int32_t output_
return real_out * output_inverse_scale + output_zp;
}

void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param,
const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
ArgMinMaxParameter *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
bool out_value = param->out_value_;
const float output_inverse_scale = 1.f / out_quant_arg->scale_;
float bias = -in_quant_arg->zp_ * in_quant_arg->scale_;
@@ -109,14 +127,16 @@ void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output, const int *in_shape,

for (int j = 0; j < param->topk_; ++j) {
int out_offset = j * param->out_strides_[0] + i;
float real_out = out_value ? param->arg_elements_[j].data_.f_data_ : param->arg_elements_[j].index_;
output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp);
SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2,
out_offset, output_inverse_scale, output_zp, out_value);
// float real_out = out_value ? param->arg_elements_[j].data_.f_data_ : param->arg_elements_[j].index_;
// output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp);
}
}
}

void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param,
const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
ArgMinMaxParameter *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
bool out_value = param->out_value_;
const float output_inverse_scale = 1.f / out_quant_arg->scale_;
float bias = -in_quant_arg->zp_ * in_quant_arg->scale_;
@@ -139,15 +159,17 @@ void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output, const int *in_shape,

for (int k = 0; k < param->topk_; ++k) {
int out_offset = out_dim0_offset + j + k * param->out_strides_[1];
float real_out = out_value ? param->arg_elements_[k].data_.f_data_ : param->arg_elements_[k].index_;
output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp);
SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2,
out_offset, output_inverse_scale, output_zp, out_value);
// float real_out = out_value ? param->arg_elements_[k].data_.f_data_ : param->arg_elements_[k].index_;
// output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp);
}
}
}
}

void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param,
const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
ArgMinMaxParameter *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
bool out_value = param->out_value_;
const float output_inverse_scale = 1.f / out_quant_arg->scale_;
float bias = -in_quant_arg->zp_ * in_quant_arg->scale_;
@@ -173,16 +195,18 @@ void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output, const int *in_shape,
}
for (int l = 0; l < param->topk_; ++l) {
int out_offset = out_dim1_offset + k + l * param->out_strides_[2];
float real_out = out_value ? param->arg_elements_[l].data_.f_data_ : param->arg_elements_[l].index_;
output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp);
SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2,
out_offset, output_inverse_scale, output_zp, out_value);
// float real_out = out_value ? param->arg_elements_[l].data_.f_data_ : param->arg_elements_[l].index_;
// output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp);
}
}
}
}
}

void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param,
const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
ArgMinMaxParameter *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
bool out_value = param->out_value_;
const float output_inverse_scale = 1.f / out_quant_arg->scale_;
float bias = -in_quant_arg->zp_ * in_quant_arg->scale_;
@@ -211,8 +235,10 @@ void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output, const int *in_shape,
}
for (int l = 0; l < param->topk_; ++l) {
int out_offset = out_dim2_offset + l;
float real_out = out_value ? param->arg_elements_[l].data_.f_data_ : param->arg_elements_[l].index_;
output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp);
SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2,
out_offset, output_inverse_scale, output_zp, out_value);
// float real_out = out_value ? param->arg_elements_[l].data_.f_data_ : param->arg_elements_[l].index_;
// output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp);
}
}
}


+ 10
- 10
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/arg_min_max_int8.h View File

@@ -23,16 +23,16 @@
extern "C" {
#endif

void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output, const int *in_shape, const ArgMinMaxParameter *param,
const QuantArg *in_quant, const QuantArg *out_quant);
void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param,
const QuantArg *in_quant, const QuantArg *out_quant);
void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param,
const QuantArg *in_quant, const QuantArg *out_quant);
void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param,
const QuantArg *in_quant, const QuantArg *out_quant);
void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param,
const QuantArg *in_quant, const QuantArg *out_quant);
void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
const ArgMinMaxParameter *param, const QuantArg *in_quant, const QuantArg *out_quant);
void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
ArgMinMaxParameter *param, const QuantArg *in_quant, const QuantArg *out_quant);
void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
ArgMinMaxParameter *param, const QuantArg *in_quant, const QuantArg *out_quant);
void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
ArgMinMaxParameter *param, const QuantArg *in_quant, const QuantArg *out_quant);
void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output1, int8_t *output2, const int *in_shape,
ArgMinMaxParameter *param, const QuantArg *in_quant, const QuantArg *out_quant);

#ifdef __cplusplus
}


+ 10
- 6
mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc View File

@@ -58,9 +58,9 @@ int ArgMinMaxInt8CPUKernel::Prepare() {
auto *out_tensor = out_tensors_.at(kOutputIndex);
auto out_quant_args = out_tensor->quant_params();
CHECK_LESS_RETURN(out_quant_args.size(), 1);
out_quant_arg_ = reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg)));
out_quant_arg_->scale_ = out_quant_args.front().scale;
out_quant_arg_->zp_ = out_quant_args.front().zeroPoint;
out_quant_arg_ = reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg)));
if (out_quant_arg_ == nullptr) {
MS_LOG(ERROR) << "Malloc QuantArg for argmin or argmax int8 op failed!";
return RET_ERROR;
@@ -98,6 +98,10 @@ int ArgMinMaxInt8CPUKernel::Run() {

const int8_t *input_data = reinterpret_cast<const int8_t *>(in_tensors_.at(0)->MutableData());
int8_t *output_data = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
int8_t *output_value = nullptr;
if (out_tensors_.size() == C2NUM) {
output_value = reinterpret_cast<int8_t *>(out_tensors_.at(C1NUM)->MallocData());
}
CHECK_NULL_RETURN(input_data);
CHECK_NULL_RETURN(output_data);
auto in_shape = input->shape();
@@ -105,23 +109,23 @@ int ArgMinMaxInt8CPUKernel::Run() {
CHECK_NULL_RETURN(in_shape.data());
CHECK_NULL_RETURN(param);
if (param->topk_ == 1) {
Int8ArgMinMaxQuant(input_data, output_data, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
Int8ArgMinMaxQuant(input_data, output_data, output_value, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
return RET_OK;
}
CHECK_NULL_RETURN(in_quant_arg_);
CHECK_NULL_RETURN(out_quant_arg_);
switch (param->axis_) {
case 0:
Int8ArgMinMaxDim0(input_data, output_data, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
Int8ArgMinMaxDim0(input_data, output_data, output_value, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
break;
case 1:
Int8ArgMinMaxDim1(input_data, output_data, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
Int8ArgMinMaxDim1(input_data, output_data, output_value, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
break;
case 2:
Int8ArgMinMaxDim2(input_data, output_data, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
Int8ArgMinMaxDim2(input_data, output_data, output_value, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
break;
case 3:
Int8ArgMinMaxDim3(input_data, output_data, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
Int8ArgMinMaxDim3(input_data, output_data, output_value, in_shape.data(), param, in_quant_arg_, out_quant_arg_);
break;
default:
MS_LOG(ERROR) << "axis is invalid";


Loading…
Cancel
Save