diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 7ba679991b..2dfc3f037f 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -274,8 +274,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
       MS_LOG(ERROR) << "quant param is invalid.";
       return RET_ERROR;
     }
-    input_quant_params = quant_param_holder->input_quant_params();
-    output_quant_params = quant_param_holder->output_quant_params();
+    input_quant_params = quant_param_holder->get_input_quant_params();
+    output_quant_params = quant_param_holder->get_output_quant_params();
     dst_node->quantType = quant_param_holder->quant_type();
   }
   // add quant param
diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
index fcff9784cd..40a97eb5a4 100644
--- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
@@ -311,19 +311,8 @@ STATUS CaffeModelParser::ConvertLayerQuantParams(const caffe::LayerParameter &la
     MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
     return RET_NULL_PTR;
   }
-  auto quant_params_holder = std::make_shared<QuantParamHolder>();
-  for (auto input_idx : layer.bottom()) {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_params_holder->AddInputQuantParam(notinited_quant_params);
-  }
-  for (auto input_idx : weight.blobs()) {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_params_holder->AddInputQuantParam(notinited_quant_params);
-  }
-  for (auto output_idx : layer.top()) {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_params_holder->AddOutputQuantParam(notinited_quant_params);
-  }
+  auto quant_params_holder =
+    std::make_shared<QuantParamHolder>(layer.bottom_size() + weight.blobs_size(), layer.top_size());
   primitive_c->AddAttr("quant_params", quant_params_holder);
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
index 8566481d86..0d35d215f0 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
@@ -535,7 +535,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
     return RET_ERROR;
   }
   // set input tensors
-  auto quant_params_holder = std::make_shared<QuantParamHolder>();
+  auto quant_params_holder = std::make_shared<QuantParamHolder>(onnx_node.input_size(), onnx_node.output_size());
   for (int i = 0; i < onnx_node.input_size(); ++i) {
     const auto &input_name = onnx_node.input(i);
     std::vector<schema::QuantParamT> quant_params;
@@ -544,7 +544,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
       MS_LOG(ERROR) << "set input tensor quant param failed.";
       return status;
     }
-    quant_params_holder->AddInputQuantParam(quant_params);
+    quant_params_holder->set_input_quant_param(i, quant_params);
   }
   // set out tensors
   for (int i = 0; i < onnx_node.output_size(); ++i) {
@@ -555,7 +555,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
       MS_LOG(ERROR) << "set output tensor quant param failed.";
       return status;
     }
-    quant_params_holder->AddOutputQuantParam(quant_params);
+    quant_params_holder->set_output_quant_param(i, quant_params);
   }
   primitive_c->AddAttr("quant_params", quant_params_holder);
   return RET_OK;
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
index a2949cd8ab..993fb06919 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
@@ -953,9 +953,25 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
   }
+
+  status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
+  }
   return status;
 }
 
+STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size,
+                                         ops::PrimitiveC *primitive_c) {
+  if (primitive_c == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
+    return RET_NULL_PTR;
+  }
+  auto quant_params_holder = std::make_shared<QuantParamHolder>(input_size, output_size);
+  primitive_c->AddAttr("quant_params", quant_params_holder);
+  return RET_OK;
+}
+
 STATUS TFModelParser::ConvertRootGraphOutputs() {
   // because output of intermediate node in anf graph may also be output tensors, we search output tensors in
   // tf_root_graph_nodes_ but not anf_root_node_map_
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
index 173158c4e0..512407a2e5 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
@@ -29,6 +29,7 @@
 #include "securec/include/securec.h"
 #include "tools/common/tensor_util.h"
 #include "tools/converter/model_parser.h"
+#include "ops/primitive_c.h"
 
 namespace mindspore {
 namespace lite {
@@ -75,6 +76,8 @@ class TFModelParser : public ModelParser {
   STATUS ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map,
                                     const std::map<CNodePtr, FuncGraphPtr> &second_func_map);
 
+  STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, ops::PrimitiveC *primitive_c);
+
   static STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);
 
   STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found);
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
index 9ee55e6824..b456bfcbd6 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
@@ -243,7 +243,8 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, ops:
     round_type = 2;
   }
   const auto &tflite_subgraph = tflite_model_->subgraphs.front();
-  auto quant_params_holder = std::make_shared<QuantParamHolder>();
+  auto quant_params_holder = std::make_shared<QuantParamHolder>(op->inputs.size(), op->outputs.size());
+  size_t idx = 0;
   for (auto input_idx : op->inputs) {
     if (input_idx < 0) {
       input_idx += tflite_subgraph->tensors.size();
@@ -255,8 +256,10 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, ops:
       MS_LOG(ERROR) << "set input tensor quant param failed.";
       return status;
     }
-    quant_params_holder->AddInputQuantParam(quant_params);
+    quant_params_holder->set_input_quant_param(idx, quant_params);
+    idx++;
   }
+  idx = 0;
   for (auto output_idx : op->outputs) {
     if (output_idx < 0) {
       output_idx += tflite_subgraph->tensors.size();
@@ -268,7 +271,8 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, ops:
       MS_LOG(ERROR) << "set output tensor quant param failed.";
       return status;
     }
-    quant_params_holder->AddOutputQuantParam(quant_params);
+    quant_params_holder->set_output_quant_param(idx, quant_params);
+    idx++;
   }
   primitive_c->AddAttr("quant_params", quant_params_holder);
   return RET_OK;
diff --git a/mindspore/lite/tools/converter/quant_param_holder.h b/mindspore/lite/tools/converter/quant_param_holder.h
index c368d9eed3..fc48985214 100644
--- a/mindspore/lite/tools/converter/quant_param_holder.h
+++ b/mindspore/lite/tools/converter/quant_param_holder.h
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H
 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H
 
+#include <utility>
 #include <vector>
 #include <memory>
 #include "ir/anf.h"
@@ -27,7 +28,24 @@
 namespace lite {
 using QuantParamsVector = std::vector<std::vector<schema::QuantParamT>>;
 class QuantParamHolder : public Value {
  public:
-  QuantParamHolder() = default;
+  QuantParamHolder(size_t input_size, size_t output_size) {
+    input_quant_params_.resize(input_size);
+    output_quant_params_.resize(output_size);
+    for (size_t i = 0; i < input_size; i++) {
+      std::vector<schema::QuantParamT> notinited_quant_params(1);
+      set_input_quant_param(i, notinited_quant_params);
+    }
+
+    for (size_t i = 0; i < output_size; i++) {
+      std::vector<schema::QuantParamT> notinited_quant_params(1);
+      set_output_quant_param(i, notinited_quant_params);
+    }
+  }
+
+  QuantParamHolder(const QuantParamsVector &input_quant_params, const QuantParamsVector &output_quant_params) {
+    input_quant_params_ = input_quant_params;
+    output_quant_params_ = output_quant_params;
+  }
 
   ~QuantParamHolder() override = default;
 
@@ -36,17 +54,17 @@ class QuantParamHolder : public Value {
   bool operator==(const Value &rhs) const override {  // unused
     if (rhs.isa<QuantParamHolder>()) {
       auto other_holder = dynamic_cast<const QuantParamHolder &>(rhs);
-      auto input_quant_params_rhs = other_holder.input_quant_params();
-      auto output_quant_params_rhs = other_holder.output_quant_params();
-      if (input_quant_params_rhs.size() != this->input_quant_param_.size() ||
-          output_quant_params_rhs.size() != this->output_quant_param_.size()) {
+      auto input_quant_params_rhs = other_holder.get_input_quant_params();
+      auto output_quant_params_rhs = other_holder.get_output_quant_params();
+      if (input_quant_params_rhs.size() != this->input_quant_params_.size() ||
+          output_quant_params_rhs.size() != this->output_quant_params_.size()) {
         return false;
       }
       for (size_t i = 0; i < input_quant_params_rhs.size(); ++i) {
-        if (input_quant_params_rhs.at(i).size() != this->input_quant_param_.at(i).size()) {
+        if (input_quant_params_rhs.at(i).size() != this->input_quant_params_.at(i).size()) {
           return false;
         }
-        auto *params = reinterpret_cast<const char *>(this->input_quant_param_.at(i).data());
+        auto *params = reinterpret_cast<const char *>(this->input_quant_params_.at(i).data());
         auto *params_rhs = reinterpret_cast<const char *>(input_quant_params_rhs.at(i).data());
         for (size_t j = 0; j < input_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) {
           if (params[j] != params_rhs[j]) {
@@ -55,10 +73,10 @@ class QuantParamHolder : public Value {
         }
       }
       for (size_t i = 0; i < output_quant_params_rhs.size(); ++i) {
-        if (output_quant_params_rhs.at(i).size() != this->output_quant_param_.at(i).size()) {
+        if (output_quant_params_rhs.at(i).size() != this->output_quant_params_.at(i).size()) {
          return false;
         }
-        auto *params = reinterpret_cast<const char *>(this->output_quant_param_.at(i).data());
+        auto *params = reinterpret_cast<const char *>(this->output_quant_params_.at(i).data());
         auto *params_rhs = reinterpret_cast<const char *>(output_quant_params_rhs.at(i).data());
         for (size_t j = 0; j < output_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) {
           if (params[j] != params_rhs[j]) {
@@ -76,58 +94,53 @@ class QuantParamHolder : public Value {
 
   schema::QuantType quant_type() const { return quant_type_; }
 
-  void set_input_quant_params(const QuantParamsVector &input_quant_param) {
-    this->input_quant_param_ = input_quant_param;
-  }
-
   void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
-    if (index >= this->input_quant_param_.size()) {
+    if (index >= this->input_quant_params_.size()) {
       std::vector<schema::QuantParamT> place_quant(1);
-      this->input_quant_param_.insert(this->input_quant_param_.end(), index + 1 - input_quant_param_.size(),
-                                      place_quant);
+      this->input_quant_params_.insert(this->input_quant_params_.end(), index + 1 - input_quant_params_.size(),
+                                       place_quant);
     }
-    this->input_quant_param_.at(index) = input_quant_param;
-  }
-
-  void set_output_quant_params(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
-    this->output_quant_param_ = output_quant_param;
+    this->input_quant_params_.at(index) = input_quant_param;
   }
 
   void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) {
-    if (index >= this->output_quant_param_.size()) {
+    if (index >= this->output_quant_params_.size()) {
       std::vector<schema::QuantParamT> place_quant(1);
-      this->output_quant_param_.insert(this->output_quant_param_.end(), index + 1 - output_quant_param_.size(),
-                                       place_quant);
+      this->output_quant_params_.insert(this->output_quant_params_.end(), index + 1 - output_quant_params_.size(),
+                                        place_quant);
     }
-    this->output_quant_param_.at(index) = output_quant_param;
+    this->output_quant_params_.at(index) = output_quant_param;
   }
 
   void set_enable_huffman_code(bool enable_huffman_code) { enable_huffman_code_ = enable_huffman_code; }
 
   bool enable_huffman_code() const { return enable_huffman_code_; }
 
+  // deprecated
   void AddInputQuantParam(const std::vector<schema::QuantParamT> &quant_param) {
-    this->input_quant_param_.emplace_back(quant_param);
+    this->input_quant_params_.emplace_back(quant_param);
   }
 
-  std::vector<std::vector<schema::QuantParamT>> input_quant_params() const { return this->input_quant_param_; }
-
+  // deprecated
   void AddOutputQuantParam(const std::vector<schema::QuantParamT> &quant_param) {
-    this->output_quant_param_.emplace_back(quant_param);
+    this->output_quant_params_.emplace_back(quant_param);
   }
 
-  std::vector<std::vector<schema::QuantParamT>> output_quant_params() const { return this->output_quant_param_; }
+  std::vector<std::vector<schema::QuantParamT>> get_input_quant_params() const { return this->input_quant_params_; }
+
+  std::vector<std::vector<schema::QuantParamT>> get_output_quant_params() const { return this->output_quant_params_; }
 
+  // deprecated
   void ClearInputOutputQuantParam() {
-    input_quant_param_.clear();
-    output_quant_param_.clear();
+    input_quant_params_.clear();
+    output_quant_params_.clear();
   }
 
   bool IsInputQuantParamsInited() {
-    if (this->input_quant_param_.empty()) {
+    if (this->input_quant_params_.empty()) {
       return false;
     }
-    for (auto &quant_param : this->input_quant_param_) {
+    for (auto &quant_param : this->input_quant_params_) {
       if (!quant_param.front().inited) {
         return false;
       }
@@ -136,10 +149,10 @@ class QuantParamHolder : public Value {
   }
 
   bool IsOutputQuantParamsInited() {
-    if (this->output_quant_param_.empty()) {
+    if (this->output_quant_params_.empty()) {
      return false;
    }
-    for (auto &quant_param : this->output_quant_param_) {
+    for (auto &quant_param : this->output_quant_params_) {
       if (!quant_param.front().inited) {
         return false;
       }
@@ -149,8 +162,8 @@ class QuantParamHolder : public Value {
 
  private:
   schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
-  QuantParamsVector input_quant_param_;
-  QuantParamsVector output_quant_param_;
+  QuantParamsVector input_quant_params_;
+  QuantParamsVector output_quant_params_;
   bool enable_huffman_code_ = false;
 };
 using QuantParamHolderPtr = std::shared_ptr<QuantParamHolder>;
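
[Sketch, not part of the patch] The quant_param_holder.h hunks above change the
construction contract: a QuantParamHolder is now presized with (input_size,
output_size), every slot starts out as a single not-inited placeholder
schema::QuantParamT, and the indexed setters grow the vectors on demand. That is
what lets the parsers above collapse their per-tensor AddInputQuantParam loops
into one constructor call. A minimal, self-contained C++ sketch of that
behavior, using a simplified QuantParam in place of schema::QuantParamT and a
plain Holder in place of the real class (which also derives from Value):

#include <cstddef>
#include <iostream>
#include <vector>

struct QuantParam {     // stand-in for schema::QuantParamT
  double scale = 1.0;
  int zeroPoint = 0;
  bool inited = false;  // placeholder slots stay false until a real param lands
};
using QuantParams = std::vector<std::vector<QuantParam>>;

class Holder {
 public:
  // Every input/output slot is prefilled with one uninited placeholder,
  // mirroring the new QuantParamHolder(input_size, output_size) constructor.
  Holder(size_t input_size, size_t output_size) {
    input_quant_params_.assign(input_size, std::vector<QuantParam>(1));
    output_quant_params_.assign(output_size, std::vector<QuantParam>(1));
  }
  // Indexed setter: grows the vector with placeholders when the index is past
  // the end, as set_input_quant_param does in the patch.
  void set_input_quant_param(size_t index, const std::vector<QuantParam> &param) {
    if (index >= input_quant_params_.size()) {
      input_quant_params_.resize(index + 1, std::vector<QuantParam>(1));
    }
    input_quant_params_.at(index) = param;
  }
  const QuantParams &get_input_quant_params() const { return input_quant_params_; }

 private:
  QuantParams input_quant_params_;
  QuantParams output_quant_params_;
};

int main() {
  Holder holder(3, 1);  // e.g. activation + weight + bias in, one tensor out
  QuantParam weight_param;
  weight_param.scale = 0.02;
  weight_param.inited = true;
  holder.set_input_quant_param(1, {weight_param});  // only the weight is quantized
  for (const auto &slot : holder.get_input_quant_params()) {
    std::cout << (slot.front().inited ? "inited" : "placeholder") << '\n';
  }
  return 0;  // prints: placeholder, inited, placeholder
}

This is also why IsInputQuantParamsInited() checks quant_param.front().inited for
every slot: a presized holder is full of placeholders, so "has entries" no longer
implies "was quantized".
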
diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
index 38be7ba2ac..e177e42fbb 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
@@ -59,7 +59,7 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons
   MS_ASSERT(quant_params != nullptr && quant_datas != nullptr);
   double bias_scale_tmp;
   const constexpr int32_t quanted_bias_abs_limit = 0.5 * INT32_MAX;
-  auto active_weight_quant_params = quant_param_holder->input_quant_params();
+  auto weight_quant_params = quant_param_holder->get_input_quant_params().at(1);
   auto shape_size = quant_datas->size();
   if (bias_scales.size() == shape_size) {
     for (size_t i = 0; i < shape_size; i++) {
@@ -69,14 +69,14 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons
         return RET_ERROR;
       }
       if (std::abs(raw_datas[i] / bias_scale_tmp) >= quanted_bias_abs_limit) {
-        MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][i].scale
+        MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[i].scale
                       << " is too small, need to update";
         // update filter scale and zp
         double activate_scale = input_scales[0];
         double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit);
-        active_weight_quant_params[1][i].scale = filter_scale;
-        active_weight_quant_params[1][i].zeroPoint = 0;
-        quant_param_holder->set_input_quant_params(active_weight_quant_params);
+        weight_quant_params[i].scale = filter_scale;
+        weight_quant_params[i].zeroPoint = 0;
+        quant_param_holder->set_input_quant_param(1, weight_quant_params);
         bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit;
         quant_params->at(i).scale = bias_scale_tmp;
         MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
@@ -99,13 +99,13 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons
       return RET_ERROR;
     }
     if (std::abs(max_raw_data / bias_scale_tmp) >= quanted_bias_abs_limit) {
-      MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][0].scale
+      MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[0].scale
                     << " is too small, need to update";
       double activate_scale = input_scales[0];
       double filter_scale = std::abs(max_raw_data) / (activate_scale * quanted_bias_abs_limit);
-      active_weight_quant_params[1][0].scale = filter_scale;
-      active_weight_quant_params[1][0].zeroPoint = 0;
-      quant_param_holder->set_input_quant_params(active_weight_quant_params);
+      weight_quant_params[0].scale = filter_scale;
+      weight_quant_params[0].zeroPoint = 0;
+      quant_param_holder->set_input_quant_param(1, weight_quant_params);
       bias_scale_tmp = max_raw_data / quanted_bias_abs_limit;
       quant_params->front().scale = bias_scale_tmp;
       MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
@@ -117,7 +117,7 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons
     return RET_OK;
   }
   MS_LOG(ERROR) << "unexpected input_scales size: " << input_scales.size()
-                << " weight_scales size: " << active_weight_quant_params[1].size();
+                << " weight_scales size: " << weight_quant_params.size();
   return RET_ERROR;
 }
 }  // namespace
@@ -620,7 +620,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const Primitiv
   MS_ASSERT(bias_parameter != nullptr);
   auto quant_param_holder = GetCNodeQuantHolder(primitive);
   MS_ASSERT(quant_param_holder != nullptr);
-  auto active_weight_quant_params = quant_param_holder->input_quant_params();
+  auto active_weight_quant_params = quant_param_holder->get_input_quant_params();
   if (active_weight_quant_params.size() != 2) {
     MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size();
     return RET_ERROR;
@@ -731,7 +731,7 @@ STATUS PostTrainingQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
       auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive);
       MS_ASSERT(input_primitive_quant_holder != nullptr);
       if (input_primitive_quant_holder->IsOutputQuantParamsInited()) {
-        auto quant_param = input_primitive_quant_holder->output_quant_params().front();
+        auto quant_param = input_primitive_quant_holder->get_output_quant_params().front();
         primitive_quant_holder->AddInputQuantParam(quant_param);
       } else {
         // do input quant
@@ -820,14 +820,14 @@ STATUS PostTrainingQuantizer::QuantNode() {
         }
         auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive);
         MS_ASSERT(input_primitive_quant_holder != nullptr);
-        if (input_primitive_quant_holder->output_quant_params().size() > index) {
-          auto quant_param = input_primitive_quant_holder->output_quant_params()[index];
+        if (input_primitive_quant_holder->get_output_quant_params().size() > index) {
+          auto quant_param = input_primitive_quant_holder->get_output_quant_params()[index];
           primitive_quant_holder->AddInputQuantParam(quant_param);
           primitive_quant_holder->AddOutputQuantParam(quant_param);
         } else {
           MS_LOG(WARNING) << "this TupleGetItem node's input node: " << input_cnode->fullname_with_scope()
-                          << "'s output quant_params size: " << input_primitive_quant_holder->output_quant_params().size()
-                          << ", but index: " << index;
+                          << "'s output quant_params size: "
+                          << input_primitive_quant_holder->get_output_quant_params().size() << ", but index: " << index;
         }
         primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
         continue;
@@ -1125,7 +1125,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con
     }
     auto quant_param_holder = GetCNodeQuantHolder(primitive);
     MS_ASSERT(quant_param_holder != nullptr);
-    auto input_quant_params = quant_param_holder->input_quant_params();
+    auto input_quant_params = quant_param_holder->get_input_quant_params();
     if (input_quant_params.size() == 3) {
       // compensate the existed
       auto bias_quant_params = input_quant_params[2];
@@ -1191,7 +1191,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con
       cnode->add_input(parameter);
       DoBiasQuant(parameter, primitive);
     } else {
-      MS_LOG(ERROR) << "unexpected input_quant_params size: " << input_quant_params.size();
+      MS_LOG(ERROR) << "unexpected get_input_quant_params size: " << input_quant_params.size();
     }
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc
index 0570946b1f..6e05fdcced 100644
--- a/mindspore/lite/tools/converter/quantizer/quant_cast.cc
+++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc
@@ -25,7 +25,7 @@ namespace mindspore::lite::quant {
 ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) {
   auto prim_c = std::make_shared<ops::QuantDTypeCast>();
   prim_c->Init(src_type, dst_type);
-  auto quant_params_holder = std::make_shared<QuantParamHolder>();
+  auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
   quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
   for (auto &quant_param : quant_params) {
     std::vector<schema::QuantParamT> quant_params_in = {quant_param};
@@ -82,17 +82,17 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) {
       ValueNodePtr value_node = nullptr;
       if (curnode_quant_type == schema::QuantType_QUANT_ALL &&
           input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
-        if (primitive_quant_param_holder->input_quant_params().size() < i) {
+        if (primitive_quant_param_holder->get_input_quant_params().size() < i) {
           MS_LOG(ERROR) << "quant param is invalid.";
           return RET_ERROR;
         }
         value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
-                                           primitive_quant_param_holder->input_quant_params()[i - 1]);
+                                           primitive_quant_param_holder->get_input_quant_params()[i - 1]);
       } else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
                  input_cnode_quant_type == schema::QuantType_QUANT_ALL) {
         auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c);
         value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
-                                           input_primitive_quant_param_holder->output_quant_params().front());
+                                           input_primitive_quant_param_holder->get_output_quant_params().front());
       }
       if (value_node == nullptr) {
         MS_LOG(WARNING) << "value_node is null! "
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index 682b07d52e..bed6a79a06 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -186,12 +186,12 @@ QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
   QuantParamHolderPtr quant_params_holder = nullptr;
   auto quant_params_valueptr = primitive->GetAttr("quant_params");
   if (quant_params_valueptr == nullptr) {
-    quant_params_holder = std::make_shared<QuantParamHolder>();
+    quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
     primitive->AddAttr("quant_params", quant_params_holder);
   } else {
     quant_params_holder = quant_params_valueptr->cast<QuantParamHolderPtr>();
     if (quant_params_holder == nullptr) {
-      quant_params_holder = std::make_shared<QuantParamHolder>();
+      quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
       primitive->AddAttr("quant_params", quant_params_holder);
     }
   }
diff --git a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc
index f87efb8f79..e75fd0ec36 100644
--- a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc
@@ -133,7 +133,6 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
   }
   auto left_matmul_input = left_slice_cnode->input(1);
   auto right_reshape_node = fullconnect_cnode->input(2);
-
   auto matmul_cvalue = new (std::nothrow) mindspore::ops::MatMul();
   if (matmul_cvalue == nullptr) {
     MS_LOG(ERROR) << "new MatMul failed";
@@ -153,29 +152,29 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
       MS_LOG(ERROR) << "quant param is invalid.";
       return nullptr;
     }
-    auto fc_input_quantParams = fc_input_quantParams_holder->input_quant_params();
+    auto fc_input_quantParams = fc_input_quantParams_holder->get_input_quant_params();
     if (fc_input_quantParams.size() > 1 && !fc_input_quantParams[1].empty()) {
       jointed_quant_params.push_back(fc_input_quantParams[1][0]);
     }
   }
-  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>();
   auto fc_prim = GetValueNode<PrimitivePtr>(fullconnect_cnode->input(0));
   lite::QuantParamsVector rmatmul_quant_params;
   auto rmatmul_quant_params_valueptr = fc_prim->GetAttr("quant_params");
+  lite::QuantParamsVector output_quant_params;
   if (rmatmul_quant_params_valueptr != nullptr) {
     auto rmatmul_quant_params_holder = rmatmul_quant_params_valueptr->cast<lite::QuantParamHolderPtr>();
     if (rmatmul_quant_params_holder == nullptr) {
       MS_LOG(ERROR) << "quant param is invalid.";
       return nullptr;
     }
-    rmatmul_quant_params = rmatmul_quant_params_holder->input_quant_params();
-    quant_params_holder->set_output_quant_params(rmatmul_quant_params_holder->output_quant_params());
+    rmatmul_quant_params = rmatmul_quant_params_holder->get_input_quant_params();
+    output_quant_params = rmatmul_quant_params_holder->get_output_quant_params();
   }
   rmatmul_quant_params.pop_back();
   rmatmul_quant_params.pop_back();
   // no bias quantParams
   rmatmul_quant_params.emplace_back(jointed_quant_params);
-  quant_params_holder->set_input_quant_params(rmatmul_quant_params);
+  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(rmatmul_quant_params, output_quant_params);
   matmul_cvalue->AddAttr("quant_params", quant_params_holder);
   auto matmul_value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(matmul_cvalue));
   std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};
diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
index cc1d6d577c..36c6afd7fa 100644
--- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
@@ -205,7 +205,7 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector
     MS_LOG(ERROR) << "quant param is invalid.";
     return lite::RET_ERROR;
   }
-  auto input_quant_params = quant_param_holder->input_quant_params();
+  auto input_quant_params = quant_param_holder->get_input_quant_params();
   for (size_t m = 0; m < input_quant_params.size(); m++) {
     for (auto inputQuantParam : input_quant_params[m]) {
       lite::QuantArg quant_arg{};
@@ -216,7 +216,7 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector
       inputs[m]->AddQuantParam(quant_arg);
     }
   }
-  auto output_quant_params = quant_param_holder->output_quant_params();
+  auto output_quant_params = quant_param_holder->get_output_quant_params();
   for (size_t m = 0; m < output_quant_params.size(); m++) {
     for (auto outputQuantParam : output_quant_params[m]) {
       lite::QuantArg quant_arg{};
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
index 98eb45e385..a4dacbe378 100644
--- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
@@ -25,12 +25,17 @@
 namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t kDoubleNum = 2;
-int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
-  auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
-  if (quant_tensor_info_ptr == nullptr) {
-    prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>());
+size_t GetCNodeOutputsSize(std::shared_ptr<AnfNode> anf_node) {
+  auto cnode = anf_node->cast<CNodePtr>();
+  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
+    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
+    return tuple->elements().size();
+  } else {
+    return 1;
   }
+}
+
+int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
   auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
   std::vector<schema::QuantParamT> quants;
   schema::QuantParamT quant_param;
@@ -50,10 +55,7 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
       return ret;
     }
     quants.emplace_back(quant_param);
-    quant_param_holder->AddInputQuantParam(quants);
-  } else {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_param_holder->AddInputQuantParam(notinited_quant_params);
+    quant_param_holder->set_input_quant_param(0, quants);
   }
 
   quants.clear();
@@ -78,19 +80,12 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
       return ret;
     }
     quants.emplace_back(quant_param);
-    quant_param_holder->AddInputQuantParam(quants);
-  } else {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_param_holder->AddInputQuantParam(notinited_quant_params);
+    quant_param_holder->set_input_quant_param(1, quants);
   }
   return lite::RET_OK;
 }
 
 int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
-  auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
-  if (quant_tensor_info_ptr == nullptr) {
-    prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>());
-  }
   auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
   std::vector<schema::QuantParamT> quants;
   schema::QuantParamT quant_param;
@@ -110,22 +105,14 @@ int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
       return ret;
     }
     quants.emplace_back(quant_param);
-    quant_param_holder->AddOutputQuantParam(quants);
-  } else {
-    schema::QuantParamT tmpQuantParam;
-    quants.emplace_back(tmpQuantParam);
-    quant_param_holder->AddOutputQuantParam(quants);
+    quant_param_holder->set_output_quant_param(0, quants);
   }
   return lite::RET_OK;
 }
 
 void CheckQuantParams(const PrimitivePtr &prim) {
-  auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
-  if (quant_tensor_info_ptr == nullptr) {
-    prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>());
-  }
   auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
-  auto input_quant_params = quant_param_holder->input_quant_params();
+  auto input_quant_params = quant_param_holder->get_input_quant_params();
   bool is_quant = false;
   for (size_t i = 0; i < input_quant_params.size(); ++i) {
     if (!input_quant_params.at(i).empty() && input_quant_params.at(i).at(0).inited) {
@@ -133,7 +120,7 @@ void CheckQuantParams(const PrimitivePtr &prim) {
       break;
     }
   }
-  auto output_quant_params = quant_param_holder->output_quant_params();
+  auto output_quant_params = quant_param_holder->get_output_quant_params();
   for (size_t i = 0; i < output_quant_params.size(); ++i) {
     if (!output_quant_params.at(i).empty() && output_quant_params.at(i).at(0).inited) {
       is_quant = true;
@@ -145,8 +132,6 @@ void CheckQuantParams(const PrimitivePtr &prim) {
 }
 
 int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
-  auto quant_param_holder = std::make_shared<lite::QuantParamHolder>();
-  prim->AddAttr("quant_params", quant_param_holder);
   auto narrow_range = prim->GetAttr("narrow_range");
   bool narrow_range_param = false;
   if (narrow_range != nullptr) {
@@ -248,6 +233,10 @@ int MindirAdjustPass::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
   }
   auto inputs = cnode->inputs();
   inputs.erase(inputs.begin());
+
+  auto quant_param_holder = std::make_shared<lite::QuantParamHolder>(inputs.size(), GetCNodeOutputsSize(anf_node));
+  primitive->AddAttr("quant_params", quant_param_holder);
+
   if (ConvertQuantParam(primitive, inputs) != lite::RET_OK) {
     MS_LOG(ERROR) << "compute quant param failed.";
     return lite::RET_ERROR;
diff --git a/mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
index c92f3dfb11..82b5e6100b 100644
--- a/mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
@@ -56,12 +56,16 @@ lite::STATUS ReorderCnodeInputs(CNode *cnode, const std::vector<size_t> &perm) {
   std::vector<AnfNodePtr> new_inputs = {cnode->input(0)};
   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
   auto input_quant_params = primitive->GetAttr("quant_params");
-  auto input_quant_params_holder = input_quant_params == nullptr
-                                     ? std::make_shared<lite::QuantParamHolder>()
-                                     : input_quant_params->cast<lite::QuantParamHolderPtr>();
-  auto old_quant_params = input_quant_params_holder->input_quant_params();
-  auto new_input_quant_holder = std::make_shared<lite::QuantParamHolder>();
+  if (input_quant_params == nullptr) {
+    MS_LOG(ERROR) << "quant params holder is null";
+    return RET_ERROR;
+  }
+  auto input_quant_params_holder = input_quant_params->cast<lite::QuantParamHolderPtr>();
+  auto old_quant_params = input_quant_params_holder->get_input_quant_params();
+  auto new_input_quant_holder =
+    std::make_shared<lite::QuantParamHolder>(perm.size(), input_quant_params_holder->get_output_quant_params().size());
   // add inputs as perm order
+  size_t new_idx = 0;
   for (size_t idx : perm) {
     if (idx > cnode->inputs().size() - 1) {
       MS_LOG(ERROR) << "Idx " << idx << " is larger than inputs size: " << cnode->inputs().size() - 1;
@@ -69,7 +73,12 @@ lite::STATUS ReorderCnodeInputs(CNode *cnode, const std::vector<size_t> &perm) {
     }
     new_inputs.emplace_back(cnode->input(idx));
     auto quant_param = idx < old_quant_params.size() ? old_quant_params.at(idx) : std::vector<schema::QuantParamT>();
-    new_input_quant_holder->AddInputQuantParam(quant_param);
+    new_input_quant_holder->set_input_quant_param(new_idx, quant_param);
+    new_idx++;
+  }
+
+  for (size_t i = 0; i < input_quant_params_holder->get_output_quant_params().size(); i++) {
+    new_input_quant_holder->set_output_quant_param(i, input_quant_params_holder->get_output_quant_params().at(i));
   }
   cnode->set_inputs(new_inputs);
   primitive->set_attr("quant_params", new_input_quant_holder);
diff --git a/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc b/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
index ece4d6ccae..3a134d4f55 100644
--- a/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
+++ b/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
@@ -93,9 +93,7 @@ AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_gr
   std::string trans_name = before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1)
                                   : cnode->fullname_with_scope() + "_post";
   auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
-  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>();
-  quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1));
-  quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1));
+  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(1, 1);
   auto trans_insert_prim = GetValueNode<PrimitivePtr>(trans_insert_node->input(0));
   trans_insert_prim->AddAttr("quant_params", quant_params_holder);
   return trans_insert_node;
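
[Sketch, not part of the patch] In the ComputeBiasDataAndQuantParam hunks of
post_training_quantizer.cc, the patch only narrows the access: instead of copying
all input params and indexing [1][i], it fetches the weight slot with
get_input_quant_params().at(1) and writes it back via set_input_quant_param(1, ...).
The overflow guard itself is unchanged: the bias is quantized with
scale = activation_scale * weight_scale into int32, and when a tiny weight scale
would overflow that range, the weight scale is widened so the bias lands at the
limit. A self-contained sketch of that arithmetic (the scales and bias value
below are made-up example numbers):

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const double limit = 0.5 * INT32_MAX;  // quanted_bias_abs_limit in the patch
  const double activation_scale = 0.05;
  double weight_scale = 1e-9;            // pathologically small weight scale
  const double raw_bias = 0.3;

  double bias_scale = activation_scale * weight_scale;
  if (std::abs(raw_bias / bias_scale) >= limit) {
    // Quantized bias would overflow int32: widen the weight scale so the bias
    // quantizes exactly to the limit; the patch also resets zeroPoint to 0 here.
    weight_scale = std::abs(raw_bias) / (activation_scale * limit);
    bias_scale = std::abs(raw_bias) / limit;
  }
  auto quantized = static_cast<int64_t>(std::round(raw_bias / bias_scale));
  std::cout << "new weight scale: " << weight_scale
            << ", quantized bias: " << quantized << '\n';  // now fits in int32
  return 0;
}

The trade-off, as the MS_LOG(DEBUG) messages in the patch note, is a deliberately
coarser weight scale in exchange for a bias that still fits the int32 bias tensor.
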