|
|
|
@@ -1397,6 +1397,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro |
|
|
|
|
|
|
|
GE_RETURN_IF_ERROR(AddEdges(graph)); |
|
|
|
GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); |
|
|
|
GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph)); |
|
|
|
GE_RETURN_IF_ERROR(graph->TopologicalSorting()); |
|
|
|
|
|
|
|
if (has_error) { |
|
|
|
@@ -2196,6 +2197,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, |
|
|
|
PARSER_TIMESTAMP_START(RemoveIsolateNode); |
|
|
|
// Delete isolated nodes |
|
|
|
GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); |
|
|
|
GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph)); |
|
|
|
|
|
|
|
PARSER_TIMESTAMP_END(RemoveIsolateNode, "TensorFlowModelParser::RemoveIsolateNode"); |
|
|
|
PARSER_TIMESTAMP_START(TopologicalSorting); |
|
|
|
@@ -3714,6 +3716,41 @@ void TensorFlowModelParser::DumpAllNodeContext(const string &phase) { |
|
|
|
DumpNodeContext(iter.first, iter.second, phase); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Status TensorFlowModelParser::CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph) { |
|
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
|
for (auto &node : compute_graph->GetDirectNode()) { |
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
for (auto &in_anchor : node->GetAllInDataAnchors()) { |
|
|
|
if (!(op_desc->IsOptionalInput(static_cast<uint32_t>(in_anchor->GetIdx())))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); |
|
|
|
auto in_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(in_anchor->GetIdx())); |
|
|
|
if ((peer_out_anchor != nullptr) && (in_desc == nullptr)) { |
|
|
|
// The input is connected to the peer output but TensorDesc is invalid, update TensorDesc to valid. |
|
|
|
ge::GeTensorDesc tensor_desc; |
|
|
|
auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc); |
|
|
|
if (ret != ge::GRAPH_SUCCESS) { |
|
|
|
GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx()); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
GELOGI("Update input desc to valid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx()); |
|
|
|
} else if ((peer_out_anchor == nullptr) && (in_desc != nullptr)) { |
|
|
|
// The input is not connected to the peer output but TensorDesc is valid, update TensorDesc to invalid. |
|
|
|
ge::GeTensorDesc tensor_desc(ge::GeShape(), FORMAT_RESERVED, DT_UNDEFINED); |
|
|
|
auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc); |
|
|
|
if (ret != ge::GRAPH_SUCCESS) { |
|
|
|
GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx()); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
GELOGI("Update input desc to invalid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} // namespace ge |
|
|
|
|
|
|
|
namespace domi { |
|
|
|
|