| @@ -274,8 +274,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me | |||
| MS_LOG(ERROR) << "quant param is invalid."; | |||
| return RET_ERROR; | |||
| } | |||
| input_quant_params = quant_param_holder->input_quant_params(); | |||
| output_quant_params = quant_param_holder->output_quant_params(); | |||
| input_quant_params = quant_param_holder->get_input_quant_params(); | |||
| output_quant_params = quant_param_holder->get_output_quant_params(); | |||
| dst_node->quantType = quant_param_holder->quant_type(); | |||
| } | |||
| // add quant param | |||
| @@ -311,19 +311,8 @@ STATUS CaffeModelParser::ConvertLayerQuantParams(const caffe::LayerParameter &la | |||
| MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(); | |||
| for (auto input_idx : layer.bottom()) { | |||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| quant_params_holder->AddInputQuantParam(notinited_quant_params); | |||
| } | |||
| for (auto input_idx : weight.blobs()) { | |||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| quant_params_holder->AddInputQuantParam(notinited_quant_params); | |||
| } | |||
| for (auto output_idx : layer.top()) { | |||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| quant_params_holder->AddOutputQuantParam(notinited_quant_params); | |||
| } | |||
| auto quant_params_holder = | |||
| std::make_shared<QuantParamHolder>(layer.bottom_size() + weight.blobs_size(), layer.top_size()); | |||
| primitive_c->AddAttr("quant_params", quant_params_holder); | |||
| return RET_OK; | |||
| } | |||
| @@ -535,7 +535,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o | |||
| return RET_ERROR; | |||
| } | |||
| // set input tensors | |||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(); | |||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(onnx_node.input_size(), onnx_node.output_size()); | |||
| for (int i = 0; i < onnx_node.input_size(); ++i) { | |||
| const auto &input_name = onnx_node.input(i); | |||
| std::vector<schema::QuantParamT> quant_params; | |||
| @@ -544,7 +544,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o | |||
| MS_LOG(ERROR) << "set input tensor quant param failed."; | |||
| return status; | |||
| } | |||
| quant_params_holder->AddInputQuantParam(quant_params); | |||
| quant_params_holder->set_input_quant_param(i, quant_params); | |||
| } | |||
| // set out tensors | |||
| for (int i = 0; i < onnx_node.output_size(); ++i) { | |||
| @@ -555,7 +555,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o | |||
| MS_LOG(ERROR) << "set output tensor quant param failed."; | |||
| return status; | |||
| } | |||
| quant_params_holder->AddOutputQuantParam(quant_params); | |||
| quant_params_holder->set_output_quant_param(i, quant_params); | |||
| } | |||
| primitive_c->AddAttr("quant_params", quant_params_holder); | |||
| return RET_OK; | |||
| @@ -953,9 +953,25 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||
| } | |||
| status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed."; | |||
| } | |||
| return status; | |||
| } | |||
| // Attaches a QuantParamHolder built from `input_size` and `output_size` to `primitive_c` under the | |||
| // "quant_params" attribute. Returns RET_NULL_PTR when `primitive_c` is null, RET_OK otherwise. | |||
| STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size, | |||
| ops::PrimitiveC *primitive_c) { | |||
| if (primitive_c != nullptr) { | |||
| primitive_c->AddAttr("quant_params", std::make_shared<QuantParamHolder>(input_size, output_size)); | |||
| return RET_OK; | |||
| } | |||
| MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| STATUS TFModelParser::ConvertRootGraphOutputs() { | |||
| // because output of intermediate node in anf graph may also be output tensors, we search output tensors in | |||
| // tf_root_graph_nodes_ but not anf_root_node_map_ | |||
| @@ -29,6 +29,7 @@ | |||
| #include "securec/include/securec.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -75,6 +76,8 @@ class TFModelParser : public ModelParser { | |||
| STATUS ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map, | |||
| const std::map<CNodePtr, FuncGraphPtr> &second_func_map); | |||
| STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, ops::PrimitiveC *primitive_c); | |||
| static STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph); | |||
| STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found); | |||
| @@ -243,7 +243,8 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, ops: | |||
| round_type = 2; | |||
| } | |||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | |||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(); | |||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(op->inputs.size(), op->outputs.size()); | |||
| size_t idx = 0; | |||
| for (auto input_idx : op->inputs) { | |||
| if (input_idx < 0) { | |||
| input_idx += tflite_subgraph->tensors.size(); | |||
| @@ -255,8 +256,10 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, ops: | |||
| MS_LOG(ERROR) << "set input tensor quant param failed."; | |||
| return status; | |||
| } | |||
| quant_params_holder->AddInputQuantParam(quant_params); | |||
| quant_params_holder->set_input_quant_param(idx, quant_params); | |||
| idx++; | |||
| } | |||
| idx = 0; | |||
| for (auto output_idx : op->outputs) { | |||
| if (output_idx < 0) { | |||
| output_idx += tflite_subgraph->tensors.size(); | |||
| @@ -268,7 +271,8 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, ops: | |||
| MS_LOG(ERROR) << "set output tensor quant param failed."; | |||
| return status; | |||
| } | |||
| quant_params_holder->AddOutputQuantParam(quant_params); | |||
| quant_params_holder->set_output_quant_param(idx, quant_params); | |||
| idx++; | |||
| } | |||
| primitive_c->AddAttr("quant_params", quant_params_holder); | |||
| return RET_OK; | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| @@ -27,7 +28,24 @@ namespace lite { | |||
| using QuantParamsVector = std::vector<std::vector<schema::QuantParamT>>; | |||
| class QuantParamHolder : public Value { | |||
| public: | |||
| QuantParamHolder() = default; | |||
| // Builds a holder with `input_size` input slots and `output_size` output slots; every slot starts out as a | |||
| // vector holding one default-constructed schema::QuantParamT. | |||
| QuantParamHolder(size_t input_size, size_t output_size) | |||
| : input_quant_params_(input_size, std::vector<schema::QuantParamT>(1)), | |||
| output_quant_params_(output_size, std::vector<schema::QuantParamT>(1)) {} | |||
| // Builds a holder that stores copies of the given input and output quant params. | |||
| QuantParamHolder(const QuantParamsVector &input_quant_params, const QuantParamsVector &output_quant_params) | |||
| : input_quant_params_(input_quant_params), output_quant_params_(output_quant_params) {} | |||
| ~QuantParamHolder() override = default; | |||
| @@ -36,17 +54,17 @@ class QuantParamHolder : public Value { | |||
| bool operator==(const Value &rhs) const override { // unused | |||
| if (rhs.isa<QuantParamHolder>()) { | |||
| auto other_holder = dynamic_cast<const QuantParamHolder &>(rhs); | |||
| auto input_quant_params_rhs = other_holder.input_quant_params(); | |||
| auto output_quant_params_rhs = other_holder.output_quant_params(); | |||
| if (input_quant_params_rhs.size() != this->input_quant_param_.size() || | |||
| output_quant_params_rhs.size() != this->output_quant_param_.size()) { | |||
| auto input_quant_params_rhs = other_holder.get_input_quant_params(); | |||
| auto output_quant_params_rhs = other_holder.get_output_quant_params(); | |||
| if (input_quant_params_rhs.size() != this->input_quant_params_.size() || | |||
| output_quant_params_rhs.size() != this->output_quant_params_.size()) { | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < input_quant_params_rhs.size(); ++i) { | |||
| if (input_quant_params_rhs.at(i).size() != this->input_quant_param_.at(i).size()) { | |||
| if (input_quant_params_rhs.at(i).size() != this->input_quant_params_.at(i).size()) { | |||
| return false; | |||
| } | |||
| auto *params = reinterpret_cast<const char *>(this->input_quant_param_.at(i).data()); | |||
| auto *params = reinterpret_cast<const char *>(this->input_quant_params_.at(i).data()); | |||
| auto *params_rhs = reinterpret_cast<const char *>(input_quant_params_rhs.at(i).data()); | |||
| for (size_t j = 0; j < input_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) { | |||
| if (params[j] != params_rhs[j]) { | |||
| @@ -55,10 +73,10 @@ class QuantParamHolder : public Value { | |||
| } | |||
| } | |||
| for (size_t i = 0; i < output_quant_params_rhs.size(); ++i) { | |||
| if (output_quant_params_rhs.at(i).size() != this->output_quant_param_.at(i).size()) { | |||
| if (output_quant_params_rhs.at(i).size() != this->output_quant_params_.at(i).size()) { | |||
| return false; | |||
| } | |||
| auto *params = reinterpret_cast<const char *>(this->output_quant_param_.at(i).data()); | |||
| auto *params = reinterpret_cast<const char *>(this->output_quant_params_.at(i).data()); | |||
| auto *params_rhs = reinterpret_cast<const char *>(output_quant_params_rhs.at(i).data()); | |||
| for (size_t j = 0; j < output_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) { | |||
| if (params[j] != params_rhs[j]) { | |||
| @@ -76,58 +94,53 @@ class QuantParamHolder : public Value { | |||
| schema::QuantType quant_type() const { return quant_type_; } | |||
| void set_input_quant_params(const QuantParamsVector &input_quant_param) { | |||
| this->input_quant_param_ = input_quant_param; | |||
| } | |||
| void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { | |||
| if (index >= this->input_quant_param_.size()) { | |||
| if (index >= this->input_quant_params_.size()) { | |||
| std::vector<schema::QuantParamT> place_quant(1); | |||
| this->input_quant_param_.insert(this->input_quant_param_.end(), index + 1 - input_quant_param_.size(), | |||
| place_quant); | |||
| this->input_quant_params_.insert(this->input_quant_params_.end(), index + 1 - input_quant_params_.size(), | |||
| place_quant); | |||
| } | |||
| this->input_quant_param_.at(index) = input_quant_param; | |||
| } | |||
| void set_output_quant_params(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) { | |||
| this->output_quant_param_ = output_quant_param; | |||
| this->input_quant_params_.at(index) = input_quant_param; | |||
| } | |||
| void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) { | |||
| if (index >= this->output_quant_param_.size()) { | |||
| if (index >= this->output_quant_params_.size()) { | |||
| std::vector<schema::QuantParamT> place_quant(1); | |||
| this->output_quant_param_.insert(this->output_quant_param_.end(), index + 1 - output_quant_param_.size(), | |||
| place_quant); | |||
| this->output_quant_params_.insert(this->output_quant_params_.end(), index + 1 - output_quant_params_.size(), | |||
| place_quant); | |||
| } | |||
| this->output_quant_param_.at(index) = output_quant_param; | |||
| this->output_quant_params_.at(index) = output_quant_param; | |||
| } | |||
| void set_enable_huffman_code(bool enable_huffman_code) { enable_huffman_code_ = enable_huffman_code; } | |||
| bool enable_huffman_code() const { return enable_huffman_code_; } | |||
| // deprecated | |||
| void AddInputQuantParam(const std::vector<schema::QuantParamT> &quant_param) { | |||
| this->input_quant_param_.emplace_back(quant_param); | |||
| this->input_quant_params_.emplace_back(quant_param); | |||
| } | |||
| std::vector<std::vector<schema::QuantParamT>> input_quant_params() const { return this->input_quant_param_; } | |||
| // deprecated | |||
| void AddOutputQuantParam(const std::vector<schema::QuantParamT> &quant_param) { | |||
| this->output_quant_param_.emplace_back(quant_param); | |||
| this->output_quant_params_.emplace_back(quant_param); | |||
| } | |||
| std::vector<std::vector<schema::QuantParamT>> output_quant_params() const { return this->output_quant_param_; } | |||
| std::vector<std::vector<schema::QuantParamT>> get_input_quant_params() const { return this->input_quant_params_; } | |||
| std::vector<std::vector<schema::QuantParamT>> get_output_quant_params() const { return this->output_quant_params_; } | |||
| // deprecated | |||
| void ClearInputOutputQuantParam() { | |||
| input_quant_param_.clear(); | |||
| output_quant_param_.clear(); | |||
| input_quant_params_.clear(); | |||
| output_quant_params_.clear(); | |||
| } | |||
| bool IsInputQuantParamsInited() { | |||
| if (this->input_quant_param_.empty()) { | |||
| if (this->input_quant_params_.empty()) { | |||
| return false; | |||
| } | |||
| for (auto &quant_param : this->input_quant_param_) { | |||
| for (auto &quant_param : this->input_quant_params_) { | |||
| if (!quant_param.front().inited) { | |||
| return false; | |||
| } | |||
| @@ -136,10 +149,10 @@ class QuantParamHolder : public Value { | |||
| } | |||
| bool IsOutputQuantParamsInited() { | |||
| if (this->output_quant_param_.empty()) { | |||
| if (this->output_quant_params_.empty()) { | |||
| return false; | |||
| } | |||
| for (auto &quant_param : this->output_quant_param_) { | |||
| for (auto &quant_param : this->output_quant_params_) { | |||
| if (!quant_param.front().inited) { | |||
| return false; | |||
| } | |||
| @@ -149,8 +162,8 @@ class QuantParamHolder : public Value { | |||
| private: | |||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | |||
| QuantParamsVector input_quant_param_; | |||
| QuantParamsVector output_quant_param_; | |||
| QuantParamsVector input_quant_params_; | |||
| QuantParamsVector output_quant_params_; | |||
| bool enable_huffman_code_ = false; | |||
| }; | |||
| using QuantParamHolderPtr = std::shared_ptr<QuantParamHolder>; | |||
| @@ -59,7 +59,7 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons | |||
| MS_ASSERT(quant_params != nullptr && quant_datas != nullptr); | |||
| double bias_scale_tmp; | |||
| const constexpr int32_t quanted_bias_abs_limit = 0.5 * INT32_MAX; | |||
| auto active_weight_quant_params = quant_param_holder->input_quant_params(); | |||
| auto weight_quant_params = quant_param_holder->get_input_quant_params().at(1); | |||
| auto shape_size = quant_datas->size(); | |||
| if (bias_scales.size() == shape_size) { | |||
| for (size_t i = 0; i < shape_size; i++) { | |||
| @@ -69,14 +69,14 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons | |||
| return RET_ERROR; | |||
| } | |||
| if (std::abs(raw_datas[i] / bias_scale_tmp) >= quanted_bias_abs_limit) { | |||
| MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][i].scale | |||
| MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[i].scale | |||
| << " is too small, need to update"; | |||
| // update filter scale and zp | |||
| double activate_scale = input_scales[0]; | |||
| double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit); | |||
| active_weight_quant_params[1][i].scale = filter_scale; | |||
| active_weight_quant_params[1][i].zeroPoint = 0; | |||
| quant_param_holder->set_input_quant_params(active_weight_quant_params); | |||
| weight_quant_params[i].scale = filter_scale; | |||
| weight_quant_params[i].zeroPoint = 0; | |||
| quant_param_holder->set_input_quant_param(1, weight_quant_params); | |||
| bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit; | |||
| quant_params->at(i).scale = bias_scale_tmp; | |||
| MS_LOG(DEBUG) << "new filter scale: " << filter_scale; | |||
| @@ -99,13 +99,13 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons | |||
| return RET_ERROR; | |||
| } | |||
| if (std::abs(max_raw_data / bias_scale_tmp) >= quanted_bias_abs_limit) { | |||
| MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][0].scale | |||
| MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[0].scale | |||
| << " is too small, need to update"; | |||
| double activate_scale = input_scales[0]; | |||
| double filter_scale = std::abs(max_raw_data) / (activate_scale * quanted_bias_abs_limit); | |||
| active_weight_quant_params[1][0].scale = filter_scale; | |||
| active_weight_quant_params[1][0].zeroPoint = 0; | |||
| quant_param_holder->set_input_quant_params(active_weight_quant_params); | |||
| weight_quant_params[0].scale = filter_scale; | |||
| weight_quant_params[0].zeroPoint = 0; | |||
| quant_param_holder->set_input_quant_param(1, weight_quant_params); | |||
| bias_scale_tmp = max_raw_data / quanted_bias_abs_limit; | |||
| quant_params->front().scale = bias_scale_tmp; | |||
| MS_LOG(DEBUG) << "new filter scale: " << filter_scale; | |||
| @@ -117,7 +117,7 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons | |||
| return RET_OK; | |||
| } | |||
| MS_LOG(ERROR) << "unexpected input_scales size: " << input_scales.size() | |||
| << " weight_scales size: " << active_weight_quant_params[1].size(); | |||
| << " weight_scales size: " << weight_quant_params.size(); | |||
| return RET_ERROR; | |||
| } | |||
| } // namespace | |||
| @@ -620,7 +620,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const Primitiv | |||
| MS_ASSERT(bias_parameter != nullptr); | |||
| auto quant_param_holder = GetCNodeQuantHolder(primitive); | |||
| MS_ASSERT(quant_param_holder != nullptr); | |||
| auto active_weight_quant_params = quant_param_holder->input_quant_params(); | |||
| auto active_weight_quant_params = quant_param_holder->get_input_quant_params(); | |||
| if (active_weight_quant_params.size() != 2) { | |||
| MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size(); | |||
| return RET_ERROR; | |||
| @@ -731,7 +731,7 @@ STATUS PostTrainingQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) { | |||
| auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive); | |||
| MS_ASSERT(input_primitive_quant_holder != nullptr); | |||
| if (input_primitive_quant_holder->IsOutputQuantParamsInited()) { | |||
| auto quant_param = input_primitive_quant_holder->output_quant_params().front(); | |||
| auto quant_param = input_primitive_quant_holder->get_output_quant_params().front(); | |||
| primitive_quant_holder->AddInputQuantParam(quant_param); | |||
| } else { | |||
| // do input quant | |||
| @@ -820,14 +820,14 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||
| } | |||
| auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive); | |||
| MS_ASSERT(input_primitive_quant_holder != nullptr); | |||
| if (input_primitive_quant_holder->output_quant_params().size() > index) { | |||
| auto quant_param = input_primitive_quant_holder->output_quant_params()[index]; | |||
| if (input_primitive_quant_holder->get_output_quant_params().size() > index) { | |||
| auto quant_param = input_primitive_quant_holder->get_output_quant_params()[index]; | |||
| primitive_quant_holder->AddInputQuantParam(quant_param); | |||
| primitive_quant_holder->AddOutputQuantParam(quant_param); | |||
| } else { | |||
| MS_LOG(WARNING) << "this TupleGetItem node's input node: " << input_cnode->fullname_with_scope() | |||
| << "'s output quant_params size: " << input_primitive_quant_holder->output_quant_params().size() | |||
| << ", but index: " << index; | |||
| << "'s output quant_params size: " | |||
| << input_primitive_quant_holder->get_output_quant_params().size() << ", but index: " << index; | |||
| } | |||
| primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL); | |||
| continue; | |||
| @@ -1125,7 +1125,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con | |||
| } | |||
| auto quant_param_holder = GetCNodeQuantHolder(primitive); | |||
| MS_ASSERT(quant_param_holder != nullptr); | |||
| auto input_quant_params = quant_param_holder->input_quant_params(); | |||
| auto input_quant_params = quant_param_holder->get_input_quant_params(); | |||
| if (input_quant_params.size() == 3) { | |||
| // compensate the existed | |||
| auto bias_quant_params = input_quant_params[2]; | |||
| @@ -1191,7 +1191,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con | |||
| cnode->add_input(parameter); | |||
| DoBiasQuant(parameter, primitive); | |||
| } else { | |||
| MS_LOG(ERROR) << "unexpected input_quant_params size: " << input_quant_params.size(); | |||
| MS_LOG(ERROR) << "unexpected get_input_quant_params size: " << input_quant_params.size(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -25,7 +25,7 @@ namespace mindspore::lite::quant { | |||
| ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) { | |||
| auto prim_c = std::make_shared<ops::QuantDTypeCast>(); | |||
| prim_c->Init(src_type, dst_type); | |||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(); | |||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0); | |||
| quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL); | |||
| for (auto &quant_param : quant_params) { | |||
| std::vector<schema::QuantParamT> quant_params_in = {quant_param}; | |||
| @@ -82,17 +82,17 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) { | |||
| ValueNodePtr value_node = nullptr; | |||
| if (curnode_quant_type == schema::QuantType_QUANT_ALL && | |||
| input_cnode_quant_type == schema::QuantType_QUANT_NONE) { | |||
| if (primitive_quant_param_holder->input_quant_params().size() < i) { | |||
| if (primitive_quant_param_holder->get_input_quant_params().size() < i) { | |||
| MS_LOG(ERROR) << "quant param is invalid."; | |||
| return RET_ERROR; | |||
| } | |||
| value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, | |||
| primitive_quant_param_holder->input_quant_params()[i - 1]); | |||
| primitive_quant_param_holder->get_input_quant_params()[i - 1]); | |||
| } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && | |||
| input_cnode_quant_type == schema::QuantType_QUANT_ALL) { | |||
| auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c); | |||
| value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, | |||
| input_primitive_quant_param_holder->output_quant_params().front()); | |||
| input_primitive_quant_param_holder->get_output_quant_params().front()); | |||
| } | |||
| if (value_node == nullptr) { | |||
| MS_LOG(WARNING) << "value_node is null! " | |||
| @@ -186,12 +186,12 @@ QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) { | |||
| QuantParamHolderPtr quant_params_holder = nullptr; | |||
| auto quant_params_valueptr = primitive->GetAttr("quant_params"); | |||
| if (quant_params_valueptr == nullptr) { | |||
| quant_params_holder = std::make_shared<QuantParamHolder>(); | |||
| quant_params_holder = std::make_shared<QuantParamHolder>(0, 0); | |||
| primitive->AddAttr("quant_params", quant_params_holder); | |||
| } else { | |||
| quant_params_holder = quant_params_valueptr->cast<QuantParamHolderPtr>(); | |||
| if (quant_params_holder == nullptr) { | |||
| quant_params_holder = std::make_shared<QuantParamHolder>(); | |||
| quant_params_holder = std::make_shared<QuantParamHolder>(0, 0); | |||
| primitive->AddAttr("quant_params", quant_params_holder); | |||
| } | |||
| } | |||
| @@ -133,7 +133,6 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons | |||
| } | |||
| auto left_matmul_input = left_slice_cnode->input(1); | |||
| auto right_reshape_node = fullconnect_cnode->input(2); | |||
| auto matmul_cvalue = new (std::nothrow) mindspore::ops::MatMul(); | |||
| if (matmul_cvalue == nullptr) { | |||
| MS_LOG(ERROR) << "new MatMul failed"; | |||
| @@ -153,29 +152,29 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons | |||
| MS_LOG(ERROR) << "quant param is invalid."; | |||
| return nullptr; | |||
| } | |||
| auto fc_input_quantParams = fc_input_quantParams_holder->input_quant_params(); | |||
| auto fc_input_quantParams = fc_input_quantParams_holder->get_input_quant_params(); | |||
| if (fc_input_quantParams.size() > 1 && !fc_input_quantParams[1].empty()) { | |||
| jointed_quant_params.push_back(fc_input_quantParams[1][0]); | |||
| } | |||
| } | |||
| auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(); | |||
| auto fc_prim = GetValueNode<PrimitiveCPtr>(fullconnect_cnode->input(0)); | |||
| lite::QuantParamsVector rmatmul_quant_params; | |||
| auto rmatmul_quant_params_valueptr = fc_prim->GetAttr("quant_params"); | |||
| lite::QuantParamsVector output_quant_params; | |||
| if (rmatmul_quant_params_valueptr != nullptr) { | |||
| auto rmatmul_quant_params_holder = rmatmul_quant_params_valueptr->cast<lite::QuantParamHolderPtr>(); | |||
| if (rmatmul_quant_params_holder == nullptr) { | |||
| MS_LOG(ERROR) << "quant param is invalid."; | |||
| return nullptr; | |||
| } | |||
| rmatmul_quant_params = rmatmul_quant_params_holder->input_quant_params(); | |||
| quant_params_holder->set_output_quant_params(rmatmul_quant_params_holder->output_quant_params()); | |||
| rmatmul_quant_params = rmatmul_quant_params_holder->get_input_quant_params(); | |||
| output_quant_params = rmatmul_quant_params_holder->get_output_quant_params(); | |||
| } | |||
| rmatmul_quant_params.pop_back(); | |||
| rmatmul_quant_params.pop_back(); | |||
| // no bias quantParams | |||
| rmatmul_quant_params.emplace_back(jointed_quant_params); | |||
| quant_params_holder->set_input_quant_params(rmatmul_quant_params); | |||
| auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(rmatmul_quant_params, output_quant_params); | |||
| matmul_cvalue->AddAttr("quant_params", quant_params_holder); | |||
| auto matmul_value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(matmul_cvalue)); | |||
| std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input}; | |||
| @@ -205,7 +205,7 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector<Tensor *> | |||
| MS_LOG(ERROR) << "quant param is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto input_quant_params = quant_param_holder->input_quant_params(); | |||
| auto input_quant_params = quant_param_holder->get_input_quant_params(); | |||
| for (size_t m = 0; m < input_quant_params.size(); m++) { | |||
| for (auto inputQuantParam : input_quant_params[m]) { | |||
| lite::QuantArg quant_arg{}; | |||
| @@ -216,7 +216,7 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector<Tensor *> | |||
| inputs[m]->AddQuantParam(quant_arg); | |||
| } | |||
| } | |||
| auto output_quant_params = quant_param_holder->output_quant_params(); | |||
| auto output_quant_params = quant_param_holder->get_output_quant_params(); | |||
| for (size_t m = 0; m < output_quant_params.size(); m++) { | |||
| for (auto outputQuantParam : output_quant_params[m]) { | |||
| lite::QuantArg quant_arg{}; | |||
| @@ -25,12 +25,17 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kDoubleNum = 2; | |||
| int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) { | |||
| auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); | |||
| if (quant_tensor_info_ptr == nullptr) { | |||
| prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>()); | |||
| // Returns how many outputs `anf_node` produces: the element count of its abstract when that abstract is an | |||
| // abstract::AbstractTuple, otherwise 1. Logs an error and returns 0 when `anf_node` is null or does not | |||
| // cast to a CNodePtr. | |||
| size_t GetCNodeOutputsSize(std::shared_ptr<AnfNode> anf_node) { | |||
| if (anf_node == nullptr) { | |||
| MS_LOG(ERROR) << "anf_node is nullptr."; | |||
| return 0; | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| MS_LOG(ERROR) << "anf_node is not a cnode."; | |||
| return 0; | |||
| } | |||
| // A checked downcast replaces the unchecked std::reinterpret_pointer_cast: a mismatching or null abstract | |||
| // yields nullptr here instead of a pointer that is unsafe to dereference. | |||
| auto tuple = std::dynamic_pointer_cast<abstract::AbstractTuple>(cnode->abstract()); | |||
| if (tuple != nullptr) { | |||
| return tuple->elements().size(); | |||
| } | |||
| return 1; | |||
| } | |||
| int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) { | |||
| auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>(); | |||
| std::vector<schema::QuantParamT> quants; | |||
| schema::QuantParamT quant_param; | |||
| @@ -50,10 +55,7 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t | |||
| return ret; | |||
| } | |||
| quants.emplace_back(quant_param); | |||
| quant_param_holder->AddInputQuantParam(quants); | |||
| } else { | |||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| quant_param_holder->AddInputQuantParam(notinited_quant_params); | |||
| quant_param_holder->set_input_quant_param(0, quants); | |||
| } | |||
| quants.clear(); | |||
| @@ -78,19 +80,12 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t | |||
| return ret; | |||
| } | |||
| quants.emplace_back(quant_param); | |||
| quant_param_holder->AddInputQuantParam(quants); | |||
| } else { | |||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| quant_param_holder->AddInputQuantParam(notinited_quant_params); | |||
| quant_param_holder->set_input_quant_param(1, quants); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) { | |||
| auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); | |||
| if (quant_tensor_info_ptr == nullptr) { | |||
| prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>()); | |||
| } | |||
| auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>(); | |||
| std::vector<schema::QuantParamT> quants; | |||
| schema::QuantParamT quant_param; | |||
| @@ -110,22 +105,14 @@ int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t | |||
| return ret; | |||
| } | |||
| quants.emplace_back(quant_param); | |||
| quant_param_holder->AddOutputQuantParam(quants); | |||
| } else { | |||
| schema::QuantParamT tmpQuantParam; | |||
| quants.emplace_back(tmpQuantParam); | |||
| quant_param_holder->AddOutputQuantParam(quants); | |||
| quant_param_holder->set_output_quant_param(0, quants); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| void CheckQuantParams(const PrimitivePtr &prim) { | |||
| auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); | |||
| if (quant_tensor_info_ptr == nullptr) { | |||
| prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>()); | |||
| } | |||
| auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>(); | |||
| auto input_quant_params = quant_param_holder->input_quant_params(); | |||
| auto input_quant_params = quant_param_holder->get_input_quant_params(); | |||
| bool is_quant = false; | |||
| for (size_t i = 0; i < input_quant_params.size(); ++i) { | |||
| if (!input_quant_params.at(i).empty() && input_quant_params.at(i).at(0).inited) { | |||
| @@ -133,7 +120,7 @@ void CheckQuantParams(const PrimitivePtr &prim) { | |||
| break; | |||
| } | |||
| } | |||
| auto output_quant_params = quant_param_holder->output_quant_params(); | |||
| auto output_quant_params = quant_param_holder->get_output_quant_params(); | |||
| for (size_t i = 0; i < output_quant_params.size(); ++i) { | |||
| if (!output_quant_params.at(i).empty() && output_quant_params.at(i).at(0).inited) { | |||
| is_quant = true; | |||
| @@ -145,8 +132,6 @@ void CheckQuantParams(const PrimitivePtr &prim) { | |||
| } | |||
| int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| auto quant_param_holder = std::make_shared<lite::QuantParamHolder>(); | |||
| prim->AddAttr("quant_params", quant_param_holder); | |||
| auto narrow_range = prim->GetAttr("narrow_range"); | |||
| bool narrow_range_param = false; | |||
| if (narrow_range != nullptr) { | |||
| @@ -248,6 +233,10 @@ int MindirAdjustPass::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) { | |||
| } | |||
| auto inputs = cnode->inputs(); | |||
| inputs.erase(inputs.begin()); | |||
| auto quant_param_holder = std::make_shared<lite::QuantParamHolder>(inputs.size(), GetCNodeOutputsSize(anf_node)); | |||
| primitive->AddAttr("quant_params", quant_param_holder); | |||
| if (ConvertQuantParam(primitive, inputs) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "compute quant param failed."; | |||
| return lite::RET_ERROR; | |||
| @@ -56,12 +56,16 @@ lite::STATUS ReorderCnodeInputs(CNode *cnode, const std::vector<size_t> &perm) { | |||
| std::vector<AnfNodePtr> new_inputs = {cnode->input(0)}; | |||
| auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| auto input_quant_params = primitive->GetAttr("quant_params"); | |||
| auto input_quant_params_holder = input_quant_params == nullptr | |||
| ? std::make_shared<lite::QuantParamHolder>() | |||
| : input_quant_params->cast<lite::QuantParamHolderPtr>(); | |||
| auto old_quant_params = input_quant_params_holder->input_quant_params(); | |||
| auto new_input_quant_holder = std::make_shared<lite::QuantParamHolder>(); | |||
| if (input_quant_params == nullptr) { | |||
| MS_LOG(ERROR) << "quant params holder is null"; | |||
| return RET_ERROR; | |||
| } | |||
| auto input_quant_params_holder = input_quant_params->cast<lite::QuantParamHolderPtr>(); | |||
| auto old_quant_params = input_quant_params_holder->get_input_quant_params(); | |||
| auto new_input_quant_holder = | |||
| std::make_shared<lite::QuantParamHolder>(perm.size(), input_quant_params_holder->get_output_quant_params().size()); | |||
| // add inputs as perm order | |||
| size_t new_idx = 0; | |||
| for (size_t idx : perm) { | |||
| if (idx > cnode->inputs().size() - 1) { | |||
| MS_LOG(ERROR) << "Idx " << idx << " is larger than inputs size: " << cnode->inputs().size() - 1; | |||
| @@ -69,7 +73,12 @@ lite::STATUS ReorderCnodeInputs(CNode *cnode, const std::vector<size_t> &perm) { | |||
| } | |||
| new_inputs.emplace_back(cnode->input(idx)); | |||
| auto quant_param = idx < old_quant_params.size() ? old_quant_params.at(idx) : std::vector<schema::QuantParamT>(); | |||
| new_input_quant_holder->AddInputQuantParam(quant_param); | |||
| new_input_quant_holder->set_input_quant_param(new_idx, quant_param); | |||
| new_idx++; | |||
| } | |||
| for (size_t i = 0; i < input_quant_params_holder->get_output_quant_params().size(); i++) { | |||
| new_input_quant_holder->set_output_quant_param(i, input_quant_params_holder->get_output_quant_params().at(i)); | |||
| } | |||
| cnode->set_inputs(new_inputs); | |||
| primitive->set_attr("quant_params", new_input_quant_holder); | |||
| @@ -93,9 +93,7 @@ AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_gr | |||
| std::string trans_name = | |||
| before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post"; | |||
| auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name); | |||
| auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(); | |||
| quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1)); | |||
| quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1)); | |||
| auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(1, 1); | |||
| auto trans_insert_prim = GetValueNode<PrimitivePtr>(trans_insert_node->input(0)); | |||
| trans_insert_prim->AddAttr("quant_params", quant_params_holder); | |||
| return trans_insert_node; | |||