
bias correction

tags/v1.1.0
xutianchun committed 5 years ago
parent commit 46d6c6f197
3 changed files with 491 additions and 44 deletions
  1. mindspore/lite/tools/converter/anf_transform.cc (+4, -16)
  2. mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc (+466, -24)
  3. mindspore/lite/tools/converter/quantizer/post_training_quantizer.h (+21, -4)

mindspore/lite/tools/converter/anf_transform.cc (+4, -16)

@@ -58,12 +58,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
schema::ActivationType_RELU));
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
schema::ActivationType_RELU6));
pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(true, "conv_tuple_relu",
schema::PrimitiveType_Activation,
schema::ActivationType_RELU));
pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(true, "conv_tuple_relu6",
schema::PrimitiveType_Activation,
schema::ActivationType_RELU6));
pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU));
pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6));
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
weight_format_hardcode_pass->SetFmkType(config->fmk);
weight_format_hardcode_pass->SetQuantType(config->quantType);
@@ -114,16 +112,6 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
if (config->quantType == schema::QuantType_PostTraining) {
quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32);
status = quant_cast.Run(new_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
}
}

return new_graph;


mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc (+466, -24)

@@ -17,23 +17,28 @@
#include "tools/converter/quantizer/post_training_quantizer.h"
#include <dirent.h>
#include <sys/stat.h>
#include <future>
#include <map>
#include <unordered_map>
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <string>
#include <thread>
#include <vector>
#include <fstream>
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "utils/log_adapter.h"
#include "securec/include/securec.h"
#include "tools/common/tensor_util.h"
#include "src/common/file_utils.h"
#include "src/common/utils.h"

using std::string;
using std::vector;
@@ -431,6 +436,8 @@ STATUS Calibrator::ReadConfig() {
}
auto key = line.substr(0, index);
auto value = line.substr(index + 1);
Trim(&key);
Trim(&value);
if (key == "image_path") {
config_param_.image_path = value;
} else if (key == "batch_count") {
@@ -443,6 +450,11 @@ STATUS Calibrator::ReadConfig() {
} else {
config_param_.method_x = value;
}
} else if (key == "bias_correction") {
std::transform(value.begin(), value.end(), value.begin(), ::tolower);
if (value == "true") {
config_param_.bias_correction = true;
}
} else {
MS_LOG(WARNING) << "unsupported parameter";
}
@@ -450,7 +462,8 @@ STATUS Calibrator::ReadConfig() {
MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " "
<< "batch_count: " << config_param_.batch_count << " "
<< "method_x: " << config_param_.method_x << " "
<< "thread_num: " << config_param_.thread_num;
<< "thread_num: " << config_param_.thread_num << " "
<< "bias_correction: " << config_param_.bias_correction;

delete[] resolved_path;
fs.close();
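
For reference, a minimal calibration config exercising the new key might look like the sketch below. This assumes the key=value layout that ReadConfig splits on; the separator parsing itself happens outside this hunk, and the paths are placeholders:

  image_path=/path/to/calibration/images
  batch_count=100
  method_x=KL
  thread_num=4
  bias_correction=true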
@@ -533,8 +546,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
return RET_ERROR;
}
auto status = QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num,
perchanel);
auto status =
QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, perchanel);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
@@ -637,8 +650,8 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
quant_params[i].scale = bias_scale_tmp;
MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
} else {
MS_LOG(WARNING) << "unexpected input_scales size: " << input_scales.size() << " weight_scales size: "
<< active_weight_quant_params[1].size();
MS_LOG(WARNING) << "unexpected input_scales size: " << input_scales.size()
<< " weight_scales size: " << active_weight_quant_params[1].size();
}
}
auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
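
The rounding above follows the standard int32 bias quantization relation: the bias scale is the product of the input activation scale and the (per-channel) filter scale, so the quantized bias lives on the same grid the int8 convolution accumulates in. As a sketch (names illustrative, not taken from the source):

  bias_scale[c] = input_scale * filter_scale[c]
  q_bias[c]     = round(bias_fp32[c] / bias_scale[c])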
@@ -694,7 +707,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
}

auto op_type = (schema::PrimitiveType)primitive_c->Type();
MS_LOG(INFO) << "OpName: " << op_name;
MS_LOG(DEBUG) << "OpName: " << op_name;
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
op_type != PrimitiveType_FullConnection) {
for (size_t i = 1; i < cnode->inputs().size(); i++) {
@@ -811,16 +824,16 @@ STATUS PostTrainingQuantizer::PreProcess() {
return RET_OK;
}

STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &node_name,
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const {
STATUS PostTrainingQuantizer::CheckFp32TensorVec(const std::string &node_name,
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const {
if (tensor_vec.size() < 1) {
MS_LOG(ERROR) << "node: " << node_name << " input tensors is empty";
return RET_ERROR;
}
auto *tensor = tensor_vec[0];
if (tensor->data_type() != kNumberTypeFloat32) {
MS_LOG(DEBUG) << "node: " << node_name << " will not quantize"
<< " tensor data_type: " << tensor->data_type();
MS_LOG(WARNING) << "node: " << node_name << " will not quantize"
<< " tensor data_type: " << tensor->data_type();
return RET_ERROR;
}
return RET_OK;
@@ -834,7 +847,7 @@ STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &node_name,
STATUS PostTrainingQuantizer::DoInference() {
for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
// get input tensor
vector<mindspore::tensor::MSTensor *> inputs = session_->GetInputs();
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
if (inputs.size() > 1) {
MS_LOG(ERROR) << "model's input tensor size: " << inputs.size() << " >1";
return RET_ERROR;
@@ -848,7 +861,7 @@ STATUS PostTrainingQuantizer::DoInference() {
[&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
const mindspore::session::CallBackParam &callParam) -> bool {
if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
return false;
}
auto tensor = beforeInputs[0];
@@ -863,7 +876,7 @@ STATUS PostTrainingQuantizer::DoInference() {
const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
const mindspore::session::CallBackParam &callParam) -> bool {
if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_param, afterOutputs) != RET_OK) {
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, afterOutputs) != RET_OK) {
return false;
}
auto tensor = afterOutputs[0];
@@ -873,7 +886,7 @@ STATUS PostTrainingQuantizer::DoInference() {
this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetOutputDivergInfo());
return true;
};
status = session_->RunGraph(beforeCallBack, afterCallBack);
status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
if (status != RET_OK) {
MS_LOG(ERROR) << "run model failed!";
return RET_ERROR;
@@ -882,10 +895,376 @@ STATUS PostTrainingQuantizer::DoInference() {
return RET_OK;
}

STATUS PostTrainingQuantizer::Int8Inference() {
// int8 inference
// get input tensor
vector<mindspore::tensor::MSTensor *> inputs = int8_session_->GetInputs();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "model's input tensor size: " << inputs.size();
return RET_ERROR;
}
auto elem_count = inputs.front()->ElementsNum();
vector<float> dummy_data(elem_count);
std::fill(dummy_data.begin(), dummy_data.end(), 0.1);
auto ret = memcpy_s(inputs.front()->MutableData(), inputs.front()->Size(), dummy_data.data(),
sizeof(float) * dummy_data.size());
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s error: " << ret;
return RET_ERROR;
}

for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
mindspore::session::KernelCallBack beforeCallBack =
[this](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
const mindspore::session::CallBackParam &callParam) -> bool {
if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
while (!fp32_op_input_ready) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
if (callParam.name_callback_param != fp32_op_input_name) {
MS_LOG(ERROR) << "current int8 op name: " << callParam.name_callback_param
<< " ready fp32 op name: " << fp32_op_input_name;
return false;
}
auto tensor = beforeInputs[0];
auto lite_tensor = dynamic_cast<mindspore::lite::Tensor *>(tensor);
if (lite_tensor == nullptr) {
MS_LOG(ERROR) << "lite_tensor is nullptr";
return false;
}

if (tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
return false;
}

// do quantization: activation is always per layer quantized
std::vector<int8_t> quant_datas;
auto quant_params = lite_tensor->GetQuantParams();
if (quant_params.size() != 1) {
MS_LOG(ERROR) << "unexpected quant_params size: " << quant_params.size();
return false;
}
schema::QuantParamT quant_param_t;
quant_param_t.scale = quant_params[0].scale;
quant_param_t.zeroPoint = quant_params[0].zeroPoint;
for (auto float_data : fp32_op_input) {
auto quant_data = QuantizeData<int8_t>(float_data, quant_param_t, quant_max, quant_min);
quant_datas.push_back(quant_data);
}

if (tensor->Size() != quant_datas.size() * sizeof(int8_t)) {
MS_LOG(ERROR) << "unexpected tensor size: " << tensor->Size()
<< " not the same as: " << quant_datas.size() * sizeof(int8_t);
return false;
}

auto ret =
memcpy_s(tensor->MutableData(), tensor->Size(), quant_datas.data(), quant_datas.size() * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return false;
}
fp32_op_input_ready = false;
}
return true;
};
// after each op runs: compare the dequantized int8 output with the fp32 reference, per channel
mindspore::session::KernelCallBack afterCallBack = [this](
const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
const mindspore::session::CallBackParam &callParam) -> bool {
if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
while (!fp32_op_output_ch_mean_ready) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
if (callParam.name_callback_param != fp32_op_output_name) {
MS_LOG(ERROR) << "current int8 op name: " << callParam.name_callback_param
<< " ready fp32 op name: " << fp32_op_output_name;
return false;
}
auto tensor = afterOutputs[0];
auto lite_tensor = dynamic_cast<mindspore::lite::Tensor *>(tensor);
if (lite_tensor == nullptr) {
MS_LOG(ERROR) << "lite_tensor is nullptr";
return false;
}

if (tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
return false;
}

const int8_t *tensor_data = static_cast<int8_t *>(tensor->MutableData());
size_t elem_count = tensor->ElementsNum();
auto shapes = tensor->shape();
if (shapes.size() != 4) {
MS_LOG(ERROR) << "unexpected shape size: " << shapes.size();
return false;
}
// suppose the format is NHWC
auto channels = shapes[3];
if (channels == 0) {
MS_LOG(ERROR) << "unexpected channels: 0";
return false;
}
auto quant_params = lite_tensor->GetQuantParams();
if (quant_params.size() != 1) {
MS_LOG(ERROR) << "unexpected activatation quant_params size: " << quant_params.size();
return false;
}
auto scale = quant_params[0].scale;
auto zp = quant_params[0].zeroPoint;

std::vector<float> dequant_op_output_ch_mean(channels);
auto one_filter_size = elem_count / channels;
for (int i = 0; i < channels; i++) {
float sum = 0;
for (size_t j = 0; j < one_filter_size; j++) {
auto index = j * channels + i;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
// dequantize activation
auto float_data = scale * (tensor_data[index] - zp);
sum += float_data;
}
sum = sum / one_filter_size;
dequant_op_output_ch_mean[i] = sum;
}
std::transform(fp32_op_output_ch_mean.begin(), fp32_op_output_ch_mean.end(), dequant_op_output_ch_mean.begin(),
dequant_op_output_ch_mean.begin(), std::minus<>());

if (op_bias_diff_map.find(callParam.name_callback_param) != op_bias_diff_map.end()) {
auto &bias_diff = op_bias_diff_map[callParam.name_callback_param];
std::transform(bias_diff.begin(), bias_diff.end(), dequant_op_output_ch_mean.begin(), bias_diff.begin(),
std::plus<>());
} else {
op_bias_diff_map[callParam.name_callback_param] = dequant_op_output_ch_mean;
}
fp32_op_output_ch_mean_ready = false;
}

return true;
};
ret = int8_session_->RunGraph(beforeCallBack, afterCallBack);
if (ret != RET_OK) {
MS_LOG(ERROR) << "run model failed!";
return RET_ERROR;
}
} // end for images
return RET_OK;
}
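
Int8Inference runs on a worker thread (see BiasCorrection below) and hand-shakes with the fp32 pass through the fp32_op_input_ready and fp32_op_output_ch_mean_ready atomics: the fp32 callbacks publish an op's input tensor and per-channel output means, and the int8 callbacks spin until a value is published, consume it, then clear the flag. The dequantization applied before comparing means is the usual affine mapping; with illustrative numbers:

  float_data = scale * (q - zp)
  e.g. scale = 0.05, zp = 0, int8 outputs {10, 12, 14} -> {0.5, 0.6, 0.7}, channel mean = 0.6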

STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) {
auto ret = RET_OK;
std::future<STATUS> int8_inference = std::async(std::launch::async, &PostTrainingQuantizer::Int8Inference, this);

// fp32 inference
for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
// get input tensor
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "model's input tensor size: " << inputs.size();
return RET_ERROR;
}
STATUS status = calibrator_->GenerateInputData(i, inputs.front());
if (status != RET_OK) {
MS_LOG(ERROR) << "generate input data from images failed!";
return RET_ERROR;
}
mindspore::session::KernelCallBack beforeCallBack =
[this](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
const mindspore::session::CallBackParam &callParam) -> bool {
if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
while (fp32_op_input_ready) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
return false;
}
auto tensor = beforeInputs[0];
size_t elem_count = tensor->ElementsNum();
fp32_op_input.resize(elem_count);
auto ret =
memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->MutableData(), tensor->Size());
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return false;
}
fp32_op_input_name = callParam.name_callback_param;
fp32_op_input_ready = true;
}
return true;
};
// after each op runs: record the per-channel mean of the fp32 output
mindspore::session::KernelCallBack afterCallBack = [this](
const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
const mindspore::session::CallBackParam &callParam) -> bool {
if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
while (fp32_op_output_ch_mean_ready) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, afterOutputs) != RET_OK) {
return false;
}
auto tensor = afterOutputs[0];
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
size_t elem_count = tensor->ElementsNum();
auto shapes = tensor->shape();
if (shapes.size() != 4) {
MS_LOG(ERROR) << "unexpected shape size: " << shapes.size();
return false;
}
// suppose the activation format: NHWC
auto channels = shapes[3];
if (channels == 0) {
MS_LOG(ERROR) << "unexpected channels: 0";
return false;
}
fp32_op_output_ch_mean.resize(channels);
auto one_filter_size = elem_count / channels;
for (int i = 0; i < channels; i++) {
float sum = 0;
for (size_t j = 0; j < one_filter_size; j++) {
auto index = j * channels + i;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
sum += tensor_data[index];
}
sum = sum / one_filter_size;
fp32_op_output_ch_mean[i] = sum;
}
fp32_op_output_name = callParam.name_callback_param;
fp32_op_output_ch_mean_ready = true;
}

return true;
};
status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
if (status != RET_OK) {
MS_LOG(ERROR) << "run model failed!";
return RET_ERROR;
}
} // end for images

ret = int8_inference.get();
if (ret != RET_OK) {
MS_LOG(ERROR) << "int8 inference failed!";
return RET_ERROR;
}
for (auto &key_value : op_bias_diff_map) {
std::for_each(key_value.second.begin(), key_value.second.end(),
[this](float &data) { data = data / calibrator_->GetBatchNum(); });
}
auto cnodes = func_graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto op_name = cnode->fullname_with_scope();
if (op_bias_diff_map.find(op_name) != op_bias_diff_map.end()) {
const auto &bias_diff = op_bias_diff_map[op_name];
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
continue;
}
auto input_quant_params = primitive_c->GetInputQuantParams();

if (input_quant_params.size() == 3) {
// compensate the existed
auto bias_quant_params = input_quant_params[2];
auto bias = cnode->input(3);
auto bias_parameter_ptr = std::dynamic_pointer_cast<Parameter>(bias);
auto bias_default_param = bias_parameter_ptr->default_param();
auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param);
int *bias_datas = static_cast<int *>(bias_param->tensor_addr());

if (static_cast<size_t>(bias_param->tensor_shape_size()) != bias_diff.size()) {
MS_LOG(ERROR) << "unexpected bias data count: " << bias_param->tensor_shape_size()
<< " not the same as bias_diff: " << bias_diff.size();
continue;
}
if (bias_quant_params.size() != bias_diff.size()) {
MS_LOG(ERROR) << "unexpected bias quant params size: " << bias_quant_params.size()
<< " not the same as bias_diff: " << bias_diff.size();
continue;
}

for (int i = 0; i < bias_param->tensor_shape_size(); i++) {
auto scale = bias_quant_params[i].scale;
double after_correct = std::round(bias_diff[i] / scale) + bias_datas[i];
constexpr int32_t corrected_bias_abs_limit = 0.6 * INT32_MAX;
if (after_correct > corrected_bias_abs_limit) {
MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too large: " << after_correct
<< " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i]
<< " scale: " << scale;
bias_datas[i] = static_cast<int>(corrected_bias_abs_limit);
} else if (after_correct < -corrected_bias_abs_limit) {
MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too small: " << after_correct
<< " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i]
<< " scale: " << scale;
bias_datas[i] = static_cast<int>(-corrected_bias_abs_limit);
} else {
auto diff = static_cast<int>(std::round(bias_diff[i] / scale));
bias_datas[i] += diff;
}
}
} else if (input_quant_params.size() == 2) {
MS_LOG(INFO) << op_name << " add bias input";
// need to add bias input
auto parameter = func_graph->add_parameter();
ShapeVector shape;
shape.push_back(bias_diff.size());
auto type_ptr = TypeIdToType(kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
parameter->set_abstract(abstract_tensor);
parameter->set_name("added_" + op_name + "_bias");

ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
MS_ASSERT(param_value != nullptr);
param_value->set_tensor_shape(shape);
param_value->set_tensor_type(kNumberTypeFloat32);
// param_value->set_format(tensor->format);

auto size = sizeof(float) * bias_diff.size();
char *tensor_data = new (std::nothrow) char[size];
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new char[] failed";
return RET_MEMORY_FAILED;
}
std::memcpy(tensor_data, bias_diff.data(), size);
param_value->set_tensor_addr(tensor_data);
param_value->set_tensor_size(size);
parameter->set_default_param(param_value);
cnode->add_input(parameter);
DoBiasQuant(parameter, primitive_c);

auto op_type = (schema::PrimitiveType)primitive_c->Type();
if (op_type == schema::PrimitiveType_Conv2D) {
auto conv2d = primitive_c->GetPrimitiveT()->value.AsConv2D();
if (conv2d == nullptr) {
MS_LOG(ERROR) << "conv2d is null";
return RET_ERROR;
}
conv2d->hasBias = true;
} else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
auto depthwise_conv2d = primitive_c->GetPrimitiveT()->value.AsDepthwiseConv2D();
if (depthwise_conv2d == nullptr) {
MS_LOG(ERROR) << "conv2d is null";
return RET_ERROR;
}
depthwise_conv2d->hasBias = true;
}
} else {
MS_LOG(ERROR) << "unexpected input_quant_params size: " << input_quant_params.size();
continue;
}
} // end if op_name found
}

return ret;
}
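
Per channel, the accumulated diff is the batch-averaged difference (fp32 mean - dequantized int8 mean); for ops that already carry an int32 bias it is folded in as round(diff / bias_scale) and clamped to +/- 0.6 * INT32_MAX. With illustrative numbers:

  bias_diff[c] = 0.012, bias_scale[c] = 0.0005
  bias_datas[c] += round(0.012 / 0.0005) = bias_datas[c] + 24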

STATUS PostTrainingQuantizer::CollectDataFrequency() {
for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
// get input tensor
vector<mindspore::tensor::MSTensor *> inputs = session_->GetInputs();
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
if (inputs.size() > 1) {
MS_LOG(ERROR) << "model's input tensor size: " << inputs.size() << " > 1";
return RET_ERROR;
@@ -900,7 +1279,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
[&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
const mindspore::session::CallBackParam &callParam) {
if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
return false;
}
auto tensor = beforeInputs[0];
@@ -916,7 +1295,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
[&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
const mindspore::session::CallBackParam &call_param) {
if (PostTrainingQuantizer::CheckTensorVec(call_param.name_callback_param, after_outputs) != RET_OK) {
if (PostTrainingQuantizer::CheckFp32TensorVec(call_param.name_callback_param, after_outputs) != RET_OK) {
return false;
}
auto tensor = after_outputs[0];
@@ -927,7 +1306,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
this->calibrator_->GetOutputDivergInfo());
return true;
};
status = session_->RunGraph(beforeCallBack, afterCallBack);
status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
if (status != RET_OK) {
MS_LOG(ERROR) << "run model failed!";
return RET_ERROR;
@@ -939,16 +1318,15 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {

STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }

STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_LOG(INFO) << "start to parse config file";
STATUS status = PreProcess();
if (status != RET_OK) {
MS_LOG(ERROR) << "do pre process failed!";
return status;
}

// anf -- fb
auto meta_graph = Export(funcGraph, true);
auto meta_graph = Export(func_graph, true);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph return nullptr";
return RET_ERROR;
@@ -980,13 +1358,13 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
ctx.thread_num_ = calibrator_->GetThreadNum();
ctx.cpu_bind_mode_ = MID_CPU;

session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx));
if (session_ == nullptr) {
fp32_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx));
if (fp32_session_ == nullptr) {
MS_LOG(ERROR) << "create session failed!";
return RET_ERROR;
}

auto ret = session_->CompileGraph(model);
auto ret = fp32_session_->CompileGraph(model);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "compile graph error";
return RET_ERROR;
@@ -1017,6 +1395,70 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
if (status != RET_OK) {
return status;
}

// add quant_cast
quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32);
status = quant_cast.Run(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}

if (calibrator_->GetBiasCorrection()) {
// init int8 session
// anf -- fb
auto int8_meta_graph = Export(func_graph, true);
if (int8_meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to int8_meta_graph return nullptr";
return RET_ERROR;
}

// transform
GraphDefTransform fb_transform;
fb_transform.SetGraphDef(int8_meta_graph);
flags.quantType = schema::QuantType_PostTraining;
status = fb_transform.Transform(flags);
if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed " << status;
return RET_ERROR;
}
MS_LOG(INFO) << "start create quantized session";
flatbuffers::FlatBufferBuilder int8_builder(1024);
auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph);
int8_builder.Finish(int8_offset);
size = int8_builder.GetSize();
auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer());
if (int8_content == nullptr) {
MS_LOG(ERROR) << "GetBufferPointer nullptr";
return RET_ERROR;
}
auto int8_model = lite::Model::Import(int8_content, size);

Context int8_ctx;
int8_ctx.device_type_ = DT_CPU;
int8_ctx.thread_num_ = calibrator_->GetThreadNum();
int8_ctx.cpu_bind_mode_ = HIGHER_CPU;

int8_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&int8_ctx));
if (int8_session_ == nullptr) {
MS_LOG(ERROR) << "create session failed!";
return RET_ERROR;
}
ret = int8_session_->CompileGraph(int8_model);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "compile graph error";
return RET_ERROR;
}

MS_LOG(INFO) << "do bias correction";
status = BiasCorrection(func_graph);
if (status != RET_OK) {
MS_LOG(WARNING) << "BiasCorrection failed.";
}
}

return RET_OK;
}
} // namespace quant


mindspore/lite/tools/converter/quantizer/post_training_quantizer.h (+21, -4)

@@ -49,6 +49,7 @@ struct ConfigParam {
uint32_t batch_count{100};
std::string method_x{kMethodKL};
uint32_t thread_num{1};
bool bias_correction{false};
};

class PostTrainingQuantizer : public Quantizer {
@@ -56,7 +57,7 @@ class PostTrainingQuantizer : public Quantizer {
PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8,
bool per_channel = true);

STATUS DoQuantize(FuncGraphPtr funcGraph) override;
STATUS DoQuantize(FuncGraphPtr func_graph) override;

size_t bit_num;
int quant_max{INT8_MAX};
@@ -69,12 +70,24 @@ class PostTrainingQuantizer : public Quantizer {

std::unique_ptr<Calibrator> calibrator_;

mindspore::lite::LiteSession *session_;
mindspore::lite::LiteSession *fp32_session_;
mindspore::lite::LiteSession *int8_session_;

std::string fp32_op_input_name;
std::string fp32_op_output_name;
std::vector<float> fp32_op_input;
std::vector<float> fp32_op_output_ch_mean;
std::map<std::string, std::vector<float>> op_bias_diff_map;
std::atomic<bool> fp32_op_input_ready{false};
std::atomic<bool> fp32_op_output_ch_mean_ready{false};

const std::string kTypeConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_Conv2D);
const std::string kTypeDepthwiseConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_DepthwiseConv2D);

STATUS PreProcess();

STATUS CheckTensorVec(const std::string &node_name,
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const;
STATUS CheckFp32TensorVec(const std::string &node_name,
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const;

STATUS DoInference();

@@ -92,6 +105,8 @@ class PostTrainingQuantizer : public Quantizer {
STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel);

STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c);
STATUS Int8Inference();
STATUS BiasCorrection(FuncGraphPtr func_graph);
};

struct DivergInfo {
@@ -153,6 +168,8 @@ class Calibrator {

std::string GetMethodX() const { return config_param_.method_x; }

bool GetBiasCorrection() const { return config_param_.bias_correction; }

STATUS AddQuantizedOp(CNodePtr node);

STATUS RecordMaxValue(const std::string &op_name, const std::vector<float> &data,
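
Putting the pieces together, a caller might drive the quantizer roughly as below. This is a minimal sketch based on the constructor declared above: config_path is a hypothetical variable, namespace qualification is abbreviated, and the surrounding converter plumbing is omitted.

  // sketch: run post-training quantization on an exported func_graph
  auto quantizer = std::make_unique<quant::PostTrainingQuantizer>(
      func_graph, config_path, /*bit_num=*/8, kNumberTypeInt8, /*per_channel=*/true);
  STATUS status = quantizer->DoQuantize(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "post training quantize failed: " << status;
  }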

