Browse Source

bugfix for NormalizeInputOrOutputMap

pull/346/head
y00500818 4 years ago
parent
commit
c6e6d050c7
2 changed files with 8 additions and 5 deletions
  1. +6
    -4
      parser/tensorflow/tensorflow_parser.cc
  2. +2
    -1
      parser/tensorflow/tensorflow_parser.h

+ 6
- 4
parser/tensorflow/tensorflow_parser.cc View File

@@ -2084,8 +2084,8 @@ Status TensorFlowModelParser::UpdateNormalOpContext(shared_ptr<ge::ScopeGraph> &
Status TensorFlowModelParser::NormalizeAllNodeOpContext() {
for (auto iter = op_node_context_map_.begin(); iter != op_node_context_map_.end();) {
OpNodeContext &context = iter->second;
NormalizeInputOrOutputMap(context.input_map);
NormalizeInputOrOutputMap(context.output_map);
NormalizeInputOrOutputMap(iter->first, context.input_map);
NormalizeInputOrOutputMap(iter->first, context.output_map);

if ((context.input_map.size() == 0) && (context.output_map.size() == 0)) {
GELOGD("[Update op context] node: %s will be removed at the back.", iter->first.c_str());
@@ -2097,7 +2097,7 @@ Status TensorFlowModelParser::NormalizeAllNodeOpContext() {
return SUCCESS;
}

Status TensorFlowModelParser::NormalizeInputOrOutputMap(
Status TensorFlowModelParser::NormalizeInputOrOutputMap(const string &node_name,
std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map) {
if (context_map.size() == 0) {
return SUCCESS;
@@ -2109,7 +2109,9 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap(
std::set<std::string> compare_set;

for (auto &pair : pairs) {
if ((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) {
if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) &&
((fusion_op_children_.find(node_name) != fusion_op_children_.end()) ||
(fusion_op_children_.find(iter->first) != fusion_op_children_.end()))) {
// The edge will be cut off at the back, ignoring
continue;
}


+ 2
- 1
parser/tensorflow/tensorflow_parser.h View File

@@ -371,7 +371,8 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
* @brief Normalized I / O relationship: according to context map, de duplicate and de outliers

*/
Status NormalizeInputOrOutputMap(std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map);
Status NormalizeInputOrOutputMap(const string &node_name,
std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map);

/**
* @ingroup domi_omg


Loading…
Cancel
Save