diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc index 637aeab950..23e7cddc74 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc @@ -45,6 +45,9 @@ ops::PrimitiveC *OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const padding_mode = mindspore::PaddingMode::REFLECT; } else if (mode == "edge") { padding_mode = mindspore::PaddingMode::SYMMETRIC; + } else { + MS_LOG(ERROR) << "Unsupported pad mode: " << mode; + return nullptr; } prim->set_padding_mode(padding_mode); } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 993fb06919..7d90dd9afb 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -40,13 +40,14 @@ bool IsTensorListOp(const AnfNodePtr &anf_node) { opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListReserve); } -AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map &anf_node_map) { +AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map &anf_node_map, + int index = 0) { AnfNodePtr ret = nullptr; auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name); if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) { ret = anf_node_map.at(flat_anf_name); - } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { - ret = anf_node_map.at(flat_anf_name + ":0"); + } else if (anf_node_map.find(name + ":" + to_string(index)) != anf_node_map.end()) { + ret = anf_node_map.at(flat_anf_name + ":" + to_string(index)); } return ret; } @@ -901,6 +902,10 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, MS_LOG(ERROR) << "node " << op_type << " parser failed"; return RET_ERROR; } + node_output_num_[node_def.name()] = output_size; + for (int i = 0; i < output_size; i++) { + node_output_num_[node_def.name() + ":" + to_string(i)] = 1; + } auto value_node = NewValueNode(std::shared_ptr(primitiveC)); if (value_node == nullptr) { MS_LOG(ERROR) << "value_node is nullptr"; @@ -915,6 +920,32 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, // control_depends are not processed currently auto anf_node = func_graph_ptr->NewCNode(inputs); anf_node->set_fullname_with_scope(node_def.name()); + status = ProcessControlFlowOp(anf_node, op_type, node_def); + if (status != RET_OK) { + MS_LOG(ERROR) << "ProcessControlFlowOp failed."; + return RET_ERROR; + } + + if (!input_name_not_found.empty()) { + RecordNullInput(anf_node, input_name_not_found); + } + + status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; + return status; + } + + status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed."; + return status; + } + return status; +} + +STATUS TFModelParser::ProcessControlFlowOp(CNodePtr anf_node, const string op_type, + const tensorflow::NodeDef &node_def) { if (op_type == "StatelessWhile" || op_type == "While") { MS_LOG(INFO) << "find while node:" << node_def.name(); tensorflow::AttrValue attr_value; @@ -944,21 +975,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, MS_LOG(DEBUG) << "parse else name:" << else_name; } } - - if (!input_name_not_found.empty()) { - RecordNullInput(anf_node, input_name_not_found); - } - - status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size); - if (status != RET_OK) { - MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; - } - - status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC); - if (status != RET_OK) { - MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed."; - } - return status; + return RET_OK; } STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size, @@ -994,13 +1011,15 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { auto it = all_node_inputs.find(pair.first); if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); - auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); - if (anf_node == nullptr) { - MS_LOG(ERROR) << "can't find anf node: " << origin_name; - return RET_ERROR; + for (int i = 0; i < node_output_num_[origin_name]; i++) { + auto anf_node = GetAnfNode(origin_name, anf_root_node_map_, i); + if (anf_node == nullptr) { + MS_LOG(ERROR) << "can't find anf node: " << origin_name; + return RET_ERROR; + } + output_nodes.push_back(anf_node); + graph_output_names_.push_back(anf_node->fullname_with_scope()); } - output_nodes.push_back(anf_node); - graph_output_names_.push_back(anf_node->fullname_with_scope()); } } auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index 512407a2e5..528049dae7 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -60,6 +60,9 @@ class TFModelParser : public ModelParser { STATUS ConvertOps(const tensorflow::NodeDef &node_def, const std::map &tf_node_map, const FuncGraphPtr &func_graph_ptr, std::unordered_map *anf_node_map); + + STATUS ProcessControlFlowOp(CNodePtr anf_node, const string op_type, const tensorflow::NodeDef &node_def); + STATUS ConvertRootGraphOutputs(); STATUS ConvertSubgraph(); @@ -95,6 +98,7 @@ class TFModelParser : public ModelParser { std::vector>> nodes_with_null_input_{}; std::vector while_cond_branch_name_; std::vector if_then_branch_name_; + std::unordered_map node_output_num_; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.cc b/mindspore/lite/tools/converter/parser/tf/tf_util.cc index b99a79d3da..b2c83e3264 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.cc @@ -70,9 +70,13 @@ TypeId TensorFlowUtils::ParseAttrDataType(const tensorflow::NodeDef &node_def, c } bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) { - if (str_view == nullptr || value == nullptr) { + if (value == nullptr) { + MS_LOG(ERROR) << "value is nullptr"; + return false; + } + if (str_view == nullptr) { *value = 0; - MS_LOG(ERROR) << "str_view or value is nullptr"; + MS_LOG(ERROR) << "str_view is nullptr"; return false; } auto data = str_view->data(); diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_cond.h b/mindspore/lite/tools/optimizer/graph/functionalize_cond.h index 7bc1b27886..847aec3016 100644 --- a/mindspore/lite/tools/optimizer/graph/functionalize_cond.h +++ b/mindspore/lite/tools/optimizer/graph/functionalize_cond.h @@ -36,7 +36,7 @@ class FunctionalizeCond { public: FunctionalizeCond(FuncGraphPtr fg, CNodePtr merge_node) : fg_(fg), merge_node_(merge_node) {} - ~FunctionalizeCond() = default; + virtual ~FunctionalizeCond() = default; STATUS Process(); diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_while.h b/mindspore/lite/tools/optimizer/graph/functionalize_while.h index 0052883135..75e05ce194 100644 --- a/mindspore/lite/tools/optimizer/graph/functionalize_while.h +++ b/mindspore/lite/tools/optimizer/graph/functionalize_while.h @@ -35,7 +35,7 @@ class FunctionalizeWhile { FunctionalizeWhile(std::vector node_cluster, const CNodePtr &loop_cond_node, FuncGraphPtr fg) : node_cluster_(node_cluster), loop_cond_node_(loop_cond_node), fg_(fg) {} - ~FunctionalizeWhile() = default; + virtual ~FunctionalizeWhile() = default; // while STATUS BuildWhileNode();