Browse Source

!14458 [MS_LITE] fix if op

From: @YeFeng_24
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
pull/14458/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
d514f891ad
2 changed files with 9 additions and 3 deletions
  1. +7
    -3
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  2. +2
    -0
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.h

+ 7
- 3
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -699,13 +699,15 @@ STATUS TFModelParser::ConvertSubgraph() {

// add while cond body function to while node input
if (opt::CheckPrimitiveType(cnode, prim::kPrimWhile)) {
if (sub_graph_name.find("cond") != std::string::npos) {
if (find(while_cond_branch_name_.begin(), while_cond_branch_name_.end(), sub_graph_name) !=
while_cond_branch_name_.end()) {
while_cond_map[cnode] = sub_func_graph;
} else {
while_body_map[cnode] = sub_func_graph;
}
} else {
if (sub_graph_name.find("true") != std::string::npos) {
if (find(if_then_branch_name_.begin(), if_then_branch_name_.end(), sub_graph_name) !=
if_then_branch_name_.end()) {
if_then_map[cnode] = sub_func_graph;
} else {
if_else_map[cnode] = sub_func_graph;
@@ -914,13 +916,15 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) {
auto cond_name = attr_value.func().name();
function_while_map_[cond_name] = anf_node;
while_cond_branch_name_.push_back(cond_name);
MS_LOG(DEBUG) << "parse cond name:" << cond_name;
}
} else if (op_type == "StatelessIf") {
} else if (op_type == "StatelessIf" || op_type == "If") {
MS_LOG(INFO) << "find if node:" << node_def.name();
tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(node_def, "then_branch", &attr_value)) {
auto then_name = attr_value.func().name();
if_then_branch_name_.push_back(then_name);
function_if_map_[then_name] = anf_node;
MS_LOG(DEBUG) << "parse then name:" << then_name;
}


+ 2
- 0
mindspore/lite/tools/converter/parser/tf/tf_model_parser.h View File

@@ -90,6 +90,8 @@ class TFModelParser : public ModelParser {
std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name
std::map<std::string, AnfNodePtr> function_if_map_; // tf function name->if_node
std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{};
std::vector<std::string> while_cond_branch_name_;
std::vector<std::string> if_then_branch_name_;
};
} // namespace lite
} // namespace mindspore


Loading…
Cancel
Save