From adb7d7c1d544f5cca080fccdcf8860e05c094c00 Mon Sep 17 00:00:00 2001
From: yankai
Date: Wed, 26 Aug 2020 15:18:24 +0800
Subject: [PATCH] MS quant

---
 mindspore/lite/src/ir/tensor.cc               |  10 ++
 mindspore/lite/src/ops/conv2d.cc              | 110 +++-----------
 mindspore/lite/src/ops/conv2d.h               |   3 -
 mindspore/lite/src/ops/depthwise_conv2d.cc    | 100 ++-----------
 mindspore/lite/src/ops/depthwise_conv2d.h     |   4 -
 mindspore/lite/src/ops/matmul.cc              |  77 ----------
 mindspore/lite/src/ops/matmul.h               |   4 -
 mindspore/lite/src/ops/primitive_c.cc         | 139 +++++++++++++++---
 mindspore/lite/src/ops/primitive_c.h          |  12 +-
 .../lite/tools/anf_exporter/anf_exporter.cc   |   4 +-
 .../anf_importer/import_from_protobuf.cc      |   2 +-
 mindspore/lite/tools/common/node_util.cc      |   3 +-
 .../graph/weight_format_hardcode_pass.cc      |   4 +-
 .../converter/quantizer/aware_quantizer.cc    |  99 +++++--------
 .../converter/quantizer/quantize_util.cc      |  24 +--
 .../tools/converter/quantizer/quantize_util.h |  14 +-
 16 files changed, 227 insertions(+), 382 deletions(-)

diff --git a/mindspore/lite/src/ir/tensor.cc b/mindspore/lite/src/ir/tensor.cc
index 373223a0f3..9698d9d632 100644
--- a/mindspore/lite/src/ir/tensor.cc
+++ b/mindspore/lite/src/ir/tensor.cc
@@ -253,6 +253,16 @@ std::string Tensor::ToString() const {
         }
       }
     } break;
+    case kNumberTypeInt8: {
+      auto data = static_cast<int8_t *>(this->data_);
+      if (data == nullptr) {
+        return "Data of tensor is nullptr";
+      } else {
+        for (int i = 0; i < 40 && i < this->ElementsNum(); i++) {
+          oss << " " << static_cast<int>(data[i]);
+        }
+      }
+    } break;
     default:
       oss << "Unsupported data type to print";
       break;
diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc
index 7c9c32058f..e2b72536c7 100644
--- a/mindspore/lite/src/ops/conv2d.cc
+++ b/mindspore/lite/src/ops/conv2d.cc
@@ -15,11 +15,16 @@
  */
 
 #include "src/ops/conv2d.h"
-#include <memory>
+
+#include <memory>
 #include <string>
+
+#include "include/errorcode.h"
 #include "utils/log_adapter.h"
 #ifdef PRIMITIVE_WRITEABLE
+#include
+
 #include "tools/converter/quantizer/quantize_util.h"
 #endif
@@ -156,6 +161,13 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
     attr->padMode = schema::PadMode_NOTSET;
   }
 
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+
   int channel_mutiplier = 1;
   if (prim.GetAttr("channel_mutiplier") != nullptr) {
     channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
@@ -213,98 +225,18 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive
   } else {
     attr->padMode = schema::PadMode_NOTSET;
   }
-  primitive->value.type = schema::PrimitiveType_Conv2D;
-  primitive->value.value = attr.release();
-}
-
-void Conv2D::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
-  const float qmin = 0;
-  const float qmax = 255;
-  *mMin = static_cast<float>((qmin - mean) / stdDev);
-  *mMax = static_cast<float>((qmax - mean) / stdDev);
-}
-
-void Conv2D::PopulaterQuantParam(const Primitive &prim,
-                                 std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
-                                 std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
-  auto narrow_range = prim.GetAttr("narrow_range");
-  bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
-  auto num_bits = prim.GetAttr("num_bits");
-  int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
-
-  std::vector<schema::QuantParamT> quants;
-  schema::QuantParamT quantParam;
-  auto mean = prim.GetAttr("mean");
-  auto std_dev = prim.GetAttr("std_dev");
prim.GetAttr("std_dev"); - if (mean != nullptr && std_dev != nullptr) { - auto meanQuantOaram = GetValue(mean); - double stddevQuantOaram = GetValue(std_dev); - float mMin = 0.0; - float mMax = 0.0; - CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); - quantParam.min = mMin; - quantParam.max = mMax; + if (prim.GetAttr("activation_name") != nullptr) { + std::string activate_name = GetValue(prim.GetAttr("activation_name")); + attr->activationType = kActivationTypeMap[activate_name]; } else { - auto inputMin = prim.GetAttr("input_minq"); - auto inputMax = prim.GetAttr("input_maxq"); - auto inputMinPtr = inputMin->cast(); - auto inputMaxPtr = inputMax->cast(); - float *minBuf = static_cast(inputMinPtr->Data()); - float *maxBuf = static_cast(inputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - } - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecInputQuantParam->emplace_back(quants); - - quants.clear(); - int biasQuantSize = 0; - auto filterMin = prim.GetAttr("filter_minq"); - auto filterMax = prim.GetAttr("filter_maxq"); - if (filterMin != nullptr && filterMax != nullptr) { - auto filterMinPtr = filterMin->cast(); - auto filterMaxPtr = filterMax->cast(); - float *minBuf = static_cast(filterMinPtr->Data()); - float *maxBuf = static_cast(filterMaxPtr->Data()); - biasQuantSize = filterMinPtr->DataSize(); - for (int i = 0; i < biasQuantSize; ++i) { - quantParam.min = *(minBuf++); - quantParam.max = *(maxBuf++); - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - } - vecInputQuantParam->emplace_back(quants); + attr->activationType = schema::ActivationType_NO_ACTIVATION; } - quants.clear(); - for (int i = 0; i < biasQuantSize; ++i) { - quantParam.min = 0.0; - quantParam.max = 0.0; - quantParam.zeroPoint = 0; - - quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale; - quants.emplace_back(quantParam); - } - vecInputQuantParam->emplace_back(quants); - - quants.clear(); - auto outputMin = prim.GetAttr("output_minq"); - auto outputMax = prim.GetAttr("output_maxq"); - if (outputMin != nullptr && outputMax != nullptr) { - auto outputMinPtr = outputMin->cast(); - auto outputMaxPtr = outputMax->cast(); - float *minBuf = static_cast(outputMinPtr->Data()); - float *maxBuf = static_cast(outputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecOutputQuantParam->emplace_back(quants); - } + // attr->padMode = schema::PadMode_SAME; + // attr->activationType = schema::ActivationType_RELU; + primitive->value.type = schema::PrimitiveType_Conv2D; + primitive->value.value = attr.release(); } int Conv2D::UnPackAttr(const Primitive &prim, const std::vector &inputs) { diff --git a/mindspore/lite/src/ops/conv2d.h b/mindspore/lite/src/ops/conv2d.h index 21367dcdc5..3b85a20e85 100644 --- a/mindspore/lite/src/ops/conv2d.h +++ b/mindspore/lite/src/ops/conv2d.h @@ -57,9 +57,6 @@ class Conv2D : public PrimitiveC { void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, const std::vector &inputs); void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int 
&group); - void PopulaterQuantParam(const Primitive &prim, std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); #else public: diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index ba318a1bb1..99d78dc150 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -15,6 +15,7 @@ */ #include "src/ops/depthwise_conv2d.h" + #include #include #ifdef PRIMITIVE_WRITEABLE @@ -69,96 +70,6 @@ void DepthwiseConv2D::SetActivationType(int activation_type) { this->primitive_->value.AsDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type; } -void DepthwiseConv2D::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { - const float qmin = 0; - const float qmax = 255; - *mMin = static_cast((qmin - mean) / stdDev); - *mMax = static_cast((qmax - mean) / stdDev); -} - -void DepthwiseConv2D::PopulaterQuantParam(const Primitive &prim, - std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam) { - auto narrow_range = prim.GetAttr("narrow_range"); - bool narrowRangeQuantParam = GetValue(narrow_range); - auto num_bits = prim.GetAttr("num_bits"); - int32_t numbitsRangeQuantParam = GetValue(num_bits); - - std::vector quants; - schema::QuantParamT quantParam; - auto mean = prim.GetAttr("mean"); - auto std_dev = prim.GetAttr("std_dev"); - if (mean != nullptr && std_dev != nullptr) { - auto meanQuantOaram = GetValue(mean); - double stddevQuantOaram = GetValue(std_dev); - float mMin = 0.0; - float mMax = 0.0; - CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); - quantParam.min = mMin; - quantParam.max = mMax; - } else { - auto inputMin = prim.GetAttr("input_minq"); - auto inputMax = prim.GetAttr("input_maxq"); - auto inputMinPtr = inputMin->cast(); - auto inputMaxPtr = inputMax->cast(); - float *minBuf = static_cast(inputMinPtr->Data()); - float *maxBuf = static_cast(inputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - } - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecInputQuantParam->emplace_back(quants); - - quants.clear(); - int biasQuantSize = 0; - auto filterMin = prim.GetAttr("filter_minq"); - auto filterMax = prim.GetAttr("filter_maxq"); - if (filterMin != nullptr && filterMax != nullptr) { - auto filterMinPtr = filterMin->cast(); - auto filterMaxPtr = filterMax->cast(); - float *minBuf = static_cast(filterMinPtr->Data()); - float *maxBuf = static_cast(filterMaxPtr->Data()); - biasQuantSize = filterMinPtr->DataSize(); - for (int i = 0; i < biasQuantSize; ++i) { - quantParam.min = *(minBuf++); - quantParam.max = *(maxBuf++); - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - } - vecInputQuantParam->emplace_back(quants); - } - - quants.clear(); - for (int i = 0; i < biasQuantSize; ++i) { - quantParam.min = 0.0; - quantParam.max = 0.0; - quantParam.zeroPoint = 0; - - quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale; - quants.emplace_back(quantParam); - } - vecInputQuantParam->emplace_back(quants); - - quants.clear(); - auto outputMin = prim.GetAttr("output_minq"); - auto outputMax = prim.GetAttr("output_maxq"); - if (outputMin != 
nullptr && outputMax != nullptr) { - auto outputMinPtr = outputMin->cast(); - auto outputMaxPtr = outputMax->cast(); - float *minBuf = static_cast(outputMinPtr->Data()); - float *maxBuf = static_cast(outputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecOutputQuantParam->emplace_back(quants); - } -} - int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector &inputs) { this->primitive_ = new (schema::PrimitiveT); auto attr = std::make_unique(); @@ -197,7 +108,14 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vectorpadMode = schema::PadMode_NOTSET; } - + if (prim.GetAttr("activation_name") != nullptr) { + std::string activate_name = GetValue(prim.GetAttr("activation_name")); + attr->activationType = kActivationTypeMap[activate_name]; + } else { + attr->activationType = schema::ActivationType_NO_ACTIVATION; + } + // attr->padMode = schema::PadMode_SAME; + // attr->activationType = schema::ActivationType_RELU; auto channel_multiplier = GetValue(prim.GetAttr("channel_multiplier")); attr->channelMultiplier = channel_multiplier; diff --git a/mindspore/lite/src/ops/depthwise_conv2d.h b/mindspore/lite/src/ops/depthwise_conv2d.h index aada41f542..877f083f6f 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.h +++ b/mindspore/lite/src/ops/depthwise_conv2d.h @@ -51,10 +51,6 @@ class DepthwiseConv2D : public PrimitiveC { void SetHasBias(bool has_bias); void SetActivationType(int activation_type); - private: - void PopulaterQuantParam(const Primitive &prim, std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); #else public: diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index c12038312c..25e437366c 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -30,83 +30,6 @@ bool MatMul::GetTransposeB() const { return this->primitive_->value.AsMatMul()-> void MatMul::SetTransposeA(bool transpose_a) { this->primitive_->value.AsMatMul()->transposeA = transpose_a; } void MatMul::SetTransposeB(bool transpose_b) { this->primitive_->value.AsMatMul()->transposeB = transpose_b; } -void MatMul::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { - const float qmin = 0; - const float qmax = 255; - *mMin = static_cast((qmin - mean) / stdDev); - *mMax = static_cast((qmax - mean) / stdDev); -} - -void MatMul::PopulaterQuantParam(const Primitive &prim, - std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam) { - auto narrow_range = prim.GetAttr("narrow_range"); - bool narrowRangeQuantParam = GetValue(narrow_range); - auto num_bits = prim.GetAttr("num_bits"); - int32_t numbitsRangeQuantParam = GetValue(num_bits); - - std::vector quants; - schema::QuantParamT quantParam; - auto mean = prim.GetAttr("mean"); - auto std_dev = prim.GetAttr("std_dev"); - if (mean != nullptr && std_dev != nullptr) { - auto meanQuantOaram = GetValue(mean); - double stddevQuantOaram = GetValue(std_dev); - float mMin = 0.0; - float mMax = 0.0; - CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); - quantParam.min = mMin; - quantParam.max = mMax; - } else { - auto inputMin = prim.GetAttr("input_minq"); - auto inputMax = prim.GetAttr("input_maxq"); - auto inputMinPtr = inputMin->cast(); - auto 
-    auto inputMaxPtr = inputMax->cast<ParamValueLitePtr>();
-    float *minBuf = static_cast<float *>(inputMinPtr->Data());
-    float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
-    quantParam.min = *minBuf;
-    quantParam.max = *maxBuf;
-  }
-  quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
-                               numbitsRangeQuantParam);
-  quants.emplace_back(quantParam);
-  vecInputQuantParam->emplace_back(quants);
-
-  quants.clear();
-  auto filterMin = prim.GetAttr("filter_minq");
-  auto filterMax = prim.GetAttr("filter_maxq");
-  if (filterMin != nullptr && filterMax != nullptr) {
-    auto filterMinPtr = filterMin->cast<ParamValueLitePtr>();
-    auto filterMaxPtr = filterMax->cast<ParamValueLitePtr>();
-    float *minBuf = static_cast<float *>(filterMinPtr->Data());
-    float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
-    for (int i = 0; i < filterMinPtr->DataSize(); ++i) {
-      quantParam.min = *(minBuf++);
-      quantParam.max = *(maxBuf++);
-      quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
-                                   numbitsRangeQuantParam);
-      quants.emplace_back(quantParam);
-    }
-    vecInputQuantParam->emplace_back(quants);
-  }
-
-  quants.clear();
-  auto outputMin = prim.GetAttr("output_minq");
-  auto outputMax = prim.GetAttr("output_maxq");
-  if (outputMin != nullptr && outputMax != nullptr) {
-    auto outputMinPtr = outputMin->cast<ParamValueLitePtr>();
-    auto outputMaxPtr = outputMax->cast<ParamValueLitePtr>();
-    float *minBuf = static_cast<float *>(outputMinPtr->Data());
-    float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
-    quantParam.min = *minBuf;
-    quantParam.max = *maxBuf;
-    quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
-                                 numbitsRangeQuantParam);
-    quants.emplace_back(quantParam);
-    vecOutputQuantParam->emplace_back(quants);
-  }
-}
-
 int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
   if (this->primitive_ == nullptr) {
     this->primitive_ = new (std::nothrow) schema::PrimitiveT;
diff --git a/mindspore/lite/src/ops/matmul.h b/mindspore/lite/src/ops/matmul.h
index edbd61b252..d943b4c419 100644
--- a/mindspore/lite/src/ops/matmul.h
+++ b/mindspore/lite/src/ops/matmul.h
@@ -36,10 +36,6 @@ class MatMul : public PrimitiveC {
   void SetTransposeA(bool transpose_a);
   void SetTransposeB(bool transpose_b);
 
- private:
-  void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
-                           std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
-  void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
 
 #else
  public:
diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc
index d881320442..13033708dc 100644
--- a/mindspore/lite/src/ops/primitive_c.cc
+++ b/mindspore/lite/src/ops/primitive_c.cc
@@ -16,6 +16,7 @@
 
 #include "src/ops/primitive_c.h"
 #include <memory>
+#include <float.h>
 #include "src/ops/space_to_batch.h"
 #include "src/ops/space_to_batch_nd.h"
 #include "src/ops/conv2d.h"
@@ -121,10 +122,99 @@
 #include "src/ops/l2_norm.h"
 #include "src/ops/sparse_to_dense.h"
 #include "src/ops/detection_post_process.h"
-
+#ifdef PRIMITIVE_WRITEABLE
+#include "tools/converter/quantizer/quantize_util.h"
+#endif
 namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
+void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
+  const float qmin = 0;
+  const float qmax = 255;
+  *mMin = static_cast<float>((qmin - mean) / stdDev);
+  *mMax = static_cast<float>((qmax - mean) / stdDev);
+}
+
+void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
+                                     std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
+                                     std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
+  auto narrow_range = prim.GetAttr("narrow_range");
prim.GetAttr("narrow_range"); + bool narrowRangeQuantParam = GetValue(narrow_range); + auto num_bits = prim.GetAttr("num_bits"); + int32_t numbitsRangeQuantParam = GetValue(num_bits); + + std::vector quants; + schema::QuantParamT quantParam; + auto mean = prim.GetAttr("mean"); + auto std_dev = prim.GetAttr("std_dev"); + if (mean != nullptr && std_dev != nullptr) { + auto meanQuantOaram = GetValue(mean); + double stddevQuantOaram = GetValue(std_dev); + float mMin = 0.0; + float mMax = 0.0; + CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); + quantParam.min = mMin; + quantParam.max = mMax; + } else { + auto inputMin = prim.GetAttr("input_minq"); + auto inputMax = prim.GetAttr("input_maxq"); + auto inputMinPtr = inputMin->cast(); + auto inputMaxPtr = inputMax->cast(); + float *minBuf = static_cast(inputMinPtr->Data()); + float *maxBuf = static_cast(inputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + } + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecInputQuantParam->emplace_back(quants); + + quants.clear(); + auto filterMin = prim.GetAttr("filter_minq"); + auto filterMax = prim.GetAttr("filter_maxq"); + if (filterMin != nullptr && filterMax != nullptr) { + auto filterMinPtr = filterMin->cast(); + auto filterMaxPtr = filterMax->cast(); + float *minBuf = static_cast(filterMinPtr->Data()); + float *maxBuf = static_cast(filterMaxPtr->Data()); + quantParam.min = FLT_MAX; + quantParam.max = FLT_MIN; + for (int i = 0; i < filterMinPtr->DataSize(); ++i) { + quantParam.min = (*(minBuf) < quantParam.min) ? (*minBuf) : quantParam.min; + quantParam.max = (*(maxBuf) > quantParam.max) ? (*maxBuf) : quantParam.max; + minBuf++; + maxBuf++; + } + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, true, numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecInputQuantParam->emplace_back(quants); + } + + quants.clear(); + quantParam.min = 0.0; + quantParam.max = 0.0; + quantParam.zeroPoint = 0; + quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale; + quants.emplace_back(quantParam); + vecInputQuantParam->emplace_back(quants); + + quants.clear(); + auto outputMin = prim.GetAttr("output_minq"); + auto outputMax = prim.GetAttr("output_maxq"); + if (outputMin != nullptr && outputMax != nullptr) { + auto outputMinPtr = outputMin->cast(); + auto outputMaxPtr = outputMax->cast(); + float *minBuf = static_cast(outputMinPtr->Data()); + float *maxBuf = static_cast(outputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecOutputQuantParam->emplace_back(quants); + } +} schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; } void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } @@ -152,7 +242,7 @@ void PrimitiveC::AddOutputQuantParam(std::vector quant_para } std::vector> PrimitiveC::GetOutputQuantParams() const { return output_quant_param_; } -void PrimitiveC::SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } +void PrimitiveC::SetQuantType(const schema::QuantType &quant_type) { this->quant_type_ = quant_type; } schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; } @@ -205,12 +295,14 @@ std::shared_ptr 
 }
 
 template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
-std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
+                                          const schema::QuantType &quantType) {
   auto primc = std::make_shared<T>();
   if (primc == nullptr) {
     MS_LOG(ERROR) << "make_shared PrimitiveC failed";
     return nullptr;
   }
+  primc->SetQuantType(quantType);
   auto ret = primc->UnPackAttr(prim, inputs);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "UnPackAttr failed";
@@ -220,46 +312,47 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vect
 }
 
 std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &prim,
-                                                            const std::vector<AnfNodePtr> &inputs) {
+                                                            const std::vector<AnfNodePtr> &inputs,
+                                                            const schema::QuantType &quantType) {
   const auto &op_type = prim.name();
   if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid") {
-    return NewPrimitiveC<Activation>(prim, inputs);
+    return NewPrimitiveC<Activation>(prim, inputs, quantType);
   } else if (op_type == "BatchNorm") {
-    return NewPrimitiveC<BatchNorm>(prim, inputs);
+    return NewPrimitiveC<BatchNorm>(prim, inputs, quantType);
  } else if (op_type == "BiasAdd") {
-    return NewPrimitiveC<BiasAdd>(prim, inputs);
+    return NewPrimitiveC<BiasAdd>(prim, inputs, quantType);
  } else if (op_type == "Concat") {
-    return NewPrimitiveC<Concat>(prim, inputs);
+    return NewPrimitiveC<Concat>(prim, inputs, quantType);
  } else if (op_type == "Conv2D") {
-    return NewPrimitiveC<Conv2D>(prim, inputs);
+    return NewPrimitiveC<Conv2D>(prim, inputs, quantType);
  } else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") {
-    return NewPrimitiveC<DepthwiseConv2D>(prim, inputs);
+    return NewPrimitiveC<DepthwiseConv2D>(prim, inputs, quantType);
  } else if (op_type == "Dequant") {
-    return NewPrimitiveC<Dequant>(prim, inputs);
+    return NewPrimitiveC<Dequant>(prim, inputs, quantType);
  } else if (op_type == "Flatten") {
-    return NewPrimitiveC<Flatten>(prim, inputs);
+    return NewPrimitiveC<Flatten>(prim, inputs, quantType);
  } else if (op_type == "make_tuple") {
-    return NewPrimitiveC<MakeTuple>(prim, inputs);
+    return NewPrimitiveC<MakeTuple>(prim, inputs, quantType);
  } else if (op_type == "MatMul") {
-    return NewPrimitiveC<MatMul>(prim, inputs);
+    return NewPrimitiveC<MatMul>(prim, inputs, quantType);
  } else if (op_type == "Mul") {
-    return NewPrimitiveC<Mul>(prim, inputs);
+    return NewPrimitiveC<Mul>(prim, inputs, quantType);
  } else if (op_type == "MaxPool") {
-    return NewPrimitiveC<Pooling>(prim, inputs);
+    return NewPrimitiveC<Pooling>(prim, inputs, quantType);
  } else if (op_type == "Quant") {
-    return NewPrimitiveC<Quant>(prim, inputs);
+    return NewPrimitiveC<Quant>(prim, inputs, quantType);
  } else if (op_type == "ReduceMean") {
-    return NewPrimitiveC<Reduce>(prim, inputs);
+    return NewPrimitiveC<Reduce>(prim, inputs, quantType);
  } else if (op_type == "Reshape") {
-    return NewPrimitiveC<Reshape>(prim, inputs);
+    return NewPrimitiveC<Reshape>(prim, inputs, quantType);
  } else if (op_type == "TensorAdd") {
-    return NewPrimitiveC<Add>(prim, inputs);
+    return NewPrimitiveC<Add>(prim, inputs, quantType);
  } else if (op_type == "Transpose") {
-    return NewPrimitiveC<Transpose>(prim, inputs);
+    return NewPrimitiveC<Transpose>(prim, inputs, quantType);
  } else if (op_type == "tuple_getitem") {
-    return NewPrimitiveC<TupleGetItem>(prim, inputs);
+    return NewPrimitiveC<TupleGetItem>(prim, inputs, quantType);
  } else if (op_type == "Softmax") {
-    return NewPrimitiveC<SoftMax>(prim, inputs);
+    return NewPrimitiveC<SoftMax>(prim, inputs, quantType);
  } else {
     MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
     return nullptr;
diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h
index 73ba1f4d87..bb19204a56 100644
--- a/mindspore/lite/src/ops/primitive_c.h
+++ b/mindspore/lite/src/ops/primitive_c.h
@@ -20,6 +20,7 @@
 #include <string>
 #include <set>
 #include <vector>
+#include <map>
 #ifdef PRIMITIVE_WRITEABLE
 #include "ir/primitive.h"
 #include "schema/inner/model_generated.h"
@@ -44,6 +45,9 @@ const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNum
 constexpr int kAnfPopulaterOne = 1;
 constexpr int kAnfPopulaterTwo = 2;
 constexpr int kAnfPopulaterThree = 3;
+static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU},
+                                                                        {"ReLU6", schema::ActivationType_RELU6},
+                                                                        {"Sigmoid", schema::ActivationType_SIGMOID}};
 class PrimitiveC : public mindspore::Primitive {
  public:
   // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().
@@ -94,7 +98,7 @@ class PrimitiveC : public mindspore::Primitive {
 
   std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const;
 
-  void SetQuantType(schema::QuantType quant_type);
+  void SetQuantType(const schema::QuantType &quant_type);
 
   schema::QuantType GetQuantType() const;
 
@@ -110,7 +114,11 @@ class PrimitiveC : public mindspore::Primitive {
 
   static PrimitiveC *UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT *primitive);
 
-  static std::shared_ptr<PrimitiveC> UnPackFromPrimitive(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
+  static std::shared_ptr<PrimitiveC> UnPackFromPrimitive(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
+                                                         const schema::QuantType &quantType);
+  void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
+                           std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
+  void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
 
  protected:
   virtual int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_ERROR; }
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 5fd289af2f..1006b4fc2d 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -71,8 +71,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
   auto input_quant_params = primitive->GetInputQuantParams();
   auto node_type = (schema::PrimitiveType)primitive->Type();
   if (input_quant_params.empty()) {
-    MS_LOG(ERROR) << "node: " << dst_node->name << " input quant params is empty";
-    return RET_ERROR;
+    MS_LOG(WARNING) << "node: " << dst_node->name << " input quant params is empty";
+    return RET_OK;
   }
   for (size_t i = 0; i < input_quant_params.size(); i++) {
     if (i >= dst_node->inputIndex.size()) {
diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc
index 85f8465c49..3f0a0667ac 100644
--- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc
+++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc
@@ -473,7 +473,7 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
     }
     inputs.push_back(anfnode_build_map_[input_name]);
   }
-  auto primitivec_ptr = PrimitiveC::UnPackFromPrimitive(*prim, inputs);
+  auto primitivec_ptr = PrimitiveC::UnPackFromPrimitive(*prim, inputs, quantType);
   if (primitivec_ptr == nullptr) {
     MS_LOG(ERROR) << "Create PrimitiveC return nullptr, " << prim->name();
     return nullptr;
diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc
index de8cf2c305..4148217e70 100644
--- a/mindspore/lite/tools/common/node_util.cc
+++ b/mindspore/lite/tools/common/node_util.cc
@@ -50,7 +50,8 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
     schema::PrimitiveType_Mul,     schema::PrimitiveType_Slice,   schema::PrimitiveType_SoftMax,
     schema::PrimitiveType_Split,   schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
-    schema::PrimitiveType_TopK,    schema::PrimitiveType_Unsqueeze};
+    schema::PrimitiveType_TopK,    schema::PrimitiveType_Unsqueeze,
+    schema::PrimitiveType_MatMul};
 
 static const std::vector<schema::PrimitiveType> needInsertOpList = {
     schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation,
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc
index c8d9812b43..c23e2681d0 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc
@@ -155,11 +155,11 @@ STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr<CNodeT> &node,
   switch (this->quantType) {
     case QuantType_AwareTraining: {
       if (opType == schema::PrimitiveType_Conv2D) {
-        weightTensor->format = schema::Format_HWCK;
+        weightTensor->format = schema::Format_KCHW;
       } else if (opType == PrimitiveType_DepthwiseConv2D) {
         weightTensor->format = Format_CKHW;
       } else {
-        weightTensor->format = schema::Format_HWKC;
+        weightTensor->format = schema::Format_KCHW;
       }
     } break;
     case QuantType_QUANT_NONE: {
diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
index c921add7d3..881398cabb 100644
--- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
@@ -36,15 +36,13 @@ using std::vector;
 
 namespace mindspore::lite::quant {
 const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = {
-    {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize,
-     schema::PrimitiveType_Reshape, schema::PrimitiveType_Squeeze,
-     schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
-     schema::PrimitiveType_DetectionPostProcess}};
+    {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape,
+     schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
+     schema::PrimitiveType_DetectionPostProcess}};
 
 STATUS InputArray::InitQuantParam() {
   this->quantParam = std::make_unique<schema::QuantParamT>();
-  auto status = CalQuantizationParams(quantParam.get(), mMin, mMax,
-                                      narrowRange, numBits);
+  auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits);
   if (status != RET_OK) {
     return status;
   }
@@ -58,8 +56,7 @@ STATUS InputArray::SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensor
   if (!tensor->quantParams.empty()) {
     auto param = GetTensorQuantParam(tensor);
     if (param != nullptr && param->inited) {
-      MS_LOG(DEBUG) << "tensor " << inputTensorIdx
-                    << " already has quantParam";
+      MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam";
       return RET_OK;
     }
     tensor->quantParams.clear();
@@ -74,9 +71,7 @@ STATUS InputArray::SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensor
   return RET_OK;
 }
 
-AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
-                               const string &inputInferType,
-                               const string &stdValues,
+AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues,
                                const string &meanValues)
     : FbQuantizer(graph) {
   MS_ASSERT(graph != nullptr);
@@ -94,12 +89,9 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
   mInputArray->InitQuantParam();
 }
 
-STATUS AwareQuantizer::RemoveFakeQuant() {
-  return RET_OK;
-}
+STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; }
 
-STATUS AwareQuantizer::GenerateDefaultQuantParam(
-    const schema::MetaGraphT *subGraph) {
+STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) {
   MS_ASSERT(subGraph != nullptr);
   for (const auto &tensor : subGraph->allTensors) {
     if (!tensor->quantParams.empty()) {
@@ -111,8 +103,7 @@ STATUS AwareQuantizer::GenerateDefaultQuantParam(
   return RET_OK;
 }
 
-STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph,
-                                            schema::CNodeT *node) {
+STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
   //  MS_ASSERT(subGraph != nullptr);
   //  MS_ASSERT(node != nullptr);
   //  auto inputIndexes = node->inputIndex;
@@ -193,19 +184,15 @@ STATUS AwareQuantizer::GenerateQuantParam() {
         GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
       MS_ASSERT(false);
     }
-    auto *quantParamCalcer =
-        quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
+    auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
     if (quantParamCalcer == nullptr) {
-      MS_LOG(ERROR) << "Can not find QuantParamCalcer for "
-                    << node->name.c_str()
-                    << ", type: " << GetCNodeTTypeName(*node).c_str()
-                    << " set node to QuantNone and skip";
+      MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
+                    << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
       node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
     } else {
       status = quantParamCalcer->Calc(graph, *node);
       if (status != RET_OK) {
-        MS_LOG(ERROR) << "quantParamCalcer failed: " << status
-                      << " node: " << node->name.c_str();
+        MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
         node->quantType = schema::QuantType_QUANT_NONE;
       } else {
         node->quantType = schema::QuantType_AwareTraining;
@@ -227,11 +214,11 @@ STATUS AwareQuantizer::DoQuantize() {
       STATUS status;
       if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D ||
          GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D ||
-          GetCNodeTType(*node) == schema::PrimitiveType_FullConnection) {
+          GetCNodeTType(*node) == schema::PrimitiveType_FullConnection ||
+          GetCNodeTType(*node) == schema::PrimitiveType_MatMul) {
         auto inputIndexes = node->inputIndex;
         if (inputIndexes.size() < 2) {
-          MS_LOG(ERROR) << node->name.c_str()
-                        << " node input has invalid inputs tensor count";
+          MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
           return RET_ERROR;
         }
         // quant weight
@@ -248,8 +235,7 @@ STATUS AwareQuantizer::DoQuantize() {
             return RET_ERROR;
           }
         }
-      } else if (GetCNodeTType(*node) ==
-                 schema::PrimitiveType_DetectionPostProcess) {
+      } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
         status = QuantDetectionPostProcessConstTensor(graph, node.get());
         if (status != RET_OK) {
           MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
@@ -275,8 +261,7 @@ STATUS AwareQuantizer::DoQuantize() {
   return RET_OK;
 }
 
-STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph,
-                                           schema::CNodeT *node) {
+STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
   MS_ASSERT(graph != nullptr);
   MS_ASSERT(node != nullptr);
   for (size_t i = 0; i < node->inputIndex.size(); i++) {
@@ -295,8 +280,7 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph,
       void *inData = inTensor->data.data();
      auto *castedInData = static_cast<float *>(inData);
       for (size_t j = 0; j < constTensorShapeSize; j++) {
-        qDatas[j] =
-            QuantizeData<uint8_t>(castedInData[j], quantParam.get());
+        qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get());
       }
       inTensor->data = std::move(qDatas);
       inTensor->dataType = kNumberTypeUInt8;
@@ -312,17 +296,14 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph,
   return RET_OK;
 }
 
-STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(
-    const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
+STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
   MS_ASSERT(subGraph != nullptr);
   MS_ASSERT(node != nullptr);
   auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]);
   MS_ASSERT(constTensor != nullptr);
-  const auto *constData =
-      reinterpret_cast<const float *>(constTensor->data.data());
+  const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());
 
-  if (constTensor->refCount == 999 &&
-      constTensor->dataType == TypeId::kNumberTypeFloat) {
+  if (constTensor->nodeType == schema::NodeType_ValueNode && constTensor->dataType == TypeId::kNumberTypeFloat) {
     size_t constTensorShapeSize = GetShapeSize(*constTensor);
     std::unique_ptr<schema::QuantParamT> quantParam = GetTensorQuantParam(constTensor);
     if (quantParam == nullptr) {
@@ -340,8 +321,7 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(
   return RET_OK;
 }
 
-STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
-                                     mindspore::schema::CNodeT *node) {
+STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) {
   MS_ASSERT(graph != nullptr);
   MS_ASSERT(node != nullptr);
   auto inputIndexes = node->inputIndex;
@@ -351,15 +331,10 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
   MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2));
   auto &biasTensor = graph->allTensors.at(inputIndexes.at(2));
   MS_ASSERT(biasTensor != nullptr);
-  if (biasTensor->dataType != TypeId::kNumberTypeFloat) {
-    //    MS_LOGD("conv %s's bias data is not float", node->name.c_str());
-    return RET_OK;
-  }
-
   if (biasTensor->dataType == TypeId::kNumberTypeInt32) {
     return RET_OK;
   }
-  if (biasTensor->dataType != TypeId::kNumberTypeFloat) {
+  if (biasTensor->dataType != TypeId::kNumberTypeFloat && biasTensor->dataType != TypeId::kNumberTypeFloat32) {
     //    MS_LOGE("conv %s's bias data is not float", node->name.c_str());
     return RET_ERROR;
   }
@@ -400,8 +375,8 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
   biasTensor->dataType = TypeId::kNumberTypeInt32;
   biasTensor->data.clear();
   biasTensor->data.resize(bShapeSize * sizeof(int32_t));
-  auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t),
-                      qDatas.get(), bShapeSize * sizeof(int32_t));
+  auto ret =
+      memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t));
   if (ret != EOK) {
     MS_LOG(ERROR) << "memcpy_s failed: " << ret;
     return RET_ERROR;
@@ -409,12 +384,10 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph,
   return RET_OK;
 }
 
-STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph,
-                                       schema::CNodeT *node) {
+STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
   MS_ASSERT(subGraph != nullptr);
   MS_ASSERT(node != nullptr);
-  MS_ASSERT(node->quantParam.size() ==
-            node->inputIndex.size() + node->outputIndex.size());
+  MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size());
   auto inputIndexes = node->inputIndex;
node->inputIndex; MS_ASSERT(inputIndexes.size() >= 2); MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1)); @@ -422,11 +395,9 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, if (weightTensor->dataType == TypeId::kNumberTypeInt8) { return RET_OK; } - if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && - weightTensor->dataType != TypeId::kNumberTypeFloat && + if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat && weightTensor->dataType != TypeId::kNumberTypeUInt8) { - MS_LOG(ERROR) << "conv " << node->name.c_str() - << "'s weight data is not float or uint8"; + MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; return RET_ERROR; } size_t wShapeSize = GetShapeSize(*(weightTensor.get())); @@ -434,8 +405,8 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr); vector qDatas(wShapeSize); auto weightQauntParam = GetTensorQuantParam(weightTensor); - if (weightTensor->dataType == - TypeId::kNumberTypeFloat) { // normal awareing quant + if (weightTensor->dataType == TypeId::kNumberTypeFloat || + weightTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant auto *weightData = static_cast(oriWeightData); for (size_t j = 0; j < wShapeSize; j++) { qDatas[j] = QuantizeData(weightData[j], weightQauntParam.get()); @@ -463,8 +434,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { MS_ASSERT(graph->allTensors.size() > inTensorIdx); auto &inTensor = graph->allTensors.at(inTensorIdx); MS_ASSERT(inTensor != nullptr); - if (inTensor->quantParams.empty() || - inTensor->quantParams.front() == nullptr || + if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || !inTensor->quantParams.front()->inited) { canQuant = false; break; @@ -476,8 +446,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { MS_ASSERT(graph->allTensors.size() > outTensorIdx); auto &outTensor = graph->allTensors.at(outTensorIdx); MS_ASSERT(outTensor != nullptr); - if (outTensor->quantParams.empty() || - outTensor->quantParams.front() == nullptr || + if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || !outTensor->quantParams.front()->inited) { canQuant = false; break; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 6e07cd0710..d527aea453 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
  */
+
+#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
 #include <cmath>
 #include <string>
 #include <algorithm>
 #include <memory>
 #include <vector>
 #include "src/ops/primitive_c.h"
-#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
 #include "mindspore/lite/tools/converter/quantizer/general_bitpacking.h"
 #include "src/common/utils.h"
 #include "abstract/abstract_value.h"
@@ -32,7 +33,7 @@
 namespace mindspore {
 namespace lite {
 namespace quant {
 const std::array<std::string, 4> QuantStrategy::mConvTypes = {
-    {"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}};
+  {"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}};
 const std::array<std::string, 4> QuantStrategy::mMulTypes = {{"Mul", "MatMul", "BatchMatMul", "FullConnection"}};
 
 QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
@@ -99,10 +100,9 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
       schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
       schema::PrimitiveType_Conv2D,    schema::PrimitiveType_DepthwiseConv2D,
       schema::PrimitiveType_Add,       schema::PrimitiveType_Pooling,
-      schema::PrimitiveType_Concat,    /*schema::PrimitiveType_SoftMax,*/
+      schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/
       schema::PrimitiveType_Reshape,   schema::PrimitiveType_FullConnection,
-      schema::PrimitiveType_MatMul,
-      schema::PrimitiveType_Activation};
+      schema::PrimitiveType_MatMul,    schema::PrimitiveType_Activation};
 
   return IsContain(uint8OpList, type);
 }
@@ -164,8 +164,8 @@ bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
   return true;
 }
 
-STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange,
-                             int quant_max, int quant_min, int num_bits) {
+STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
+                             int quant_min, int num_bits) {
   MS_ASSERT(quantParam != nullptr);
   if (mMin > 0.0f) {
     MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
@@ -216,8 +216,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
   return RET_OK;
 }
 
-STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax,
-                             bool narrowRange, int numBits) {
+STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int numBits) {
   MS_ASSERT(quantParam != nullptr);
   if (mMin > 0.0f) {
     MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
@@ -246,8 +245,8 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
     return RET_OK;
   }
 
-  int quantMin = narrowRange ? 1 : 0 - 128;
-  int quantMax = (1 << (unsigned int)numBits) - 1 - 128;
+  const int8_t quantMin = std::numeric_limits<int8_t>::min() + (narrowRange ? 1 : 0);
+  const int8_t quantMax = std::numeric_limits<int8_t>::max();
   auto quantMinFloat = static_cast<double>(quantMin);
   auto quantMaxFloat = static_cast<double>(quantMax);
   double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
@@ -264,6 +263,9 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
   } else {
     zeroPoint = static_cast<int32_t>(std::round(zpDouble));
   }
+  if (std::abs(mMin) == std::abs(mMax)) {
+    zeroPoint = 0;
+  }
   // The zero point should always be in the range of quantized value,
   // [qmin, qmax].
   MS_ASSERT(zeroPoint >= quantMin);
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h
index b955ab45f7..d5e1dd5ee3 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.h
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h
@@ -23,6 +23,7 @@
 #include <string>
 #include <cmath>
 #include <array>
+#include <limits>
 #include "tools/converter/quantizer/quantizer.h"
 #include "src/ops/primitive_c.h"
 #include "include/errorcode.h"
@@ -75,13 +76,15 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
   const auto zeroPoint = quantParam->zeroPoint;
   const auto numBit = quantParam->numBits;
   const auto narrowRange = quantParam->narrowRange;
-  const double maxLimit = static_cast<double>((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale;
+  double maxLimitTemp = static_cast<double>((1 << (unsigned int)numBit) - 1);
+  const double maxLimit = static_cast<double>(maxLimitTemp - zeroPoint + std::numeric_limits<int8_t>::min()) * scale;
   double minLimit;
   if (narrowRange) {
-    minLimit = static_cast<double>(1 - zeroPoint) * scale;
+    minLimit = static_cast<double>(std::numeric_limits<int8_t>::min() + 1 - zeroPoint) * scale;
   } else {
-    minLimit = static_cast<double>(0 - zeroPoint) * scale;
+    minLimit = static_cast<double>(std::numeric_limits<int8_t>::min() - zeroPoint) * scale;
   }
+
   return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
     double tmp = 0.0f;
     if (originData > maxLimit) {
@@ -91,10 +94,7 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
     } else {
       tmp = originData;
     }
-    auto quantData = static_cast<T>(std::round(tmp / scale + zeroPoint));
-    if (quantData == 0 && narrowRange) {
-      quantData++;
-    }
+    auto quantData = static_cast<T>(std::round(zeroPoint + tmp / scale));
     return quantData;
   }();
 }
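
Editor's note (not part of the patch): the 5-argument CalQuantizationParams above switches aware-training quantization from uint8-centred arithmetic to the int8 range [-128, 127]. The following self-contained sketch reproduces that range math under the assumption numBits == 8; the names (CalcInt8QuantParam, QuantParam) are illustrative, not MindSpore API.

// Sketch only -- illustrative names, not MindSpore source. Assumes numBits == 8.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

struct QuantParam {
  double scale = 1.0;
  int32_t zeroPoint = 0;
};

QuantParam CalcInt8QuantParam(double mMin, double mMax, bool narrowRange) {
  // Target range after this patch: [-128, 127], or [-127, 127] when narrowRange.
  const int quantMin = std::numeric_limits<int8_t>::min() + (narrowRange ? 1 : 0);
  const int quantMax = std::numeric_limits<int8_t>::max();
  QuantParam p;
  p.scale = (mMax - mMin) / (quantMax - quantMin);
  // Zero point: the integer code that represents real value 0, clamped to range.
  double zp = std::round(quantMin - mMin / p.scale);
  zp = std::min(static_cast<double>(quantMax), std::max(static_cast<double>(quantMin), zp));
  p.zeroPoint = static_cast<int32_t>(zp);
  // The patch pins an exactly symmetric range (|min| == |max|) to zero point 0.
  if (std::abs(mMin) == std::abs(mMax)) {
    p.zeroPoint = 0;
  }
  return p;
}

int main() {
  // A tensor observed in [-1.0, 2.0]: scale = 3/255, zero point = -43.
  QuantParam p = CalcInt8QuantParam(-1.0, 2.0, false);
  std::printf("scale=%f zeroPoint=%d\n", p.scale, static_cast<int>(p.zeroPoint));
  return 0;
}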
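
Editor's note (not part of the patch): the QuantizeData<T> change in quantize_util.h clamps the input to the real-value interval representable by int8 at the given scale/zero point, then rounds, and drops the old uint8-era narrow-range bump of quantized 0 to 1. A sketch of the new behaviour specialised to int8; QuantizeToInt8 is an illustrative name, not MindSpore API.

// Sketch only -- QuantizeToInt8 is an illustrative name, not MindSpore source.
#include <cmath>
#include <cstdint>
#include <cstdio>

int8_t QuantizeToInt8(float x, double scale, int32_t zeroPoint) {
  // Real-value interval representable by int8 at this scale/zero point;
  // mirrors the minLimit/maxLimit computation in the patched template.
  const double minLimit = (-128.0 - zeroPoint) * scale;
  const double maxLimit = (127.0 - zeroPoint) * scale;
  double tmp = x;
  if (tmp > maxLimit) tmp = maxLimit;
  if (tmp < minLimit) tmp = minLimit;
  // Round after clamping; the old "bump quantized 0 to 1 under narrowRange"
  // special case is gone.
  return static_cast<int8_t>(std::round(zeroPoint + tmp / scale));
}

int main() {
  const double scale = 3.0 / 255;  // values from the previous sketch
  const int32_t zeroPoint = -43;
  std::printf("%d %d\n",
              QuantizeToInt8(0.0f, scale, zeroPoint),   // -43: the code for real 0
              QuantizeToInt8(3.0f, scale, zeroPoint));  // 127: out of range, saturates
  return 0;
}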
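
Editor's note (not part of the patch): the consolidated PrimitiveC::PopulaterQuantParam no longer emits one quant param per filter channel as the removed Conv2D/DepthwiseConv2D/MatMul copies did; it reduces the per-channel fake-quant minima/maxima to a single layer-wise range and always passes narrowRange = true for the weight tensor. A sketch of that reduction with hypothetical data follows; note that the patch initialises the running maximum to FLT_MIN (the smallest positive float, not the most negative one), which is safe only because fake-quant maxima are never negative.

// Sketch only -- hypothetical data; shows the layer-wise min/max reduction.
#include <cfloat>
#include <cstdio>
#include <vector>

int main() {
  // Per-channel ranges as recorded by fake-quant ("filter_minq"/"filter_maxq").
  std::vector<float> channelMin = {-0.9f, -0.4f, -1.2f};
  std::vector<float> channelMax = {0.7f, 1.1f, 0.8f};
  // As in the patch: min starts at FLT_MAX, max at FLT_MIN. FLT_MIN works here
  // only because filter maxima are non-negative; -FLT_MAX would be the fully
  // general lower bound.
  float layerMin = FLT_MAX;
  float layerMax = FLT_MIN;
  for (size_t i = 0; i < channelMin.size(); ++i) {
    layerMin = channelMin[i] < layerMin ? channelMin[i] : layerMin;
    layerMax = channelMax[i] > layerMax ? channelMax[i] : layerMax;
  }
  std::printf("layer range: [%f, %f]\n", layerMin, layerMax);  // [-1.2, 1.1]
  return 0;
}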