Browse Source

!10155 resolve some bug in tf_model_parser

From: @wangzhe128
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
977b52cdba
2 changed files with 66 additions and 45 deletions
  1. +59
    -38
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  2. +7
    -7
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.h

+ 59
- 38
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -24,10 +24,16 @@
#include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace lite {
namespace {
static const std::vector<schema::PrimitiveType> tensorListOutputOpList = {
schema::PrimitiveType_TensorListFromTensor,
schema::PrimitiveType_TensorListSetItem,
schema::PrimitiveType_TensorListReserve,
};

AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) {
AnfNodePtr ret = nullptr;
@@ -216,7 +222,6 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value
param_value->set_tensor_type(type);
param_value->set_format(schema::Format::Format_NHWC);
parameter->set_default_param(param_value);
parameter->set_name("const_" + std::to_string(anf_root_node_map.size()) + "_parameter");
return RET_OK;
}

@@ -248,8 +253,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
return status;
}
} else {
parameter->set_name("placeholder_" + std::to_string(anf_root_node_map.size()));
graph_input_names.emplace_back(parameter->name()); // only root graph need set graph input names
graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names
}

auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
@@ -257,8 +261,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
MS_LOG(ERROR) << "abstract_tensor is nullptr";
return RET_ERROR;
}
parameter->set_name(node.name());
parameter->set_abstract(abstract_tensor);

(*anf_node_map)[node.name()] = parameter;
(*anf_node_map)[node.name() + ":0"] = parameter;
return RET_OK;
}
@@ -294,43 +300,48 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
tf_root_graph = std::make_unique<tensorflow::GraphDef>();
if (tf_root_graph == nullptr) {
MS_LOG(ERROR) << "tf_root_graph is nullptr";
tf_root_graph_ = std::make_unique<tensorflow::GraphDef>();
if (tf_root_graph_ == nullptr) {
MS_LOG(ERROR) << "tf_root_graph_ is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph.get());
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
anf_root_graph = std::make_shared<FuncGraph>();
if (anf_root_graph == nullptr) {
anf_root_graph_ = std::make_shared<FuncGraph>();
if (anf_root_graph_ == nullptr) {
MS_LOG(ERROR) << "funGraphPtr is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}

for (int i = 0; i < tf_root_graph->node_size(); i++) {
auto &node_def = tf_root_graph->node(i);
tf_root_graph_nodes[node_def.name()] = &node_def;
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
tf_root_graph_nodes_[node_def.name()] = &node_def;
}

status = ConvertGraphInputsAndConsts(tf_root_graph_nodes, anf_root_graph, &anf_root_node_map);
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
for (int i = 0; i < tf_root_graph->node_size(); i++) {
auto &node_def = tf_root_graph->node(i);
if (ConvertOps(node_def, tf_root_graph_nodes, anf_root_graph, &anf_root_node_map) != RET_OK) {
MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
bool success_flag = true;
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
if (status != RET_OK) {
success_flag = false;
}
}
if (!success_flag) {
MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
status = ConvertRootGraphOutputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed.";
@@ -345,10 +356,10 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
return nullptr;
}

return anf_root_graph;
return anf_root_graph_;
}
STATUS TFModelParser::ConvertSubgraph() {
auto graph_def_liarary = tf_root_graph->library();
auto graph_def_liarary = tf_root_graph_->library();
auto subgraph_size = graph_def_liarary.function_size();
std::map<CNodePtr, FuncGraphPtr> while_cond_map;
std::map<CNodePtr, FuncGraphPtr> while_body_map;
@@ -359,11 +370,11 @@ STATUS TFModelParser::ConvertSubgraph() {
auto input_arg_size = tf_sub_signature.input_arg_size();

auto &sub_graph_name = tf_sub_signature.name();
if (!function_while_map.count(sub_graph_name)) {
if (!function_while_map_.count(sub_graph_name)) {
MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name;
return RET_ERROR;
}
auto while_cnode = function_while_map[sub_graph_name]->cast<CNodePtr>();
auto while_cnode = function_while_map_[sub_graph_name]->cast<CNodePtr>();
if (while_cnode == nullptr || static_cast<int>(while_cnode->inputs().size()) != input_arg_size + 1) {
MS_LOG(ERROR) << "while cnode not equal input arg size";
return RET_ERROR;
@@ -441,7 +452,7 @@ STATUS TFModelParser::ConvertSubgraph() {
}
// hardcode subgraph inputs name
for (size_t j = 0; j < sub_graph_inputs.size(); j++) {
sub_graph_inputs[j]->set_name("graph_input_" + std::to_string(j) + "parameter");
sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter");
}
MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name;
}
@@ -458,9 +469,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr
MS_LOG(ERROR) << "while cond body size error";
return RET_ERROR;
}
std::vector<FuncGraphPtr> roots = {anf_root_graph};
std::vector<FuncGraphPtr> roots = {anf_root_graph_};
auto root_func_manager = std::make_shared<FuncGraphManager>(roots);
anf_root_graph->set_manager(root_func_manager);
anf_root_graph_->set_manager(root_func_manager);
for (auto &kv : while_cond_map) {
auto while_node = kv.first;
auto &cond_sub_graph = kv.second;
@@ -513,10 +524,20 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
MS_ASSERT(op != nullptr);
MS_ASSERT(anf_node != nullptr);
MS_ASSERT(anf_graph != nullptr);
if (output_size == 1) {
if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) {
MS_LOG(ERROR) << "tensorlist output op output_size !=1";
return RET_ERROR;
}
if (output_size == 0) {
return RET_OK;
} else if (output_size == 1) {
auto type = kFloat32;
std::vector<int64_t> shape_vector;
anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
anf_node_map->insert(std::pair(op.name() + ":0", anf_node));
if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) {
type = TypeIdToType(kObjectTypeTensorType);
}
anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector));
anf_node_map->insert(std::pair(op.name(), anf_node));
} else {
AbstractBasePtrList abstractList;
for (int output_idx = 0; output_idx < output_size; output_idx++) {
@@ -585,12 +606,12 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) {
auto body_name = attr_value.func().name();
function_while_map[body_name] = anf_node;
function_while_map_[body_name] = anf_node;
MS_LOG(DEBUG) << "parse body name:" << body_name;
}
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) {
auto cond_name = attr_value.func().name();
function_while_map[cond_name] = anf_node;
function_while_map_[cond_name] = anf_node;
MS_LOG(DEBUG) << "parse cond name:" << cond_name;
}
}
@@ -606,31 +627,31 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,

STATUS TFModelParser::ConvertRootGraphOutputs() {
// because output of intermediate node in anf graph may also be output tensors, we search output tensors in
// tf_root_graph_nodes but not anf_root_node_map
// tf_root_graph_nodes_ but not anf_root_node_map_
std::set<std::string> all_node_inputs;
std::vector<AnfNodePtr> output_nodes;
for (auto &pair : tf_root_graph_nodes) {
for (auto &pair : tf_root_graph_nodes_) {
for (int i = 0; i < pair.second->input_size(); ++i) {
all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i)));
}
}
for (auto &pair : tf_root_graph_nodes) {
for (auto &pair : tf_root_graph_nodes_) {
if (pair.second->op() == "Assert") {
continue;
}
auto it = all_node_inputs.find(pair.first);
if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes);
auto anf_node = GetAnfNode(origin_name, anf_root_node_map);
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_);
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_);
if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node";
return RET_ERROR;
}
output_nodes.push_back(anf_node);
graph_output_names.push_back(anf_node->fullname_with_scope());
graph_output_names_.push_back(anf_node->fullname_with_scope());
}
}
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph);
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_);
if (status != RET_OK) {
MS_LOG(ERROR) << "make anf graph outputs node error";
return status;


+ 7
- 7
mindspore/lite/tools/converter/parser/tf/tf_model_parser.h View File

@@ -71,13 +71,13 @@ class TFModelParser : public ModelParser {

STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);

FuncGraphPtr anf_root_graph;
std::unique_ptr<tensorflow::GraphDef> tf_root_graph; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes; // tf root graph node map
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map;
std::vector<std::string> graph_input_names;
std::vector<std::string> graph_output_names;
std::map<std::string, AnfNodePtr> function_while_map; // tf function name->while_node_name
FuncGraphPtr anf_root_graph_;
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;
std::vector<std::string> graph_input_names_;
std::vector<std::string> graph_output_names_;
std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name
};
} // namespace lite
} // namespace mindspore


Loading…
Cancel
Save