| @@ -322,11 +322,6 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) { | |||
| } | |||
| int LiteSession::CompileGraph(Model *model) { | |||
| if (!ModelVerify(*model)) { | |||
| MS_LOG(ERROR) << "wrong model input, please check"; | |||
| return RET_ERROR; | |||
| } | |||
| bool expected = false; | |||
| if (!is_running_.compare_exchange_strong(expected, true)) { | |||
| MS_LOG(ERROR) << "Not support multi-threading"; | |||
| @@ -343,6 +338,11 @@ int LiteSession::CompileGraph(Model *model) { | |||
| is_running_.store(false); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (!ModelVerify(*model)) { | |||
| MS_LOG(ERROR) << "wrong model input, please check"; | |||
| is_running_.store(false); | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = ConvertTensors(model); | |||
| if (ret != RET_OK) { | |||
| @@ -44,7 +44,11 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| MS_ASSERT(nullptr != abstract_tensor); | |||
| parameter->set_abstract(abstract_tensor); | |||
| parameter->set_name("const_" + std::to_string(i) + "_parameter"); | |||
| if (!tensor->name.empty()) { | |||
| parameter->set_name(tensor->name); | |||
| } else { | |||
| parameter->set_name("const-" + std::to_string(i)); | |||
| } | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| MS_ASSERT(nullptr != param_value); | |||
| @@ -74,6 +74,7 @@ class TensorCache { | |||
| } else { | |||
| tensor->nodeType = schema::NodeType_Parameter; | |||
| } | |||
| tensor->name = name; | |||
| tensors.push_back(tensor); | |||
| if (Category == GRAPH_INPUT) { | |||
| @@ -180,24 +180,25 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| } | |||
| } | |||
| // topological sorting | |||
| // tensor name | |||
| { | |||
| Optimizer topologicalOptimizer; | |||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| status = topologicalOptimizer.Run(graphDefT); | |||
| Optimizer nameOptimizer; | |||
| nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| nameOptimizer.AddPass(new (std::nothrow) TensorNamePass()); | |||
| status = nameOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| // tensor name | |||
| // topological sorting | |||
| { | |||
| Optimizer nameOptimizer; | |||
| nameOptimizer.AddPass(new (std::nothrow) TensorNamePass()); | |||
| status = nameOptimizer.Run(graphDefT); | |||
| Optimizer topologicalOptimizer; | |||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| status = topologicalOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -21,54 +21,31 @@ | |||
| namespace mindspore::lite { | |||
| STATUS TensorNamePass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (int i = 0; i < static_cast<int>(graph->inputIndex.size()); i++) { | |||
| auto tensor_id = graph->inputIndex.at(i); | |||
| auto &tensor = graph->allTensors.at(tensor_id); | |||
| tensor->name = "graph_input-" + std::to_string(i); | |||
| if (graph == nullptr) { | |||
| MS_LOG(ERROR) << "graph is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| for (auto &node : graph->nodes) { | |||
| if (node == nullptr || node->primitive == nullptr) { | |||
| MS_LOG(ERROR) << " node or node->primitive is nullptr"; | |||
| return RET_ERROR; | |||
| MS_LOG(ERROR) << "node or node->primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) { | |||
| auto tensor_id = node->outputIndex.at(i); | |||
| for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) { | |||
| auto tensor_id = node->inputIndex.at(i); | |||
| auto &tensor = graph->allTensors.at(tensor_id); | |||
| if (tensor->name.empty()) { | |||
| tensor->name = node->name + "/output-" + std::to_string(i); | |||
| MS_LOG(WARNING) << "input tensor (id = " << tensor_id << ") name is null"; | |||
| tensor->name = node->name + "/input-" + std::to_string(i); | |||
| } | |||
| } | |||
| auto type = node->primitive->value.type; | |||
| if (type == PrimitiveType_Conv2D || type == PrimitiveType_DeConv2D || type == PrimitiveType_DepthwiseConv2D || | |||
| type == PrimitiveType_DeDepthwiseConv2D || type == PrimitiveType_FullConnection) { | |||
| auto input_size = node->inputIndex.size(); | |||
| if (input_size > 1) { | |||
| auto weight_tensor_id = node->inputIndex.at(1); | |||
| auto &weight_tensor = graph->allTensors.at(weight_tensor_id); | |||
| if (weight_tensor->name.empty()) { | |||
| weight_tensor->name = node->name + "/weight"; | |||
| } | |||
| if (input_size > 2) { | |||
| auto bias_tensor_id = node->inputIndex.at(2); | |||
| auto &bias_tensor = graph->allTensors.at(bias_tensor_id); | |||
| if (bias_tensor->name.empty()) { | |||
| bias_tensor->name = node->name + "/bias"; | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| for (int i = 0; i < static_cast<int>(node->inputIndex.size()); i++) { | |||
| auto tensor_id = node->inputIndex.at(i); | |||
| auto &tensor = graph->allTensors.at(tensor_id); | |||
| if (tensor->name.empty()) { | |||
| tensor->name = node->name + "/input-" + std::to_string(i); | |||
| } | |||
| for (int i = 0; i < static_cast<int>(node->outputIndex.size()); i++) { | |||
| auto tensor_id = node->outputIndex.at(i); | |||
| auto &tensor = graph->allTensors.at(tensor_id); | |||
| if (tensor->name.empty()) { | |||
| tensor->name = node->name + "/output-" + std::to_string(i); | |||
| } | |||
| } | |||
| } | |||
| @@ -115,7 +115,8 @@ STATUS TfliteModelParser::ConvertOps() { | |||
| std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))}; | |||
| // parse inputs | |||
| for (auto input_idx : op->inputs) { | |||
| for (int i = 0; i < static_cast<int>(op->inputs.size()); i++) { | |||
| auto input_idx = op->inputs.at(i); | |||
| if (tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED && input_idx == -1) { | |||
| continue; | |||
| } | |||
| @@ -127,9 +128,27 @@ STATUS TfliteModelParser::ConvertOps() { | |||
| op_inputs.emplace_back(nodes_.at(input_idx)); | |||
| continue; | |||
| } | |||
| // const tensor | |||
| std::string tensor_name; | |||
| if (!input_tensor->name.empty()) { | |||
| tensor_name = input_tensor->name; | |||
| } else { | |||
| tensor_name = op_name + "/input-" + std::to_string(op_inputs.size()); | |||
| if (tflite_op_type == tflite::BuiltinOperator_CONV_2D || | |||
| tflite_op_type == tflite::BuiltinOperator_TRANSPOSE_CONV || | |||
| tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D || | |||
| tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) { | |||
| if (i == 1) { | |||
| tensor_name = op_name + "/weight"; | |||
| } | |||
| if (i == 2) { | |||
| tensor_name = op_name + "/bias"; | |||
| } | |||
| } | |||
| } | |||
| auto parameter = func_graph_->add_parameter(); | |||
| status = ConvertConstTensor(input_tensor.get(), parameter.get()); | |||
| status = ConvertConstTensor(input_tensor.get(), parameter.get(), tensor_name); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; | |||
| continue; | |||
| @@ -248,11 +267,12 @@ STATUS TfliteModelParser::ConvertGraphInputs() { | |||
| auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| parameter->set_abstract(abstract_tensor); | |||
| parameter->set_name("graph_input_" + std::to_string(tflite_graph_input) + "_parameter"); | |||
| parameter->set_name("graph_input-" + std::to_string(tflite_graph_input)); | |||
| nodes_.insert(std::pair(tflite_graph_input, parameter)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TfliteModelParser::ConvertGraphOutputs() { | |||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | |||
| if (tflite_subgraph->outputs.size() > 1) { | |||
| @@ -312,7 +332,8 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { | |||
| return RET_OK; | |||
| } | |||
| STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter) { | |||
| STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter, | |||
| const std::string &tensor_name) { | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "tensor is null, get const tensor failed."; | |||
| return RET_NULL_PTR; | |||
| @@ -329,7 +350,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para | |||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| parameter->set_abstract(abstract_tensor); | |||
| parameter->set_name("const_" + std::to_string(nodes_.size()) + "_parameter"); | |||
| parameter->set_name(tensor_name); | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| MS_ASSERT(param_value != nullptr); | |||
| @@ -42,13 +42,13 @@ class TfliteModelParser : public ModelParser { | |||
| FuncGraphPtr func_graph_; | |||
| char *tflite_model_buf_ = nullptr; | |||
| std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path); | |||
| STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter); | |||
| STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter, const std::string &tensor_name); | |||
| STATUS ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode); | |||
| STATUS ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c); | |||
| STATUS ConvertOps(); | |||
| STATUS ConvertGraphInputs(); | |||
| STATUS ConvertGraphOutputs(); | |||
| STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params); | |||
| static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params); | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // LITE_TFLITE_MODEL_PARSER_H | |||