Browse Source

onnx output name

pull/351/head
wangzhengjun 4 years ago
parent
commit
d2187fe086
18 changed files with 612 additions and 76 deletions
  1. +12
    -12
      parser/caffe/caffe_parser.cc
  2. +41
    -19
      parser/common/acl_graph_parser_util.cc
  3. +2
    -2
      parser/common/acl_graph_parser_util.h
  4. +31
    -5
      parser/common/parser_utils.cc
  5. +11
    -3
      parser/common/parser_utils.h
  6. +72
    -18
      parser/onnx/onnx_parser.cc
  7. +7
    -2
      parser/onnx/onnx_parser.h
  8. +16
    -2
      parser/tensorflow/tensorflow_parser.cc
  9. +2
    -0
      parser/tensorflow/tensorflow_parser.h
  10. +2
    -2
      tests/depends/mmpa/src/mmpa_stub.cc
  11. +4
    -0
      tests/ut/parser/CMakeLists.txt
  12. +44
    -0
      tests/ut/parser/parser_ut_utils.cc
  13. +29
    -0
      tests/ut/parser/parser_ut_utils.h
  14. +14
    -0
      tests/ut/parser/testcase/caffe_parser_testcase/caffe_model/caffe_abs.pbtxt
  15. +158
    -0
      tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc
  16. +97
    -0
      tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc
  17. +66
    -9
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc
  18. +4
    -2
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 12
- 12
parser/caffe/caffe_parser.cc View File

@@ -601,17 +601,17 @@ void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_ind
}

Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) {
if (ge::GetParserContext().user_out_nodes_top_vec.empty()) {
if (ge::GetParserContext().user_out_tensors.empty()) {
return SUCCESS;
}

ge::GetParserContext().out_nodes_map.clear();
ge::GetParserContext().user_out_nodes.clear();
int32_t layer_count = proto_message.layer_size();
const std::vector<string> &user_out_nodes_top_vec =
ge::GetParserContext().user_out_nodes_top_vec;
const std::vector<string> &user_out_tensors =
ge::GetParserContext().user_out_tensors;

for (const auto &top_name : user_out_nodes_top_vec) {
for (const auto &top_name : user_out_tensors) {
bool find_node_falg = false;
string layer_name;
int32_t top_index = 0;
@@ -1082,7 +1082,7 @@ Status CaffeModelParser::AddUserOutNodesTop() {
string top_name = layer_iter->second[out_pair.second];
auto top_node_iter = node_map.find(out_pair.first);
if (top_node_iter != node_map.end()) {
ge::GetParserContext().out_top_names.push_back(top_name);
ge::GetParserContext().out_tensor_names.push_back(top_name);
GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str());
}
++index;
@@ -1129,7 +1129,7 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes
auto top_node_iter = node_map.find(layer.name());
GELOGI("output in top_blob: %s", layer.name().c_str());
if (top_node_iter != node_map.end()) {
ge::GetParserContext().out_top_names.push_back(top_origin);
ge::GetParserContext().out_tensor_names.push_back(top_origin);
ge::GetParserContext().default_out_nodes.push_back(std::make_pair(layer.name(), (int32_t)i));
GELOGI("The top of out node [%s] is [%s]", layer.name().c_str(), top_origin.c_str());
}
@@ -1389,13 +1389,13 @@ Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &la
}

string top_name = layer.top(0);
auto data_top_names = ge::GetParserContext().data_top_names;
if (find(data_top_names.begin(), data_top_names.end(), top_name) != data_top_names.end()) {
auto data_tensor_names = ge::GetParserContext().data_tensor_names;
if (find(data_tensor_names.begin(), data_tensor_names.end(), top_name) != data_tensor_names.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E11036", {"topname"}, {top_name});
GELOGE(FAILED, "[Check][Node]Different data node can not have same top name: %s.", top_name.c_str());
return FAILED;
}
ge::GetParserContext().data_top_names.push_back(top_name);
ge::GetParserContext().data_tensor_names.push_back(top_name);
}

return SUCCESS;
@@ -1464,18 +1464,18 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap

int32_t layer_count = proto_message.layer_size();

if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) {
if (!ge::GetParserContext().user_out_tensors.empty()) {
GELOGW("The out_put info has top_name items.");
GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputNodeTopInfo(proto_message),
"[Parse][OutputNodeTopInfo] failed.");
ge::GetParserContext().user_out_nodes_top_vec.clear();
ge::GetParserContext().user_out_tensors.clear();
}

std::map<std::string, std::string> inplace_blob_name_remapping;
// Map of operator name and occurrence times
std::map<std::string, int32_t> layer_name_map;

GetParserContext().data_top_names.clear();
GetParserContext().data_tensor_names.clear();
// <layername,paramnames>
std::map<std::string, std::vector<std::string>> layer_params_map;
// same param name set <paramnames,layernames>


+ 41
- 19
parser/common/acl_graph_parser_util.cc View File

@@ -52,6 +52,13 @@ const int kMaxFileSizeLimit = INT_MAX;
const int kMaxBuffSize = 256;
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M
const uint32_t kSetOutputWithNodeAndIndex = 0x1;
const uint32_t kSetOutputWithTensorName = 0x2;
const uint32_t kSetOutputModeMixed = 0x3;
const std::unordered_set<domi::FrameworkType> kSupportTensorAsOutput = {
domi::CAFFE,
domi::ONNX
};

static string GetSoPath() {
Dl_info dl_info;
@@ -263,14 +270,19 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
if (!out_nodes.empty()) {
ge::GetParserContext().out_nodes_map.clear();
ge::GetParserContext().user_out_nodes.clear();
ge::GetParserContext().user_out_nodes_top_vec.clear();
ge::GetParserContext().user_out_tensors.clear();
uint32_t set_output_mode = 0;

vector<string> nodes_v = StringUtils::Split(out_nodes, ';');
for (const string &node : nodes_v) {
vector<string> key_value_v = StringUtils::Split(node, ':');
if (key_value_v.size() != 2) { // The size must be 2.
if (key_value_v.size() == 1 && ge::GetParserContext().type == domi::CAFFE) {
ge::GetParserContext().user_out_nodes_top_vec.push_back(node);
if (key_value_v.size() == 1 && kSupportTensorAsOutput.count(ge::GetParserContext().type) > 0) {
set_output_mode |= kSetOutputWithTensorName;
if (set_output_mode == kSetOutputModeMixed) {
break;
}
ge::GetParserContext().user_out_tensors.push_back(node);
continue;
}
ErrorManager::GetInstance().ATCReportErrMessage(
@@ -281,12 +293,9 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
node.c_str());
return PARAM_INVALID;
}
if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"out_nodes", out_nodes, "is not all index or top_name"});
GELOGE(PARAM_INVALID, "[Check][Param] This out_nodes str must be all index or top_name, "
"while the actual input is %s", out_nodes.c_str());
return PARAM_INVALID;
set_output_mode |= kSetOutputWithNodeAndIndex;
if (set_output_mode == kSetOutputModeMixed) {
break;
}
// stoi: The method may throw an exception: invalid_argument/out_of_range
if (!CheckDigitStr(key_value_v[1])) {
@@ -309,6 +318,13 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
}
ge::GetParserContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index));
}
if (set_output_mode == kSetOutputModeMixed) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--out_nodes", out_nodes, "is not all index or top_name"});
GELOGE(PARAM_INVALID, "[Parse][Param]This out_nodes str must be all index or tensor_name, "
"while the actual input is %s", out_nodes.c_str());
return PARAM_INVALID;
}
}
} catch (std::invalid_argument &) {
GELOGE(PARAM_INVALID, "[Check][Param] Invalid of out_nodes: %s ", out_nodes.c_str());
@@ -410,10 +426,11 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra
return SUCCESS;
}

void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) {
void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) {
output_nodes_name.clear();
if (ge::GetParserContext().out_top_names.empty()) {
auto &out_tensor_names = ge::GetParserContext().out_tensor_names;
if (out_tensor_names.empty()) {
// tf process, no top name.
for (const auto output_node_info : output_nodes_info) {
std::string node_name = output_node_info.first->GetName();
@@ -422,13 +439,18 @@ void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::Node
}
return;
}
// caffe process, need add top name after node_name:index

// Need add top name after node_name:index
for (size_t i = 0; i < output_nodes_info.size(); ++i) {
std::string node_name = output_nodes_info[i].first->GetName();
auto node = output_nodes_info[i].first;
int32_t index = output_nodes_info[i].second;
if (i < ge::GetParserContext().out_top_names.size()) {
output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" +
ge::GetParserContext().out_top_names[i]);
std::string node_name = node->GetName();
if (i < out_tensor_names.size()) {
auto output_desc = node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(index));
(void)AttrUtils::SetStr(output_desc, ATTR_NAME_ORIGIN_OUTPUT_TENSOR_NAME, out_tensor_names[i]);
std::string output_name = node->GetName() + ":" + std::to_string(index) + ":" + out_tensor_names[i];
output_nodes_name.push_back(output_name);
GELOGD("Output[%zu] name[%s]", i, output_name.c_str());
} else {
GELOGW("Get top name of node [%s] fail.", node_name.c_str());
output_nodes_name.push_back(node_name + ":" + std::to_string(index));
@@ -469,7 +491,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node,
domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes;
if (ge::GetParserContext().type == domi::CAFFE && !default_out_nodes.empty()) {
if (!default_out_nodes.empty()) {
for (uint32_t i = 0; i < default_out_nodes.size(); ++i) {
ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first);
if (out_node == nullptr) {
@@ -543,7 +565,7 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
return domi::FAILED;
}
}
GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name);
CreateOutputNodesInfo(output_nodes_info, output_nodes_name);
compute_graph->SetGraphOutNodesInfo(output_nodes_info);
ge::GetParserContext().net_out_nodes = output_nodes_name;
GELOGI("Set graph %s output node success.", graph.GetName().c_str());


+ 2
- 2
parser/common/acl_graph_parser_util.h View File

@@ -50,8 +50,8 @@ class AclGrphParseUtil {
bool parser_initialized = false;
domi::Status CheckOptions(const std::map<AscendString, AscendString> &parser_params);
domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
void CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
void SetDefaultFormat();
domi::Status ParseAclOutputNodes(const std::string &out_nodes);
domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16);


+ 31
- 5
parser/common/parser_utils.cc View File

@@ -71,7 +71,7 @@ Status HandleNewOp(const NodePtr &node,
}
}

Status ParserUtils::ExpandOneToManyGraph(Graph &graph) {
Status ParserUtils::ExpandOneToManyGraph(Graph &graph, OutputMapping &output_mapping) {
GELOGD("Begin run ParserUtils::ExpandOneToManyGraph.");
for (const auto &gn : graph.GetDirectNode()) {
NodePtr n = NodeAdapter::GNode2Node(gn);
@@ -95,7 +95,7 @@ Status ParserUtils::ExpandOneToManyGraph(Graph &graph) {
GELOGE(FAILED, "[Invoke][ParseOpToGraphFunc]Get one to many graph failed for op:%s.", op.GetName().c_str());
return FAILED;
}
ret = ExpandNodeToSubgraph(subgraph, n, graph);
ret = ExpandNodeToSubgraph(subgraph, n, graph, output_mapping);
if (ret != SUCCESS) {
GELOGE(FAILED, "[Invoke][ExpandNodeToSubgraph]Expand one to many graph failed for op:%s.", op.GetName().c_str());
return FAILED;
@@ -105,7 +105,8 @@ Status ParserUtils::ExpandOneToManyGraph(Graph &graph) {
return SUCCESS;
}

Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph) {
Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph,
OutputMapping &output_mapping) {
ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph);
GE_CHECK_NOTNULL(sub_compute_graph);
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
@@ -135,7 +136,7 @@ Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &n

// handle output context.
std::vector<std::pair<NodePtr, int32_t>> out_node_index = sub_compute_graph->GetGraphOutNodesInfo();
ret = HandleOutputContext(node, out_node_index);
ret = HandleOutputContext(node, out_node_index, output_mapping);
if (ret != SUCCESS) {
GELOGE(FAILED, "[Run][HandleOutputContext] failed, node:%s.", node->GetName().c_str());
return FAILED;
@@ -235,7 +236,8 @@ Status ParserUtils::HandleInputContext(const NodePtr &node,
}

Status ParserUtils::HandleOutputContext(const NodePtr &node,
const std::vector<std::pair<NodePtr, int32_t>> &out_node_index) {
const std::vector<std::pair<NodePtr, int32_t>> &out_node_index,
OutputMapping &output_mapping) {
GE_CHECK_NOTNULL(node);
GELOGD("The size of out node is %zu", out_node_index.size());
for (size_t index = 0; index < out_node_index.size(); index++) {
@@ -247,6 +249,8 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node,
NodePtr out_node = out_node_index[index].first;
int32_t out_index = out_node_index[index].second;
GELOGD("Begin to handle output node:%s[%d] with index:%zu", out_node->GetName().c_str(), out_index, index);
std::string key = GenOutputKey({node->GetName(), index});
output_mapping[key] = std::make_pair(out_node->GetName(), out_index);
auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor.
GE_CHECK_NOTNULL(src_out_anchor);
for (const auto &dest_in_anchor : node_out_anchor->GetPeerInDataAnchors()) {
@@ -273,4 +277,26 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node,
}
return SUCCESS;
}

string ParserUtils::GenOutputKey(const OutputNodeInfo &node_info) {
return node_info.first + ":" + std::to_string(node_info.second);
}

void ParserUtils::UpdateOutputNodeInfo(const OutputMapping &final_output_nodes, OutputNodeInfo &output_node_info) {
std::string key = ParserUtils::GenOutputKey(output_node_info);
auto iter = final_output_nodes.find(key);
if (iter != final_output_nodes.end()) {
output_node_info = iter->second;
GELOGD("Update output node info, origin[%s], now[%s].",
key.c_str(), ParserUtils::GenOutputKey(output_node_info).c_str());
}
}

void ParserUtils::UpdateOutputCtx(const OutputMapping &final_output_nodes, OutputMapping &tensor_to_nodes) {
for (auto &tensor_to_node : tensor_to_nodes) {
std::string tensor_name = tensor_to_node.first;
auto &output_node_info = tensor_to_node.second;
UpdateOutputNodeInfo(final_output_nodes, output_node_info);
}
}
} // namespace ge

+ 11
- 3
parser/common/parser_utils.h View File

@@ -17,6 +17,7 @@
#ifndef PARSER_COMMON_PARSER_UTILS_H_
#define PARSER_COMMON_PARSER_UTILS_H_

#include <unordered_map>
#include "graph/graph.h"
#include "graph/node.h"
#include "external/ge/ge_api_error_codes.h"
@@ -24,15 +25,22 @@
namespace ge {
class ParserUtils {
public:
static Status ExpandOneToManyGraph(Graph &graph);
using OutputNodeInfo = std::pair<std::string, int32_t>;
using OutputMapping = std::unordered_map<std::string, OutputNodeInfo>;
static Status ExpandOneToManyGraph(Graph &graph, OutputMapping &output_mapping);
static string GenOutputKey(const OutputNodeInfo &node_info);
static void UpdateOutputNodeInfo(const OutputMapping &final_output_nodes, OutputNodeInfo &output_node_info);
static void UpdateOutputCtx(const OutputMapping &final_output_nodes, OutputMapping &tensor_to_nodes);

private:
static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph);
static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph,
OutputMapping &output_mapping);
static Status HandleInputContext(const NodePtr &node,
const std::vector<NodePtr> &input_nodes,
const ComputeGraphPtr &compute_graph);
static Status HandleOutputContext(const NodePtr &node,
const std::vector<std::pair<NodePtr, int32_t>> &out_node_index);
const std::vector<std::pair<NodePtr, int32_t>> &out_node_index,
OutputMapping &output_mapping);
};
} // namespace ge
#endif // PARSER_COMMON_PARSER_UTILS_H_

+ 72
- 18
parser/onnx/onnx_parser.cc View File

@@ -360,7 +360,7 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph,
return SUCCESS;
}

Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) {
void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) {
int index = 0;
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i);
@@ -369,8 +369,6 @@ Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) {
node->set_name(node_name);
}
}

return SUCCESS;
}

Status OnnxModelParser::ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type) {
@@ -676,7 +674,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve
return SUCCESS;
}

Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops) {
Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops,
ParserUtils::OutputMapping &out_tensor_to_nodes) {
for (auto output_name : output_node_names_) {
auto itr = outputs_map_.find(output_name);
if (itr == outputs_map_.end()) {
@@ -696,6 +695,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec
}
int index = node_name_index.second;
output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)});
out_tensor_to_nodes[output_name] = std::make_pair(node_name, index);
GELOGI("out node index %d, node:%s", index, node_name.c_str());
}
}
@@ -870,7 +870,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP

GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX),
"Run ProtoType Pass Failed");
// 2. Get all inializer.
// 1. Get all inializer.
std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor;
for (int i = 0; i < onnx_graph.initializer_size(); i++) {
ge::onnx::TensorProto initializer_tensor = onnx_graph.initializer(i);
@@ -880,7 +880,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP
}
}

// 3. Parse Input from graph.
// 2. Parse Input from graph.
GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size());

Status ret = ParseInput(initializer_name_tensor, is_subgraph, onnx_graph);
@@ -890,13 +890,14 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP
}
GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size());

// 4. Parse Constant from graph.
// 3. Parse Constant from graph.
ret = ParseInitializer(onnx_graph, initializer_name_tensor);
if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Initializer] for onnx failed.");
return ret;
}

// 4. Get all output name form origin graph
ret = ParseOutput(onnx_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Output] Parse output for onnx failed.");
@@ -904,11 +905,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP
}

// 5. Update node name for node do not has name.
ret = UpdateAllNodeName(onnx_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Update][Name] of all node for onnx failed.");
return ret;
}
UpdateAllNodeName(onnx_graph);

// 6 Precheck.
ret = Prechecker(onnx_graph);
@@ -950,19 +947,32 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP
}
graph.SetInputs(input_ops);

// 10. Get output info and set outpus for subgraph
std::vector<std::pair<Operator, std::vector<size_t>>> output_ops;
ParserUtils::OutputMapping out_tensor_to_nodes;
ret = GetGraphOutputs(output_ops, out_tensor_to_nodes);
if (ret != SUCCESS) {
GELOGE(ret, "[Get][Outputs] failed.");
return ret;
}
// root graph needn't set outputs.
if(is_subgraph) {
std::vector<std::pair<Operator, std::vector<size_t>>> output_ops;
ret = GetGraphOutputs(output_ops);
graph.SetOutputs(output_ops);
}

// 11. Expand node to graph if need
ParserUtils::OutputMapping final_output_nodes;
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph, final_output_nodes));

// 12. Set outputs info in ParserContext for root graph
if (!is_subgraph) {
ret = SetOutputsInfo(final_output_nodes, out_tensor_to_nodes);
if (ret != SUCCESS) {
GELOGE(ret, "[Get][Outputs] failed.");
GELOGE(ret, "[Set][OutputsInfo] Graph:%s.", graph.GetName().c_str());
return ret;
}
graph.SetOutputs(output_ops);
}

GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph));

GELOGI("Onnx model parser success.");
return SUCCESS;
}
@@ -1048,6 +1058,50 @@ void OnnxModelParser::UpdateDataFormat(ge::Graph &graph) {
return;
}

Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes,
const ParserUtils::OutputMapping &tensor_to_nodes) {
auto &user_specified_nodes = ge::GetParserContext().user_out_nodes;
if (!user_specified_nodes.empty()) {
GELOGI("User specified the output nodes with node_name and index.");
for (auto &output_node_info : user_specified_nodes) {
ParserUtils::UpdateOutputNodeInfo(final_output_nodes, output_node_info);
}
return SUCCESS;
}

auto final_tensor_to_nodes = tensor_to_nodes;
ParserUtils::UpdateOutputCtx(final_output_nodes, final_tensor_to_nodes);
auto &user_specified_tensors = ge::GetParserContext().user_out_tensors;
auto &output_tensor_names = ge::GetParserContext().out_tensor_names;
output_tensor_names.clear();
if (!user_specified_tensors.empty()) {
for (auto &tensor_name : user_specified_tensors) {
auto iter = final_tensor_to_nodes.find(tensor_name);
if (iter != final_tensor_to_nodes.end()) {
user_specified_nodes.emplace_back(iter->second);
output_tensor_names.emplace_back(tensor_name);
GELOGI("[UserSpecified]Add network output node[%s], index[%d], tensor name[%s].",
iter->second.first.c_str(), iter->second.second, tensor_name.c_str());
} else {
REPORT_INNER_ERROR("E19999", "User specified tensor[%s] is not output of graph.", tensor_name.c_str());
GELOGE(FAILED, "[Set][OutputsInfo]User specified tensor[%s] is not output of graph.", tensor_name.c_str());
return FAILED;
}
}
return SUCCESS;
}

// for default output
auto &default_out_nodes = ge::GetParserContext().default_out_nodes;
for (auto &tensor_name : output_node_names_) {
auto &output_node_info = final_tensor_to_nodes[tensor_name];
default_out_nodes.emplace_back(output_node_info);
output_tensor_names.emplace_back(tensor_name);
GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].",
output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str());
}
return SUCCESS;
}
} // namespace domi

namespace domi {


+ 7
- 2
parser/onnx/onnx_parser.h View File

@@ -38,6 +38,7 @@
#include "omg/parser/model_parser.h"
#include "omg/parser/op_parser.h"
#include "omg/parser/weights_parser.h"
#include "common/parser_utils.h"
#include "proto/onnx/ge_onnx.pb.h"

namespace ge {
@@ -80,7 +81,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {
Status ParseInitializer(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor);

Status UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph);
void UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph);

Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type);

@@ -94,7 +95,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops);

Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &outputs);
Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &outputs,
ParserUtils::OutputMapping &out_tensor_to_nodes);

Status Prechecker(ge::onnx::GraphProto &onnx_graph);
@@ -115,6 +117,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {
Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);

Status SetOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes,
const ParserUtils::OutputMapping &tensor_to_nodes);

std::map<std::string, std::string> ori_to_om_type_;

std::map<std::string, int64_t> domain_verseion_;


+ 16
- 2
parser/tensorflow/tensorflow_parser.cc View File

@@ -1494,7 +1494,9 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro

GE_RETURN_IF_ERROR(AddEdges(graph));
Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph));
ParserUtils::OutputMapping final_output_nodes;
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes));
GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes));
GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph));
GE_RETURN_IF_ERROR(graph->TopologicalSorting());
@@ -2304,7 +2306,9 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
ret = AddEdges(graph);

Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph));
ParserUtils::OutputMapping final_output_nodes;
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes));
GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes));

DeleteFuisonNodeDef();
GE_CHK_STATUS_EXEC(ret, return ret, "AddEdges failed");
@@ -4020,6 +4024,16 @@ Status TensorFlowModelParser::CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compu
}
return SUCCESS;
}

Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes) {
auto &user_specified_nodes = ge::GetParserContext().user_out_nodes;
if (!user_specified_nodes.empty()) {
for (auto &output_node_info : user_specified_nodes) {
ParserUtils::UpdateOutputNodeInfo(final_output_nodes, output_node_info);
}
}
return SUCCESS;
}
} // namespace ge

namespace domi {


+ 2
- 0
parser/tensorflow/tensorflow_parser.h View File

@@ -44,6 +44,7 @@
#include "proto/tensorflow/graph_library.pb.h"
#include "external/register/scope/scope_fusion_pass_register.h"
#include "scope/scope_pass_manager.h"
#include "common/parser_utils.h"

using ge::ScopePassManager;
using domi::tensorflow::GraphDef;
@@ -647,6 +648,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {

Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser);
Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph);
static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes);

/**
* save <node_name, node_def>


+ 2
- 2
tests/depends/mmpa/src/mmpa_stub.cc View File

@@ -275,10 +275,10 @@ INT32 mmGetPid()
}

INT32 mmDup2(INT32 oldFd, INT32 newFd) {
return -1;
return 0;
}

INT32 mmDup(INT32 fd) {
return -1;
return 0;
}


+ 4
- 0
tests/ut/parser/CMakeLists.txt View File

@@ -296,6 +296,7 @@ include_directories(${PARSER_DIR})
include_directories(${PARSER_DIR}/inc)
include_directories(${PARSER_DIR}/parser)
include_directories(${PARSER_DIR}/parser/onnx)
include_directories(${PARSER_DIR}/tests)
include_directories(${PARSER_DIR}/metadef/inc)
include_directories(${PARSER_DIR}/metadef/inc/external)
include_directories(${PARSER_DIR}/metadef/inc/register)
@@ -306,7 +307,10 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework)


set(PARSER_UT_FILES
"parser_ut_utils.cc"
"testcase/common/acl_graph_parser_unittest.cc"
"testcase/onnx_parser_testcase/onnx_parser_unittest.cc"
"testcase/caffe_parser_testcase/caffe_parser_unittest.cc"
"testcase/onnx_parser_testcase/message2operator_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc"


+ 44
- 0
tests/ut/parser/parser_ut_utils.cc View File

@@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ut/parser/parser_ut_utils.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
void ParerUTestsUtils::ClearParserInnerCtx() {
ge::GetParserContext().input_nodes_format_map.clear();
ge::GetParserContext().output_formats.clear();
ge::GetParserContext().user_input_dims.clear();
ge::GetParserContext().input_dims.clear();
ge::GetParserContext().op_conf_map.clear();
ge::GetParserContext().user_out_nodes.clear();
ge::GetParserContext().default_out_nodes.clear();
ge::GetParserContext().out_nodes_map.clear();
ge::GetParserContext().user_out_tensors.clear();
ge::GetParserContext().net_out_nodes.clear();
ge::GetParserContext().out_tensor_names.clear();
ge::GetParserContext().data_tensor_names.clear();
ge::GetParserContext().is_dynamic_input = false;
ge::GetParserContext().train_flag = false;
ge::GetParserContext().format = domi::DOMI_TENSOR_ND;
ge::GetParserContext().type = domi::FRAMEWORK_RESERVED;
ge::GetParserContext().run_mode = GEN_OM_MODEL;
ge::GetParserContext().custom_proto_path = "";
ge::GetParserContext().caffe_proto_path = "";
ge::GetParserContext().enable_scope_fusion_passes = "";
GELOGI("Clear parser inner context successfully.");
}
} // namespace ge

+ 29
- 0
tests/ut/parser/parser_ut_utils.h View File

@@ -0,0 +1,29 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_TESTS_UT_PARSER_H_
#define GE_PARSER_TESTS_UT_PARSER_H_

#include "framework/omg/parser/parser_inner_ctx.h"

namespace ge {
class ParerUTestsUtils {
public:
static void ClearParserInnerCtx();
};
} // namespace ge

#endif // GE_PARSER_TESTS_UT_PARSER_H_

+ 14
- 0
tests/ut/parser/testcase/caffe_parser_testcase/caffe_model/caffe_abs.pbtxt View File

@@ -0,0 +1,14 @@
name: "TestAbs"
layer {
name: "data"
type: "Input"
top: "data"
input_param { shape: { dim: 64 dim: 1 dim: 28 dim: 28 } }
}
layer {
name: "abs"
type: "AbsVal"
bottom: "data"
top: "abs_out"
}

+ 158
- 0
tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc View File

@@ -0,0 +1,158 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <iostream>
#include "parser/common/op_parser_factory.h"
#include "graph/operator_reg.h"
#include "external/graph/types.h"
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"
#include "framework/omg/parser/model_parser.h"
#include "framework/omg/parser/parser_factory.h"
#include "external/parser/caffe_parser.h"
#include "ut/parser/parser_ut_utils.h"
#include "external/ge/ge_api_types.h"

namespace ge {
class UtestCaffeParser : public testing::Test {
protected:
void SetUp() {
ParerUTestsUtils::ClearParserInnerCtx();
RegisterCustomOp();
}

void TearDown() {}

public:
void RegisterCustomOp();
};

static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
return SUCCESS;
}
void UtestCaffeParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Data")
.FrameworkType(domi::CAFFE)
.OriginOpType("Input")
.ParseParamsFn(ParseParams);

REGISTER_CUSTOM_OP("Abs")
.FrameworkType(domi::CAFFE)
.OriginOpType("AbsVal")
.ParseParamsFn(ParseParams);

std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();
}

namespace {
REG_OP(Data)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(index, Int, 0)
.OP_END_FACTORY_REG(Data)

REG_OP(Abs)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
.OP_END_FACTORY_REG(Abs)
}

TEST_F(UtestCaffeParser, caffe_parser_user_output_with_name_and_index) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/caffe_model/caffe_abs.pbtxt";
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE);
ASSERT_NE(model_parser, nullptr);
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
ASSERT_NE(compute_graph, nullptr);
ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);

ge::GetParserContext().user_out_nodes.push_back({"abs", 0});
auto ret = model_parser->Parse(model_file.c_str(), graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
AclGrphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> parser_params;
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params);
ASSERT_EQ(status, SUCCESS);

auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
ASSERT_EQ(output_nodes_info.size(), 1);
EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs");
EXPECT_EQ((output_nodes_info.at(0).second), 0);
auto &net_out_name = ge::GetParserContext().net_out_nodes;
ASSERT_EQ(net_out_name.size(), 1);
EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out");
}

TEST_F(UtestCaffeParser, caffe_parser_user_output_with_top_name) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/caffe_model/caffe_abs.pbtxt";
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE);
ASSERT_NE(model_parser, nullptr);
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
ASSERT_NE(compute_graph, nullptr);
ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);

ge::GetParserContext().user_out_tensors.push_back("abs_out");
auto ret = model_parser->Parse(model_file.c_str(), graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
AclGrphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> parser_params;
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params);
ASSERT_EQ(status, SUCCESS);

auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
ASSERT_EQ(output_nodes_info.size(), 1);
EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs");
EXPECT_EQ((output_nodes_info.at(0).second), 0);
auto &net_out_name = ge::GetParserContext().net_out_nodes;
ASSERT_EQ(net_out_name.size(), 1);
EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out");
}

TEST_F(UtestCaffeParser, caffe_parser_user_output_with_default) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/caffe_model/caffe_abs.pbtxt";
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE);
ASSERT_NE(model_parser, nullptr);
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
ASSERT_NE(compute_graph, nullptr);
ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto ret = model_parser->Parse(model_file.c_str(), graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
AclGrphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> parser_params;
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params);
ASSERT_EQ(status, SUCCESS);

auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
ASSERT_EQ(output_nodes_info.size(), 1);
EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs");
EXPECT_EQ((output_nodes_info.at(0).second), 0);
auto &net_out_name = ge::GetParserContext().net_out_nodes;
ASSERT_EQ(net_out_name.size(), 1);
EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out");
}

} // namespace ge

+ 97
- 0
tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc View File

@@ -0,0 +1,97 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <iostream>
#include "parser/common/op_parser_factory.h"
#include "graph/operator_reg.h"
#include "external/graph/types.h"
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"
#include "external/parser/onnx_parser.h"
#include "ut/parser/parser_ut_utils.h"
#include "external/ge/ge_api_types.h"

namespace ge {
class UtestAclGraphParser : public testing::Test {
protected:
void SetUp() {

}
void TearDown() {}
};

TEST_F(UtestAclGraphParser, test_parse_acl_output_nodes) {
AclGrphParseUtil acl_graph_parse_util;
string graph_name;
// case 1: Normal with 'node and index'
ParerUTestsUtils::ClearParserInnerCtx();
GetParserContext().type = domi::ONNX;
std::map<AscendString, AscendString> out_nodes_with_node_and_index = {
{AscendString(ge::ir_option::OUT_NODES), AscendString("Out1:0;Out2:1")}};
ParerUTestsUtils::ClearParserInnerCtx();
auto ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_node_and_index, graph_name);
ASSERT_EQ(ret, SUCCESS);
EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 2);
EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 2);
EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 0);

// case 2: Normal with 'tensor name'
ParerUTestsUtils::ClearParserInnerCtx();
GetParserContext().type = domi::ONNX;
std::map<AscendString, AscendString> out_nodes_with_tensor_name = {
{AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out_tensor_2")}};
ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_tensor_name, graph_name);
ASSERT_EQ(ret, SUCCESS);
EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0);
EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0);
EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 2);

// case 3: Failed with 'node and index' before 'tensor name'
ParerUTestsUtils::ClearParserInnerCtx();
GetParserContext().type = domi::ONNX;
std::map<AscendString, AscendString> out_nodes_mode_mixex_pre = {
{AscendString(ge::ir_option::OUT_NODES), AscendString("Out1:0;Out2:1;Out_tensor_1;Out_tensor_2")}};
ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_pre, graph_name);
ASSERT_EQ(ret, PARAM_INVALID);
EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 2);
EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 2);
EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 0);

// case 4: Failed with 'node and index' inserted in 'tensor name'
ParerUTestsUtils::ClearParserInnerCtx();
GetParserContext().type = domi::ONNX;
std::map<AscendString, AscendString> out_nodes_mode_mixex_mid = {
{AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out1:0;Out2:1;Out_tensor_2")}};
ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_mid, graph_name);
ASSERT_EQ(ret, PARAM_INVALID);
EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0);
EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0);
EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 1);

// case 5: Failed with 'node and index' after 'tensor name'
ParerUTestsUtils::ClearParserInnerCtx();
GetParserContext().type = domi::ONNX;
std::map<AscendString, AscendString> out_nodes_mode_mixex_post = {
{AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out_tensor_2;Out1:0;Out2:1")}};
ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_post, graph_name);
ASSERT_EQ(ret, PARAM_INVALID);
EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0);
EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0);
EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 2);

}
} // namespace ge

+ 66
- 9
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -22,12 +22,16 @@
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"
#include "external/parser/onnx_parser.h"

#include "ut/parser/parser_ut_utils.h"
#include "external/ge/ge_api_types.h"

namespace ge {
class UtestOnnxParser : public testing::Test {
protected:
void SetUp() {}
void SetUp() {
ParerUTestsUtils::ClearParserInnerCtx();
RegisterCustomOp();
}

void TearDown() {}

@@ -152,28 +156,81 @@ REG_OP(Identity)
.OP_END_FACTORY_REG(Identity)
}

TEST_F(UtestOnnxParser, onnx_parser_success) {
RegisterCustomOp();
TEST_F(UtestOnnxParser, onnx_parser_if_node) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/onnx_model/if.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, GRAPH_SUCCESS);
}

TEST_F(UtestOnnxParser, onnx_parser_user_output_with_name_and_index) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/onnx_model/conv2d.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
parser_params.insert({AscendString(ge::ir_option::OUT_NODES), AscendString("Conv_0:0")});
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, domi::SUCCESS);
ASSERT_EQ(ret, GRAPH_SUCCESS);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
ASSERT_EQ(output_nodes_info.size(), 1);
EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0");
EXPECT_EQ((output_nodes_info.at(0).second), 0);
auto &net_out_name = ge::GetParserContext().net_out_nodes;
ASSERT_EQ(net_out_name.size(), 1);
EXPECT_EQ(net_out_name.at(0), "Conv_0:0");
}

TEST_F(UtestOnnxParser, onnx_parser_if_node) {
RegisterCustomOp();
TEST_F(UtestOnnxParser, onnx_parser_user_output_with_tensor) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/onnx_model/conv2d.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
parser_params.insert({AscendString(ge::ir_option::OUT_NODES), AscendString("y")});
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
ASSERT_EQ(output_nodes_info.size(), 1);
EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0");
EXPECT_EQ((output_nodes_info.at(0).second), 0);
auto &net_out_name = ge::GetParserContext().net_out_nodes;
ASSERT_EQ(net_out_name.size(), 1);
EXPECT_EQ(net_out_name.at(0), "Conv_0:0:y");
}

TEST_F(UtestOnnxParser, onnx_parser_user_output_with_default) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/onnx_model/if.onnx";
std::string model_file = case_dir + "/onnx_model/conv2d.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
ASSERT_EQ(output_nodes_info.size(), 1);
EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0");
EXPECT_EQ((output_nodes_info.at(0).second), 0);
auto &net_out_name = ge::GetParserContext().net_out_nodes;
ASSERT_EQ(net_out_name.size(), 1);
EXPECT_EQ(net_out_name.at(0), "Conv_0:0:y");
}

TEST_F(UtestOnnxParser, onnx_parser_user_output_with_tensor_failed) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/onnx_model/conv2d.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
parser_params.insert({AscendString(ge::ir_option::OUT_NODES), AscendString("not_exist_output")});
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, domi::SUCCESS);
EXPECT_EQ(ret, FAILED);
}

} // namespace ge

+ 4
- 2
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -24,12 +24,14 @@
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"
#include "external/parser/tensorflow_parser.h"
#include "ut/parser/parser_ut_utils.h"

namespace ge {
class UtestTensorflowParser : public testing::Test {
protected:
void SetUp() {}
void SetUp() {
ParerUTestsUtils::ClearParserInnerCtx();
}

void TearDown() {}



Loading…
Cancel
Save