
!10155 resolve some bug in tf_model_parser

From: @wangzhe128
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot (Gitee) committed 5 years ago
Parent
Commit 977b52cdba
Showing 2 changed files with 66 additions and 45 deletions.
  1. mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc  +59 -38
  2. mindspore/lite/tools/converter/parser/tf/tf_model_parser.h  +7 -7

mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc  (+59 -38)

@@ -24,10 +24,16 @@
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h" #include "tools/common/protobuf_utils.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/optimizer/common/gllo_utils.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace { namespace {
static const std::vector<schema::PrimitiveType> tensorListOutputOpList = {
schema::PrimitiveType_TensorListFromTensor,
schema::PrimitiveType_TensorListSetItem,
schema::PrimitiveType_TensorListReserve,
};


AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) { AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) {
AnfNodePtr ret = nullptr; AnfNodePtr ret = nullptr;
@@ -216,7 +222,6 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value
   param_value->set_tensor_type(type);
   param_value->set_format(schema::Format::Format_NHWC);
   parameter->set_default_param(param_value);
-  parameter->set_name("const_" + std::to_string(anf_root_node_map.size()) + "_parameter");
   return RET_OK;
 }


@@ -248,8 +253,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
       return status;
     }
   } else {
-    parameter->set_name("placeholder_" + std::to_string(anf_root_node_map.size()));
-    graph_input_names.emplace_back(parameter->name());  // only root graph need set graph input names
+    graph_input_names_.emplace_back(node.name());  // only root graph need set graph input names
   }
 
   auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
@@ -257,8 +261,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
MS_LOG(ERROR) << "abstract_tensor is nullptr"; MS_LOG(ERROR) << "abstract_tensor is nullptr";
return RET_ERROR; return RET_ERROR;
} }
parameter->set_name(node.name());
parameter->set_abstract(abstract_tensor); parameter->set_abstract(abstract_tensor);


(*anf_node_map)[node.name()] = parameter;
(*anf_node_map)[node.name() + ":0"] = parameter; (*anf_node_map)[node.name() + ":0"] = parameter;
return RET_OK; return RET_OK;
} }
@@ -294,43 +300,48 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
     return nullptr;
   }
-  tf_root_graph = std::make_unique<tensorflow::GraphDef>();
-  if (tf_root_graph == nullptr) {
-    MS_LOG(ERROR) << "tf_root_graph is nullptr";
+  tf_root_graph_ = std::make_unique<tensorflow::GraphDef>();
+  if (tf_root_graph_ == nullptr) {
+    MS_LOG(ERROR) << "tf_root_graph_ is nullptr";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
     return nullptr;
   }
-  status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph.get());
+  status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
     return nullptr;
   }
-  anf_root_graph = std::make_shared<FuncGraph>();
-  if (anf_root_graph == nullptr) {
+  anf_root_graph_ = std::make_shared<FuncGraph>();
+  if (anf_root_graph_ == nullptr) {
     MS_LOG(ERROR) << "funGraphPtr is nullptr";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
     return nullptr;
   }
 
-  for (int i = 0; i < tf_root_graph->node_size(); i++) {
-    auto &node_def = tf_root_graph->node(i);
-    tf_root_graph_nodes[node_def.name()] = &node_def;
+  for (int i = 0; i < tf_root_graph_->node_size(); i++) {
+    auto &node_def = tf_root_graph_->node(i);
+    tf_root_graph_nodes_[node_def.name()] = &node_def;
   }
 
-  status = ConvertGraphInputsAndConsts(tf_root_graph_nodes, anf_root_graph, &anf_root_node_map);
+  status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
   if (status != RET_OK) {
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
     return nullptr;
   }
-  for (int i = 0; i < tf_root_graph->node_size(); i++) {
-    auto &node_def = tf_root_graph->node(i);
-    if (ConvertOps(node_def, tf_root_graph_nodes, anf_root_graph, &anf_root_node_map) != RET_OK) {
-      MS_LOG(ERROR) << "Convert ops failed.";
-      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-      return nullptr;
+  bool success_flag = true;
+  for (int i = 0; i < tf_root_graph_->node_size(); i++) {
+    auto &node_def = tf_root_graph_->node(i);
+    status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
+    if (status != RET_OK) {
+      success_flag = false;
     }
   }
+  if (!success_flag) {
+    MS_LOG(ERROR) << "Convert ops failed.";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+    return nullptr;
+  }
   status = ConvertRootGraphOutputs();
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Convert graph outputs failed.";
@@ -345,10 +356,10 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
     return nullptr;
   }
 
-  return anf_root_graph;
+  return anf_root_graph_;
 }
 STATUS TFModelParser::ConvertSubgraph() {
-  auto graph_def_liarary = tf_root_graph->library();
+  auto graph_def_liarary = tf_root_graph_->library();
   auto subgraph_size = graph_def_liarary.function_size();
   std::map<CNodePtr, FuncGraphPtr> while_cond_map;
   std::map<CNodePtr, FuncGraphPtr> while_body_map;
@@ -359,11 +370,11 @@ STATUS TFModelParser::ConvertSubgraph() {
     auto input_arg_size = tf_sub_signature.input_arg_size();
 
     auto &sub_graph_name = tf_sub_signature.name();
-    if (!function_while_map.count(sub_graph_name)) {
+    if (!function_while_map_.count(sub_graph_name)) {
       MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name;
       return RET_ERROR;
     }
-    auto while_cnode = function_while_map[sub_graph_name]->cast<CNodePtr>();
+    auto while_cnode = function_while_map_[sub_graph_name]->cast<CNodePtr>();
     if (while_cnode == nullptr || static_cast<int>(while_cnode->inputs().size()) != input_arg_size + 1) {
       MS_LOG(ERROR) << "while cnode not equal input arg size";
       return RET_ERROR;
@@ -441,7 +452,7 @@ STATUS TFModelParser::ConvertSubgraph() {
     }
     // hardcode subgraph inputs name
     for (size_t j = 0; j < sub_graph_inputs.size(); j++) {
-      sub_graph_inputs[j]->set_name("graph_input_" + std::to_string(j) + "parameter");
+      sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter");
     }
     MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name;
   }
@@ -458,9 +469,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr
MS_LOG(ERROR) << "while cond body size error"; MS_LOG(ERROR) << "while cond body size error";
return RET_ERROR; return RET_ERROR;
} }
std::vector<FuncGraphPtr> roots = {anf_root_graph};
std::vector<FuncGraphPtr> roots = {anf_root_graph_};
auto root_func_manager = std::make_shared<FuncGraphManager>(roots); auto root_func_manager = std::make_shared<FuncGraphManager>(roots);
anf_root_graph->set_manager(root_func_manager);
anf_root_graph_->set_manager(root_func_manager);
for (auto &kv : while_cond_map) { for (auto &kv : while_cond_map) {
auto while_node = kv.first; auto while_node = kv.first;
auto &cond_sub_graph = kv.second; auto &cond_sub_graph = kv.second;
@@ -513,10 +524,20 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
   MS_ASSERT(op != nullptr);
   MS_ASSERT(anf_node != nullptr);
   MS_ASSERT(anf_graph != nullptr);
-  if (output_size == 1) {
+  if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) {
+    MS_LOG(ERROR) << "tensorlist output op output_size !=1";
+    return RET_ERROR;
+  }
+  if (output_size == 0) {
+    return RET_OK;
+  } else if (output_size == 1) {
+    auto type = kFloat32;
     std::vector<int64_t> shape_vector;
-    anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
-    anf_node_map->insert(std::pair(op.name() + ":0", anf_node));
+    if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) {
+      type = TypeIdToType(kObjectTypeTensorType);
+    }
+    anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector));
+    anf_node_map->insert(std::pair(op.name(), anf_node));
   } else {
     AbstractBasePtrList abstractList;
     for (int output_idx = 0; output_idx < output_size; output_idx++) {
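Note: the new branches above hinge on a membership test over the tensorListOutputOpList added at the top of the file, so ops that produce tensor lists get a tensor-list abstract instead of the default kFloat32 tensor. A minimal sketch of what such an IsContain-style check amounts to is given below; the name IsContainSketch is illustrative only, since the converter uses the IsContain utility from elsewhere in the MindSpore Lite code base.

// Sketch only: an IsContain-style membership test over a primitive-type list.
// The real IsContain helper comes from the MindSpore Lite common utilities and
// may differ in signature; this stand-in just shows the intent.
#include <algorithm>
#include <vector>

template <typename T>
bool IsContainSketch(const std::vector<T> &list, const T &value) {
  // True when `value` appears anywhere in `list` (simple linear scan).
  return std::find(list.begin(), list.end(), value) != list.end();
}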
@@ -585,12 +606,12 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
   tensorflow::AttrValue attr_value;
   if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) {
     auto body_name = attr_value.func().name();
-    function_while_map[body_name] = anf_node;
+    function_while_map_[body_name] = anf_node;
     MS_LOG(DEBUG) << "parse body name:" << body_name;
   }
   if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) {
     auto cond_name = attr_value.func().name();
-    function_while_map[cond_name] = anf_node;
+    function_while_map_[cond_name] = anf_node;
     MS_LOG(DEBUG) << "parse cond name:" << cond_name;
   }
 }
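Note: function_while_map_ is keyed by the function names carried in a While node's "body" and "cond" attributes. As a rough sketch, assuming the upstream TensorFlow GraphDef proto headers are available (the converter bundles its own copies and goes through TensorFlowUtils::FindAttrValue), the lookup is roughly equivalent to reading the attr map directly; the function name GetFuncAttrName below is illustrative, not part of the converter.

// Sketch only: reading a func-typed attr such as "body" or "cond" straight
// from the protobuf-generated NodeDef API.
#include <string>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Returns the referenced function name, or an empty string if the attr is absent.
std::string GetFuncAttrName(const tensorflow::NodeDef &node_def, const std::string &attr_name) {
  auto it = node_def.attr().find(attr_name);
  if (it == node_def.attr().end()) {
    return "";
  }
  return it->second.func().name();  // AttrValue.func is a NameAttrList
}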
@@ -606,31 +627,31 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,


 STATUS TFModelParser::ConvertRootGraphOutputs() {
   // because output of intermediate node in anf graph may also be output tensors, we search output tensors in
-  // tf_root_graph_nodes but not anf_root_node_map
+  // tf_root_graph_nodes_ but not anf_root_node_map_
   std::set<std::string> all_node_inputs;
   std::vector<AnfNodePtr> output_nodes;
-  for (auto &pair : tf_root_graph_nodes) {
+  for (auto &pair : tf_root_graph_nodes_) {
     for (int i = 0; i < pair.second->input_size(); ++i) {
       all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i)));
     }
   }
-  for (auto &pair : tf_root_graph_nodes) {
+  for (auto &pair : tf_root_graph_nodes_) {
     if (pair.second->op() == "Assert") {
       continue;
     }
     auto it = all_node_inputs.find(pair.first);
     if (it == all_node_inputs.end() && pair.second->input_size() > 0) {  // output node not constraint to Identity
-      auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes);
-      auto anf_node = GetAnfNode(origin_name, anf_root_node_map);
+      auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_);
+      auto anf_node = GetAnfNode(origin_name, anf_root_node_map_);
       if (anf_node == nullptr) {
         MS_LOG(ERROR) << "can't find anf node";
         return RET_ERROR;
       }
       output_nodes.push_back(anf_node);
-      graph_output_names.push_back(anf_node->fullname_with_scope());
+      graph_output_names_.push_back(anf_node->fullname_with_scope());
     }
   }
-  auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph);
+  auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "make anf graph outputs node error";
     return status;
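Note: the comment at the top of ConvertRootGraphOutputs states the rule this pass applies: a TF node is treated as a root-graph output when no other node lists it among its inputs (Assert nodes and nodes without inputs excluded). A minimal sketch of that rule with plain std:: containers follows; the name CollectOutputCandidates is illustrative, and the real code additionally normalizes input names via TensorFlowUtils::GetNodeName and resolves Identity chains via GetOriginInputName.

// Sketch only: the output-detection idea from ConvertRootGraphOutputs,
// reduced to plain std:: containers.
#include <map>
#include <set>
#include <string>
#include <vector>
#include "tensorflow/core/framework/node_def.pb.h"

std::vector<std::string> CollectOutputCandidates(
    const std::map<std::string, const tensorflow::NodeDef *> &nodes) {
  // Every tensor name that feeds some node cannot be a graph output.
  std::set<std::string> all_inputs;
  for (const auto &kv : nodes) {
    for (int i = 0; i < kv.second->input_size(); ++i) {
      all_inputs.insert(kv.second->input(i));  // real code strips ":N"/"^" prefixes via GetNodeName
    }
  }
  // Remaining nodes (with at least one input, and not Assert) are treated as outputs.
  std::vector<std::string> outputs;
  for (const auto &kv : nodes) {
    if (kv.second->op() == "Assert" || kv.second->input_size() == 0) {
      continue;
    }
    if (all_inputs.find(kv.first) == all_inputs.end()) {
      outputs.push_back(kv.first);
    }
  }
  return outputs;
}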


mindspore/lite/tools/converter/parser/tf/tf_model_parser.h  (+7 -7)

@@ -71,13 +71,13 @@ class TFModelParser : public ModelParser {


   STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);
 
-  FuncGraphPtr anf_root_graph;
-  std::unique_ptr<tensorflow::GraphDef> tf_root_graph;                     // tf root graph def
-  std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes;  // tf root graph node map
-  std::unordered_map<std::string, AnfNodePtr> anf_root_node_map;
-  std::vector<std::string> graph_input_names;
-  std::vector<std::string> graph_output_names;
-  std::map<std::string, AnfNodePtr> function_while_map;  // tf function name->while_node_name
+  FuncGraphPtr anf_root_graph_;
+  std::unique_ptr<tensorflow::GraphDef> tf_root_graph_;                     // tf root graph def
+  std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_;  // tf root graph node map
+  std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;
+  std::vector<std::string> graph_input_names_;
+  std::vector<std::string> graph_output_names_;
+  std::map<std::string, AnfNodePtr> function_while_map_;  // tf function name->while_node_name
 };
 }  // namespace lite
 }  // namespace mindspore

