diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index ac8b1b2212..c39e8d94b3 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -58,12 +58,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
                                                           schema::ActivationType_RELU));
   pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
                                                           schema::ActivationType_RELU6));
-  pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(true, "conv_tuple_relu",
-                                                               schema::PrimitiveType_Activation,
-                                                               schema::ActivationType_RELU));
-  pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(true, "conv_tuple_relu6",
-                                                               schema::PrimitiveType_Activation,
-                                                               schema::ActivationType_RELU6));
+  pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
+    true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU));
+  pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
+    true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6));
   auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
   weight_format_hardcode_pass->SetFmkType(config->fmk);
   weight_format_hardcode_pass->SetQuantType(config->quantType);
@@ -114,16 +112,6 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
       return nullptr;
     }
-    if (config->quantType == schema::QuantType_PostTraining) {
-      quant::QuantCast quant_cast;
-      quant_cast.SetInputDataDType(kNumberTypeFloat32);
-      status = quant_cast.Run(new_graph);
-      if (status != RET_OK) {
-        MS_LOG(ERROR) << "add QuantCast error";
-        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-        return nullptr;
-      }
-    }
   }
 
   return new_graph;
diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
index 9508f621da..e70c873a13 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
@@ -17,23 +17,28 @@
 #include "tools/converter/quantizer/post_training_quantizer.h"
 #include <...>
 #include <...>
+#include <future>
 #include <...>
 #include <...>
 #include <...>
+#include <thread>
 #include <...>
 #include <...>
 #include <...>
 #include <...>
+#include <...>
 #include <...>
 #include <...>
 #include "schema/inner/model_generated.h"
 #include "src/tensor.h"
 #include "tools/anf_exporter/anf_exporter.h"
+#include "tools/converter/quantizer/quant_cast.h"
 #include "tools/converter/quantizer/quantize_util.h"
 #include "utils/log_adapter.h"
 #include "securec/include/securec.h"
 #include "tools/common/tensor_util.h"
 #include "src/common/file_utils.h"
+#include "src/common/utils.h"
 
 using std::string;
 using std::vector;
@@ -431,6 +436,8 @@ STATUS Calibrator::ReadConfig() {
     }
     auto key = line.substr(0, index);
     auto value = line.substr(index + 1);
+    Trim(&key);
+    Trim(&value);
     if (key == "image_path") {
       config_param_.image_path = value;
     } else if (key == "batch_count") {
@@ -443,6 +450,11 @@ STATUS Calibrator::ReadConfig() {
       } else {
         config_param_.method_x = value;
       }
+    } else if (key == "bias_correction") {
+      std::transform(value.begin(), value.end(), value.begin(), ::tolower);
+      if (value == "true") {
+        config_param_.bias_correction = true;
+      }
     } else {
       MS_LOG(WARNING) << "unsupported parameter";
     }
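In `Calibrator::ReadConfig`, the new `bias_correction` key takes effect only after both key and value are trimmed and the value lowercases to exactly `true`. A standalone sketch of that trim-and-parse sequence (the `ParseBoolValue` helper is illustrative, not part of the patch):

```cpp
#include <algorithm>
#include <cctype>
#include <string>

// Illustrative stand-in for the patch's Trim + tolower + compare sequence.
bool ParseBoolValue(std::string value) {
  auto not_space = [](unsigned char c) { return std::isspace(c) == 0; };
  value.erase(value.begin(), std::find_if(value.begin(), value.end(), not_space));          // left trim
  value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(), value.end());   // right trim
  std::transform(value.begin(), value.end(), value.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  return value == "true";  // any other spelling leaves bias_correction at its default (false)
}
```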
@@ -450,7 +462,8 @@ STATUS Calibrator::ReadConfig() {
   MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " "
                 << "batch_count: " << config_param_.batch_count << " "
                 << "method_x: " << config_param_.method_x << " "
-                << "thread_num: " << config_param_.thread_num;
+                << "thread_num: " << config_param_.thread_num << " "
+                << "bias_correction: " << config_param_.bias_correction;
 
   delete[] resolved_path;
   fs.close();
@@ -533,8 +546,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC>
     MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
     return RET_ERROR;
   }
-  auto status = QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num,
-                            perchanel);
+  auto status =
+    QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, perchanel);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
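The reflowed call above hands the weight tensor to `QuantFilter`, which quantizes it (per-channel when `perchanel` is set). The sketch below only illustrates the symmetric per-channel scheme such a filter quantizer applies; the layout, names, and the choice of `quant_max` are assumptions, not the real implementation from quantize_util.h:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative symmetric per-channel weight quantization: one scale per output
// channel, scale = max|w| / quant_max, q = round(w / scale). Weights are
// assumed laid out as [channel][element] for simplicity.
void QuantWeightPerChannel(const std::vector<std::vector<float>> &weights,
                           int quant_max,  // e.g. 127 for int8
                           std::vector<std::vector<int8_t>> *quantized,
                           std::vector<float> *scales) {
  for (const auto &channel : weights) {
    float max_abs = 1e-9f;  // avoid a zero scale for all-zero channels
    for (float w : channel) max_abs = std::max(max_abs, std::fabs(w));
    float scale = max_abs / quant_max;
    std::vector<int8_t> q_channel;
    for (float w : channel) q_channel.push_back(static_cast<int8_t>(std::round(w / scale)));
    scales->push_back(scale);
    quantized->push_back(std::move(q_channel));
  }
}
```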
@@ -637,8 +650,8 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC>
   auto op_type = (schema::PrimitiveType)primitive_c->Type();
-  MS_LOG(INFO) << "OpName: " << op_name;
+  MS_LOG(DEBUG) << "OpName: " << op_name;
   if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
       op_type != PrimitiveType_FullConnection) {
     for (size_t i = 1; i < cnode->inputs().size(); i++) {
@@ -811,16 +824,16 @@ STATUS PostTrainingQuantizer::PreProcess() {
   return RET_OK;
 }
 
-STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &node_name,
-                                             const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const {
+STATUS PostTrainingQuantizer::CheckFp32TensorVec(const std::string &node_name,
+                                                 const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const {
   if (tensor_vec.size() < 1) {
     MS_LOG(ERROR) << "node: " << node_name << " input tensors is 0";
     return RET_ERROR;
   }
   auto *tensor = tensor_vec[0];
   if (tensor->data_type() != kNumberTypeFloat32) {
-    MS_LOG(DEBUG) << "node: " << node_name << " will not quantize"
-                  << " tensor data_type: " << tensor->data_type();
+    MS_LOG(WARNING) << "node: " << node_name << " will not quantize"
+                    << " tensor data_type: " << tensor->data_type();
     return RET_ERROR;
   }
   return RET_OK;
 }
@@ -834,7 +847,7 @@ STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &node_name,
 STATUS PostTrainingQuantizer::DoInference() {
   for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
     // get input tensor
-    vector<mindspore::tensor::MSTensor *> inputs = session_->GetInputs();
+    vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
     if (inputs.size() > 1) {
       MS_LOG(ERROR) << "model's input tensor size: " << inputs.size() << " >1";
       return RET_ERROR;
@@ -848,7 +861,7 @@ STATUS PostTrainingQuantizer::DoInference() {
       [&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
           const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
           const mindspore::session::CallBackParam &callParam) -> bool {
-      if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
+      if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
         return false;
       }
       auto tensor = beforeInputs[0];
@@ -863,7 +876,7 @@ STATUS PostTrainingQuantizer::DoInference() {
       const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
       const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
       const mindspore::session::CallBackParam &callParam) -> bool {
-      if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_param, afterOutputs) != RET_OK) {
+      if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, afterOutputs) != RET_OK) {
         return false;
       }
       auto tensor = afterOutputs[0];
@@ -873,7 +886,7 @@ STATUS PostTrainingQuantizer::DoInference() {
       this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetOutputDivergInfo());
       return true;
     };
-    status = session_->RunGraph(beforeCallBack, afterCallBack);
+    status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "run model failed!";
       return RET_ERROR;
@@ -882,10 +895,376 @@ STATUS PostTrainingQuantizer::DoInference() {
   return RET_OK;
 }
 
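The `DoInference` pass above still drives only the fp32 session. The new `Int8Inference` that follows runs in a second thread and consumes activations that the fp32 callbacks in `BiasCorrection` publish through `fp32_op_input` / `fp32_op_input_ready`. Reduced to its essentials, the synchronization is a flag-guarded exchange over `std::atomic<bool>` with sleep-polling; a minimal sketch with illustrative names:

```cpp
#include <atomic>
#include <chrono>
#include <thread>
#include <vector>

std::atomic<bool> ready{false};
std::vector<float> shared_data;

// fp32 side: wait until the previous tensor was consumed, then publish.
void Fp32Producer(const std::vector<float> &tensor) {
  while (ready) std::this_thread::sleep_for(std::chrono::milliseconds(10));
  shared_data = tensor;  // publish the fp32 activation
  ready = true;
}

// int8 side: spin until a tensor is available, consume it, clear the flag.
void Int8Consumer(std::vector<float> *out) {
  while (!ready) std::this_thread::sleep_for(std::chrono::milliseconds(10));
  *out = shared_data;
  ready = false;
}
```

Because both callbacks fire in the same per-op order on their respective sessions, this lock-step handshake keeps the fp32 and int8 graphs aligned op by op.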
+STATUS PostTrainingQuantizer::Int8Inference() {
+  // int8 inference
+  vector<mindspore::tensor::MSTensor *> inputs = int8_session_->GetInputs();
+  // get input tensor
+  if (inputs.size() != 1) {
+    MS_LOG(ERROR) << "model's input tensor size: " << inputs.size();
+    return RET_ERROR;
+  }
+  auto elem_count = inputs.front()->ElementsNum();
+  vector<float> dummy_data(elem_count);
+  std::fill(dummy_data.begin(), dummy_data.end(), 0.1f);
+  auto ret = memcpy_s(inputs.front()->MutableData(), inputs.front()->Size(), dummy_data.data(),
+                      sizeof(float) * dummy_data.size());
+  if (ret != EOK) {
+    MS_LOG(ERROR) << "memcpy_s error: " << ret;
+    return RET_ERROR;
+  }
+
+  for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
+    // feed each conv with the requantized fp32 input published by BiasCorrection
+    mindspore::session::KernelCallBack beforeCallBack =
+      [this](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
+             const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
+             const mindspore::session::CallBackParam &callParam) -> bool {
+      if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
+        while (!fp32_op_input_ready) {
+          std::this_thread::sleep_for(std::chrono::milliseconds(10));
+        }
+        if (callParam.name_callback_param != fp32_op_input_name) {
+          MS_LOG(ERROR) << "current int8 op name: " << callParam.name_callback_param
+                        << " ready fp32 op name: " << fp32_op_input_name;
+          return false;
+        }
+        auto tensor = beforeInputs[0];
+        auto lite_tensor = dynamic_cast<mindspore::lite::Tensor *>(tensor);
+
+        if (tensor->data_type() != kNumberTypeInt8) {
+          MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
+          return false;
+        }
+
+        // do quantization: activation is always per layer quantized
+        std::vector<int8_t> quant_datas;
+        auto quant_params = lite_tensor->GetQuantParams();
+        if (quant_params.size() != 1) {
+          MS_LOG(ERROR) << "unexpected quant_params size: " << quant_params.size();
+          return false;
+        }
+        schema::QuantParamT quant_param_t;
+        quant_param_t.scale = quant_params[0].scale;
+        quant_param_t.zeroPoint = quant_params[0].zeroPoint;
+        for (auto float_data : fp32_op_input) {
+          auto quant_data = QuantizeData<int8_t>(float_data, quant_param_t, quant_max, quant_min);
+          quant_datas.push_back(quant_data);
+        }
+
+        if (tensor->Size() != quant_datas.size() * sizeof(int8_t)) {
+          MS_LOG(ERROR) << "unexpected tensor size: " << tensor->Size()
+                        << " not the same with: " << quant_datas.size() * sizeof(int8_t);
+          return false;
+        }
+
+        auto ret =
+          memcpy_s(tensor->MutableData(), tensor->Size(), quant_datas.data(), quant_datas.size() * sizeof(int8_t));
+        if (ret != EOK) {
+          MS_LOG(ERROR) << "memcpy error: " << ret;
+          return false;
+        }
+        fp32_op_input_ready = false;
+      }
+      return true;
+    };
+    // compare the dequantized int8 output with the fp32 output, channel by channel
+    mindspore::session::KernelCallBack afterCallBack =
+      [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
+             const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
+             const mindspore::session::CallBackParam &callParam) -> bool {
+      if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
+        while (!fp32_op_output_ch_mean_ready) {
+          std::this_thread::sleep_for(std::chrono::milliseconds(10));
+        }
+        if (callParam.name_callback_param != fp32_op_output_name) {
+          MS_LOG(ERROR) << "current int8 op name: " << callParam.name_callback_param
+                        << " ready fp32 op name: " << fp32_op_output_name;
+          return false;
+        }
+        auto tensor = afterOutputs[0];
+        auto lite_tensor = dynamic_cast<mindspore::lite::Tensor *>(tensor);
+
+        if (tensor->data_type() != kNumberTypeInt8) {
+          MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
+          return false;
+        }
+
+        const int8_t *tensor_data = static_cast<const int8_t *>(tensor->MutableData());
+        size_t elem_count = tensor->ElementsNum();
+        auto shapes = tensor->shape();
+        if (shapes.size() != 4) {
+          MS_LOG(ERROR) << "unexpected shape size: " << shapes.size();
+          return false;
+        }
+        // suppose the format is NHWC
+        auto channels = shapes[3];
+        if (channels == 0) {
+          MS_LOG(ERROR) << "unexpected channels: 0";
+          return false;
+        }
+        auto quant_params = lite_tensor->GetQuantParams();
+        if (quant_params.size() != 1) {
+          MS_LOG(ERROR) << "unexpected activation quant_params size: " << quant_params.size();
+          return false;
+        }
+        auto scale = quant_params[0].scale;
+        auto zp = quant_params[0].zeroPoint;
+
+        std::vector<float> dequant_op_output_ch_mean(channels);
+        auto one_filter_size = elem_count / channels;
+        for (int i = 0; i < channels; i++) {
+          float sum = 0;
+          for (size_t j = 0; j < one_filter_size; j++) {
+            auto index = j * channels + i;
+            if (index >= elem_count) {
+              MS_LOG(ERROR) << "over flow!";
+              return false;
+            }
+            // dequant activation
+            auto float_data = scale * (tensor_data[index] - zp);
+            sum += float_data;
+          }
+          sum = sum / one_filter_size;
+          dequant_op_output_ch_mean[i] = sum;
+        }
+        std::transform(fp32_op_output_ch_mean.begin(), fp32_op_output_ch_mean.end(),
+                       dequant_op_output_ch_mean.begin(), dequant_op_output_ch_mean.begin(), std::minus<>());
+
+        if (op_bias_diff_map.find(callParam.name_callback_param) != op_bias_diff_map.end()) {
+          auto &bias_diff = op_bias_diff_map[callParam.name_callback_param];
+          std::transform(bias_diff.begin(), bias_diff.end(), dequant_op_output_ch_mean.begin(), bias_diff.begin(),
+                         std::plus<>());
+        } else {
+          op_bias_diff_map[callParam.name_callback_param] = dequant_op_output_ch_mean;
+        }
+        fp32_op_output_ch_mean_ready = false;
+      }
+
+      return true;
+    };
+    ret = int8_session_->RunGraph(beforeCallBack, afterCallBack);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "run model failed!";
+      return RET_ERROR;
+    }
+  }  // end for images
+  return RET_OK;
+}
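Inside these callbacks every conv input is requantized with `QuantizeData<int8_t>` and every conv output is dequantized as `scale * (q - zeroPoint)` before the per-channel means are compared. A minimal stand-in for that round trip, assuming the usual affine int8 mapping (the real helper lives in quantize_util.h):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
int8_t QuantizeActivation(float x, double scale, int zero_point,
                          int quant_min = -128, int quant_max = 127) {
  int q = static_cast<int>(std::round(x / scale)) + zero_point;
  return static_cast<int8_t>(std::max(quant_min, std::min(quant_max, q)));
}

// x ~= scale * (q - zero_point), as used when computing the per-channel means
float DequantizeActivation(int8_t q, double scale, int zero_point) {
  return static_cast<float>(scale * (q - zero_point));
}
```

Since the activation is NHWC, element (n, h, w, c) sits at flat index `j * channels + c`, which is exactly the stride pattern the per-channel mean loops above use.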
+
+STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) {
+  auto ret = RET_OK;
+  std::future<STATUS> int8_inference = std::async(std::launch::async, &PostTrainingQuantizer::Int8Inference, this);
+
+  // fp32 inference
+  for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
+    // get input tensor
+    vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
+    if (inputs.size() != 1) {
+      MS_LOG(ERROR) << "model's input tensor size: " << inputs.size();
+      return RET_ERROR;
+    }
+    STATUS status = calibrator_->GenerateInputData(i, inputs.front());
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "generate input data from images failed!";
+      return RET_ERROR;
+    }
+    // publish each conv's fp32 input for the int8 session to requantize
+    mindspore::session::KernelCallBack beforeCallBack =
+      [this](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
+             const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
+             const mindspore::session::CallBackParam &callParam) -> bool {
+      if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
+        while (fp32_op_input_ready) {
+          std::this_thread::sleep_for(std::chrono::milliseconds(10));
+        }
+        if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
+          return false;
+        }
+        auto tensor = beforeInputs[0];
+        size_t elem_count = tensor->ElementsNum();
+        fp32_op_input.resize(elem_count);
+        auto ret =
+          memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->MutableData(), tensor->Size());
+        if (ret != EOK) {
+          MS_LOG(ERROR) << "memcpy error: " << ret;
+          return false;
+        }
+        fp32_op_input_name = callParam.name_callback_param;
+        fp32_op_input_ready = true;
+      }
+      return true;
+    };
+    // publish each conv's per-channel fp32 output mean
+    mindspore::session::KernelCallBack afterCallBack =
+      [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
+             const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
+             const mindspore::session::CallBackParam &callParam) -> bool {
+      if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
+        while (fp32_op_output_ch_mean_ready) {
+          std::this_thread::sleep_for(std::chrono::milliseconds(10));
+        }
+        if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, afterOutputs) != RET_OK) {
+          return false;
+        }
+        auto tensor = afterOutputs[0];
+        const float *tensor_data = static_cast<const float *>(tensor->MutableData());
+        size_t elem_count = tensor->ElementsNum();
+        auto shapes = tensor->shape();
+        if (shapes.size() != 4) {
+          MS_LOG(ERROR) << "unexpected shape size: " << shapes.size();
+          return false;
+        }
+        // suppose the activation format: NHWC
+        auto channels = shapes[3];
+        if (channels == 0) {
+          MS_LOG(ERROR) << "unexpected channels: 0";
+          return false;
+        }
+        fp32_op_output_ch_mean.resize(channels);
+        auto one_filter_size = elem_count / channels;
+        for (int i = 0; i < channels; i++) {
+          float sum = 0;
+          for (size_t j = 0; j < one_filter_size; j++) {
+            auto index = j * channels + i;
+            if (index >= elem_count) {
+              MS_LOG(ERROR) << "over flow!";
+              return false;
+            }
+            sum += tensor_data[index];
+          }
+          sum = sum / one_filter_size;
+          fp32_op_output_ch_mean[i] = sum;
+        }
+        fp32_op_output_name = callParam.name_callback_param;
+        fp32_op_output_ch_mean_ready = true;
+      }
+
+      return true;
+    };
+    status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "run model failed!";
+      return RET_ERROR;
+    }
+  }  // end for images
+
+  ret = int8_inference.get();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "int8 inference failed!";
+    return RET_ERROR;
+  }
+  // average the accumulated per-channel differences over all batches
+  for (auto &key_value : op_bias_diff_map) {
+    std::for_each(key_value.second.begin(), key_value.second.end(),
+                  [this](float &data) { data = data / calibrator_->GetBatchNum(); });
+  }
+  auto cnodes = func_graph->GetOrderedCnodes();
+  for (auto &cnode : cnodes) {
+    auto op_name = cnode->fullname_with_scope();
+    if (op_bias_diff_map.find(op_name) != op_bias_diff_map.end()) {
+      const auto &bias_diff = op_bias_diff_map[op_name];
+      auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+      if (primitive_c == nullptr) {
+        MS_LOG(ERROR) << "primitive_c is nullptr";
+        continue;
+      }
+      auto input_quant_params = primitive_c->GetInputQuantParams();
+
+      if (input_quant_params.size() == 3) {
+        // compensate the existing bias
+        auto bias_quant_params = input_quant_params[2];
+        auto bias = cnode->input(3);
+        auto bias_parameter_ptr = std::dynamic_pointer_cast<Parameter>(bias);
+        auto bias_default_param = bias_parameter_ptr->default_param();
+        auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param);
+        int *bias_datas = static_cast<int *>(bias_param->tensor_addr());
+
+        if (static_cast<size_t>(bias_param->tensor_shape_size()) != bias_diff.size()) {
+          MS_LOG(ERROR) << "unexpected bias data count: " << bias_param->tensor_shape_size()
+                        << " not the same as bias_diff: " << bias_diff.size();
+          continue;
+        }
+        if (bias_quant_params.size() != bias_diff.size()) {
+          MS_LOG(ERROR) << "unexpected bias quant params size: " << bias_quant_params.size()
+                        << " not the same as bias_diff: " << bias_diff.size();
+          continue;
+        }
+
+        for (int i = 0; i < bias_param->tensor_shape_size(); i++) {
+          auto scale = bias_quant_params[i].scale;
+          double after_correct = std::round(bias_diff[i] / scale) + bias_datas[i];
+          constexpr int32_t corrected_bias_abs_limit = 0.6 * INT32_MAX;
+          if (after_correct > corrected_bias_abs_limit) {
+            MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too large: " << after_correct
+                            << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i]
+                            << " scale: " << scale;
+            bias_datas[i] = static_cast<int>(corrected_bias_abs_limit);
+          } else if (after_correct < -corrected_bias_abs_limit) {
+            MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too small: " << after_correct
+                            << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i]
+                            << " scale: " << scale;
+            bias_datas[i] = static_cast<int>(-corrected_bias_abs_limit);
+          } else {
+            auto diff = static_cast<int>(std::round(bias_diff[i] / scale));
+            bias_datas[i] += diff;
+          }
+        }
+      } else if (input_quant_params.size() == 2) {
+        MS_LOG(INFO) << op_name << " add bias input";
+        // need to add bias input
+        auto parameter = func_graph->add_parameter();
+        ShapeVector shape;
+        shape.push_back(bias_diff.size());
+        auto type_ptr = TypeIdToType(kNumberTypeFloat32);
+        auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
+        parameter->set_abstract(abstract_tensor);
+        parameter->set_name("added_" + op_name + "_bias");
+
+        ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
+        MS_ASSERT(param_value != nullptr);
+        param_value->set_tensor_shape(shape);
+        param_value->set_tensor_type(kNumberTypeFloat32);
+        // param_value->set_format(tensor->format);
+
+        auto size = sizeof(float) * bias_diff.size();
+        char *tensor_data = new (std::nothrow) char[size];
+        if (tensor_data == nullptr) {
+          MS_LOG(ERROR) << "new char[] failed";
+          return RET_MEMORY_FAILED;
+        }
+        std::memcpy(tensor_data, bias_diff.data(), size);
+        param_value->set_tensor_addr(tensor_data);
+        param_value->set_tensor_size(size);
+        parameter->set_default_param(param_value);
+        cnode->add_input(parameter);
+        DoBiasQuant(parameter, primitive_c);
+
+        auto op_type = (schema::PrimitiveType)primitive_c->Type();
+        if (op_type == schema::PrimitiveType_Conv2D) {
+          auto conv2d = primitive_c->GetPrimitiveT()->value.AsConv2D();
+          if (conv2d == nullptr) {
+            MS_LOG(ERROR) << "conv2d is null";
+            return RET_ERROR;
+          }
+          conv2d->hasBias = true;
+        } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
+          auto depthwise_conv2d = primitive_c->GetPrimitiveT()->value.AsDepthwiseConv2D();
+          if (depthwise_conv2d == nullptr) {
+            MS_LOG(ERROR) << "depthwise_conv2d is null";
+            return RET_ERROR;
+          }
+          depthwise_conv2d->hasBias = true;
+        }
+      } else {
+        MS_LOG(ERROR) << "unexpected input_quant_params size: " << input_quant_params.size();
+        continue;
+      }
+    }  // end if op_name found
+  }    // end for cnodes
+
+  return ret;
+}
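In the `input_quant_params.size() == 3` branch above, each channel's correction is `round(bias_diff / scale)` added onto the already-quantized bias, clamped to plus or minus 0.6 * INT32_MAX so the int32 accumulator keeps headroom. The same arithmetic in isolation:

```cpp
#include <cmath>
#include <cstdint>

// Fold the fp32-vs-int8 output mean difference for one channel into the
// quantized bias, clamping exactly as the patch does.
int32_t CorrectBias(int32_t bias, float mean_diff, double bias_scale) {
  constexpr int32_t kAbsLimit = static_cast<int32_t>(0.6 * INT32_MAX);
  double corrected = bias + std::round(mean_diff / bias_scale);
  if (corrected > kAbsLimit) return kAbsLimit;
  if (corrected < -kAbsLimit) return -kAbsLimit;
  return static_cast<int32_t>(corrected);
}
```

When a conv has no bias input at all (the `size() == 2` branch), the averaged difference itself becomes a new fp32 bias parameter, which is then quantized through the existing `DoBiasQuant` path and the primitive's `hasBias` flag is set.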
 
 STATUS PostTrainingQuantizer::CollectDataFrequency() {
   for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
     // get input tensor
-    vector<mindspore::tensor::MSTensor *> inputs = session_->GetInputs();
+    vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
     if (inputs.size() > 1) {
       MS_LOG(ERROR) << "model's input tensor size: " << inputs.size() << " > 1";
       return RET_ERROR;
@@ -900,7 +1279,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
       [&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
           const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
           const mindspore::session::CallBackParam &callParam) {
-      if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
+      if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
         return false;
       }
       auto tensor = beforeInputs[0];
@@ -916,7 +1295,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
       [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
          const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
          const mindspore::session::CallBackParam &call_param) {
-      if (PostTrainingQuantizer::CheckTensorVec(call_param.name_callback_param, after_outputs) != RET_OK) {
+      if (PostTrainingQuantizer::CheckFp32TensorVec(call_param.name_callback_param, after_outputs) != RET_OK) {
         return false;
       }
       auto tensor = after_outputs[0];
@@ -927,7 +1306,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
                                                  this->calibrator_->GetOutputDivergInfo());
       return true;
     };
-    status = session_->RunGraph(beforeCallBack, afterCallBack);
+    status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "run model failed!";
       return RET_ERROR;
@@ -939,16 +1318,15 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
 
 STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }
 
-STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
+STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   MS_LOG(INFO) << "start to parse config file";
   STATUS status = PreProcess();
   if (status != RET_OK) {
     MS_LOG(ERROR) << "do pre process failed!";
     return status;
   }
-  // anf -- fb
-  auto meta_graph = Export(funcGraph, true);
+  auto meta_graph = Export(func_graph, true);
   if (meta_graph == nullptr) {
     MS_LOG(ERROR) << "Export to meta_graph return nullptr";
     return RET_ERROR;
   }
@@ -980,13 +1358,13 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
   ctx.thread_num_ = calibrator_->GetThreadNum();
   ctx.cpu_bind_mode_ = MID_CPU;
 
-  session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx));
-  if (session_ == nullptr) {
+  fp32_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx));
+  if (fp32_session_ == nullptr) {
     MS_LOG(ERROR) << "create session failed!";
     return RET_ERROR;
   }
 
-  auto ret = session_->CompileGraph(model);
+  auto ret = fp32_session_->CompileGraph(model);
   if (ret != lite::RET_OK) {
     MS_LOG(ERROR) << "compile graph error";
     return RET_ERROR;
@@ -1017,6 +1395,70 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
   if (status != RET_OK) {
     return status;
   }
+
+  // add quant_cast
+  quant::QuantCast quant_cast;
+  quant_cast.SetInputDataDType(kNumberTypeFloat32);
+  status = quant_cast.Run(func_graph);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "add QuantCast error";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+    return RET_ERROR;
+  }
+
+  if (calibrator_->GetBiasCorrection()) {
+    // init int8 session
+    // anf -- fb
+    auto int8_meta_graph = Export(func_graph, true);
+    if (int8_meta_graph == nullptr) {
+      MS_LOG(ERROR) << "Export to int8_meta_graph return nullptr";
+      return RET_ERROR;
+    }
+
+    // transform
+    GraphDefTransform fb_transform;
+    fb_transform.SetGraphDef(int8_meta_graph);
+    flags.quantType = schema::QuantType_PostTraining;
+    status = fb_transform.Transform(flags);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "FBTransform model failed " << status;
+      return RET_ERROR;
+    }
+    MS_LOG(INFO) << "start create quantized session";
+    flatbuffers::FlatBufferBuilder int8_builder(1024);
+    auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph);
+    int8_builder.Finish(int8_offset);
+    size = int8_builder.GetSize();
+    auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer());
+    if (int8_content == nullptr) {
+      MS_LOG(ERROR) << "GetBufferPointer nullptr";
+      return RET_ERROR;
+    }
+    auto int8_model = lite::Model::Import(int8_content, size);
+
+    Context int8_ctx;
+    int8_ctx.device_type_ = DT_CPU;
+    int8_ctx.thread_num_ = calibrator_->GetThreadNum();
+    int8_ctx.cpu_bind_mode_ = HIGHER_CPU;
+
+    int8_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&int8_ctx));
+    if (int8_session_ == nullptr) {
+      MS_LOG(ERROR) << "create session failed!";
+      return RET_ERROR;
+    }
+    ret = int8_session_->CompileGraph(int8_model);
+    if (ret != lite::RET_OK) {
+      MS_LOG(ERROR) << "compile graph error";
+      return RET_ERROR;
+    }
+
+    MS_LOG(INFO) << "do bias correction";
+    status = BiasCorrection(func_graph);
+    if (status != RET_OK) {
+      MS_LOG(WARNING) << "BiasCorrection failed.";
+    }
+  }
+
   return RET_OK;
 }
 }  // namespace quant
diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h
index 55084ff2ad..5d78abe691 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h
@@ -49,6 +49,7 @@ struct ConfigParam {
   uint32_t batch_count{100};
   std::string method_x{kMethodKL};
   uint32_t thread_num{1};
+  bool bias_correction{false};
 };
 
 class PostTrainingQuantizer : public Quantizer {
@@ -56,7 +57,7 @@ class PostTrainingQuantizer : public Quantizer {
   PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8,
                         bool per_channel = true);
 
-  STATUS DoQuantize(FuncGraphPtr funcGraph) override;
+  STATUS DoQuantize(FuncGraphPtr func_graph) override;
 
   size_t bit_num;
   int quant_max{INT8_MAX};
@@ -69,12 +70,24 @@ class PostTrainingQuantizer : public Quantizer {
 
   std::unique_ptr<Calibrator> calibrator_;
 
-  mindspore::lite::LiteSession *session_;
+  mindspore::lite::LiteSession *fp32_session_;
+  mindspore::lite::LiteSession *int8_session_;
+
+  std::string fp32_op_input_name;
+  std::string fp32_op_output_name;
+  std::vector<float> fp32_op_input;
+  std::vector<float> fp32_op_output_ch_mean;
+  std::map<std::string, std::vector<float>> op_bias_diff_map;
+  std::atomic<bool> fp32_op_input_ready{false};
+  std::atomic<bool> fp32_op_output_ch_mean_ready{false};
+
+  const std::string kTypeConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_Conv2D);
+  const std::string kTypeDepthwiseConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_DepthwiseConv2D);
 
   STATUS PreProcess();
 
-  STATUS CheckTensorVec(const std::string &node_name,
-                        const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const;
+  STATUS CheckFp32TensorVec(const std::string &node_name,
+                            const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const;
 
   STATUS DoInference();
 
@@ -92,6 +105,8 @@ class PostTrainingQuantizer : public Quantizer {
   STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel);
 
   STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c);
+  STATUS Int8Inference();
+  STATUS BiasCorrection(FuncGraphPtr func_graph);
 };
 
 struct DivergInfo {
@@ -153,6 +168,8 @@ class Calibrator {
 
   std::string GetMethodX() const { return config_param_.method_x; }
 
+  bool GetBiasCorrection() const { return config_param_.bias_correction; }
+
   STATUS AddQuantizedOp(CNodePtr node);
 
   STATUS RecordMaxValue(const std::string &op_name, const std::vector<float> &data,
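With `bias_correction` in `ConfigParam` and exposed via `Calibrator::GetBiasCorrection()`, the feature is driven entirely by the calibration config file. A plausible config consumed by `Calibrator::ReadConfig` (paths and values are illustrative):

```
image_path=/path/to/calibration/images
batch_count=100
method_x=KL
thread_num=4
bias_correction=true
```

When the flag parses to true, `DoQuantize` exports the already-quantized graph a second time, compiles it into `int8_session_`, and runs `BiasCorrection` after the standard post-training pass; when it is absent or false, behavior is unchanged apart from `QuantCast` now running inside the quantizer rather than in `AnfTransform`.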