|
|
@@ -699,13 +699,15 @@ STATUS TFModelParser::ConvertSubgraph() { |
|
|
|
|
|
|
|
|
// add while cond body function to while node input |
|
|
// add while cond body function to while node input |
|
|
if (opt::CheckPrimitiveType(cnode, prim::kPrimWhile)) { |
|
|
if (opt::CheckPrimitiveType(cnode, prim::kPrimWhile)) { |
|
|
if (sub_graph_name.find("cond") != std::string::npos) { |
|
|
|
|
|
|
|
|
if (find(while_cond_branch_name_.begin(), while_cond_branch_name_.end(), sub_graph_name) != |
|
|
|
|
|
while_cond_branch_name_.end()) { |
|
|
while_cond_map[cnode] = sub_func_graph; |
|
|
while_cond_map[cnode] = sub_func_graph; |
|
|
} else { |
|
|
} else { |
|
|
while_body_map[cnode] = sub_func_graph; |
|
|
while_body_map[cnode] = sub_func_graph; |
|
|
} |
|
|
} |
|
|
} else { |
|
|
} else { |
|
|
if (sub_graph_name.find("true") != std::string::npos) { |
|
|
|
|
|
|
|
|
if (find(if_then_branch_name_.begin(), if_then_branch_name_.end(), sub_graph_name) != |
|
|
|
|
|
if_then_branch_name_.end()) { |
|
|
if_then_map[cnode] = sub_func_graph; |
|
|
if_then_map[cnode] = sub_func_graph; |
|
|
} else { |
|
|
} else { |
|
|
if_else_map[cnode] = sub_func_graph; |
|
|
if_else_map[cnode] = sub_func_graph; |
|
|
@@ -914,13 +916,15 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, |
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { |
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { |
|
|
auto cond_name = attr_value.func().name(); |
|
|
auto cond_name = attr_value.func().name(); |
|
|
function_while_map_[cond_name] = anf_node; |
|
|
function_while_map_[cond_name] = anf_node; |
|
|
|
|
|
while_cond_branch_name_.push_back(cond_name); |
|
|
MS_LOG(DEBUG) << "parse cond name:" << cond_name; |
|
|
MS_LOG(DEBUG) << "parse cond name:" << cond_name; |
|
|
} |
|
|
} |
|
|
} else if (op_type == "StatelessIf") { |
|
|
|
|
|
|
|
|
} else if (op_type == "StatelessIf" || op_type == "If") { |
|
|
MS_LOG(INFO) << "find if node:" << node_def.name(); |
|
|
MS_LOG(INFO) << "find if node:" << node_def.name(); |
|
|
tensorflow::AttrValue attr_value; |
|
|
tensorflow::AttrValue attr_value; |
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "then_branch", &attr_value)) { |
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "then_branch", &attr_value)) { |
|
|
auto then_name = attr_value.func().name(); |
|
|
auto then_name = attr_value.func().name(); |
|
|
|
|
|
if_then_branch_name_.push_back(then_name); |
|
|
function_if_map_[then_name] = anf_node; |
|
|
function_if_map_[then_name] = anf_node; |
|
|
MS_LOG(DEBUG) << "parse then name:" << then_name; |
|
|
MS_LOG(DEBUG) << "parse then name:" << then_name; |
|
|
} |
|
|
} |
|
|
|