Browse Source

Merge branch 'development' of https://gitee.com/ascend/parser into development

pull/63/head
l00444296 5 years ago
parent
commit
fd6da93289
5 changed files with 45 additions and 3 deletions
  1. +1
    -1
      metadef
  2. +1
    -1
      parser/CMakeLists.txt
  3. +5
    -1
      parser/caffe/caffe_parser.cc
  4. +37
    -0
      parser/tensorflow/tensorflow_parser.cc
  5. +1
    -0
      parser/tensorflow/tensorflow_parser.h

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit bf012e65ec60b40d18936677547c94bbd1f89323
Subproject commit d08620f627a7fb8087271c2dea0677f8a00b96d4

+ 1
- 1
parser/CMakeLists.txt View File

@@ -131,7 +131,7 @@ target_compile_options(fmk_parser_stub PRIVATE
)

target_compile_definitions(fmk_parser_stub PRIVATE
$<$<STREQUAL:${PRODUCT_SIDE},host>:FMK_SUPPORT_DUMP>
$<$<OR:$<STREQUAL:${PRODUCT_SIDE},host>,$<STREQUAL:${ENABLE_OPEN_SRC},True>>:FMK_SUPPORT_DUMP>
PROTOBUF_INLINE_NOT_IN_HEADERS=0
REUSE_MEMORY=1
FMK_HOST_INFER


+ 5
- 1
parser/caffe/caffe_parser.cc View File

@@ -1599,6 +1599,7 @@ void CaffeModelParser::SaveOrigionLayerTops(domi::caffe::LayerParameter &layer)
Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &layer) {
string name = layer.name();
if (node_map.find(name) == node_map.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E11034", {"opname"}, {name});
GELOGE(FAILED, "Node can not be found by layer name: %s", name.c_str());
return FAILED;
}
@@ -1608,6 +1609,8 @@ Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &la

if (node->GetType() == ge::parser::DATA) {
if (layer.top_size() != 1) {
ErrorManager::GetInstance().ATCReportErrMessage("E11035", {"opname", "size"},
{name, std::to_string(layer.top_size())});
GELOGE(FAILED, "Data layer[%s] top size must be 1, real size: %d", name.c_str(), layer.top_size());
return FAILED;
}
@@ -1615,7 +1618,8 @@ Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &la
string top_name = layer.top(0);
auto data_top_names = ge::GetParserContext().data_top_names;
if (find(data_top_names.begin(), data_top_names.end(), top_name) != data_top_names.end()) {
GELOGE(FAILED, "Different data can not have same top name: %s.", top_name.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E11036", {"topname"}, {top_name});
GELOGE(FAILED, "Different data node can not have same top name: %s.", top_name.c_str());
return FAILED;
}
ge::GetParserContext().data_top_names.push_back(top_name);


+ 37
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -1444,6 +1444,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro

GE_RETURN_IF_ERROR(AddEdges(graph));
GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph));
GE_RETURN_IF_ERROR(graph->TopologicalSorting());

if (has_error) {
@@ -2243,6 +2244,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
PARSER_TIMESTAMP_START(RemoveIsolateNode);
// Delete isolated nodes
GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph));

PARSER_TIMESTAMP_END(RemoveIsolateNode, "TensorFlowModelParser::RemoveIsolateNode");
PARSER_TIMESTAMP_START(TopologicalSorting);
@@ -3761,6 +3763,41 @@ void TensorFlowModelParser::DumpAllNodeContext(const string &phase) {
DumpNodeContext(iter.first, iter.second, phase);
}
}

Status TensorFlowModelParser::CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph) {
GE_CHECK_NOTNULL(compute_graph);
for (auto &node : compute_graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (auto &in_anchor : node->GetAllInDataAnchors()) {
if (!(op_desc->IsOptionalInput(static_cast<uint32_t>(in_anchor->GetIdx())))) {
continue;
}
auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
auto in_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()));
if ((peer_out_anchor != nullptr) && (in_desc == nullptr)) {
// The input is connected to the peer output but TensorDesc is invalid, update TensorDesc to valid.
ge::GeTensorDesc tensor_desc;
auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc);
if (ret != ge::GRAPH_SUCCESS) {
GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
return ret;
}
GELOGI("Update input desc to valid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
} else if ((peer_out_anchor == nullptr) && (in_desc != nullptr)) {
// The input is not connected to the peer output but TensorDesc is valid, update TensorDesc to invalid.
ge::GeTensorDesc tensor_desc(ge::GeShape(), FORMAT_RESERVED, DT_UNDEFINED);
auto ret = op_desc->UpdateInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()), tensor_desc);
if (ret != ge::GRAPH_SUCCESS) {
GELOGE(ret, "Failed to update input desc, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
return ret;
}
GELOGI("Update input desc to invalid, node:%s, index:%d.", node->GetName().c_str(), in_anchor->GetIdx());
}
}
}
return SUCCESS;
}
} // namespace ge

namespace domi {


+ 1
- 0
parser/tensorflow/tensorflow_parser.h View File

@@ -603,6 +603,7 @@ class TensorFlowModelParser : public domi::ModelParser {
void DumpAllNodeContext(const string &phase);

Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser);
Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph);

/**
* save <node_name, node_def>


Loading…
Cancel
Save