diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 001d6df..568d8a1 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -2084,8 +2084,8 @@ Status TensorFlowModelParser::UpdateNormalOpContext(shared_ptr & Status TensorFlowModelParser::NormalizeAllNodeOpContext() { for (auto iter = op_node_context_map_.begin(); iter != op_node_context_map_.end();) { OpNodeContext &context = iter->second; - NormalizeInputOrOutputMap(context.input_map); - NormalizeInputOrOutputMap(context.output_map); + NormalizeInputOrOutputMap(iter->first, context.input_map); + NormalizeInputOrOutputMap(iter->first, context.output_map); if ((context.input_map.size() == 0) && (context.output_map.size() == 0)) { GELOGD("[Update op context] node: %s will be removed at the back.", iter->first.c_str()); @@ -2097,7 +2097,7 @@ Status TensorFlowModelParser::NormalizeAllNodeOpContext() { return SUCCESS; } -Status TensorFlowModelParser::NormalizeInputOrOutputMap( +Status TensorFlowModelParser::NormalizeInputOrOutputMap(const string &node_name, std::map>> &context_map) { if (context_map.size() == 0) { return SUCCESS; @@ -2109,7 +2109,9 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap( std::set compare_set; for (auto &pair : pairs) { - if ((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) { + if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) && + ((fusion_op_children_.find(node_name) != fusion_op_children_.end()) || + (fusion_op_children_.find(iter->first) != fusion_op_children_.end()))) { // The edge will be cut off at the back, ignoring continue; } diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h index 5ecf9e6..23bf750 100644 --- a/parser/tensorflow/tensorflow_parser.h +++ b/parser/tensorflow/tensorflow_parser.h @@ -371,7 +371,8 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { * @brief Normalized I / O relationship: according to context map, de duplicate and de outliers */ - Status NormalizeInputOrOutputMap(std::map>> &context_map); + Status NormalizeInputOrOutputMap(const string &node_name, + std::map>> &context_map); /** * @ingroup domi_omg