From: @cjh9368 Reviewed-by: @hangangqiang, @HilbertDavid Signed-off-by: @hangangqiang tags/v1.3.0
| @@ -164,7 +164,6 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me | |||
| auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param); | |||
| MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale | |||
| << " zp: " << input_quant_param_ptr->zeroPoint; | |||
| input_quant_param_ptr->dstDtype = tensor_input->dataType; | |||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); | |||
| } | |||
| } | |||
| @@ -185,7 +184,6 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me | |||
| std::make_unique<schema::QuantParamT>(channel_quant_param); | |||
| MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale | |||
| << " zp: " << output_quant_param_ptr->zeroPoint; | |||
| output_quant_param_ptr->dstDtype = output_tensor->dataType; | |||
| output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr)); | |||
| } | |||
| } | |||
| @@ -222,7 +220,6 @@ int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &m | |||
| tensor->format = schema::Format_NHWC; | |||
| if (!IsContain(subgraph->inputIndices, input)) { | |||
| if (subgraph_index == kMainGraphIndex) { | |||
| TensorDataType::GetInstance()->UpdateGraphInputDType(meta_graphT->inputIndex.size(), tensor->dataType); | |||
| meta_graphT->inputIndex.push_back(input); | |||
| } | |||
| subgraph->inputIndices.push_back(input); | |||
| @@ -624,40 +621,36 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| } | |||
| auto elements = tuple->elements(); | |||
| for (size_t i = 0; i < lite::GetCNodeOutputsSize(cnode, train_flag_); i++) { | |||
| auto msTensor = new (std::nothrow) schema::TensorT(); | |||
| if (msTensor == nullptr) { | |||
| auto ms_tensor = new (std::nothrow) schema::TensorT(); | |||
| if (ms_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new msTensor failed"; | |||
| return; | |||
| } | |||
| msTensor->nodeType = NodeType_CNode; | |||
| ms_tensor->nodeType = NodeType_CNode; | |||
| fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| if (train_flag_) { | |||
| std::string name = cnode_name + "_o:" + std::to_string(i); | |||
| node_id_map_[name] = meta_graphT->allTensors.size(); | |||
| meta_graphT->allTensors.emplace_back(msTensor); | |||
| meta_graphT->allTensors.emplace_back(ms_tensor); | |||
| } else { | |||
| if (elements.size() == 1) { | |||
| node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | |||
| msTensor->name = cnode_name; | |||
| ms_tensor->name = cnode_name; | |||
| } else { | |||
| std::string name = cnode_name + "_o:" + std::to_string(i); | |||
| node_id_map_[name] = meta_graphT->allTensors.size(); | |||
| msTensor->name = name; | |||
| ms_tensor->name = name; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| MS_LOG(ERROR) << "abstract is not AbstractTensor"; | |||
| delete (msTensor); | |||
| delete (ms_tensor); | |||
| return; | |||
| } | |||
| auto type = kNumberTypeFloat32; | |||
| if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| type = typePtr->type_id(); | |||
| } | |||
| msTensor->dataType = type; | |||
| meta_graphT->allTensors.emplace_back(msTensor); | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]); | |||
| auto type_ptr = abstract_tensor->element()->GetTypeTrack(); | |||
| ms_tensor->dataType = type_ptr->type_id(); | |||
| meta_graphT->allTensors.emplace_back(ms_tensor); | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || | |||
| opt::CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm)) { | |||
| break; | |||
| @@ -709,5 +709,58 @@ std::string BoolVectorToString(const std::vector<bool> &bool_vec) { | |||
| return str; | |||
| } | |||
| TypeId GetAbstractTensorDtype(const abstract::AbstractTensorPtr &tensor) { | |||
| if (tensor == nullptr || tensor->element() == nullptr) { | |||
| MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr"; | |||
| return kTypeUnknown; | |||
| } | |||
| auto type_ptr = tensor->element()->GetTypeTrack(); | |||
| return type_ptr->type_id(); | |||
| } | |||
| TypeId GetParameterDtype(const ParameterPtr ¶m_node) { | |||
| auto abstract_base = param_node->abstract(); | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||
| auto type_ptr = abstract_tensor->element()->GetTypeTrack(); | |||
| return type_ptr->type_id(); | |||
| } | |||
| STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph) { | |||
| // update graph inputs dtype | |||
| size_t idx = 0; | |||
| for (auto &input : func_graph->get_inputs()) { | |||
| TypeId type = GetParameterDtype(input->cast<ParameterPtr>()); | |||
| TensorDataType::GetInstance()->UpdateGraphInputDType(idx, type); | |||
| idx++; | |||
| } | |||
| // update graph outputs dtype | |||
| auto graph_return = func_graph->get_return(); | |||
| idx = 0; | |||
| for (auto &input : graph_return->inputs()) { | |||
| if (input->isa<CNode>()) { | |||
| if (utils::isa<abstract::AbstractTuple>(input->abstract())) { | |||
| auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(input->abstract()); | |||
| if (tuple == nullptr) { | |||
| MS_LOG(ERROR) << "tuple is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| for (const auto &tuple_item : tuple->elements()) { | |||
| TypeId type = GetAbstractTensorDtype(tuple_item->cast<abstract::AbstractTensorPtr>()); | |||
| TensorDataType::GetInstance()->UpdateGraphOutputDType(idx, type); | |||
| idx++; | |||
| } | |||
| } else if (utils::isa<abstract::AbstractTensor>(input->abstract())) { | |||
| TypeId type = GetAbstractTensorDtype(input->abstract()->cast<abstract::AbstractTensorPtr>()); | |||
| TensorDataType::GetInstance()->UpdateGraphOutputDType(idx, type); | |||
| idx++; | |||
| } else { | |||
| TensorDataType::GetInstance()->UpdateGraphOutputDType(idx, kTypeUnknown); | |||
| idx++; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -32,6 +32,8 @@ | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/common/graph_util.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -104,6 +106,12 @@ std::vector<int> GetTransposePerm(schema::MetaGraphT *graph, const std::unique_p | |||
| std::string BoolVectorToString(const std::vector<bool> &bool_vec); | |||
| TypeId GetAbstractTensorDtype(const abstract::AbstractTensorPtr &tensor); | |||
| TypeId GetParameterDtype(const ParameterPtr ¶m_node); | |||
| STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph); | |||
| template <typename T> | |||
| bool IndexingCompress(const std::set<T> &quant_data_set, const std::map<T, size_t> &unique_value_index_map, | |||
| size_t unique_value_bit, size_t unique_value_cnt, size_t pack_repetition_size_in_byte, | |||
| @@ -25,6 +25,7 @@ | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/anf_transform.h" | |||
| #include "tools/converter/converter_context.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "load_mindir/load_model.h" | |||
| namespace mindspore { | |||
| @@ -62,6 +63,12 @@ class MindsporeImporter : public Converter { | |||
| } | |||
| func_graph->set_attr("graph_name", MakeValue("main_graph")); | |||
| func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS))); | |||
| auto status = UpdateFuncGraphInputsAndOutputsDtype(func_graph); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "update graph inputs and outputs dtype failed."; | |||
| return nullptr; | |||
| } | |||
| return func_graph; | |||
| } | |||
| }; | |||
| @@ -70,7 +70,8 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| } | |||
| if (this->input_data_dtype == TypeId::kTypeUnknown) { | |||
| if (tensor->dataType != TensorDataType::GetInstance()->GetGraphInputDType(i)) { | |||
| auto origin_input_dtype = TensorDataType::GetInstance()->GetGraphInputDType(i); | |||
| if (origin_input_dtype != kTypeUnknown && tensor->dataType != origin_input_dtype) { | |||
| MS_LOG(ERROR) << "Change graph input dtype is not allowed."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -118,7 +119,8 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| } | |||
| if (this->output_data_dtype == TypeId::kTypeUnknown) { | |||
| if (tensor->dataType != TensorDataType::GetInstance()->GetGraphOutputDType(i)) { | |||
| auto origin_output_dtype = TensorDataType::GetInstance()->GetGraphOutputDType(i); | |||
| if (origin_output_dtype != kTypeUnknown && tensor->dataType != origin_output_dtype) { | |||
| MS_LOG(ERROR) << "Change graph output dtype is not allowed."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -66,6 +66,12 @@ int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std: | |||
| } | |||
| res_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||
| res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE))); | |||
| status = UpdateFuncGraphInputsAndOutputsDtype(res_graph_); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "update graph inputs and outputs dtype failed."; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -67,6 +67,12 @@ int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std:: | |||
| } | |||
| res_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||
| res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); | |||
| status = UpdateFuncGraphInputsAndOutputsDtype(res_graph_); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "update graph inputs and outputs dtype failed."; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -549,6 +549,11 @@ int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::str | |||
| return status; | |||
| } | |||
| status = UpdateFuncGraphInputsAndOutputsDtype(res_graph_); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "update graph inputs and outputs dtype failed."; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -24,6 +24,7 @@ | |||
| #include "tools/converter/ops/ops_def.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "ir/func_graph.h" | |||
| #include "tools/common/graph_util.h" | |||
| namespace mindspore::lite { | |||
| std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *model_path) { | |||
| @@ -80,6 +81,12 @@ int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std | |||
| return status; | |||
| } | |||
| res_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||
| status = UpdateFuncGraphInputsAndOutputsDtype(res_graph_); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "update graph inputs and outputs dtype failed."; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -217,6 +224,7 @@ STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tens | |||
| if (!tflite_tensor->quantization->max.empty()) { | |||
| quant_param->max = tflite_tensor->quantization->max[i]; | |||
| } | |||
| quant_param->dstDtype = GetTfliteDataType(tflite_tensor->type); | |||
| quant_param->inited = true; | |||
| quant_param->roundType = round_type; | |||
| quant_param->multiplier = 1; | |||
| @@ -287,7 +295,8 @@ STATUS TfliteModelParser::ConvertGraphInputs() { | |||
| std::vector<int64_t> shape_vector; | |||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type)); | |||
| auto dtype = GetTfliteDataType(tensor->type); | |||
| auto abstract_tensor = CreateTensorAbstract(shape_vector, dtype); | |||
| if (abstract_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||
| return RET_ERROR; | |||
| @@ -310,9 +319,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { | |||
| } | |||
| auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | |||
| make_tuple_inputs.emplace_back(make_tuple_prim); | |||
| for (auto outputNode : tflite_subgraph->outputs) { | |||
| outputNode = outputNode < 0 ? outputNode + tflite_subgraph->tensors.size() : outputNode; | |||
| auto cnode = nodes_.at(outputNode); | |||
| for (auto output_idx : tflite_subgraph->outputs) { | |||
| output_idx = output_idx < 0 ? output_idx + tflite_subgraph->tensors.size() : output_idx; | |||
| auto cnode = nodes_.at(output_idx); | |||
| if (nullptr == cnode) { | |||
| MS_LOG(ERROR) << "Can't find input node."; | |||
| return RET_NOT_FIND_OP; | |||
| @@ -30,15 +30,15 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t | |||
| auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>(); | |||
| std::vector<schema::QuantParamT> quants; | |||
| schema::QuantParamT quant_param; | |||
| auto inputMin = prim->GetAttr("input_minq"); | |||
| auto inputMax = prim->GetAttr("input_maxq"); | |||
| if (inputMin != nullptr && inputMax != nullptr) { | |||
| auto inputMinPtr = inputMin->cast<tensor::TensorPtr>(); | |||
| auto inputMaxPtr = inputMax->cast<tensor::TensorPtr>(); | |||
| auto *minBuf = static_cast<float *>(inputMinPtr->data_c()); | |||
| auto *maxBuf = static_cast<float *>(inputMaxPtr->data_c()); | |||
| quant_param.min = *minBuf; | |||
| quant_param.max = *maxBuf; | |||
| auto input_min = prim->GetAttr("input_minq"); | |||
| auto input_max = prim->GetAttr("input_maxq"); | |||
| if (input_min != nullptr && input_max != nullptr) { | |||
| auto input_min_ptr = input_min->cast<tensor::TensorPtr>(); | |||
| auto input_max_ptr = input_max->cast<tensor::TensorPtr>(); | |||
| auto *min_buf = static_cast<float *>(input_min_ptr->data_c()); | |||
| auto *max_buf = static_cast<float *>(input_max_ptr->data_c()); | |||
| quant_param.min = *min_buf; | |||
| quant_param.max = *max_buf; | |||
| auto ret = | |||
| lite::quant::CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, narrow_range, numbits); | |||
| if (ret != RET_OK) { | |||
| @@ -50,19 +50,19 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t | |||
| } | |||
| quants.clear(); | |||
| auto filterMin = prim->GetAttr("filter_minq"); | |||
| auto filterMax = prim->GetAttr("filter_maxq"); | |||
| if (filterMin != nullptr && filterMax != nullptr) { | |||
| auto filterMinPtr = filterMin->cast<tensor::TensorPtr>(); | |||
| auto filterMaxPtr = filterMax->cast<tensor::TensorPtr>(); | |||
| auto *minBuf = static_cast<float *>(filterMinPtr->data_c()); | |||
| auto *maxBuf = static_cast<float *>(filterMaxPtr->data_c()); | |||
| auto filter_min = prim->GetAttr("filter_minq"); | |||
| auto filter_max = prim->GetAttr("filter_maxq"); | |||
| if (filter_min != nullptr && filter_max != nullptr) { | |||
| auto filter_min_ptr = filter_min->cast<tensor::TensorPtr>(); | |||
| auto filter_max_ptr = filter_max->cast<tensor::TensorPtr>(); | |||
| auto *min_buf = static_cast<float *>(filter_min_ptr->data_c()); | |||
| auto *max_buf = static_cast<float *>(filter_max_ptr->data_c()); | |||
| quant_param.min = FLT_MAX; | |||
| quant_param.max = FLT_MIN; | |||
| for (int i = 0; i < filterMinPtr->ElementsNum(); ++i) { | |||
| for (int i = 0; i < filter_min_ptr->ElementsNum(); ++i) { | |||
| schema::QuantParamT tmp_quant_param; | |||
| tmp_quant_param.min = *minBuf; | |||
| tmp_quant_param.max = *maxBuf; | |||
| tmp_quant_param.min = *min_buf; | |||
| tmp_quant_param.max = *max_buf; | |||
| auto ret = | |||
| lite::quant::CalQuantizationParams(&tmp_quant_param, tmp_quant_param.min, tmp_quant_param.max, true, numbits); | |||
| if (ret != RET_OK) { | |||
| @@ -70,8 +70,8 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t | |||
| return ret; | |||
| } | |||
| quants.emplace_back(tmp_quant_param); | |||
| minBuf++; | |||
| maxBuf++; | |||
| min_buf++; | |||
| max_buf++; | |||
| } | |||
| quant_param_holder->set_input_quant_param(1, quants); | |||
| } | |||