From fba6cb5a256ebc0b150701d41d9724149188b3b8 Mon Sep 17 00:00:00 2001 From: wangzhengjun Date: Fri, 30 Oct 2020 11:45:30 +0800 Subject: [PATCH] update input desc --- parser/tensorflow/tensorflow_parser.cc | 37 ++++++++++++++++++++++++++ parser/tensorflow/tensorflow_parser.h | 1 + 2 files changed, 38 insertions(+) diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index a486462..cdf727e 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -1397,6 +1397,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro GE_RETURN_IF_ERROR(AddEdges(graph)); GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); + GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph)); GE_RETURN_IF_ERROR(graph->TopologicalSorting()); if (has_error) { @@ -2196,6 +2197,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, PARSER_TIMESTAMP_START(RemoveIsolateNode); // Delete isolated nodes GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); + GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph)); PARSER_TIMESTAMP_END(RemoveIsolateNode, "TensorFlowModelParser::RemoveIsolateNode"); PARSER_TIMESTAMP_START(TopologicalSorting); @@ -3714,6 +3716,41 @@ void TensorFlowModelParser::DumpAllNodeContext(const string &phase) { DumpNodeContext(iter.first, iter.second, phase); } } + +Status TensorFlowModelParser::CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph) { + GE_CHECK_NOTNULL(compute_graph); + for (auto &node : compute_graph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (auto &in_anchor : node->GetAllInDataAnchors()) { + if (!(op_desc->IsOptionalInput(static_cast(in_anchor->GetIdx())))) { + continue; + } + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + auto in_desc = op_desc->MutableInputDesc(static_cast(in_anchor->GetIdx())); + if ((peer_out_anchor != nullptr) && (in_desc == nullptr)) { + // The input is connected to the peer output but TensorDesc is invalid, update TensorDesc to valid. + ge::GeTensorDesc tensor_desc; + auto ret = op_desc->UpdateInputDesc(static_cast(in_anchor->GetIdx()), tensor_desc); + if (ret != ge::GRAPH_SUCCESS) { + GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx()); + return ret; + } + GELOGI("Update input desc to valid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx()); + } else if ((peer_out_anchor == nullptr) && (in_desc != nullptr)) { + // The input is not connected to the peer output but TensorDesc is valid, update TensorDesc to invalid. + ge::GeTensorDesc tensor_desc(ge::GeShape(), FORMAT_RESERVED, DT_UNDEFINED); + auto ret = op_desc->UpdateInputDesc(static_cast(in_anchor->GetIdx()), tensor_desc); + if (ret != ge::GRAPH_SUCCESS) { + GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx()); + return ret; + } + GELOGI("Update input desc to invalid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx()); + } + } + } + return SUCCESS; +} } // namespace ge namespace domi { diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h index ac313fc..2200a8d 100644 --- a/parser/tensorflow/tensorflow_parser.h +++ b/parser/tensorflow/tensorflow_parser.h @@ -603,6 +603,7 @@ class TensorFlowModelParser : public domi::ModelParser { void DumpAllNodeContext(const string &phase); Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr &op_parser); + Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph); /** * save