From ef86629298917c7c7ff784d1b40ce2e59d144856 Mon Sep 17 00:00:00 2001 From: lyvette Date: Wed, 30 Sep 2020 16:35:51 +0800 Subject: [PATCH] rewrite aware train converter --- mindspore/lite/schema/model.fbs | 5 +- mindspore/lite/src/common/utils.h | 35 ++ mindspore/lite/src/lite_session.cc | 5 +- mindspore/lite/src/ops/add.cc | 7 +- mindspore/lite/src/ops/conv2d.cc | 8 +- mindspore/lite/src/ops/deconv2d.cc | 9 +- mindspore/lite/src/ops/depthwise_conv2d.cc | 9 +- mindspore/lite/src/ops/matmul.cc | 9 +- mindspore/lite/src/ops/primitive_c.cc | 83 +++-- mindspore/lite/src/ops/primitive_c.h | 18 +- .../kernel/arm/base/fullconnection_base.cc | 10 +- .../kernel/arm/base/quant_dtype_cast.cc | 5 +- .../arm/fp16/convolution_depthwise_fp16.cc | 4 +- .../kernel/arm/fp16/convolution_fp16.cc | 4 +- .../arm/fp16/deconvolution_depthwise_fp16.cc | 4 +- .../kernel/arm/fp16/deconvolution_fp16.cc | 4 +- .../kernel/arm/fp16/fullconnection_fp16.cc | 10 +- .../runtime/kernel/arm/fp16/matmul_fp16.cc | 10 +- .../runtime/kernel/arm/fp32/convolution.cc | 10 +- .../kernel/arm/fp32/convolution_depthwise.cc | 1 + .../runtime/kernel/arm/fp32/deconvolution.cc | 10 +- .../arm/fp32/deconvolution_depthwise.cc | 10 +- .../runtime/kernel/arm/int8/matmul_int8.cc | 5 +- mindspore/lite/src/sub_graph_kernel.cc | 32 +- mindspore/lite/src/tensor.h | 1 + .../lite/test/models_tflite_awaretraining.cfg | 6 +- mindspore/lite/test/run_benchmark_nets.sh | 4 +- .../lite/tools/anf_exporter/anf_exporter.cc | 1 + .../anf_importer/import_from_meta_graphT.cc | 46 +-- .../lite/tools/converter/anf_transform.cc | 45 ++- mindspore/lite/tools/converter/converter.cc | 1 - .../lite/tools/converter/converter_flags.cc | 49 ++- .../lite/tools/converter/converter_flags.h | 5 +- .../tools/converter/graphdef_transform.cc | 64 ++-- .../lite/tools/converter/graphdef_transform.h | 4 - .../legacy_optimizer/graph/CMakeLists.txt | 2 + .../graph/dtype_trans_pass.cc | 81 +---- .../legacy_optimizer/graph/dtype_trans_pass.h | 2 - .../graph/infer_quant_param_pass.cc | 87 +++++ .../graph/infer_quant_param_pass.h | 39 +++ .../graph/tensor_quant_pass.cc | 87 +++++ .../graph/tensor_quant_pass.h | 36 +++ .../parser/tflite/tflite_dequantize_parser.cc | 2 +- .../parser/tflite/tflite_quantize_parser.cc | 2 +- .../converter/quantizer/aware_quantizer.cc | 304 +++--------------- .../converter/quantizer/aware_quantizer.h | 18 -- .../converter/quantizer/calc_quant_param.cc | 47 ++- .../converter/quantizer/calc_quant_param.h | 14 + .../quantizer/post_training_quantizer.cc | 32 +- .../quantizer/post_training_quantizer.h | 3 +- .../converter/quantizer/quantize_util.cc | 2 +- .../tools/converter/quantizer/quantize_util.h | 9 +- .../converter/quantizer/weight_quantizer.cc | 8 +- .../lite/tools/optimizer/common/gllo_utils.cc | 136 ++++---- .../optimizer/fusion/batchmatmul_fusion.cc | 4 +- 55 files changed, 751 insertions(+), 697 deletions(-) create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.h diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index df9889c96b..8f34e21dcf 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -32,8 +32,9 @@ table QuantParam { narrowRange: bool = true; numBits: int = 8; inited: bool = false; - var_corr: double = 1; - mean_corr: double = 0; + varCorr: double = 1; + meanCorr: double = 0; + dstDtype: int = 32; clusters: [float]; } diff --git a/mindspore/lite/src/common/utils.h b/mindspore/lite/src/common/utils.h index 7ef2389368..fbfb9ab04b 100644 --- a/mindspore/lite/src/common/utils.h +++ b/mindspore/lite/src/common/utils.h @@ -27,6 +27,9 @@ #include "src/common/log_adapter.h" #include "tools/common/option.h" #include "include/errorcode.h" +#ifdef ENABLE_ARM64 +#include "nnacl/optimized_kernel.h" +#endif namespace mindspore { namespace lite { @@ -186,6 +189,38 @@ inline Option GenericParseValue(const std::string &value) { return Option(None()); } + +using Float16CastFunc = void (*)(const void *, void *, int); + +class Float16CastUtil { + public: + static Float16CastUtil *GetInstance() { + static Float16CastUtil float16_cast_util; + return &float16_cast_util; + } + + private: + Float16CastUtil() { +#ifdef ENABLE_ARM64 + void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_; + if (fp16_op_handler != nullptr) { + dlerror(); + *(reinterpret_cast(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler"); + *(reinterpret_cast(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler"); + auto dlopen_error = dlerror(); + if (dlopen_error != nullptr) { + MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << "."; + } + } +#endif + } + ~Float16CastUtil() = default; + + public: + Float16CastFunc float16_to_float32_func_ = nullptr; + Float16CastFunc float32_to_float16_func_ = nullptr; +}; + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index e2ce2ff400..cadf01bdd8 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -108,8 +108,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) { QuantArg quant_arg{}; quant_arg.scale = quant_params->Get(j)->scale(); quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); - quant_arg.var_corr = quant_params->Get(j)->var_corr(); - quant_arg.mean_corr = quant_params->Get(j)->mean_corr(); + quant_arg.var_corr = quant_params->Get(j)->varCorr(); + quant_arg.mean_corr = quant_params->Get(j)->meanCorr(); + quant_arg.inited = quant_params->Get(j)->inited(); auto quant_clusters = quant_params->Get(j)->clusters(); if (quant_clusters != nullptr) { for (size_t k = 0; k < quant_clusters->size(); k++) { diff --git a/mindspore/lite/src/ops/add.cc b/mindspore/lite/src/ops/add.cc index 0d868ac499..8661180456 100644 --- a/mindspore/lite/src/ops/add.cc +++ b/mindspore/lite/src/ops/add.cc @@ -49,12 +49,7 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector &inputs return RET_ERROR; } } - if (GetQuantType() == schema::QuantType_AwareTraining) { - std::vector> vecInputQuantParam; - std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); - SetOutputQuantParam(vecOutputQuantParam); - } + PopulaterQuantParam(prim, inputs); return RET_OK; } diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index ed9148b168..7cd97f7a52 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -277,13 +277,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector &inp PopulaterConv2DSingleGroup(prim, this->primitive_, group); } - if (GetQuantType() == schema::QuantType_AwareTraining) { - std::vector> vecInputQuantParam; - std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); - SetInputQuantParam(vecInputQuantParam); - SetOutputQuantParam(vecOutputQuantParam); - } + PopulaterQuantParam(prim, inputs); return RET_OK; } diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 5d3f6fbc9a..641372f93a 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -254,14 +254,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector &i } else if (group > 1) { PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); } - - if (GetQuantType() == schema::QuantType_AwareTraining) { - std::vector> vecInputQuantParam; - std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); - SetInputQuantParam(vecInputQuantParam); - SetOutputQuantParam(vecOutputQuantParam); - } + PopulaterQuantParam(prim, inputs); return RET_OK; } #else diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index f7bb3222dd..f54bd3141f 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -146,14 +146,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vectorprimitive_->value.type = schema::PrimitiveType_DepthwiseConv2D; this->primitive_->value.value = attr.release(); - - if (GetQuantType() == schema::QuantType_AwareTraining) { - std::vector> vecInputQuantParam; - std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); - SetInputQuantParam(vecInputQuantParam); - SetOutputQuantParam(vecOutputQuantParam); - } + PopulaterQuantParam(prim, inputs); return RET_OK; } diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index 1e72733d8e..e07fa9b70b 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -61,13 +61,8 @@ int MatMul::UnPackAttr(const Primitive &prim, const std::vector &inp return RET_ERROR; } } - if (GetQuantType() == schema::QuantType_AwareTraining) { - std::vector> vecInputQuantParam; - std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); - SetInputQuantParam(vecInputQuantParam); - SetOutputQuantParam(vecOutputQuantParam); - } + + PopulaterQuantParam(prim, inputs); return RET_OK; } diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 60187883c3..e187a41ce2 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -164,32 +164,29 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { +void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) { const float qmin = 0; const float qmax = 255; *mMin = static_cast((qmin - mean) / stdDev); *mMax = static_cast((qmax - mean) / stdDev); } -void PrimitiveC::PopulaterQuantParam(const Primitive &prim, - std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam, - const std::vector &inputs) { +void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector &inputs) { auto narrow_range = prim.GetAttr("narrow_range"); - bool narrowRangeQuantParam = GetValue(narrow_range); + bool narrowRangeQuantParam = narrow_range != nullptr ? GetValue(narrow_range) : false; auto num_bits = prim.GetAttr("num_bits"); - int32_t numbitsRangeQuantParam = GetValue(num_bits); + int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue(num_bits) : 8; std::vector quants; schema::QuantParamT quantParam; auto mean = prim.GetAttr("mean"); auto std_dev = prim.GetAttr("std_dev"); if (mean != nullptr && std_dev != nullptr) { - auto meanQuantOaram = GetValue(mean); - double stddevQuantOaram = GetValue(std_dev); + auto meanValue = GetValue(mean); + auto stddevValue = GetValue(std_dev); float mMin = 0.0; float mMax = 0.0; - CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); + CalFloatScopeByMeanAndStddev(meanValue, stddevValue, &mMin, &mMax); quantParam.min = mMin; quantParam.max = mMax; } else { @@ -198,8 +195,8 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, if (inputMin != nullptr && inputMax != nullptr) { auto inputMinPtr = inputMin->cast(); auto inputMaxPtr = inputMax->cast(); - float *minBuf = static_cast(inputMinPtr->data_c()); - float *maxBuf = static_cast(inputMaxPtr->data_c()); + auto *minBuf = static_cast(inputMinPtr->data_c()); + auto *maxBuf = static_cast(inputMaxPtr->data_c()); quantParam.min = *minBuf; quantParam.max = *maxBuf; } @@ -207,7 +204,7 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); - vecInputQuantParam->emplace_back(quants); + input_quant_param_.emplace_back(quants); quants.clear(); auto filterMin = prim.GetAttr("filter_minq"); @@ -227,17 +224,25 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, } quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, true, numbitsRangeQuantParam); quants.emplace_back(quantParam); - vecInputQuantParam->emplace_back(quants); + input_quant_param_.emplace_back(quants); } - if (vecInputQuantParam->size() == kDoubleNum) { + if (input_quant_param_.size() == kDoubleNum) { quants.clear(); quantParam.min = 0.0; quantParam.max = 0.0; quantParam.zeroPoint = 0; - quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale; + quantParam.scale = input_quant_param_.at(0).at(0).scale * input_quant_param_.at(1).at(0).scale; quants.emplace_back(quantParam); - vecInputQuantParam->emplace_back(quants); + input_quant_param_.emplace_back(quants); + } + + // fill input_quant_param_ by not inited quant_parm + if (input_quant_param_.size() < inputs.size()) { + quants.clear(); + schema::QuantParamT tmpQuantParam; + quants.emplace_back(tmpQuantParam); + input_quant_param_.insert(input_quant_param_.end(), inputs.size() - 1 - input_quant_param_.size(), quants); } quants.clear(); @@ -253,7 +258,11 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); - vecOutputQuantParam->emplace_back(quants); + output_quant_param_.emplace_back(quants); + } else { + schema::QuantParamT tmpQuantParam; + quants.emplace_back(tmpQuantParam); + output_quant_param_.emplace_back(quants); } } @@ -279,14 +288,48 @@ schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } -void PrimitiveC::SetInputQuantParam(const std::vector> &input_quant_param) { +void PrimitiveC::SetInputQuantParams(const std::vector> &input_quant_param) { this->input_quant_param_ = input_quant_param; } -void PrimitiveC::SetOutputQuantParam(const std::vector> &output_quant_param) { +void PrimitiveC::SetInputQuantParam(const size_t &index, const std::vector &input_quant_param) { + MS_ASSERT(index < this->input_quant_param_.size()); + this->input_quant_param_[index] = input_quant_param; +} + +void PrimitiveC::SetOutputQuantParams(const std::vector> &output_quant_param) { this->output_quant_param_ = output_quant_param; } +void PrimitiveC::SetOutputQuantParam(const size_t &index, const std::vector &output_quant_param) { + MS_ASSERT(index < this->output_quant_param_.size()); + this->output_quant_param_[index] = output_quant_param; +} + +bool PrimitiveC::IsInputQuantParamsInited() { + if (this->input_quant_param_.empty()) { + return false; + } + for (auto &quant_param : this->input_quant_param_) { + if (!quant_param.front().inited) { + return false; + } + } + return true; +} + +bool PrimitiveC::IsOutputQuantParamsInited() { + if (this->output_quant_param_.empty()) { + return false; + } + for (auto &quant_param : this->output_quant_param_) { + if (!quant_param.front().inited) { + return false; + } + } + return true; +} + void PrimitiveC::ClearInputOutputQuantParam() { input_quant_param_.clear(); output_quant_param_.clear(); diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index df858cd6d9..542157d231 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -88,9 +88,17 @@ class PrimitiveC : public mindspore::Primitive { } } - void SetInputQuantParam(const std::vector> &input_quant_param); + void SetInputQuantParams(const std::vector> &input_quant_param); - void SetOutputQuantParam(const std::vector> &output_quant_param); + void SetInputQuantParam(const size_t &index, const std::vector &input_quant_param); + + void SetOutputQuantParams(const std::vector> &output_quant_param); + + void SetOutputQuantParam(const size_t &index, const std::vector &output_quant_param); + + bool IsInputQuantParamsInited(); + + bool IsOutputQuantParamsInited(); void ClearInputOutputQuantParam(); @@ -120,10 +128,8 @@ class PrimitiveC : public mindspore::Primitive { static std::shared_ptr Create(const Primitive &prim, const std::vector &inputs, const schema::QuantType &quantType); - void PopulaterQuantParam(const Primitive &prim, std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam, - const std::vector &inputs); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); + void PopulaterQuantParam(const Primitive &prim, const std::vector &inputs); + void CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax); protected: virtual int UnPackAttr(const Primitive &prim, const std::vector &inputs) { return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc index a01703c52d..4ac7633844 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc @@ -42,7 +42,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vectordata_c(); - if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; + if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { MS_LOG(ERROR) << "dequant data is nullptr."; @@ -53,7 +55,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vectorGetQuantParams().empty()) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } @@ -65,13 +67,13 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vectorname_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - if (!weight_tensor->GetQuantParams().empty()) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } return nullptr; } - if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc index 1562c37d7f..c9053bf79d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc @@ -98,8 +98,9 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) { MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found."; return RET_ERROR; } - auto quant_arg = !out_tensors_.front()->GetQuantParams().empty() ? out_tensors_.front()->GetQuantParams().front() - : in_tensors_.front()->GetQuantParams().front(); + auto quant_arg = out_tensors_.front()->GetQuantParams().front().inited + ? out_tensors_.front()->GetQuantParams().front() + : in_tensors_.front()->GetQuantParams().front(); int ret = RET_OK; if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc index af67b1272f..28d38d2ab7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc @@ -140,8 +140,8 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); - auto dequant_flag = - (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; + bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index 9015aadd35..c8c0de972d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -182,8 +182,8 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector & auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); - auto dequant_flag = - (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; + bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc index 0a34c236ca..7c334aace4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc @@ -204,8 +204,8 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vectorMutableData(); - auto dequant_flag = - (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; + auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc index 39b607727b..0b1d4b02d4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc @@ -216,8 +216,8 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); - auto dequant_flag = - (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; + auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc index 407a33cd66..7f40e995e3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc @@ -237,7 +237,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vectordata_c(); - if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; + if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { MS_LOG(ERROR) << "dequant data is nullptr."; @@ -250,7 +252,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vectorGetQuantParams().empty()) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } @@ -262,13 +264,13 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vectorname_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); delete kernel; - if (!weight_tensor->GetQuantParams().empty()) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } return nullptr; } - if (!weight_tensor->GetQuantParams().empty()) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc index 3f7a3570fd..f8ae9de1cb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc @@ -251,7 +251,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector const mindspore::lite::PrimitiveC *primitive) { auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->data_c(); - if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; + if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { MS_LOG(ERROR) << "dequant data is nullptr."; @@ -263,7 +265,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; - if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } @@ -275,13 +277,13 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); delete kernel; - if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } return nullptr; } - if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index 9bc93235ba..acbd327e46 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -284,7 +284,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; + if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { MS_LOG(ERROR) << "dequant data is nullptr."; @@ -303,7 +305,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } @@ -315,14 +317,14 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & delete kernel; MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } return nullptr; } - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc index 4469fc02dd..33b8156b74 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc @@ -124,6 +124,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); + if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc index 6d0be356a5..cd4f792171 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc @@ -234,7 +234,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; + if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { MS_LOG(ERROR) << "dequant data is nullptr."; @@ -255,7 +257,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } @@ -267,14 +269,14 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector delete kernel; MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } return nullptr; } - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc index e80f05f4b3..d60c000a46 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc @@ -196,7 +196,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vectorMutableData(); - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && + restore_data != nullptr; + if (dequant_flag) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { MS_LOG(ERROR) << "dequant data is nullptr."; @@ -209,7 +211,7 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vectordata_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } @@ -221,13 +223,13 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vectorname_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } return nullptr; } - if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { + if (dequant_flag) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc index 92e70ad84c..8b1f80f8e5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc @@ -208,9 +208,8 @@ kernel::LiteKernel *CpuMatmulInt8KernelCreator(const std::vector auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->data_c(); - auto is_const_quant_weight = - (restore_data != nullptr) && - ((weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16)); + bool is_const_quant_weight = !weight_tensor->GetQuantParams().empty() && + weight_tensor->GetQuantParams().front().inited && restore_data != nullptr; if (is_const_quant_weight) { auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { diff --git a/mindspore/lite/src/sub_graph_kernel.cc b/mindspore/lite/src/sub_graph_kernel.cc index 8b328e2965..2018e6d65a 100644 --- a/mindspore/lite/src/sub_graph_kernel.cc +++ b/mindspore/lite/src/sub_graph_kernel.cc @@ -25,36 +25,6 @@ using mindspore::lite::RET_ERROR; using mindspore::lite::RET_INFER_ERR; using mindspore::lite::RET_INFER_INVALID; using mindspore::lite::RET_OK; -using Float16CastFunc = void (*)(const void *, void *, int); - -class Float16CastUtil { - public: - static Float16CastUtil *GetInstance() { - static Float16CastUtil float16_cast_util; - return &float16_cast_util; - } - - private: - Float16CastUtil() { -#ifdef ENABLE_ARM64 - void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_; - if (fp16_op_handler != nullptr) { - dlerror(); - *(reinterpret_cast(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler"); - *(reinterpret_cast(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler"); - auto dlopen_error = dlerror(); - if (dlopen_error != nullptr) { - MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << "."; - } - } -#endif - } - ~Float16CastUtil() = default; - - public: - Float16CastFunc float16_to_float32_func_ = nullptr; - Float16CastFunc float32_to_float16_func_ = nullptr; -}; int SubGraphKernel::Prepare() { for (auto node : this->nodes_) { @@ -208,7 +178,7 @@ int CpuFp16SubGraph::PreProcess() { } int CpuFp16SubGraph::PostProcess() { - auto fp16_to_fp32_cast_func = Float16CastUtil::GetInstance()->float16_to_float32_func_; + auto fp16_to_fp32_cast_func = lite::Float16CastUtil::GetInstance()->float16_to_float32_func_; if (fp16_to_fp32_cast_func == nullptr) { MS_LOG(ERROR) << "Can not find cast fp16 to fp32 func"; return RET_ERROR; diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index fa4ff6f817..8c012526b7 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -35,6 +35,7 @@ struct QuantArg { int32_t zeroPoint; double var_corr{1}; double mean_corr{0}; + bool inited; std::vector clusters{}; }; diff --git a/mindspore/lite/test/models_tflite_awaretraining.cfg b/mindspore/lite/test/models_tflite_awaretraining.cfg index b4645eeff7..dac3fc8b41 100644 --- a/mindspore/lite/test/models_tflite_awaretraining.cfg +++ b/mindspore/lite/test/models_tflite_awaretraining.cfg @@ -33,6 +33,6 @@ lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite vision_classifier_fungi_mobile_V1_1_default_1.tflite -detect.tflite -ssd_mobilenet_v1_1_default_1.tflite -object_detection_mobile_object_localizer_v1_1_default_1.tflite +#detect.tflite +#ssd_mobilenet_v1_1_default_1.tflite +#object_detection_mobile_object_localizer_v1_1_default_1.tflite diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh index 1795b9775c..b6700aa8cf 100644 --- a/mindspore/lite/test/run_benchmark_nets.sh +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -121,8 +121,8 @@ function Run_Converter() { continue fi echo ${model_name} >> "${run_converter_log_file}" - echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}' --quantType=AwareTraining' >> "${run_converter_log_file}" - ./converter_lite --fmk=TFLITE --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name} --quantType=AwareTraining + echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}' --inputDataType=FLOAT --outputDataType=FLOAT' >> "${run_converter_log_file}" + ./converter_lite --fmk=TFLITE --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name} --inputDataType=FLOAT --outputDataType=FLOAT if [ $? = 0 ]; then converter_result='converter aware_training '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} else diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index c4a0c7ddec..8f4279102b 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -544,6 +544,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptrnodeType = schema::NodeType_CNode; + ms_tensor->dataType = TypeId::kNumberTypeFloat32; fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); node_id_map_[cnode_name] = meta_graphT->allTensors.size(); meta_graphT->allTensors.emplace_back(ms_tensor); diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc index bb40b536ea..b4aad8bef8 100644 --- a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc @@ -73,30 +73,30 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptrprimitive.release()); cNode->primitive = nullptr; // add quant parameter - if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) { - primitiveCValue->SetQuantType(cNode->quantType); - for (int index : cNode->inputIndex) { - if (!meta_graph_->allTensors[index]->quantParams.empty()) { - std::vector quant_params(meta_graph_->allTensors[index]->quantParams.size()); - std::transform( - meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), - quant_params.begin(), - [](std::unique_ptr &quant_param) -> schema::QuantParamT { return *quant_param; }); - primitiveCValue->AddInputQuantParam(quant_params); - } else { - std::vector empty_quant_params; - primitiveCValue->AddInputQuantParam(empty_quant_params); - } + for (auto index : cNode->inputIndex) { + if (!meta_graph_->allTensors[index]->quantParams.empty()) { + std::vector quant_params(meta_graph_->allTensors[index]->quantParams.size()); + std::transform( + meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), + quant_params.begin(), + [](std::unique_ptr &quant_param) -> schema::QuantParamT { return *quant_param; }); + primitiveCValue->AddInputQuantParam(quant_params); + } else { + std::vector notinited_quant_params(1); + primitiveCValue->AddInputQuantParam(notinited_quant_params); } - for (int index : cNode->outputIndex) { - if (!meta_graph_->allTensors[index]->quantParams.empty()) { - std::vector quant_params(meta_graph_->allTensors[index]->quantParams.size()); - std::transform( - meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), - quant_params.begin(), - [](std::unique_ptr &quant_param) -> schema::QuantParamT { return *quant_param; }); - primitiveCValue->AddOutputQuantParam(quant_params); - } + } + for (auto index : cNode->outputIndex) { + if (!meta_graph_->allTensors[index]->quantParams.empty()) { + std::vector quant_params(meta_graph_->allTensors[index]->quantParams.size()); + std::transform( + meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), + quant_params.begin(), + [](std::unique_ptr &quant_param) -> schema::QuantParamT { return *quant_param; }); + primitiveCValue->AddOutputQuantParam(quant_params); + } else { + std::vector notinited_quant_params(1); + primitiveCValue->AddOutputQuantParam(notinited_quant_params); } } auto value_node = NewValueNode(std::shared_ptr(primitiveCValue)); diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 93cdaaf47c..7f8e42640d 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -52,16 +52,13 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver auto convert_pm = std::make_shared("anf graph convert pass manager", true); // for now - trainning is not supporting fuse operations - if (config != nullptr && config->trainModel == false) { + if (config != nullptr && !config->trainModel) { // remove quantdtype when awaretraining if (config->fmk == lite::converter::FmkType_ONNX) { auto remove_identity_pass = std::make_shared(); remove_identity_pass->SetFmkType(config->fmk); pm->AddPass(remove_identity_pass); } - if (config->quantType == QuantType_AwareTraining) { - pm->AddPass(std::make_shared()); - } pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); @@ -101,27 +98,25 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver return nullptr; } // quant - if (config != nullptr) { - if (config->quantType == schema::QuantType_PostTraining) { - this->mQuantizer = std::make_unique(new_graph, config->configFile, 8); - if (mQuantizer == nullptr) { - MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); - return nullptr; - } - } else if (config->quantType == schema::QuantType_WeightQuant) { - if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { - MS_LOG(ERROR) << "weight quant input param error"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); - return nullptr; - } - this->mQuantizer = std::make_unique(new_graph, config->quantWeightSize, - config->quantWeightChannel, config->bitNum); - if (mQuantizer == nullptr) { - MS_LOG(ERROR) << "New WeightQuantizer failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); - return nullptr; - } + if (config->quantType == schema::QuantType_PostTraining) { + this->mQuantizer = std::make_unique(new_graph, config->configFile, 8); + if (mQuantizer == nullptr) { + MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); + return nullptr; + } + } else if (config->quantType == schema::QuantType_WeightQuant) { + if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { + MS_LOG(ERROR) << "weight quant input param error"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return nullptr; + } + this->mQuantizer = std::make_unique(new_graph, config->quantWeightSize, + config->quantWeightChannel, config->bitNum); + if (mQuantizer == nullptr) { + MS_LOG(ERROR) << "New WeightQuantizer failed"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); + return nullptr; } } if (mQuantizer != nullptr) { diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 2ab80994e1..e1e047b935 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -93,7 +93,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { } // transform transform->SetGraphDef(meta_graph); - transform->CreateQuantizer(flag); auto status = transform->Transform(*flag); if (status != RET_OK) { MS_LOG(ERROR) << "Transform meta graph failed " << status; diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 74d4cdafe5..e2fc3f1eb7 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -29,9 +29,14 @@ Flags::Flags() { AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); - AddFlag(&Flags::inferenceTypeIn, "inferenceType", "Data type of input and output tensors. FLOAT | INT8 | UINT8", - "FLOAT"); - AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", ""); + AddFlag(&Flags::inputDataTypeIn, "inputDataType", + "Data type of input tensors, default is same with the type defined in model. FLOAT | INT8 | UINT8 | DEFAULT", + "DEFAULT"); + AddFlag(&Flags::outputDataTypeIn, "outputDataType", + "Data type of output and output tensors, default is same with the type defined in model. FLOAT | INT8 | " + "UINT8 | DEFAULT", + "DEFAULT"); + AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", ""); AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0"); AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16"); @@ -78,15 +83,32 @@ int Flags::Init(int argc, const char **argv) { return RET_INPUT_PARAM_INVALID; } - if (this->inferenceTypeIn == "FLOAT") { - this->inferenceType = TypeId::kNumberTypeFloat; - } else if (this->inferenceTypeIn == "INT8") { - this->inferenceType = TypeId::kNumberTypeInt8; - } else if (this->inferenceTypeIn == "UINT8") { - this->inferenceType = TypeId::kNumberTypeUInt8; + if (this->inputDataTypeIn == "FLOAT") { + this->inputDataType = TypeId::kNumberTypeFloat; + } else if (this->inputDataTypeIn == "INT8") { + this->inputDataType = TypeId::kNumberTypeInt8; + } else if (this->inputDataTypeIn == "UINT8") { + this->inputDataType = TypeId::kNumberTypeUInt8; + } else if (this->inputDataTypeIn == "DEFAULT") { + this->inputDataType = TypeId::kTypeUnknown; } else { - std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | UINT8", - this->inferenceTypeIn.c_str(); + std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT", + this->inputDataTypeIn.c_str(); + return RET_INPUT_PARAM_INVALID; + } + + if (this->outputDataTypeIn == "FLOAT") { + this->outputDataType = TypeId::kNumberTypeFloat; + } else if (this->outputDataTypeIn == "INT8") { + this->outputDataType = TypeId::kNumberTypeInt8; + } else if (this->outputDataTypeIn == "UINT8") { + this->outputDataType = TypeId::kNumberTypeUInt8; + } else if (this->outputDataTypeIn == "DEFAULT") { + this->outputDataType = TypeId::kTypeUnknown; + } else { + std::cerr + << "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT", + this->outputDataTypeIn.c_str(); return RET_INPUT_PARAM_INVALID; } @@ -107,9 +129,8 @@ int Flags::Init(int argc, const char **argv) { std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; return RET_INPUT_PARAM_INVALID; } - if (this->quantTypeIn == "AwareTraining") { - this->quantType = QuantType_AwareTraining; - } else if (this->quantTypeIn == "WeightQuant") { + + if (this->quantTypeIn == "WeightQuant") { this->quantType = QuantType_WeightQuant; } else if (this->quantTypeIn == "PostTraining") { this->quantType = QuantType_PostTraining; diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index ebcb46e2bd..22fb2f0fb8 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -53,8 +53,11 @@ class Flags : public virtual mindspore::lite::FlagParser { std::string quantTypeIn; QuantType quantType; std::string inferenceTypeIn; + std::string inputDataTypeIn; + std::string outputDataTypeIn; // used for parse aware trainning - TypeId inferenceType = TypeId::kNumberTypeFloat; + TypeId inputDataType; + TypeId outputDataType; // used for post-trainning-weight std::string quantWeightSize; std::string bitNum; diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 1ff58281a6..f53b7fdd7c 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -34,6 +34,8 @@ #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" +#include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" +#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" #include "tools/converter/quantizer/aware_quantizer.h" using std::string; @@ -44,20 +46,6 @@ GraphDefTransform::~GraphDefTransform() = default; void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } -void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { - auto type = flags->quantType; - switch (type) { - case QuantType::QuantType_AwareTraining: { - MS_LOG(INFO) << "create AwareTrainingQuantizer!"; - fbQuantizer = std::make_unique(graphDefT, flags->inferenceType); - break; - } - default: - MS_LOG(INFO) << "will support quantizer type " << flags->quantTypeIn << " in the future"; - break; - } -} - int GraphDefTransform::Transform(const converter::Flags &ctx) { STATUS status; { @@ -84,26 +72,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // generate and infer quant parameters { - if (fbQuantizer != nullptr) { - Optimizer topologicalOptimizer; - topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - status = topologicalOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; - return status; - } - if (ctx.quantType == QuantType_AwareTraining) { - status = fbQuantizer->GenerateQuantParam(); - if (status != RET_OK) { - MS_LOG(ERROR) << "GenerateQuantParam failed"; - return status; - } - status = fbQuantizer->DetermineNodeQuantType(); - if (status != RET_OK) { - MS_LOG(ERROR) << "DetermineNodeQuant failed"; - return status; - } - } + Optimizer inferQuantParamPass; + inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); + inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); + status = inferQuantParamPass.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; + return status; } } @@ -146,12 +121,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } { - Optimizer fusionOptimizer; - fusionOptimizer.AddPass(new (std::nothrow) FormatTransPermuteFusionPass()); - fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - status = fusionOptimizer.Run(graphDefT); + Optimizer inferQuantParamOtimizer; + inferQuantParamOtimizer.AddPass(new (std::nothrow) InferQuantParamPass()); + status = inferQuantParamOtimizer.Run(graphDefT); if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; + MS_LOG(ERROR) << "Run tensorQuantOptimizer graphPasses Failed"; return status; } } @@ -168,8 +142,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } // do quantization - if (fbQuantizer != nullptr) { - status = fbQuantizer->DoQuantize(); + { + Optimizer fusionOptimizer; + fusionOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); + status = fusionOptimizer.Run(graphDefT); if (status != RET_OK) { MS_LOG(ERROR) << "DoQuantize failed!"; return status; @@ -177,11 +153,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } // insert quantNode and deQuantNode - if (ctx.quantType == QuantType_AwareTraining) { + { Optimizer quantNodeOptimizer; auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); - dTypeTransPass->SetInputDataDType(ctx.inferenceType); - dTypeTransPass->SetOutputDataDType(ctx.inferenceType); + dTypeTransPass->SetInputDataDType(ctx.inputDataType); + dTypeTransPass->SetOutputDataDType(ctx.outputDataType); quantNodeOptimizer.AddPass(dTypeTransPass); quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); diff --git a/mindspore/lite/tools/converter/graphdef_transform.h b/mindspore/lite/tools/converter/graphdef_transform.h index 822754d925..358383cb76 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.h +++ b/mindspore/lite/tools/converter/graphdef_transform.h @@ -37,14 +37,10 @@ class GraphDefTransform { virtual int Transform(const converter::Flags &ctx); void SetGraphDef(schema::MetaGraphT *dstDef); inline schema::MetaGraphT *GetOutput() { return graphDefT; } - void CreateQuantizer(const converter::Flags *flags); protected: schema::MetaGraphT *graphDefT = nullptr; Optimizer *optimizer = nullptr; - - std::unique_ptr mQuantizer; - std::unique_ptr fbQuantizer; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt index b5a2895590..38aa6a9fca 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -10,6 +10,8 @@ file(GLOB GRAPH_PASS ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc ) set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index b84d57d6df..20852c670f 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -27,9 +27,6 @@ namespace lite { #define kMinInputNum 1 #define kOutputNum 1 -static const std::set NoNeedDtypeTransList = { - PrimitiveType_QuantDTypeCast, PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; - STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); @@ -44,12 +41,6 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status; return status; } - - status = DoNodeInoutDTypeTrans(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status; - return status; - } return RET_OK; } @@ -57,7 +48,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); auto &graphInIdxes = graph->inputIndex; - if (this->inputDataDType == TypeId::kNumberTypeInt8) { + if (this->inputDataDType == TypeId::kNumberTypeInt8 || this->inputDataDType == TypeId::kTypeUnknown) { return RET_OK; } if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) { @@ -68,7 +59,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { for (auto graphInIdx : graphInIdxes) { MS_ASSERT(graphInIdx < graph->allTensors.size()); auto &tensor = graph->allTensors.at(graphInIdx); - if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) { + if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { continue; } @@ -98,7 +89,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); - if (outputDataDType == TypeId::kNumberTypeInt8) { + if (outputDataDType == TypeId::kNumberTypeInt8 || outputDataDType == TypeId::kTypeUnknown) { return RET_OK; } if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) { @@ -107,6 +98,11 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { } auto &graphOutIdxes = graph->outputIndex; for (auto graphOutIdx : graphOutIdxes) { + MS_ASSERT(graphOutIdx < graph->allTensors.size()); + auto &tensor = graph->allTensors.at(graphOutIdx); + if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { + continue; + } for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { auto nodeName = (*iter)->name; MS_ASSERT(node != nullptr); @@ -131,67 +127,6 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { return RET_OK; } -STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - // insert transNode before and after existNode - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { - continue; - } - auto iterType = GetCNodeTType(**iter); - if (NoNeedDtypeTransList.find(iterType) != NoNeedDtypeTransList.end()) { - continue; - } - bool needInsertPost = true; - if (GetCNodeTType(**iter) == PrimitiveType_Shape) { - needInsertPost = false; - } - auto nodeName = (*iter)->name; - if ((*iter)->inputIndex.size() < kMinInputNum) { - MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; - return RET_ERROR; - } - STATUS status; - // insert pre - for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { - MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); - auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); - if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) { - continue; - } - auto &graphInIdxes = graph->inputIndex; - if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { - continue; - } - if ((preTensor->dataType != TypeId::kNumberTypeInt8) && (IsContain(graphInIdxes, (*iter)->inputIndex.at(i)))) { - continue; - } - iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; - return RET_ERROR; - } - } - - if (needInsertPost) { - for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { - auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i)); - if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) { - continue; - } - iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; - return RET_ERROR; - } - } - } - (*iter)->quantType = QuantType_QUANT_NONE; - } - - return RET_OK; -} - NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) { MS_ASSERT((*existNodeIter) != nullptr); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h index a3fc5490a3..d898c10eb3 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h @@ -45,8 +45,6 @@ class DTypeTransPass : public GraphPass { STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph); - STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph); - NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc new file mode 100644 index 0000000000..7a83ac517f --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "src/common/utils.h" +#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" +#include "tools/converter/quantizer/calc_quant_param.h" +#include "tools/common/node_util.h" +#include "tools/common/converter_op_utils.h" + +namespace mindspore::lite { +STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) { + auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); + + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + MS_ASSERT(node != nullptr); + if (node->quantType == schema::QuantType_WeightQuant) { + continue; + } + if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax || + GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { + MS_ASSERT(false); + } + auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); + if (quantParamCalcer == nullptr) { + MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() + << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; + node->quantType = static_cast(schema::QuantType_QUANT_NONE); + } else { + auto status = quantParamCalcer->Calc(graph, *node); + if (status != RET_OK) { + MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); + node->quantType = schema::QuantType_QUANT_NONE; + } else { + DetermineNodeQuantType(*graph, node.get()); + } + } + } + return RET_OK; +} + +void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode) { + MS_ASSERT(cnode != nullptr); + bool canQuant = true; + for (auto &inputTensorIdx : cnode->inputIndex) { + MS_ASSERT(graph.allTensors.size() > inputTensorIdx); + auto &inTensor = graph.allTensors.at(inputTensorIdx); + MS_ASSERT(inTensor != nullptr); + if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || + !inTensor->quantParams.front()->inited) { + canQuant = false; + break; + } + } + + for (auto &outTensorIdx : cnode->outputIndex) { + MS_ASSERT(graph.allTensors.size() > outTensorIdx); + auto &outTensor = graph.allTensors.at(outTensorIdx); + MS_ASSERT(outTensor != nullptr); + if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || + !outTensor->quantParams.front()->inited) { + canQuant = false; + break; + } + } + + if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*cnode))) { + cnode->quantType = schema::QuantType_AwareTraining; + } else { + cnode->quantType = schema::QuantType_QUANT_NONE; + } +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h new file mode 100644 index 0000000000..d2e3c634ba --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_INFER_QUANT_PARAM_PASS_H +#define LITE_INFER_QUANT_PARAM_PASS_H + +#include +#include "tools/converter/optimizer.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +class InferQuantParamPass : public GraphPass { + public: + InferQuantParamPass() {} + + ~InferQuantParamPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; + + private: + void DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode); +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_INFER_QUANT_PARAM_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc new file mode 100644 index 0000000000..63b9bee8c6 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "tools/common/tensor_util.h" + +namespace mindspore::lite { +STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { + for (auto &tensor : graph->allTensors) { + if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { + continue; + } + if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && + tensor->dataType != TypeId::kNumberTypeUInt8) { + continue; + } + // perlayer + if (tensor->quantParams.size() == 1) { + auto &quantParam = tensor->quantParams.front(); + size_t wShapeSize = GetShapeSize(*(tensor.get())); + void *oriWeightData = tensor->data.data(); + if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { + std::vector qDatas(wShapeSize); + auto weightQauntParam = GetTensorQuantParam(tensor); + if (tensor->dataType == TypeId::kNumberTypeFloat || + tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant + auto *weightData = static_cast(oriWeightData); + for (size_t j = 0; j < wShapeSize; j++) { + qDatas[j] = quant::QuantizeData(weightData[j], weightQauntParam.get()); + } + } else { // tflite awareing quant + auto *weightData = static_cast(oriWeightData); + for (size_t j = 0; j < wShapeSize; j++) { + qDatas[j] = (int32_t)weightData[j] - 128; + } + weightQauntParam->zeroPoint -= 128; + tensor->quantParams.clear(); + tensor->quantParams.emplace_back(weightQauntParam.release()); + } + tensor->dataType = TypeId::kNumberTypeInt8; + ::memcpy(tensor->data.data(), qDatas.data(), wShapeSize); + } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { + // quant bias data + auto bShapeSize = GetShapeSize(*(tensor.get())); + std::unique_ptr qDatas(new (std::nothrow) int32_t[bShapeSize]); + if (qDatas == nullptr) { + MS_LOG(ERROR) << "new qDatas failed"; + return RET_ERROR; + } + void *biasData = tensor->data.data(); + auto *rawDatas = static_cast(biasData); + for (size_t i = 0; i < bShapeSize; ++i) { + qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale); + } + tensor->dataType = TypeId::kNumberTypeInt32; + tensor->data.clear(); + tensor->data.resize(bShapeSize * sizeof(int32_t)); + auto ret = + memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed: " << ret; + return RET_ERROR; + } + } + } else { // pertensor + } + } + return RET_OK; +} + +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.h new file mode 100644 index 0000000000..cc734bf000 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_TENSOR_QUANT_PASS_H +#define LITE_TENSOR_QUANT_PASS_H + +#include +#include "tools/converter/optimizer.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +class TensorQuantPass : public GraphPass { + public: + TensorQuantPass() {} + + ~TensorQuantPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TENSOR_QUANT_PASS_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 7fbc882b0b..472ed3ae11 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -52,7 +52,7 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - attr->srcT = GetTfliteDataType(in_tensor->type); + attr->srcT = kNumberTypeInt8; attr->dstT = GetTfliteDataType(out_tensor->type); op->primitive->value.value = attr.release(); op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc index 28bc97e8cf..9f76c795c5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc @@ -52,7 +52,7 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u return RET_NULL_PTR; } attr->srcT = GetTfliteDataType(in_tensor->type); - attr->dstT = GetTfliteDataType(out_tensor->type); + attr->dstT = kNumberTypeInt8; op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; op->primitive->value.value = attr.release(); } else { diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc index 208d84232a..d3587181cc 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -35,29 +35,10 @@ using std::string; using std::vector; namespace mindspore::lite::quant { -const std::array AwareQuantizer::propagatedOps = { - {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape, - schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation, - schema::PrimitiveType_DetectionPostProcess}}; - AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {} STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; } -STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) { - MS_ASSERT(subGraph != nullptr); - for (const auto &tensor : subGraph->allTensors) { - if (!tensor->quantParams.empty()) { - continue; - } - std::unique_ptr defaultQuantParam(new QuantParamT()); - tensor->quantParams.emplace_back(std::move(defaultQuantParam)); - } - return RET_OK; -} - -STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { return RET_OK; } - STATUS AwareQuantizer::GenerateQuantParam() { auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); @@ -70,13 +51,13 @@ STATUS AwareQuantizer::GenerateQuantParam() { } auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); if (quantParamCalcer == nullptr) { - MS_LOG(INFO) << "Can not find QuantParamCalcer for " << node->name.c_str() - << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; + MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() + << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; node->quantType = static_cast(QuantType_QUANT_NONE); } else { auto status = quantParamCalcer->Calc(graph, *node); if (status != RET_OK) { - MS_LOG(INFO) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); + MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); node->quantType = schema::QuantType_QUANT_NONE; } else { node->quantType = schema::QuantType_AwareTraining; @@ -87,250 +68,65 @@ STATUS AwareQuantizer::GenerateQuantParam() { } STATUS AwareQuantizer::DoQuantize() { - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - if (!IsContain(GetInt8OpList(), GetCNodeTType(*node))) { + for (auto &tensor : graph->allTensors) { + if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { continue; } - if (node->quantType != schema::QuantType_AwareTraining) { + if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && + tensor->dataType != TypeId::kNumberTypeUInt8) { continue; } - STATUS status; - if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D || - GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D || - GetCNodeTType(*node) == schema::PrimitiveType_DeConv2D || - GetCNodeTType(*node) == schema::PrimitiveType_FullConnection || - GetCNodeTType(*node) == schema::PrimitiveType_MatMul) { - auto inputIndexes = node->inputIndex; - if (inputIndexes.size() < 2) { - MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count"; - return RET_ERROR; - } - // quant weight - auto &weightTensor = graph->allTensors.at(node->inputIndex.at(1)); - if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) { - status = QuantConvWeight(graph, node.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "QuantConvWeight failed!"; - return RET_ERROR; - } - } - // quant bias - if (inputIndexes.size() == 3) { - auto &biasTensor = graph->allTensors.at(node->inputIndex.at(2)); - if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) { - status = QuantConvBias(graph, node.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "QuantConvBias failed!"; - return RET_ERROR; + // perlayer + if (tensor->quantParams.size() == 1) { + auto &quantParam = tensor->quantParams.front(); + size_t wShapeSize = GetShapeSize(*(tensor.get())); + void *oriWeightData = tensor->data.data(); + if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { + vector qDatas(wShapeSize); + auto weightQauntParam = GetTensorQuantParam(tensor); + if (tensor->dataType == TypeId::kNumberTypeFloat || + tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant + auto *weightData = static_cast(oriWeightData); + for (size_t j = 0; j < wShapeSize; j++) { + qDatas[j] = QuantizeData(weightData[j], weightQauntParam.get()); } + } else { // tflite awareing quant + auto *weightData = static_cast(oriWeightData); + for (size_t j = 0; j < wShapeSize; j++) { + qDatas[j] = (int32_t)weightData[j] - 128; + } + weightQauntParam->zeroPoint -= 128; + tensor->quantParams.clear(); + tensor->quantParams.emplace_back(weightQauntParam.release()); } - } - } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { - status = QuantDetectionPostProcessConstTensor(graph, node.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; - return RET_ERROR; - } - } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add || - GetCNodeTType(*node) == schema::PrimitiveType_Scale || - GetCNodeTType(*node) == schema::PrimitiveType_Mul) { - status = QuantArithmeticConstTensor(graph, node.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "QuantArithmeticConstTensor failed!"; - return RET_ERROR; - } - } - const auto nodeType = GetCNodeTType(*node); - auto find = std::find(propagatedOps.begin(), propagatedOps.end(), nodeType); - if (find != propagatedOps.end()) { - auto inputTensor = graph->allTensors.at(node->inputIndex[0]).get(); - auto outputTensor = graph->allTensors.at(node->outputIndex[0]).get(); - MS_ASSERT(inputTensor != nullptr); - MS_ASSERT(outputTensor != nullptr); - outputTensor->dataType = inputTensor->dataType; - } - } - return RET_OK; -} -STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) { - MS_ASSERT(graph != nullptr); - MS_ASSERT(node != nullptr); - for (size_t i = 0; i < node->inputIndex.size(); i++) { - auto inTensorIdx = node->inputIndex.at(i); - MS_ASSERT(graph->allTensors.size() > inTensorIdx); - auto &inTensor = graph->allTensors.at(inTensorIdx); - MS_ASSERT(inTensor != nullptr); - if (!inTensor->data.empty()) { - if (inTensor->dataType == TypeId::kNumberTypeInt8) { - continue; - } - if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat && - inTensor->dataType != TypeId::kNumberTypeUInt8) { - MS_LOG(ERROR) << node->name.c_str() << "'s weight data is not float or uint8"; - return RET_ERROR; - } - - auto quantParam = GetTensorQuantParam(inTensor); - MS_ASSERT(quantParam != nullptr); - MS_ASSERT(quantParam->inited); - auto constTensorShapeSize = GetShapeSize(*(inTensor.get())); - vector qDatas(constTensorShapeSize); - void *inData = inTensor->data.data(); - if (inTensor->dataType == TypeId::kNumberTypeFloat || - inTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant - auto *weightData = static_cast(inData); - for (size_t j = 0; j < constTensorShapeSize; j++) { - qDatas[j] = QuantizeData(weightData[j], quantParam.get()); + ::memcpy(tensor->data.data(), qDatas.data(), wShapeSize); + } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { + // quant bias data + auto bShapeSize = GetShapeSize(*(tensor.get())); + std::unique_ptr qDatas(new (std::nothrow) int32_t[bShapeSize]); + if (qDatas == nullptr) { + MS_LOG(ERROR) << "new qDatas failed"; + return RET_ERROR; } - } else { // tflite awareing quant - auto *weightData = static_cast(inData); - for (size_t j = 0; j < constTensorShapeSize; j++) { - qDatas[j] = (int32_t)weightData[j] - 128; + void *biasData = tensor->data.data(); + auto *rawDatas = static_cast(biasData); + for (size_t i = 0; i < bShapeSize; ++i) { + qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale); + } + tensor->dataType = TypeId::kNumberTypeInt32; + tensor->data.clear(); + tensor->data.resize(bShapeSize * sizeof(int32_t)); + auto ret = + memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed: " << ret; + return RET_ERROR; } - quantParam->zeroPoint -= 128; - inTensor->quantParams.clear(); - inTensor->quantParams.emplace_back(quantParam.release()); } - - ::memcpy(inTensor->data.data(), qDatas.data(), constTensorShapeSize); - inTensor->dataType = TypeId::kNumberTypeInt8; - } - } - return RET_OK; -} - -STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { - MS_ASSERT(subGraph != nullptr); - MS_ASSERT(node != nullptr); - auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]); - MS_ASSERT(constTensor != nullptr); - const auto *constData = reinterpret_cast(constTensor->data.data()); - - if (!constTensor->data.empty() && - (constTensor->dataType == TypeId::kNumberTypeFloat || constTensor->dataType == TypeId::kNumberTypeFloat32)) { - size_t constTensorShapeSize = GetShapeSize(*constTensor); - std::unique_ptr quantParam = GetTensorQuantParam(constTensor); - if (quantParam == nullptr) { - MS_LOG(ERROR) << "new QuantParamT failed"; - return RET_NULL_PTR; - } - vector qDatas(constTensorShapeSize); - for (size_t j = 0; j < constTensorShapeSize; j++) { - float rawData = constData[j]; - qDatas[j] = QuantizeData(rawData, quantParam.get()); - } - ::memcpy(constTensor->data.data(), qDatas.data(), constTensorShapeSize); - constTensor->dataType = TypeId::kNumberTypeInt8; - } - return RET_OK; -} - -STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) { - MS_ASSERT(graph != nullptr); - MS_ASSERT(node != nullptr); - auto inputIndexes = node->inputIndex; - MS_ASSERT(inputIndexes.size() >= 3); - MS_ASSERT(graph->allTensors.size() > inputIndexes.at(0)); - MS_ASSERT(graph->allTensors.size() > inputIndexes.at(1)); - MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2)); - auto &biasTensor = graph->allTensors.at(inputIndexes.at(2)); - MS_ASSERT(biasTensor != nullptr); - if (biasTensor->dataType == TypeId::kNumberTypeInt32) { - return RET_OK; - } - if (biasTensor->dataType != TypeId::kNumberTypeFloat && biasTensor->dataType != TypeId::kNumberTypeFloat32) { - MS_LOG(ERROR) << "conv " << node->name << "'s bias data is not float"; - return RET_ERROR; - } - auto &inputTensor = graph->allTensors.at(inputIndexes.at(0)); - auto &weightTensor = graph->allTensors.at(inputIndexes.at(1)); - - MS_ASSERT(inputTensor != nullptr); - MS_ASSERT(weightTensor != nullptr); - auto inputScale = inputTensor->quantParams.front()->scale; - auto weightScale = weightTensor->quantParams.front()->scale; - auto scale = inputScale * weightScale; - // set bias quant param - std::unique_ptr biasQuantParam = GetTensorQuantParam(biasTensor); - if (biasQuantParam == nullptr) { - MS_LOG(ERROR) << "new QuantParamT failed"; - return RET_ERROR; - } - biasQuantParam->inited = true; - biasQuantParam->scale = scale; - biasQuantParam->zeroPoint = 0; - biasQuantParam->numBits = 8; - biasQuantParam->narrowRange = false; - biasQuantParam->min = 0.0; - biasQuantParam->max = 0.0; - - // quant bias data - auto bShapeSize = GetShapeSize(*(biasTensor.get())); - std::unique_ptr qDatas(new (std::nothrow) int32_t[bShapeSize]); - if (qDatas == nullptr) { - MS_LOG(ERROR) << "new qDatas failed"; - return RET_ERROR; - } - void *biasData = biasTensor->data.data(); - auto *rawDatas = static_cast(biasData); - for (size_t i = 0; i < bShapeSize; ++i) { - qDatas[i] = (int32_t)std::round(rawDatas[i] / scale); - } - biasTensor->dataType = TypeId::kNumberTypeInt32; - biasTensor->data.clear(); - biasTensor->data.resize(bShapeSize * sizeof(int32_t)); - auto ret = - memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s failed: " << ret; - return RET_ERROR; - } - return RET_OK; -} - -STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { - MS_ASSERT(subGraph != nullptr); - MS_ASSERT(node != nullptr); - MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size()); - auto inputIndexes = node->inputIndex; - MS_ASSERT(inputIndexes.size() >= 2); - MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1)); - auto &weightTensor = subGraph->allTensors.at(inputIndexes.at(1)); - if (weightTensor->dataType == TypeId::kNumberTypeInt8) { - return RET_OK; - } - if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat && - weightTensor->dataType != TypeId::kNumberTypeUInt8) { - MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; - return RET_ERROR; - } - size_t wShapeSize = GetShapeSize(*(weightTensor.get())); - void *oriWeightData = weightTensor->data.data(); - MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr); - vector qDatas(wShapeSize); - // todo support perchannel - auto weightQauntParam = GetTensorQuantParam(weightTensor); - if (weightTensor->dataType == TypeId::kNumberTypeFloat || - weightTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant - auto *weightData = static_cast(oriWeightData); - for (size_t j = 0; j < wShapeSize; j++) { - qDatas[j] = QuantizeData(weightData[j], weightQauntParam.get()); - } - } else { // tflite awareing quant - auto *weightData = static_cast(oriWeightData); - for (size_t j = 0; j < wShapeSize; j++) { - qDatas[j] = (int32_t)weightData[j] - 128; + } else { // pertensor } - weightQauntParam->zeroPoint -= 128; - weightTensor->quantParams.clear(); - weightTensor->quantParams.emplace_back(weightQauntParam.release()); } - - weightTensor->data.resize(wShapeSize * sizeof(uint8_t)); - ::memcpy(weightTensor->data.data(), qDatas.data(), wShapeSize); - weightTensor->dataType = TypeId::kNumberTypeInt8; return RET_OK; } STATUS AwareQuantizer::DetermineNodeQuantType() { diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h index 29b9eccaa7..e28a220e4c 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h @@ -39,24 +39,6 @@ class AwareQuantizer : public FbQuantizer { STATUS DetermineNodeQuantType() override; STATUS DoQuantize() override; // override; - - private: - // RemoveFakeQuant - STATUS SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node); - - STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph); - - STATUS QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node); - - STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node); - - STATUS QuantConvBias(const schema::MetaGraphT *graph, schema::CNodeT *node); - - STATUS QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node); - - float inputScale = 0.0f; - - static const std::array propagatedOps; }; } // namespace mindspore::lite::quant #endif diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc index 79afd5f5d2..d5435db54d 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc @@ -26,6 +26,9 @@ #include "tools/converter/quantizer/quantize_util.h" namespace mindspore::lite { +static constexpr size_t BIAS_SIZE = 3; +static constexpr size_t BIAS_ADD_SIZE = 2; + STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) { MS_ASSERT(quantParam != nullptr); // int32 weight no need to quant @@ -126,6 +129,36 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { return RET_OK; } +int ConvCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { + auto status = CommonCalcer::Calc(subGraph, node); + if (status != RET_OK) { + MS_LOG(WARNING) << "Call CommonCalcer::Calc failed: " << status; + return status; + } + if (node.inputIndex.size() == BIAS_SIZE) { + auto &biasTensor = subGraph->allTensors.at(node.inputIndex.at(BIAS_SIZE - 1)); + for (auto &quantParam : biasTensor->quantParams) { + quantParam->dstDtype = TypeId::kNumberTypeInt32; + } + } + return RET_OK; +} + +int BiasAddCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { + auto status = CommonCalcer::Calc(subGraph, node); + if (status != RET_OK) { + MS_LOG(WARNING) << "Call CommonCalcer::Calc failed: " << status; + return status; + } + if (node.inputIndex.size() == BIAS_ADD_SIZE) { + auto &biasTensor = subGraph->allTensors.at(node.inputIndex.at(BIAS_ADD_SIZE - 1)); + for (auto &quantParam : biasTensor->quantParams) { + quantParam->dstDtype = TypeId::kNumberTypeInt32; + } + } + return RET_OK; +} + int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { auto status = QuantParamCalcer::Calc(graph, node); if (status != RET_OK) { @@ -474,10 +507,10 @@ QuantParamCalcRegister::QuantParamCalcRegister() { _registerMap[schema::PrimitiveType_Activation] = std::make_shared(); _registerMap[schema::PrimitiveType_Add] = std::make_shared(); _registerMap[schema::PrimitiveType_Mul] = commonCalcer; - _registerMap[schema::PrimitiveType_Scale] = commonCalcer; - _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer; - _registerMap[schema::PrimitiveType_DeConv2D] = commonCalcer; - _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer; + _registerMap[schema::PrimitiveType_Scale] = std::make_shared(); + _registerMap[schema::PrimitiveType_Conv2D] = std::make_shared(); + _registerMap[schema::PrimitiveType_DeConv2D] = std::make_shared(); + _registerMap[schema::PrimitiveType_DepthwiseConv2D] = std::make_shared(); _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; _registerMap[schema::PrimitiveType_Resize] = linearCalcer; _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; @@ -487,11 +520,11 @@ QuantParamCalcRegister::QuantParamCalcRegister() { _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer; _registerMap[schema::PrimitiveType_RealDiv] = std::make_shared(); _registerMap[schema::PrimitiveType_Reduce] = commonCalcer; - _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer; + _registerMap[schema::PrimitiveType_BiasAdd] = std::make_shared(); _registerMap[schema::PrimitiveType_Mean] = linearCalcer; _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; - _registerMap[schema::PrimitiveType_MatMul] = commonCalcer; - _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer; + _registerMap[schema::PrimitiveType_MatMul] = std::make_shared(); + _registerMap[schema::PrimitiveType_FullConnection] = std::make_shared(); _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer; _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer; // detection_postprocess op's quant param will not infer only fetch from preNode or postNode diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.h b/mindspore/lite/tools/converter/quantizer/calc_quant_param.h index 0f8f6331e7..e722c3b3fc 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.h +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.h @@ -46,6 +46,20 @@ class CommonCalcer : public QuantParamCalcer { int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; }; +class ConvCalcer : public CommonCalcer { + public: + ConvCalcer() = default; + ~ConvCalcer() override = default; + int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; +}; + +class BiasAddCalcer : public CommonCalcer { + public: + BiasAddCalcer() = default; + ~BiasAddCalcer() override = default; + int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; +}; + class LinearCalcer : public QuantParamCalcer { public: LinearCalcer() = default; diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 5acf6f9f28..d9d5402284 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -564,11 +564,8 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in } } -STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min, - std::shared_ptr lite_primitive) { - if (!lite_primitive->GetInputQuantParams().empty()) { - MS_LOG(DEBUG) << "input quant params not empty"; // multi-input op: like concat - } +STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, + std::shared_ptr lite_primitive, const size_t &index) { schema::QuantParamT quant_param; quant_param.scale = scale; quant_param.zeroPoint = zeropoint; @@ -577,15 +574,12 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M quant_param.numBits = bit_num; quant_param.narrowRange = false; std::vector quant_params = {quant_param}; - lite_primitive->AddInputQuantParam(quant_params); + lite_primitive->SetInputQuantParam(index, quant_params); return RET_OK; } STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min, std::shared_ptr lite_primitive) { - if (!lite_primitive->GetOutputQuantParams().empty()) { - MS_LOG(DEBUG) << "output quant params not empty"; // multi-output op: like split - } schema::QuantParamT quant_param; quant_param.scale = scale; quant_param.zeroPoint = zeropoint; @@ -593,8 +587,9 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct quant_param.min = max_min->min; quant_param.numBits = bit_num; quant_param.narrowRange = false; + quant_param.inited = true; std::vector quant_params = {quant_param}; - lite_primitive->AddOutputQuantParam(quant_params); + lite_primitive->SetOutputQuantParam(0, quant_params); return RET_OK; } @@ -647,7 +642,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr(bias_default_param); auto active_weight_quant_params = primitive_c->GetInputQuantParams(); - if (active_weight_quant_params.size() != 2) { + if (active_weight_quant_params.size() != 3) { MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size(); return RET_ERROR; } @@ -714,7 +709,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptrSetInputQuantParam(active_weight_quant_params); + primitive_c->SetInputQuantParams(active_weight_quant_params); bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit; quant_params[i].scale = bias_scale_tmp; MS_LOG(DEBUG) << "new filter scale: " << filter_scale; @@ -726,7 +721,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptrAddInputQuantParam(quant_params); + primitive_c->SetInputQuantParam(2, quant_params); auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s failed."; @@ -834,22 +829,21 @@ STATUS PostTrainingQuantizer::QuantNode() { << " PrimitiveC is null"; continue; } - if (!input_cnode_primitive_c->GetOutputQuantParams().empty()) { - for (auto &quant_param : input_cnode_primitive_c->GetOutputQuantParams()) { - primitive_c->AddInputQuantParam(quant_param); - } + if (input_cnode_primitive_c->IsOutputQuantParamsInited()) { + auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front(); + primitive_c->SetInputQuantParam(i - 1, quant_param); } else { // do input quant double scale = input_scale[cnode]; int32_t zp = input_zero_point[cnode]; - DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c); + DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c, i - 1); } } } else { // do input quant double scale = input_scale[cnode]; int32_t convInputzeropoint = input_zero_point[cnode]; - DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c); + DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c, 0); // do weight quant auto weight = cnode->input(2); bool perchannel = per_channel_; diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index 23c4df7696..1e61ac09e3 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -106,7 +106,8 @@ class PostTrainingQuantizer : public Quantizer { STATUS QuantNode(); - STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); + STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, + std::shared_ptr lite_primitive, const size_t &index); STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitive_c, bool perchannel); diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 922aa10879..a9e0974c7f 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -246,7 +246,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; return RET_ERROR; } - quantParam->inited = true; + quantParam->inited = false; quantParam->min = mMin; quantParam->max = mMax; quantParam->scale = 0.0f; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 2492c7172f..ed1fa0cb85 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -39,6 +39,7 @@ namespace mindspore { namespace lite { namespace quant { static constexpr size_t UINT8_QUANTIZATION = 8; +static constexpr size_t WEIGHT_INDEX = 1; /** * 1. when op's weight size > mWeightSize just skip @@ -225,16 +226,16 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti } variance_dequant = std::sqrt(variance_dequant / one_filter_size); variance_raw = std::sqrt(variance_raw / one_filter_size); - quant_param.var_corr = 1; + quant_param.varCorr = 1; if (variance_raw != 0 && variance_dequant != 0) { auto temp_var_corr = variance_raw / variance_dequant; if (temp_var_corr > 0 && temp_var_corr < 10) { - quant_param.var_corr = temp_var_corr; + quant_param.varCorr = temp_var_corr; } else { MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr; } } - quant_param.mean_corr = average_raw - average_dequant * quant_param.var_corr; + quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr; } quant_params.emplace_back(quant_param); } @@ -282,7 +283,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti MS_LOG(ERROR) << "quant_params empty"; return RET_ERROR; } - primitive_c->AddInputQuantParam(quant_params); + primitive_c->SetInputQuantParam(WEIGHT_INDEX, quant_params); return RET_OK; } diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index a649ae263a..ef83e9bc66 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -101,8 +101,6 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list &nodes) { return RET_ERROR; } - std::vector quant_params; - primitive_c->AddInputQuantParam(quant_params); auto status = RET_ERROR; if (type_id == kNumberTypeInt8) { status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); @@ -143,9 +141,9 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { ParameterPtr param_node = nullptr; for (size_t i = 1; i < node->size(); i++) { auto inputNode = node->input(i); - if (inputNode->isa() == true) { + if (inputNode->isa()) { param_node = inputNode->cast(); - if ((param_node != nullptr) && (param_node->has_default() == true)) { + if ((param_node != nullptr) && param_node->has_default()) { param_value = std::static_pointer_cast(param_node->default_param()); if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr) || @@ -169,8 +167,6 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { return RET_ERROR; } - std::vector quant_params; - primitive_c->AddInputQuantParam(quant_params); auto status = RET_ERROR; if (type_id == kNumberTypeInt8) { status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 0c81e0c780..9f0cad9dfe 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -619,7 +619,7 @@ STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int3 } return RET_OK; } -template +template static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW) { MS_ASSERT(tensor != nullptr); @@ -628,7 +628,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType MS_LOG(ERROR) << "Dim size invalid"; return RET_ERROR; } - std::unique_ptr buf(new(std::nothrow) T[count]); + std::unique_ptr buf(new (std::nothrow) T[count]); if (buf == nullptr) { MS_LOG(ERROR) << "new buf failed"; return RET_ERROR; @@ -653,18 +653,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); if (type == kCHWK2HWCK) { p2Buff = - buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); } else if (type == kCHWK2KHWC) { p2Buff = - buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); } *p2Buff = *p1Buff; } } } } - } - break; + } break; case kKHWC2HWCK: { for (int k = 0; k < filterK; ++k) { for (int h = 0; h < filterH; ++h) { @@ -677,8 +676,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType } } } - } - break; + } break; case kKCHW2HWCK: case kKCHW2CKHW: case kKCHW2KHWC: @@ -690,24 +688,23 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); if (type == kKCHW2HWCK) { p2Buff = - buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); } else if (type == kKCHW2KHWC) { p2Buff = - buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); } else if (type == kKCHW2CKHW) { p2Buff = - buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); } else { p2Buff = - buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); } *p2Buff = *p1Buff; } } } } - } - break; + } break; case kCKHW2HWCK: case kCKHW2KHWC: case kCKHW2HWKC: { @@ -718,21 +715,20 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); if (type == kCKHW2HWCK) { p2Buff = - buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); } else if (type == kCKHW2KHWC) { p2Buff = - buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); } else { p2Buff = - buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); } *p2Buff = *p1Buff; } } } } - } - break; + } break; case kHWCK2KCHW: case kHWCK2CKHW: { for (int h = 0; h < filterH; ++h) { @@ -742,18 +738,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); if (type == kHWCK2KCHW) { p2Buff = - buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); } else { p2Buff = - buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); } *p2Buff = *p1Buff; } } } } - } - break; + } break; case kHWKC2KCHW: case kHWKC2CKHW: { for (int h = 0; h < filterH; ++h) { @@ -763,18 +758,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); if (type == kHWKC2KCHW) { p2Buff = - buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); } else { p2Buff = - buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); } *p2Buff = *p1Buff; } } } } - } - break; + } break; case kNHWC2HWCK: case kNHWC2KCHW: case kNHWC2CKHW: { @@ -785,21 +779,20 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); if (type == kNHWC2HWCK) { p2Buff = - buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); } else if (type == kNHWC2CKHW) { p2Buff = - buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); } else { p2Buff = - buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); } *p2Buff = *p1Buff; } } } } - } - break; + } break; case kKHWC2CHWK: { for (int k = 0; k < filterK; ++k) { for (int h = 0; h < filterH; ++h) { @@ -812,8 +805,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType } } } - } - break; + } break; default: { MS_LOG(ERROR) << "Unsupported transFilterType: " << type; return RET_ERROR; @@ -828,7 +820,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType return RET_OK; } -template +template static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) { MS_ASSERT(tensor != nullptr); auto oriDims = tensor->tensor_shape(); @@ -882,6 +874,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kKCHW2KHWC); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kKCHW2KHWC); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kKCHW2KHWC); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -894,6 +888,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kCKHW2KHWC); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kCKHW2KHWC); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kCKHW2KHWC); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -906,18 +902,20 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kCHWK2KHWC); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kCHWK2KHWC); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kCHWK2KHWC); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; } break; - case schema::Format::Format_KHWC:return RET_OK; - default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " - << EnumNameFormat(dst_format); + case schema::Format::Format_KHWC: + return RET_OK; + default: + MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); return RET_ERROR; } - } - break; + } break; case schema::Format::Format_HWCK: { switch (src_format) { case schema::Format::Format_KCHW: @@ -927,6 +925,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kKCHW2HWCK); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kKCHW2HWCK); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kKCHW2HWCK); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -939,6 +939,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kKHWC2HWCK); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kKHWC2HWCK); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kKHWC2HWCK); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -951,6 +953,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kCKHW2HWCK); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kCKHW2HWCK); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kCKHW2HWCK); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -963,21 +967,24 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kCHWK2HWCK); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kCHWK2HWCK); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kCHWK2HWCK); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return lite::RET_ERROR; } break; - case schema::Format::Format_HWCK:return RET_OK; - default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " - << EnumNameFormat(dst_format); + case schema::Format::Format_HWCK: + return RET_OK; + default: + MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); return RET_ERROR; } - } - break; + } break; case schema::Format::Format_KCHW: { switch (src_format) { - case schema::Format::Format_KCHW:return RET_OK; + case schema::Format::Format_KCHW: + return RET_OK; case schema::Format::Format_HWCK: if (data_type == kNumberTypeFloat32) { status = TransFilterFormat(tensor, kHWCK2KCHW); @@ -985,6 +992,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kHWCK2KCHW); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kHWCK2KCHW); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kHWCK2KCHW); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -997,6 +1006,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kHWKC2KCHW); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kHWKC2KCHW); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kHWCK2KCHW); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -1009,6 +1020,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kKHWC2KCHW); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kKHWC2KCHW); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kKHWC2KCHW); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -1021,6 +1034,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kCKHW2KCHW); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kCKHW2KCHW); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kCKHW2KCHW); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -1033,17 +1048,18 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kCHWK2KCHW); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kCHWK2KCHW); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kCKHW2KCHW); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; } break; - default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " - << EnumNameFormat(dst_format); + default: + MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); return RET_ERROR; } - } - break; + } break; case schema::Format::Format_CKHW: { switch (src_format) { case schema::Format::Format_HWCK: @@ -1053,6 +1069,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kHWCK2CKHW); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kHWCK2CKHW); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kHWCK2CKHW); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -1065,6 +1083,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kHWKC2CKHW); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kHWKC2CKHW); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kHWKC2CKHW); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; @@ -1077,20 +1097,22 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for status = TransFilterFormat(tensor, kKCHW2CKHW); } else if (data_type == kNumberTypeInt8) { status = TransFilterFormat(tensor, kKCHW2CKHW); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kKCHW2CKHW); } else { MS_LOG(ERROR) << "Unsupported data_type: " << data_type; return RET_ERROR; } break; - case schema::Format::Format_CKHW:return RET_OK; - default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " - << EnumNameFormat(dst_format); + case schema::Format::Format_CKHW: + return RET_OK; + default: + MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); return RET_ERROR; } - } - break; - default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " - << EnumNameFormat(dst_format); + } break; + default: + MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); return RET_ERROR; } if (status != RET_OK) { diff --git a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc index 850d2fc6cd..8771fea49d 100644 --- a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc @@ -155,8 +155,8 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons rmatmul_quant_params.pop_back(); // no bias quantParams rmatmul_quant_params.emplace_back(jointed_quant_params); - matmul_cvalue->SetInputQuantParam(rmatmul_quant_params); - matmul_cvalue->SetOutputQuantParam(fc_prim->GetOutputQuantParams()); + matmul_cvalue->SetInputQuantParams(rmatmul_quant_params); + matmul_cvalue->SetOutputQuantParams(fc_prim->GetOutputQuantParams()); auto matmul_value_node = NewValueNode(std::shared_ptr(matmul_cvalue)); std::vector matmul_inputs = {matmul_value_node, left_matmul_input};