From 8d60d097e091e718bb1934bf7d7fa1d40cd6a6cc Mon Sep 17 00:00:00 2001
From: cjh9368
Date: Thu, 15 Apr 2021 16:48:06 +0800
Subject: [PATCH] [MS][LITE] use set quant params instead of add quant params

---
 .../cpu/nnacl/infer/arithmetic_infer.c        |  16 ++-
 .../lite/test/models_tflite_awaretraining.cfg |   2 +-
 .../lite/tools/anf_exporter/anf_exporter.cc   | 135 ++++++++----------
 .../lite/tools/anf_exporter/anf_exporter.h    |   6 +-
 mindspore/lite/tools/common/graph_util.cc     |  36 +++--
 mindspore/lite/tools/common/node_util.cc      |  15 ++
 mindspore/lite/tools/common/node_util.h       |   3 +
 .../lite/tools/converter/converter_context.h  |  19 ++-
 .../graph/dtype_trans_pass.cc                 |  91 +++++++-----
 .../graph/tensor_quant_pass.cc                |  51 ++-----
 ...quant_dtype_cast_quant_param_propogator.cc |  34 +++++
 .../quant_dtype_cast_quant_param_propogator.h |  27 ++++
 .../quant_helper/quant_node_helper.cc         |   4 +-
 .../converter/quantizer/quantize_util.cc      |  19 +++
 .../tools/converter/quantizer/quantize_util.h |   4 +
 .../lite/tools/optimizer/common/gllo_utils.cc |   4 +
 .../optimizer/graph/mindir_adjust_pass.cc     |  36 +----
 .../optimizer/graph/transpose_strategy.cc     |   4 -
 .../graph/weight_format_transform_pass.cc     |   2 +
 19 files changed, 305 insertions(+), 203 deletions(-)
 create mode 100644 mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.cc
 create mode 100644 mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c
index 7a1050ac41..d1ed115fe7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c
@@ -17,6 +17,16 @@
 #include "nnacl/infer/arithmetic_infer.h"
 #include "nnacl/infer/infer_register.h"
 
+void SetOutputDtypeFormat(const TensorC *input0, const TensorC *input1, TensorC *output) {
+  output->format_ = input0->format_;
+  output->data_type_ = input0->data_type_;
+  // when input0 is const, it is quantized before the quant trans op is inserted, so use input1's data type instead
+  if (input0->data_ != NULL ||
+      ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32))) {
+    output->data_type_ = input1->data_type_;
+  }
+}
+
 void UpdateInputShape(const int input_shape0_size, const int input_shape1_size, int *ndim, const int *input_shape0,
                       const int *input_shape1, int *in_shape0, int *in_shape1) {
   if (input_shape0_size < input_shape1_size) {
@@ -71,11 +81,7 @@ int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
   size_t input_shape0_size = input0->shape_size_;
   const int *input_shape1 = input1->shape_;
   size_t input_shape1_size = input1->shape_size_;
-  output->format_ = input0->format_;
-  output->data_type_ = input0->data_type_;
-  if ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32)) {
-    output->data_type_ = input1->data_type_;
-  }
+  SetOutputDtypeFormat(input0, input1, output);
 
   if (!parameter->infer_flag_) {
     return NNACL_INFER_INVALID;
diff --git a/mindspore/lite/test/models_tflite_awaretraining.cfg b/mindspore/lite/test/models_tflite_awaretraining.cfg
index 3f3842761c..38584226a4 100644
--- a/mindspore/lite/test/models_tflite_awaretraining.cfg
+++ b/mindspore/lite/test/models_tflite_awaretraining.cfg
@@ -1,4 +1,4 @@
-video_infer.tflite
+video_infer2.tflite
 mobilenet_v1_0.25_128_quant.tflite
 mobilenet_v1_0.25_160_quant.tflite
 mobilenet_v1_0.25_192_quant.tflite
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 1dbe38cd91..bf01fa9146 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -38,6 +38,8 @@
 #include "ops/op_utils.h"
 #include "tools/common/graph_util.h"
 #include "src/ops/ops_utils.h"
+#include "tools/common/node_util.h"
+#include "tools/converter/converter_context.h"
 
 using mindspore::ops::PrimitiveC;
 
@@ -81,9 +83,9 @@ std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
   }
 }
 }  // namespace
 
-int AnfExporter::SetQuantOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
-                                          const std::shared_ptr<mindspore::Primitive> &primitive,
-                                          const std::unique_ptr<schema::CNodeT> &dst_node) {
+int AnfExporter::SetPostTrainOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
+                                              const std::shared_ptr<mindspore::Primitive> &primitive,
+                                              const std::unique_ptr<schema::CNodeT> &dst_node) {
   auto first_output_index = dst_node->outputIndex[0];
   auto first_tensor_output = meta_graph->allTensors[first_output_index].get();
   if (dst_node->quantType == schema::QuantType_PostTraining) {
@@ -116,82 +118,63 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
   QuantParamsVector output_quant_params;
   dst_node->quantType = schema::QuantType_QUANT_NONE;
   auto quant_tensor_info_ptr = primitive->GetAttr("quant_params");
-  if (quant_tensor_info_ptr != nullptr) {
-    auto quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>();
-    if (quant_param_holder == nullptr) {
-      MS_LOG(ERROR) << "quant param is invalid.";
-      return RET_ERROR;
-    }
-    input_quant_params = quant_param_holder->get_input_quant_params();
-    output_quant_params = quant_param_holder->get_output_quant_params();
-    dst_node->quantType = quant_param_holder->quant_type();
+  QuantParamHolderPtr quant_param_holder = nullptr;
+  if (quant_tensor_info_ptr == nullptr ||
+      (quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>()) == nullptr) {
+    quant_param_holder = std::make_shared<QuantParamHolder>(dst_node->inputIndex.size(), dst_node->outputIndex.size());
   }
-  // add quant param
-  if (!input_quant_params.empty()) {
-    for (size_t i = 0; i < input_quant_params.size(); i++) {
-      if (i >= dst_node->inputIndex.size()) {
-        MS_LOG(INFO) << "node: " << dst_node->name << " input has " << input_quant_params.size()
-                     << " quant_params; but only " << dst_node->inputIndex.size() << " input";
-        break;
-      }
-      auto activate_index = dst_node->inputIndex[i];
-      auto tensor_input = meta_graph->allTensors[activate_index].get();
-      if (tensor_input->quantParams.empty()) {
-        for (auto input_quant_param : input_quant_params[i]) {
-          auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param);
-          MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
-                        << " zp: " << input_quant_param_ptr->zeroPoint;
-          input_quant_param_ptr->dstDtype = tensor_input->dataType;
-          tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
-        }
+
+  input_quant_params = quant_param_holder->get_input_quant_params();
+  output_quant_params = quant_param_holder->get_output_quant_params();
+  dst_node->quantType = quant_param_holder->quant_type();
+
+  // convert input quant param
+  for (size_t i = 0; i < dst_node->inputIndex.size(); i++) {
+    if (i >= input_quant_params.size()) {
+      MS_LOG(INFO) << "node: " << dst_node->name << " has " << dst_node->inputIndex.size() << " inputs, but only has "
+                   << input_quant_params.size() << " quant params";
+      break;
+    }
+    auto activate_index = dst_node->inputIndex[i];
+    auto tensor_input = meta_graph->allTensors[activate_index].get();
+    if (tensor_input->quantParams.empty()) {
+      for (auto input_quant_param : input_quant_params[i]) {
+        auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param);
+        MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
+                      << " zp: " << input_quant_param_ptr->zeroPoint;
+        input_quant_param_ptr->dstDtype = tensor_input->dataType;
+        tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
+      }
-      if (!tensor_input->quantParams.empty()) {
-        int bit_num = tensor_input->quantParams.at(0)->numBits;
-        if (bit_num != 8 && bit_num != 16) {
-          auto status = DoBitPack(bit_num, tensor_input);
-          if (status != RET_OK) {
-            MS_LOG(ERROR) << "do bit pack failed. " << status;
-            return RET_ERROR;
-          }
+    }
+
+    if (!tensor_input->quantParams.empty()) {
+      int bit_num = tensor_input->quantParams.at(0)->numBits;
+      if (bit_num != 8 && bit_num != 16) {
+        auto status = DoBitPack(bit_num, tensor_input);
+        if (status != RET_OK) {
+          MS_LOG(ERROR) << "do bit pack failed. " << status;
+          return RET_ERROR;
         }
       }
     }
-  } else {
-    MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
   }
+
+  // output
-  if (output_quant_params.empty()) {
-    if (primitive->name() != mindspore::ops::kNameQuantDTypeCast) {
-      MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
-    }
-  } else {
-    if (dst_node->outputIndex.size() != output_quant_params.size()) {
-      MS_LOG(INFO) << "node: " << dst_node->name << " output has " << output_quant_params.size()
-                   << " quant_params; but only " << dst_node->outputIndex.size() << " output";
-      return RET_ERROR;
-    }
-    int output_idx = 0;
-    for (const auto &output_quant_param : output_quant_params) {
-      auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[output_idx]].get();
-      output_idx++;
-      for (const auto &channel_quant_param : output_quant_param) {
-        if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
-          std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
-            std::make_unique<schema::QuantParamT>(channel_quant_param);
-          MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
-                        << " zp: " << output_quant_param_ptr->zeroPoint;
-          output_quant_param_ptr->dstDtype = output_tensor->dataType;
-          output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
-        }
+  int output_idx = 0;
+  for (const auto &output_quant_param : output_quant_params) {
+    auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[output_idx]].get();
+    output_idx++;
+    for (const auto &channel_quant_param : output_quant_param) {
+      if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
+        std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
+          std::make_unique<schema::QuantParamT>(channel_quant_param);
+        MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
+                      << " zp: " << output_quant_param_ptr->zeroPoint;
+        output_quant_param_ptr->dstDtype = output_tensor->dataType;
+        output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
       }
     }
   }
-
-  auto status = SetQuantOutputTensorType(meta_graph, primitive, dst_node);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "set quant output tensor data type failed.";
-    return RET_ERROR;
-  }
   return RET_OK;
 }
 
@@ -224,6 +207,7 @@ int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
     tensor->format = schema::Format_NHWC;
     if (!IsContain(subgraph->inputIndices, input)) {
       if (subgraph_index == kMainGraphIndex) {
+        TensorDataType::GetInstance()->UpdateGraphInputDType(meta_graphT->inputIndex.size(), tensor->dataType);
         meta_graphT->inputIndex.push_back(input);
       }
       subgraph->inputIndices.push_back(input);
@@ -262,6 +246,8 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
   }
   for (unsigned int &i : return_node->inputIndex) {
     if (subgraph_index == kMainGraphIndex) {
+      auto &tensor = meta_graphT->allTensors.at(i);
+      TensorDataType::GetInstance()->UpdateGraphOutputDType(meta_graphT->outputIndex.size(), tensor->dataType);
       meta_graphT->outputIndex.push_back(i);
     }
     meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i);
@@ -354,6 +340,13 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
     meta_graphT->nodes.push_back(std::move(node));
     meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++);
   }
@@ -615,7 +608,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
     auto elements = tuple->elements();
-    for (size_t i = 0; i < elements.size(); i++) {
+    for (size_t i = 0; i < lite::GetCNodeOutputsSize(cnode, train_flag_); i++) {
       auto msTensor = new (std::nothrow) schema::TensorT();
       if (msTensor == nullptr) {
        MS_LOG(ERROR) << "new msTensor failed";
@@ -627,8 +620,6 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
         node_id_map_[name] = meta_graphT->allTensors.size();
         meta_graphT->allTensors.emplace_back(msTensor);
-        if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))
-          break;
       } else {
         if (elements.size() == 1) {
           node_id_map_[cnode_name] = meta_graphT->allTensors.size();
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h
index ed20096e78..6fd2ca4065 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.h
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h
@@ -57,9 +57,9 @@ class AnfExporter {
   int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
   int SetGraphoutputIndex(const CNodePtr &cnode, size_t subgraph_index,
                           const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *return_node);
-  static int SetQuantOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
-                                      const std::shared_ptr<mindspore::Primitive> &primitive,
-                                      const std::unique_ptr<schema::CNodeT> &dst_node);
+  static int SetPostTrainOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
+                                          const std::shared_ptr<mindspore::Primitive> &primitive,
+                                          const std::unique_ptr<schema::CNodeT> &dst_node);
   static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                const std::shared_ptr<mindspore::Primitive> &primitive,
                                const std::unique_ptr<schema::CNodeT> &dst_node);
diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc
index d537c575b1..1225a8f89c 100644
--- a/mindspore/lite/tools/common/graph_util.cc
+++ b/mindspore/lite/tools/common/graph_util.cc
@@ -495,13 +495,21 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
   if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
     auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
     MS_ASSERT(prim != nullptr);
+    if (prim->src_t == TypeId::kNumberTypeUInt8) {
+      if (preTensor->dataType == TypeId::kNumberTypeUInt8) {
+        toAddTensor->quantParams.front()->zeroPoint -= 128;
+      } else {
+        preTensor->quantParams.front()->zeroPoint += 128;
+      }
+    } else if (prim->dst_t == TypeId::kNumberTypeUInt8) {
+      if (preTensor->dataType == TypeId::kNumberTypeInt8) {
+        toAddTensor->quantParams.front()->zeroPoint += 128;
+      } else {
+        preTensor->quantParams.front()->zeroPoint -= 128;
+      }
+    }
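+    // the +/-128 zeroPoint adjustment keeps the encoded real values unchanged:
+    // with a shared scale, a uint8 zero point equals the matching int8 zero point plus 128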
     preTensor->dataType = prim->src_t;
     toAddTensor->dataType = prim->dst_t;
-    if (prim->src_t == TypeId::kNumberTypeUInt8 && prim->dst_t == TypeId::kNumberTypeInt8) {
-      preTensor->quantParams.front()->zeroPoint += 128;
-    } else if (prim->src_t == TypeId::kNumberTypeInt8 && prim->dst_t == TypeId::kNumberTypeUInt8) {
-      toAddTensor->quantParams.front()->zeroPoint += 128;
-    }
   }
   graphT->allTensors.emplace_back(std::move(toAddTensor));
   size_t toAddTensorIdx = graphT->allTensors.size() - 1;
@@ -565,13 +573,21 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
   if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
     auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
     MS_ASSERT(prim != nullptr);
+    if (prim->dst_t == TypeId::kNumberTypeUInt8) {
+      if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
+        postTensor->quantParams.front()->zeroPoint -= 128;
+      } else {
+        toAddTensor->quantParams.front()->zeroPoint += 128;
+      }
+    } else if (prim->src_t == TypeId::kNumberTypeUInt8) {
+      if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
+        toAddTensor->quantParams.front()->zeroPoint -= 128;
+      } else {
+        postTensor->quantParams.front()->zeroPoint += 128;
+      }
+    }
     postTensor->dataType = prim->src_t;
     toAddTensor->dataType = prim->dst_t;
-    if (prim->src_t == TypeId::kNumberTypeInt8 && prim->dst_t == TypeId::kNumberTypeUInt8) {
-      toAddTensor->quantParams.front()->zeroPoint += 128;
-    } else if (prim->src_t == TypeId::kNumberTypeUInt8 && prim->dst_t == TypeId::kNumberTypeInt8) {
-      postTensor->quantParams.front()->zeroPoint += 128;
-    }
   }
   graphT->allTensors.emplace_back(std::move(toAddTensor));
   size_t toAddTensorIdx = graphT->allTensors.size() - 1;
diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc
index 98daaae2aa..4930d828ec 100644
--- a/mindspore/lite/tools/common/node_util.cc
+++ b/mindspore/lite/tools/common/node_util.cc
@@ -446,5 +446,20 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
   }
   return RET_OK;
 }
+
+size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
+  auto cnode = anf_node->cast<CNodePtr>();
+  if (train_flag &&
+      (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
+    return 1;
+  }
+  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
+    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
+    return tuple->elements().size();
+  } else {
+    return 1;
+  }
+}
+
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h
index e4bf078bb0..20d004a265 100644
--- a/mindspore/lite/tools/common/node_util.h
+++ b/mindspore/lite/tools/common/node_util.h
@@ -26,6 +26,7 @@
 #include "src/common/log_adapter.h"
 #include "include/errorcode.h"
 #include "securec/include/securec.h"
+#include "tools/optimizer/common/gllo_utils.h"
 
 namespace mindspore {
 namespace lite {
@@ -401,6 +402,8 @@ static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) {
 }
 
 STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
+
+size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag = false);
 }  // namespace lite
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
diff --git a/mindspore/lite/tools/converter/converter_context.h b/mindspore/lite/tools/converter/converter_context.h
index 24e13f7c96..9bc101ff54 100644
--- a/mindspore/lite/tools/converter/converter_context.h
+++ b/mindspore/lite/tools/converter/converter_context.h
@@ -77,18 +77,29 @@ class TensorDataType {
     static TensorDataType tensor_data_type;
     return &tensor_data_type;
   }
-  void UpdateTensorType(int32_t index, int32_t type) { tensor_data_type_map_[index] = type; }
-  int32_t GetTensorType(int32_t index) const {
-    if (tensor_data_type_map_.find(index) == tensor_data_type_map_.end()) {
+
+  void UpdateGraphInputDType(int32_t index, int32_t dtype) { graph_input_data_type_map_[index] = dtype; }
+  int32_t GetGraphInputDType(int32_t index) const {
+    if (graph_input_data_type_map_.find(index) == graph_input_data_type_map_.end()) {
+      return TypeId::kTypeUnknown;
+    }
+    return graph_input_data_type_map_.at(index);
+  }
+
+  void UpdateGraphOutputDType(int32_t index, int32_t dtype) { graph_output_data_type_map_[index] = dtype; }
+  int32_t GetGraphOutputDType(int32_t index) const {
+    if (graph_output_data_type_map_.find(index) == graph_output_data_type_map_.end()) {
       return TypeId::kTypeUnknown;
     }
-    return tensor_data_type_map_.at(index);
+    return graph_output_data_type_map_.at(index);
   }
 
  private:
   TensorDataType() {}
   virtual ~TensorDataType() = default;
   std::map<int32_t, int32_t> tensor_data_type_map_;
+  std::map<int32_t, int32_t> graph_input_data_type_map_;
+  std::map<int32_t, int32_t> graph_output_data_type_map_;
 };
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
index 5e4016bcdc..4ad2b670ef 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
@@ -23,6 +23,7 @@
 #include "tools/converter/converter_context.h"
 #include "src/common/common.h"
 #include "src/common/utils.h"
+#include "tools/converter/quantizer/quantize_util.h"
 
 namespace mindspore {
 namespace lite {
@@ -38,18 +39,17 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
     return status;
   }
 
-  status = DoNodeInoutDTypeTrans(graph);
+  status = DoModelOutputDTypeTrans(graph);
   if (status != RET_OK) {
-    MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
+    MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
     return status;
   }
 
-  status = DoModelOutputDTypeTrans(graph);
+  status = DoNodeInoutDTypeTrans(graph);
   if (status != RET_OK) {
-    MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
+    MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
     return status;
   }
-
   return RET_OK;
 }
 
@@ -61,15 +61,23 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
     MS_LOG(ERROR) << "Invalid inputDataType: " << this->input_data_dtype;
     return RET_ERROR;
   }
-  for (auto graph_in_idx : graph_in_idxes) {
+  for (size_t i = 0; i < graph_in_idxes.size(); i++) {
+    size_t graph_in_idx = graph_in_idxes.at(i);
     MS_ASSERT(graph_in_idx < graph->allTensors.size());
     auto &tensor = graph->allTensors.at(graph_in_idx);
-    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
+    if (!quant::TensorQuantParamsInited(*tensor)) {
+      continue;
+    }
+
+    if (this->input_data_dtype == TypeId::kTypeUnknown) {
+      if (tensor->dataType != TensorDataType::GetInstance()->GetGraphInputDType(i)) {
+        MS_LOG(ERROR) << "Change graph input dtype is not allowed.";
+        return RET_ERROR;
+      }
       continue;
     }
-    int32_t tensor_data_type = this->input_data_dtype != TypeId::kTypeUnknown
-                                 ? this->input_data_dtype
-                                 : TensorDataType::GetInstance()->GetTensorType(graph_in_idx);
+
+    int32_t tensor_data_type = this->input_data_dtype;
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
       auto node_name = (*iter)->name;
       for (size_t input_indexidx = 0; input_indexidx < (*iter)->inputIndex.size(); input_indexidx++) {
@@ -77,7 +85,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
           STATUS status = RET_OK;
 
           // insert dtype cast node between input tensor and input node
-          if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
+          if (tensor_data_type != tensor->dataType) {
            iter =
              InsertDTypeTransNode(graph, iter, kBefore, input_indexidx, tensor_data_type, tensor->dataType, &status);
           }
@@ -101,15 +109,23 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
     return RET_ERROR;
   }
   auto &graph_out_idxes = graph->outputIndex;
-  for (auto graph_out_idx : graph_out_idxes) {
+  for (size_t i = 0; i < graph_out_idxes.size(); i++) {
+    size_t graph_out_idx = graph_out_idxes.at(i);
     MS_ASSERT(graph_out_idx < graph->allTensors.size());
     auto &tensor = graph->allTensors.at(graph_out_idx);
-    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
+    if (!quant::TensorQuantParamsInited(*tensor)) {
      continue;
    }
-    int32_t tensor_data_type = this->output_data_dtype != TypeId::kTypeUnknown
-                                 ? this->output_data_dtype
-                                 : TensorDataType::GetInstance()->GetTensorType(graph_out_idx);
+
+    if (this->output_data_dtype == TypeId::kTypeUnknown) {
+      if (tensor->dataType != TensorDataType::GetInstance()->GetGraphOutputDType(i)) {
+        MS_LOG(ERROR) << "Change graph output dtype is not allowed.";
+        return RET_ERROR;
+      }
+      continue;
+    }
+
+    int32_t tensor_data_type = this->output_data_dtype;
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
       auto node_name = (*iter)->name;
       MS_ASSERT(node != nullptr);
@@ -117,7 +133,7 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
         if ((*iter)->outputIndex.at(outputIndexIdx) == graph_out_idx) {
           // insert transNode
           STATUS status = RET_OK;
-          if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
+          if (tensor_data_type != tensor->dataType) {
            iter =
              InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensor_data_type, &status);
           }
@@ -136,27 +152,34 @@ STATUS DTypeTransPass::InsetDTypeTransNodeForWrongDtypeQuantOp(schema::MetaGraphT *graph, NodeIter *iter) {
   auto node_name = (**iter)->name;
   auto status = RET_OK;
-  // insert fp32 to int8 before
+  // insert fp32/uint8 to int8 before
   for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
     auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
-    if (pre_tensor->dataType == kNumberTypeFloat32 && !pre_tensor->quantParams.empty() &&
-        pre_tensor->quantParams.front()->inited) {
-      *iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeFloat32, kNumberTypeInt8, &status);
+    // insert quant cast op for tensor which should be int8
+    if ((pre_tensor->dataType == kNumberTypeFloat32 || pre_tensor->dataType == kNumberTypeUInt8) &&
+        quant::TensorQuantParamsInited(*pre_tensor)) {
+      if (!pre_tensor->data.empty()) {
+        MS_LOG(ERROR) << "tensor with float data should be quantized at tensor_quant_pass.";
+        return RET_ERROR;
+      }
+      *iter = InsertDTypeTransNode(graph, *iter, kBefore, i, pre_tensor->dataType, kNumberTypeInt8, &status);
       if (status != RET_OK) {
-        MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
+        MS_LOG(ERROR) << "Insert float32 or uint8 to int8 node before " << node_name.c_str() << " failed";
         return RET_ERROR;
       }
     }
   }
 
-  // insert int8 to fp32 after
+  // insert int8 to fp32/uint8 after
   for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
     auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
-    if (post_tensor->dataType == kNumberTypeFloat32 && !post_tensor->quantParams.empty() &&
-        post_tensor->quantParams.front()->inited) {
-      *iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
+    // insert quant cast op for tensor which should be int8
+    // e.g. reshape's shape tensor doesn't need a quant op inserted, so its quant param isn't inited
+    if ((post_tensor->dataType == kNumberTypeFloat32 || post_tensor->dataType == kNumberTypeUInt8) &&
+        quant::TensorQuantParamsInited(*post_tensor)) {
+      *iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, post_tensor->dataType, &status);
       if (status != RET_OK) {
-        MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";
+        MS_LOG(ERROR) << "Insert int8 to float32 or uint8 node after " << node_name.c_str() << " failed";
         return RET_ERROR;
       }
     }
   }
@@ -170,8 +193,7 @@ STATUS DTypeTransPass::InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter) {
   // insert int8 to fp32 before
   for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
     auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
-    if (pre_tensor->dataType == kNumberTypeInt8 && !pre_tensor->quantParams.empty() &&
-        pre_tensor->quantParams.front()->inited) {
+    if (pre_tensor->dataType == kNumberTypeInt8 && quant::TensorQuantParamsInited(*pre_tensor)) {
       *iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";
@@ -183,8 +205,7 @@ STATUS DTypeTransPass::InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter) {
   // insert fp32 to int8 after
   for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
     auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
-    if (post_tensor->dataType == kNumberTypeInt8 && !post_tensor->quantParams.empty() &&
-        post_tensor->quantParams.front()->inited) {
+    if (post_tensor->dataType == kNumberTypeInt8 && quant::TensorQuantParamsInited(*post_tensor)) {
       *iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
@@ -200,8 +221,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
   for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
     auto node_name = (*iter)->name;
     if ((*iter)->inputIndex.empty()) {
-      MS_LOG(ERROR) << "Op " << node_name.c_str() << " should have " << kMinInputNum << " input tensor at least";
-      return RET_ERROR;
+      MS_LOG(WARNING) << "Op " << node_name.c_str() << " should have " << kMinInputNum << " input tensor at least";
+      continue;
     }
 
     if ((*iter)->primitive->value.type == schema::PrimitiveType_QuantDTypeCast ||
@@ -270,6 +291,10 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter,
     trans_node->name = "uint8toint8_" + tile_name + std::to_string(id_++);
   } else if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeUInt8) {
trans_node->name = "int8touint8_" + tile_name + std::to_string(id_++); + } else if (input_data_type == TypeId::kNumberTypeUInt8 && output_data_type == TypeId::kNumberTypeFloat32) { + trans_node->name = "uint8toft32_" + tile_name + std::to_string(id_++); + } else if (input_data_type == TypeId::kNumberTypeFloat32 && output_data_type == TypeId::kNumberTypeUInt8) { + trans_node->name = "ft32touint8_" + tile_name + std::to_string(id_++); } trans_node->primitive->value.value = quant_dtype_cast_param; int insert_num = 0; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc index 7ea6ee35a5..73c70b66d0 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc @@ -25,29 +25,15 @@ namespace mindspore::lite { namespace { -STATUS PreHandleQuantDtypeCast(schema::MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - for (auto &node : graph->nodes) { - if (node == nullptr || node->primitive == nullptr) { - MS_LOG(ERROR) << " node or node->primitive is nullptr"; - return RET_ERROR; - } - if (node->primitive->value.type == PrimitiveType_QuantDTypeCast) { - auto attr = node->primitive->value.AsQuantDTypeCast(); - auto &inputTensor = graph->allTensors.at(node->inputIndex.front()); - inputTensor->dataType = attr->src_t; - auto &outputTensor = graph->allTensors.at(node->outputIndex.front()); - outputTensor->dataType = attr->dst_t; - - if (attr->src_t == TypeId::kNumberTypeUInt8) { - attr->src_t = TypeId::kNumberTypeInt8; - } - if (attr->dst_t == TypeId::kNumberTypeUInt8) { - attr->dst_t = TypeId::kNumberTypeInt8; - } - } +bool TensorNeedQuant(const std::unique_ptr &tensor) { + if (!quant::TensorQuantParamsInited(*tensor)) { + return false; } - return RET_OK; + if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && + tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) { + return false; + } + return !tensor->data.empty(); } STATUS ComputeDataToInt8(const std::unique_ptr &tensor, int32_t index) { @@ -73,7 +59,6 @@ STATUS ComputeDataToInt8(const std::unique_ptr &tensor, int32_t index) weightQauntParam->zeroPoint -= 128; tensor->quantParams.clear(); tensor->quantParams.emplace_back(weightQauntParam.release()); - TensorDataType::GetInstance()->UpdateTensorType(index, TypeId::kNumberTypeUInt8); } tensor->dataType = TypeId::kNumberTypeInt8; if (tensor->data.empty()) { @@ -174,23 +159,15 @@ STATUS ComputeQuantTensorPerChannel(TensorT *tensor, const int &tensor_index, co STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); - auto status = PreHandleQuantDtypeCast(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "pre adjust failed."; - return status; - } int32_t index = 0; + auto status = RET_OK; for (auto &tensor : graph->allTensors) { - if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { - index++; - continue; - } - if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && - tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) { + if (!TensorNeedQuant(tensor)) { index++; continue; } - if (tensor->quantParams.size() != 1) { // perchannel + + if (tensor->quantParams.size() > 1) { // perchannel status = ComputeQuantTensorPerChannel(tensor.get(), index, *graph); if (status != RET_OK) 
{ MS_LOG(ERROR) << "compute tensor to int8 prechannel failed."; @@ -201,8 +178,8 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { } // perlayer auto &quantParam = tensor->quantParams.front(); - if (quantParam->dstDtype == TypeId::kNumberTypeUInt8 || quantParam->dstDtype == TypeId::kNumberTypeFloat32 || - quantParam->dstDtype == TypeId::kNumberTypeFloat) { + if (quantParam->dstDtype == TypeId::kNumberTypeInt8 || quantParam->dstDtype == TypeId::kNumberTypeUInt8 || + quantParam->dstDtype == TypeId::kNumberTypeFloat32 || quantParam->dstDtype == TypeId::kNumberTypeFloat) { status = ComputeDataToInt8(tensor, index); } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { // quant bias data diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.cc new file mode 100644 index 0000000000..5f6660838a --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h" +#include "mindspore/core/utils/log_adapter.h" +#include "mindspore/core/ir/dtype/type_id.h" +namespace mindspore::lite { +static constexpr size_t kInputIndex = 0; +static constexpr size_t kWeightIndex = 1; + +STATUS QuantDtypeCastQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) { + auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0)); + if (!input_tensor->quantParams.empty() && input_tensor->quantParams.front()->inited) { + input_tensor->quantParams.front()->dstDtype = input_tensor->dataType; + } + auto &output_tensor = graph->allTensors.at(node.outputIndex.at(0)); + if (!output_tensor->quantParams.empty() && output_tensor->quantParams.front()->inited) { + output_tensor->quantParams.front()->dstDtype = output_tensor->dataType; + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h new file mode 100644 index 0000000000..56d051d106 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H + +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" +namespace mindspore::lite { +class QuantDtypeCastQuantParamPropogator : public QuantParamPropogator { + public: + STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.cc index d70ec8b68a..43cabe1010 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.cc @@ -25,6 +25,7 @@ #include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h" #include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h" #include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h" +#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h" namespace mindspore::lite { void QuantNodeBase::UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node) { @@ -100,6 +101,7 @@ QuantNodeHelper *QuantHelperRegister::GetQuantHelper(schema::PrimitiveType op_ty QuantHelperRegister::QuantHelperRegister() { auto base_propogator = std::make_shared(); auto base_determiner = std::make_shared(); + auto quant_dtype_cast_propogator = std::make_shared(); auto bias_add_propogator = std::make_shared(); auto carry_data_propogator = std::make_shared(); auto carry_data_determiner = std::make_shared(); @@ -127,7 +129,7 @@ QuantHelperRegister::QuantHelperRegister() { register_map_[schema::PrimitiveType_MatMul] = new QuantNodeHelper(conv_propogator, conv_determiner); register_map_[schema::PrimitiveType_QuantDTypeCast] = - new QuantNodeHelper(base_propogator, default_quant_all_determiner); + new QuantNodeHelper(quant_dtype_cast_propogator, default_quant_all_determiner); register_map_[schema::PrimitiveType_DetectionPostProcess] = new QuantNodeHelper(base_propogator, only_need_inputs_determiner); diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 9231decb25..79f6434367 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -248,6 +248,25 @@ QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) { } return quant_params_holder; } +bool QuantParamEqual(const schema::QuantParamT &quant_param1, const schema::QuantParamT &quant_param2) { + return quant_param1.inited == quant_param2.inited && quant_param1.scale == quant_param2.scale && + quant_param1.zeroPoint == 
+  return quant_param1.inited == quant_param2.inited && quant_param1.scale == quant_param2.scale &&
+         quant_param1.zeroPoint == quant_param2.zeroPoint && quant_param1.min == quant_param2.min &&
+         quant_param1.max == quant_param2.max && quant_param1.narrowRange == quant_param2.narrowRange &&
+         quant_param1.numBits == quant_param2.numBits && quant_param1.varCorr == quant_param2.varCorr &&
+         quant_param1.meanCorr == quant_param2.meanCorr;
+}
+bool TensorQuantParamsInited(const schema::TensorT &tensor) {
+  if (tensor.quantParams.empty()) {
+    return false;
+  }
+
+  for (auto &quant_param : tensor.quantParams) {
+    if (!quant_param->inited) {
+      return false;
+    }
+  }
+  return true;
+}
 
 STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange,
                              int quant_max, int quant_min, int num_bits) {
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h
index 0324b0bd46..7bc02ef41d 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.h
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h
@@ -175,6 +175,10 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quant_max,
   }();
 }
 
+bool QuantParamEqual(const schema::QuantParamT &quant_param1, const schema::QuantParamT &quant_param2);
+
+bool TensorQuantParamsInited(const schema::TensorT &tensor);
+
 template <typename T>
 STATUS DoPerChannelQuant(const tensor::TensorPtr &weight, const QuantType &quant_type,
                          std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index a95dd20ce2..e80e89ef2e 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -28,6 +28,7 @@
 #include "tools/common/tensor_util.h"
 #include "frontend/operator/ops.h"
 #include "backend/optimizer/common/helper.h"
+#include "tools/converter/quant_param_holder.h"
 
 using float16 = Eigen::half;
 
@@ -1416,6 +1417,9 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
   auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm_node});
   MS_ASSERT(cnode != nullptr);
   cnode->set_fullname_with_scope(cnode_name);
+  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(2, 1);
+  auto trans_insert_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  trans_insert_prim->AddAttr("quant_params", quant_params_holder);
   return cnode;
 }
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
index 135035f02c..e2d25c4779 100644
--- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
@@ -21,20 +21,11 @@
 #include "tools/converter/quant_param_holder.h"
 #include "tools/converter/quantizer/quantize_util.h"
 #include "src/common/log_adapter.h"
+#include "tools/common/node_util.h"
 
 namespace mindspore {
 namespace opt {
 namespace {
-size_t GetCNodeOutputsSize(std::shared_ptr<AnfNode> anf_node) {
-  auto cnode = anf_node->cast<CNodePtr>();
-  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
-    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
-    return tuple->elements().size();
-  } else {
-    return 1;
-  }
-}
-
 int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
   auto quant_param_holder = prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
   std::vector<schema::QuantParamT> quants;
@@ -112,27 +103,6 @@ int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
   return lite::RET_OK;
 }
 
-void CheckQuantParams(const PrimitivePtr &prim) {
-  auto quant_param_holder = prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
-  auto input_quant_params = quant_param_holder->get_input_quant_params();
-  bool is_quant = false;
-  for (size_t i = 0; i < input_quant_params.size(); ++i) {
-    if (!input_quant_params.at(i).empty() && input_quant_params.at(i).at(0).inited) {
-      is_quant = true;
-      break;
-    }
-  }
-  auto output_quant_params = quant_param_holder->get_output_quant_params();
-  for (size_t i = 0; i < output_quant_params.size(); ++i) {
-    if (!output_quant_params.at(i).empty() && output_quant_params.at(i).at(0).inited) {
-      is_quant = true;
-    }
-  }
-  if (!is_quant) {
-    prim->EraseAttr("quant_params");
-  }
-}
-
 int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
   auto narrow_range = prim->GetAttr("narrow_range");
   bool narrow_range_param = false;
@@ -170,7 +140,6 @@ int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
     MS_LOG(ERROR) << "compute output quant param failed.";
     return status;
   }
-  CheckQuantParams(prim);
   return lite::RET_OK;
 }
 }  // namespace
@@ -236,7 +205,8 @@ int MindirAdjustPass::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
   auto inputs = cnode->inputs();
   inputs.erase(inputs.begin());
 
-  auto quant_param_holder = std::make_shared<lite::QuantParamHolder>(inputs.size(), GetCNodeOutputsSize(anf_node));
+  auto quant_param_holder =
+    std::make_shared<lite::QuantParamHolder>(inputs.size(), lite::GetCNodeOutputsSize(anf_node, train_flag_));
   primitive->AddAttr("quant_params", quant_param_holder);
 
   if (ConvertQuantParam(primitive, inputs) != lite::RET_OK) {
diff --git a/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc b/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
index af3bfdd1a9..19be1d80e9 100644
--- a/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
+++ b/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
@@ -25,7 +25,6 @@
 #include "ops/fusion/slice_fusion.h"
 #include "ops/op_utils.h"
 #include "ops/strided_slice.h"
-#include "tools/converter/quant_param_holder.h"
 
 namespace mindspore {
 namespace opt {
@@ -93,9 +92,6 @@ AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_graph,
   std::string trans_name = before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1)
                                   : cnode->fullname_with_scope() + "_post";
   auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
-  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(1, 1);
-  auto trans_insert_prim = GetValueNode<PrimitivePtr>(trans_insert_node->input(0));
-  trans_insert_prim->AddAttr("quant_params", quant_params_holder);
   return trans_insert_node;
 }
diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc
index 852a51a0d0..2c8f412e16 100644
--- a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc
@@ -20,6 +20,7 @@
 #include "ops/transpose.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "tools/common/tensor_util.h"
+#include "tools/converter/quant_param_holder.h"
 
 using mindspore::lite::converter::FmkType_CAFFE;
 using mindspore::lite::converter::FmkType_MS;
@@ -92,6 +93,7 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const FuncGraphPtr &graph,
   }
   auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm");
   auto prim = std::make_shared<ops::Transpose>();
+  prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>(1, 1));
   auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
   if (!weight_node->has_default()) {
     MS_LOG(DEBUG) << "Weight parameter should has default parameter.";