From 28052ad188dd857a3cedf990a6a9158b4ad2ff7c Mon Sep 17 00:00:00 2001 From: wangzhe Date: Tue, 15 Dec 2020 10:14:46 +0800 Subject: [PATCH] revert changes --- .../converter/parser/tf/tf_model_parser.cc | 97 +++++++++++-------- .../converter/parser/tf/tf_model_parser.h | 14 +-- 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 0bde8184b5..ce17c46491 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -24,10 +24,16 @@ #include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/optimizer/common/gllo_utils.h" namespace mindspore { namespace lite { namespace { +static const std::vector tensorListOutputOpList = { + schema::PrimitiveType_TensorListFromTensor, + schema::PrimitiveType_TensorListSetItem, + schema::PrimitiveType_TensorListReserve, +}; AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map &anf_node_map) { AnfNodePtr ret = nullptr; @@ -216,7 +222,6 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value param_value->set_tensor_type(type); param_value->set_format(schema::Format::Format_NHWC); parameter->set_default_param(param_value); - parameter->set_name("const_" + std::to_string(anf_root_node_map.size()) + "_parameter"); return RET_OK; } @@ -248,8 +253,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa return status; } } else { - parameter->set_name("placeholder_" + std::to_string(anf_root_node_map.size())); - graph_input_names.emplace_back(parameter->name()); // only root graph need set graph input names + graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names } auto abstract_tensor = std::make_shared(type_ptr, shape_vector); @@ -257,8 +261,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa MS_LOG(ERROR) << "abstract_tensor is nullptr"; return RET_ERROR; } + parameter->set_name(node.name()); parameter->set_abstract(abstract_tensor); + (*anf_node_map)[node.name()] = parameter; (*anf_node_map)[node.name() + ":0"] = parameter; return RET_OK; } @@ -294,43 +300,48 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - tf_root_graph = std::make_unique(); - if (tf_root_graph == nullptr) { - MS_LOG(ERROR) << "tf_root_graph is nullptr"; + tf_root_graph_ = std::make_unique(); + if (tf_root_graph_ == nullptr) { + MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph.get()); + status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); if (status != RET_OK) { MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - anf_root_graph = std::make_shared(); - if (anf_root_graph == nullptr) { + anf_root_graph_ = std::make_shared(); + if (anf_root_graph_ == nullptr) { MS_LOG(ERROR) << "funGraphPtr is nullptr"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - for (int i = 0; i < tf_root_graph->node_size(); i++) { - auto &node_def = tf_root_graph->node(i); - tf_root_graph_nodes[node_def.name()] = &node_def; + for (int i = 0; i < tf_root_graph_->node_size(); i++) { + auto &node_def = tf_root_graph_->node(i); + tf_root_graph_nodes_[node_def.name()] = &node_def; } - status = ConvertGraphInputsAndConsts(tf_root_graph_nodes, anf_root_graph, &anf_root_node_map); + status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - for (int i = 0; i < tf_root_graph->node_size(); i++) { - auto &node_def = tf_root_graph->node(i); - if (ConvertOps(node_def, tf_root_graph_nodes, anf_root_graph, &anf_root_node_map) != RET_OK) { - MS_LOG(ERROR) << "Convert ops failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + bool success_flag = true; + for (int i = 0; i < tf_root_graph_->node_size(); i++) { + auto &node_def = tf_root_graph_->node(i); + status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); + if (status != RET_OK) { + success_flag = false; } } + if (!success_flag) { + MS_LOG(ERROR) << "Convert ops failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } status = ConvertRootGraphOutputs(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert graph outputs failed."; @@ -345,10 +356,10 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin return nullptr; } - return anf_root_graph; + return anf_root_graph_; } STATUS TFModelParser::ConvertSubgraph() { - auto graph_def_liarary = tf_root_graph->library(); + auto graph_def_liarary = tf_root_graph_->library(); auto subgraph_size = graph_def_liarary.function_size(); std::map while_cond_map; std::map while_body_map; @@ -359,11 +370,11 @@ STATUS TFModelParser::ConvertSubgraph() { auto input_arg_size = tf_sub_signature.input_arg_size(); auto &sub_graph_name = tf_sub_signature.name(); - if (!function_while_map.count(sub_graph_name)) { + if (!function_while_map_.count(sub_graph_name)) { MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name; return RET_ERROR; } - auto while_cnode = function_while_map[sub_graph_name]->cast(); + auto while_cnode = function_while_map_[sub_graph_name]->cast(); if (while_cnode == nullptr || static_cast(while_cnode->inputs().size()) != input_arg_size + 1) { MS_LOG(ERROR) << "while cnode not equal input arg size"; return RET_ERROR; @@ -441,7 +452,7 @@ STATUS TFModelParser::ConvertSubgraph() { } // hardcode subgraph inputs name for (size_t j = 0; j < sub_graph_inputs.size(); j++) { - sub_graph_inputs[j]->set_name("graph_input_" + std::to_string(j) + "parameter"); + sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter"); } MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; } @@ -458,9 +469,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map roots = {anf_root_graph}; + std::vector roots = {anf_root_graph_}; auto root_func_manager = std::make_shared(roots); - anf_root_graph->set_manager(root_func_manager); + anf_root_graph_->set_manager(root_func_manager); for (auto &kv : while_cond_map) { auto while_node = kv.first; auto &cond_sub_graph = kv.second; @@ -513,10 +524,20 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C MS_ASSERT(op != nullptr); MS_ASSERT(anf_node != nullptr); MS_ASSERT(anf_graph != nullptr); - if (output_size == 1) { + if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) { + MS_LOG(ERROR) << "tensorlist output op output_size !=1"; + return RET_ERROR; + } + if (output_size == 0) { + return RET_OK; + } else if (output_size == 1) { + auto type = kFloat32; std::vector shape_vector; - anf_node->set_abstract(std::make_shared(kFloat32, shape_vector)); - anf_node_map->insert(std::pair(op.name() + ":0", anf_node)); + if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) { + type = TypeIdToType(kObjectTypeTensorType); + } + anf_node->set_abstract(std::make_shared(type, shape_vector)); + anf_node_map->insert(std::pair(op.name(), anf_node)); } else { AbstractBasePtrList abstractList; for (int output_idx = 0; output_idx < output_size; output_idx++) { @@ -585,12 +606,12 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, tensorflow::AttrValue attr_value; if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) { auto body_name = attr_value.func().name(); - function_while_map[body_name] = anf_node; + function_while_map_[body_name] = anf_node; MS_LOG(DEBUG) << "parse body name:" << body_name; } if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { auto cond_name = attr_value.func().name(); - function_while_map[cond_name] = anf_node; + function_while_map_[cond_name] = anf_node; MS_LOG(DEBUG) << "parse cond name:" << cond_name; } } @@ -606,31 +627,31 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, STATUS TFModelParser::ConvertRootGraphOutputs() { // because output of intermediate node in anf graph may also be output tensors, we search output tensors in - // tf_root_graph_nodes but not anf_root_node_map + // tf_root_graph_nodes_ but not anf_root_node_map_ std::set all_node_inputs; std::vector output_nodes; - for (auto &pair : tf_root_graph_nodes) { + for (auto &pair : tf_root_graph_nodes_) { for (int i = 0; i < pair.second->input_size(); ++i) { all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i))); } } - for (auto &pair : tf_root_graph_nodes) { + for (auto &pair : tf_root_graph_nodes_) { if (pair.second->op() == "Assert") { continue; } auto it = all_node_inputs.find(pair.first); if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity - auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes); - auto anf_node = GetAnfNode(origin_name, anf_root_node_map); + auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); + auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); if (anf_node == nullptr) { MS_LOG(ERROR) << "can't find anf node"; return RET_ERROR; } output_nodes.push_back(anf_node); - graph_output_names.push_back(anf_node->fullname_with_scope()); + graph_output_names_.push_back(anf_node->fullname_with_scope()); } } - auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph); + auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); if (status != RET_OK) { MS_LOG(ERROR) << "make anf graph outputs node error"; return status; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index a4dbd4230d..d112967bd5 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -71,13 +71,13 @@ class TFModelParser : public ModelParser { STATUS MakeAnfGraphOutputs(std::vector *output_nodes, const FuncGraphPtr &anf_graph); - FuncGraphPtr anf_root_graph; - std::unique_ptr tf_root_graph; // tf root graph def - std::map tf_root_graph_nodes; // tf root graph node map - std::unordered_map anf_root_node_map; - std::vector graph_input_names; - std::vector graph_output_names; - std::map function_while_map; // tf function name->while_node_name + FuncGraphPtr anf_root_graph_; + std::unique_ptr tf_root_graph_; // tf root graph def + std::map tf_root_graph_nodes_; // tf root graph node map + std::unordered_map anf_root_node_map_; + std::vector graph_input_names_; + std::vector graph_output_names_; + std::map function_while_map_; // tf function name->while_node_name }; } // namespace lite } // namespace mindspore