Browse Source

!44 update input desc

Merge pull request !44 from wangzhengjun/update_input
pull/44/MERGE
王涛 Gitee 5 years ago
parent
commit
66aba68e3d
2 changed files with 38 additions and 0 deletions
  1. +37
    -0
      parser/tensorflow/tensorflow_parser.cc
  2. +1
    -0
      parser/tensorflow/tensorflow_parser.h

+ 37
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -1397,6 +1397,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro

GE_RETURN_IF_ERROR(AddEdges(graph));
GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph));
GE_RETURN_IF_ERROR(graph->TopologicalSorting());

if (has_error) {
@@ -2196,6 +2197,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
PARSER_TIMESTAMP_START(RemoveIsolateNode);
// Delete isolated nodes
GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph));

PARSER_TIMESTAMP_END(RemoveIsolateNode, "TensorFlowModelParser::RemoveIsolateNode");
PARSER_TIMESTAMP_START(TopologicalSorting);
@@ -3714,6 +3716,41 @@ void TensorFlowModelParser::DumpAllNodeContext(const string &phase) {
DumpNodeContext(iter.first, iter.second, phase);
}
}

Status TensorFlowModelParser::CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph) {
GE_CHECK_NOTNULL(compute_graph);
for (auto &node : compute_graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (auto &in_anchor : node->GetAllInDataAnchors()) {
if (!(op_desc->IsOptionalInput(static_cast<uint32_t>(in_anchor->GetIdx())))) {
continue;
}
auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
auto in_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()));
if ((peer_out_anchor != nullptr) && (in_desc == nullptr)) {
// The input is connected to the peer output but TensorDesc is invalid, update TensorDesc to valid.
ge::GeTensorDesc tensor_desc;
auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc);
if (ret != ge::GRAPH_SUCCESS) {
GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
return ret;
}
GELOGI("Update input desc to valid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
} else if ((peer_out_anchor == nullptr) && (in_desc != nullptr)) {
// The input is not connected to the peer output but TensorDesc is valid, update TensorDesc to invalid.
ge::GeTensorDesc tensor_desc(ge::GeShape(), FORMAT_RESERVED, DT_UNDEFINED);
auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc);
if (ret != ge::GRAPH_SUCCESS) {
GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
return ret;
}
GELOGI("Update input desc to invalid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
}
}
}
return SUCCESS;
}
} // namespace ge

namespace domi {


+ 1
- 0
parser/tensorflow/tensorflow_parser.h View File

@@ -603,6 +603,7 @@ class TensorFlowModelParser : public domi::ModelParser {
void DumpAllNodeContext(const string &phase);

Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser);
Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph);

/**
* save <node_name, node_def>


Loading…
Cancel
Save