Browse Source

Merge remote-tracking branch 'refs/remotes/origin/ge_dev' into ge_dev

pull/687/head
13291271729 3 years ago
parent
commit
b1cf2e0407
54 changed files with 643 additions and 513 deletions
  1. +1
    -0
      inc/external/parser/tensorflow_parser.h
  2. +1
    -1
      metadef
  3. +4
    -4
      parser/caffe/caffe_parser.cc
  4. +2
    -2
      parser/caffe/caffe_parser.h
  5. +24
    -26
      parser/common/acl_graph_parser_util.cc
  6. +6
    -27
      parser/common/acl_graph_parser_util.h
  7. +1
    -1
      parser/common/auto_mapping_subgraph_io_index_func.h
  8. +1
    -1
      parser/common/convert/pb2json.cc
  9. +399
    -399
      parser/common/parser_types.cc
  10. +4
    -4
      parser/onnx/onnx_parser.cc
  11. +2
    -2
      parser/tensorflow/tensorflow_parser.cc
  12. +4
    -0
      tests/st/testcase/origin_models/conv2d_depthwise_pb_gen.py
  13. +4
    -0
      tests/st/testcase/origin_models/onnx_clip_v9.py
  14. +5
    -1
      tests/st/testcase/origin_models/onnx_if_const_intput_gen.py
  15. +4
    -0
      tests/st/testcase/origin_models/tensor_array_pb_gen.py
  16. +5
    -1
      tests/st/testcase/origin_models/test_VarIsInitializedOp_pb_gen.py
  17. +5
    -1
      tests/st/testcase/origin_models/test_avgpool3dgrad_pb_gen.py
  18. +4
    -0
      tests/st/testcase/origin_models/test_blocklstm_pb.gen.py
  19. +4
    -0
      tests/st/testcase/origin_models/test_constant_pb_gen.py
  20. +5
    -1
      tests/st/testcase/origin_models/test_conv2d_pb_gen.py
  21. +5
    -1
      tests/st/testcase/origin_models/test_enter_pb_gen.py
  22. +5
    -1
      tests/st/testcase/origin_models/test_fill_pb_gen.py
  23. +5
    -1
      tests/st/testcase/origin_models/test_identity_pb_gen.py
  24. +4
    -0
      tests/st/testcase/origin_models/test_merge_pb_gen.py
  25. +5
    -1
      tests/st/testcase/origin_models/test_no_op_pb_gen.py
  26. +5
    -1
      tests/st/testcase/origin_models/test_reshape_pb_gen.py
  27. +5
    -1
      tests/st/testcase/origin_models/test_sequeeze_pb_gen.py
  28. +5
    -1
      tests/st/testcase/origin_models/test_shape_n_pb_gen.py
  29. +5
    -1
      tests/st/testcase/origin_models/test_switch_pb_gen.py
  30. +5
    -1
      tests/st/testcase/origin_models/test_variableV2_pb_gen.py
  31. +2
    -2
      tests/st/testcase/test_caffe_parser.cc
  32. +7
    -7
      tests/st/testcase/test_tensorflow_parser.cc
  33. +4
    -4
      tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc
  34. +1
    -1
      tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc
  35. +4
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/onnx_clip_v9.py
  36. +4
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/conv2d_depthwise_pb_gen.py
  37. +4
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/onnx_clip_v9.py
  38. +4
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/tensor_array_pb_gen.py
  39. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_VarIsInitializedOp_pb_gen.py
  40. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_avgpool3dgrad_pb_gen.py
  41. +4
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_blocklstm_pb.gen.py
  42. +4
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_constant_pb_gen.py
  43. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_conv2d_pb_gen.py
  44. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_enter_pb_gen.py
  45. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_fill_pb_gen.py
  46. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_identity_pb_gen.py
  47. +4
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_merge_pb_gen.py
  48. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_no_op_pb_gen.py
  49. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_reshape_pb_gen.py
  50. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_sequeeze_pb_gen.py
  51. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_shape_n_pb_gen.py
  52. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_switch_pb_gen.py
  53. +5
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_variableV2_pb_gen.py
  54. +7
    -7
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 1
- 0
inc/external/parser/tensorflow_parser.h View File

@@ -35,6 +35,7 @@
#include <memory>
#include <string>
#include <vector>
#include <map>
#include "graph/ascend_string.h"
#include "graph/ge_error_codes.h"
#include "graph/graph.h"


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit f1af97e1c9ce9164901d4e719d3acaa1b8597d14
Subproject commit ecdc591e4ffb87609be93e2f630c82098586ebc2

+ 4
- 4
parser/caffe/caffe_parser.cc View File

@@ -86,7 +86,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::CAFFE)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
domi::Status status = acl_graph_parse_util.AclParserInitialize(options);
if (status != domi::SUCCESS) {
REPORT_CALL_ERROR("E19999", "AclParserInitialize failed, ret:%d.", status);
@@ -144,7 +144,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::CAFFE)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
domi::Status status = acl_graph_parse_util.AclParserInitialize(options);
if (status != domi::SUCCESS) {
REPORT_CALL_ERROR("E19999", "AclParserInitialize failed, ret:%d.", status);
@@ -429,7 +429,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons
}

Status CaffeModelParser::CustomProtoParse(const char *model_path, const string &custom_proto,
const string &caffe_proto, vector<ge::Operator> &operators) {
const string &caffe_proto, vector<ge::Operator> &operators) const {
(void)caffe_proto;
string custom_proto_path = ge::parser::RealPath(custom_proto.c_str());
if (custom_proto_path.empty()) {
@@ -1904,7 +1904,7 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto
}

Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message,
google::protobuf::Message *layer) {
google::protobuf::Message *layer) const {
const google::protobuf::Reflection *layer_reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;


+ 2
- 2
parser/caffe/caffe_parser.h View File

@@ -168,7 +168,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
* @return FAILED parse failed
*/
Status CustomProtoParse(const char *model_path, const string &custom_proto, const string &caffe_proto,
std::vector<ge::Operator> &operators);
std::vector<ge::Operator> &operators) const;

/*
* @ingroup domi_omg
@@ -396,7 +396,7 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser {
Status CheckLayersSize(const google::protobuf::Message &message) const;

Status ConvertLayerProto(const google::protobuf::Message &message,
google::protobuf::Message *layer);
google::protobuf::Message *layer) const;

Status ParseLayerField(const google::protobuf::Reflection &reflection,
const google::protobuf::Message &message,


+ 24
- 26
parser/common/acl_graph_parser_util.cc View File

@@ -149,7 +149,7 @@ static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) {
return domi::SUCCESS;
}

domi::Status AclGrphParseUtil::LoadOpsProtoLib() {
domi::Status AclGraphParseUtil::LoadOpsProtoLib() {
string opsproto_path;
ge::Status ret = ge::TBEPluginLoader::GetOpsProtoPath(opsproto_path);
if (ret != ge::SUCCESS) {
@@ -170,7 +170,7 @@ domi::Status AclGrphParseUtil::LoadOpsProtoLib() {
return SUCCESS;
}

void AclGrphParseUtil::SaveCustomCaffeProtoPath() {
void AclGraphParseUtil::SaveCustomCaffeProtoPath() {
GELOGD("Enter save custom caffe proto path.");
std::string path_base = GetSoPath();
path_base = path_base.substr(0, path_base.rfind('/'));
@@ -192,7 +192,7 @@ void AclGrphParseUtil::SaveCustomCaffeProtoPath() {

// Initialize PARSER, load custom op plugin
// options will be used later for parser decoupling
domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, std::string> &options) {
domi::Status AclGraphParseUtil::AclParserInitialize(const std::map<std::string, std::string> &options) {
GELOGT(TRACE_INIT, "AclParserInitialize start");
// check init status
if (parser_initialized) {
@@ -240,7 +240,7 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s
return SUCCESS;
}

void AclGrphParseUtil::SetDefaultFormat() {
void AclGraphParseUtil::SetDefaultFormat() {
if (ge::GetParserContext().type == domi::TENSORFLOW) {
ge::GetParserContext().format = domi::DOMI_TENSOR_NHWC;
} else {
@@ -248,7 +248,7 @@ void AclGrphParseUtil::SetDefaultFormat() {
}
}

domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) const {
domi::Status AclGraphParseUtil::ParseAclOutputNodes(const string &out_nodes) const {
try {
ge::GetParserContext().out_nodes_map.clear();
ge::GetParserContext().user_out_nodes.clear();
@@ -323,7 +323,7 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) cons
return SUCCESS;
}

domi::Status AclGrphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_output_fp16) const {
domi::Status AclGraphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_output_fp16) const {
if (is_output_fp16.empty()) {
return SUCCESS;
}
@@ -347,7 +347,7 @@ domi::Status AclGrphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_ou
return SUCCESS;
}

domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) const {
domi::Status AclGraphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) const {
ge::GetParserContext().enable_scope_fusion_passes.clear();
if (enable_scope_fusion_passes.empty()) {
return SUCCESS;
@@ -356,8 +356,8 @@ domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fu
return SUCCESS;
}

void AclGrphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec,
const string &fp16_nodes_name, size_t index, OpDescPtr &op_desc) {
void AclGraphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec,
const string &fp16_nodes_name, size_t index, OpDescPtr &op_desc) {
if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) {
if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) {
GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str());
@@ -368,8 +368,8 @@ void AclGrphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_f
}
}

domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
const string &is_input_adjust_hw_layout) const {
domi::Status AclGraphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
const string &is_input_adjust_hw_layout) const {
GE_CHECK_NOTNULL(graph);
vector<string> adjust_fp16_format_vec;
if (!is_input_adjust_hw_layout.empty()) {
@@ -411,7 +411,7 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra
return SUCCESS;
}

domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph,
domi::Status AclGraphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph,
const std::string &input_data_names) const {
std::vector<std::string> input_names = StringUtils::Split(input_data_names, ',');
std::unordered_map<std::string, size_t> name_to_index;
@@ -446,8 +446,8 @@ domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGrap
return SUCCESS;
}

void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) const {
void AclGraphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) const {
output_nodes_name.clear();
auto &out_tensor_names = ge::GetParserContext().out_tensor_names;
if (out_tensor_names.empty()) {
@@ -478,8 +478,8 @@ void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr,
}
}

domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const {
domi::Status AclGraphParseUtil::GetOutputLeaf(NodePtr node,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const {
ge::OpDescPtr tmpDescPtr = node->GetOpDesc();
if (tmpDescPtr == nullptr) {
REPORT_INNER_ERROR("E19999", "param node has no opdesc.");
@@ -508,7 +508,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node,
return SUCCESS;
}

domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
domi::Status AclGraphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const {
std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes;
if (!default_out_nodes.empty()) {
@@ -531,8 +531,8 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr
return domi::SUCCESS;
}

domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
const std::map<AscendString, AscendString> &parser_params) {
domi::Status AclGraphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
const std::map<AscendString, AscendString> &parser_params) const {
(void)parser_params;
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
@@ -588,7 +588,7 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
return domi::SUCCESS;
}

domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) const {
domi::Status AclGraphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) const {
for (auto &ele : parser_params) {
const char *key_ascend = ele.first.GetString();
if (key_ascend == nullptr) {
@@ -609,8 +609,8 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendS
return SUCCESS;
}

domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
string &graph_name) {
domi::Status AclGraphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
string &graph_name) const {
GELOGI("Parse graph user options start.");
ge::GetParserContext().input_nodes_format_map.clear();
ge::GetParserContext().output_formats.clear();
@@ -663,8 +663,8 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin
return SUCCESS;
}

domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
const std::map<AscendString, AscendString> &parser_params) const {
domi::Status AclGraphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
const std::map<AscendString, AscendString> &parser_params) const {
// support paragrams: input_fp16_nodes, is_input_adjust_hw_layout,
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
@@ -938,12 +938,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha
return ret;
}

///
/// @brief get the Original Type of FrameworkOp
/// @param [in] node
/// @param [out] type
/// @return Status
///
Status GetOriginalType(const ge::NodePtr &node, string &type) {
GE_CHECK_NOTNULL(node);
type = node->GetType();


+ 6
- 27
parser/common/acl_graph_parser_util.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -34,16 +34,16 @@ namespace ge {

using google::protobuf::Message;

class AclGrphParseUtil {
class AclGraphParseUtil {
public:
AclGrphParseUtil() {}
virtual ~AclGrphParseUtil() {}
AclGraphParseUtil() {}
virtual ~AclGraphParseUtil() {}
static domi::Status LoadOpsProtoLib();
static void SaveCustomCaffeProtoPath();
domi::Status AclParserInitialize(const std::map<std::string, std::string> &options);
domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params);
domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params) const;
domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
std::string &graph_name);
std::string &graph_name) const;
domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString,
AscendString> &parser_params) const;

@@ -67,31 +67,23 @@ class AclGrphParseUtil {
};

namespace parser {
///
/// @ingroup: domi_common
/// @brief: get length of file
/// @param [in] input_file: path of file
/// @return long: File length. If the file length fails to be obtained, the value -1 is returned.
///
extern long GetFileLength(const std::string &input_file);

///
/// @ingroup domi_common
/// @brief Absolute path for obtaining files.
/// @param [in] path of input file
/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned
///
std::string RealPath(const char *path);

///
/// @ingroup domi_common
/// @brief Obtains the absolute time (timestamp) of the current system.
/// @return Timestamp, in microseconds (US)
///
///
uint64_t GetCurrentTimestamp();

///
/// @ingroup domi_common
/// @brief Reads all data from a binary file.
/// @param [in] file_name path of file
@@ -99,20 +91,16 @@ uint64_t GetCurrentTimestamp();
/// @param [out] length Output memory size
/// @return false fail
/// @return true success
///
bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length);

///
/// @ingroup domi_common
/// @brief proto file in bianary format
/// @param [in] file path of proto file
/// @param [out] proto memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromBinaryFile(const char *file, Message *proto);

///
/// @ingroup domi_common
/// @brief Reads the proto structure from an array.
/// @param [in] data proto data to be read
@@ -120,42 +108,33 @@ bool ReadProtoFromBinaryFile(const char *file, Message *proto);
/// @param [out] proto Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromArray(const void *data, int size, Message *proto);

///
/// @ingroup domi_proto
/// @brief Reads the proto file in the text format.
/// @param [in] file path of proto file
/// @param [out] message Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromText(const char *file, google::protobuf::Message *message);

bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message);

///
/// @brief get the Original Type of FrameworkOp
/// @param [in] node
/// @param [out] type
/// @return Status
///
domi::Status GetOriginalType(const ge::NodePtr &node, string &type);

///
/// @ingroup domi_common
/// @brief Check whether the file path meets the whitelist verification requirements.
/// @param [in] filePath file path
/// @param [out] result
///
bool ValidateStr(const std::string &filePath, const std::string &mode);

///
/// @ingroup domi_common
/// @brief Obtains the current time string.
/// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555
///
std::string CurrentTimeInStr();

template <typename T, typename... Args>


+ 1
- 1
parser/common/auto_mapping_subgraph_io_index_func.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2019~2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.


+ 1
- 1
parser/common/convert/pb2json.cc View File

@@ -178,7 +178,7 @@ void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFie
switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i);
if (0UL != tmp_message.ByteSizeLong()) {
if (tmp_message.ByteSizeLong() != 0UL) {
Message2Json(tmp_message, black_fields, tmp_json, enum2str, depth + 1);
}
} break;


+ 399
- 399
parser/common/parser_types.cc View File

@@ -18,426 +18,426 @@

namespace ge {
namespace parser {
const char *DATA = "Data";
const char *AIPPDATA = "AippData";
const char *CONVOLUTION = "Convolution";
const char *CORRELATION = "Correlation";
const char *CORRELATIONV2 = "Correlation_V2";
const char *DECONVOLUTION = "Deconvolution";
const char *POOLING = "Pooling";
const char *ELTWISE = "Eltwise";
const char *RELU = "ReLU";
const char *RELU6 = "ReLU6";
const char *SIGMOID = "Sigmoid";
const char *ABSVAL = "AbsVal";
const char *TANH = "TanH";
const char *PRELU = "PReLU";
const char *BATCHNORM = "BatchNorm";
const char *FUSIONBATCHNORM = "FusionBatchNorm";
const char *SCALE = "Scale";
const char *FULL_CONNECTION = "FullConnection";
const char *SOFTMAX = "Softmax";
const char *PLUS = "Plus";
const char *ACTIVATION = "Activation";
const char *FLATTEN = "Flatten";
const char *ADD = "Add";
const char *SUB = "Sub";
const char *MUL = "Mul";
const char *MATMUL = "MatMul";
const char *RSQRT = "Rsqrt";
const char *BIASADD = "BiasAdd";
const char *RESHAPE = "Reshape";
const char *REFORMAT = "ReFormat";
const char *DEPCONVOLUTION = "ConvolutionDepthwise";
const char *DROPOUT = "Dropout";
const char *DROPOUTGENMASK = "DropOutGenMask";
const char *DROPOUTDOMASK = "DropOutDoMask";
const char *CONCAT = "Concat";
const char *ROIPOOLING = "ROIPooling";
const char *PROPOSAL = "Proposal";
const char *FSRDETECTIONOUTPUT = "FSRDetectionOutput";
const char *DETECTIONPOSTPROCESS = "Detectpostprocess";
const char *LRN = "LRN";
const char *TRANSDATA = "TransData";
const char *PERMUTE = "Permute";
const char *SSDNORMALIZE = "SSDNormalize";
const char *SSDPRIORBOX = "SSDPriorBox";
const char *NETOUTPUT = "NetOutput";
const char *SSDDETECTIONOUTPUT = "SSDDetectionOutput";
const char *REFINEDETDETECTIONOUTPUT = "RefinedetDetectionOutput";
const char *CHANNELAXPY = "ChannelAxpy";
const char *PSROIPOOLING = "PSROIPooling";
const char *POWER = "Power";
const char *POW = "Pow";
const char *ROIALIGN = "ROIAlign";
const char *PYTHON = "Python";
const char *FREESPACEEXTRACT = "FreespaceExtract";
const char *SPATIALTF = "SpatialTransform";
const char *SHAPE = "Shape";
const char *SHAPEN = "ShapeN";
const char *ARGMAX = "ArgMax";
const char *GATHERND = "GatherNd";
const char *GATHER = "Gather";
const char *REALDIV = "RealDiv";
const char *PACK = "Pack";
const char *SLICE = "Slice";
const char *SLICED = "SliceD";
const char *FLOORDIV = "FloorDiv";
const char *SQUEEZE = "Squeeze";
const char *UNSQUEEZE = "Unsqueeze";
const char *STRIDEDSLICE = "StridedSlice";
const char *RANGE = "Range";
const char *RPNPROPOSALS = "RpnProposals";
const char *DECODEBBOX = "DecodeBbox";
const char *PAD = "Pad";
const char *PADV2 = "PadV2";
const char *MIRRORPAD = "MirrorPad";
const char *TILE = "Tile";
const char *SIZE = "Size";
const char *CLIPBOXES = "ClipBoxes";
const char *FASTRCNNPREDICTIONS = "FastrcnnPredictions";
const char *SPLIT = "Split";
const char *SPLITV = "SplitV";
const char *EXPANDDIMS = "ExpandDims";
const char *EMPTY = "Empty";
const char *MEAN = "Mean";
const char *GREATER = "Greater";
const char *SWITCH = "Switch";
const char *SWITCHN = "SwitchN";
const char *MERGE = "Merge";
const char *SYMBOLICGRADIENT = "SymbolicGradient";
const char *REMOTECALL = "RemoteCall";
const char *_IF = "_If";
const char *STATELESSIF = "StatelessIf";
const char *IF = "If";
const char *CASE = "Case";
const char *_WHILE = "_While";
const char *WHILE = "While";
const char *STATELESSWHILE = "StatelessWhile";
const char *FOR = "For";
const char *PARTITIONEDCALL = "PartitionedCall";
const char *STATEFULPARTITIONEDCALL = "StatefulPartitionedCall";
const char *FAKEPARAM = "FakeParam";
const char *TRANSPOSE = "Transpose";
const char *TRANSPOSED = "TransposeD";
const char *CAST = "Cast";
const char *REGION = "Region";
const char *YOLO = "Yolo";
const char *YOLODETECTIONOUTPUT = "YoloDetectionOutput";
const char *FILL = "Fill";
const char *REVERSE = "Reverse";
const char *UNPACK = "Unpack";
const char *YOLO2REORG = "Yolo2Reorg";
const char *REDUCESUM = "ReduceSum";
const char *SUM = "Sum";
const char *CONSTANT = "Const";
const char *FILECONSTANT = "FileConstant";
const char *RESIZEBILINEAR = "ResizeBilinear";
const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad";
const char *MAXIMUM = "Maximum";
const char *FRAMEWORKOP = "FrameworkOp";
const char *ARG = "_Arg";
const char *FUSEDBATCHNORMGRAD = "FusedBatchNormGrad";
const char *LSTM = "LSTM";
const char *HIGHWAY = "HighWay";
const char *RNN = "RNN";
const char *ATTENTIONDECODER = "AttentionDecoder";
const char *LOGICAL_NOT = "LogicalNot";
const char *LOGICAL_AND = "LogicalAnd";
const char *LOGICAL_OR = "LogicalOr";
const char *EQUAL = "Equal";
const char *NOTEQUAL = "NotEqual";
const char *INTERP = "Interp";
const char *SHUFFLECHANNEL = "ShuffleChannel";
const char *AIPP = "Aipp";
const char *MULTISHAPE = "MultiShape";
const char *RECIPROCAL = "Reciprocal";
const char *SELU = "Selu";
const char *ELU = "Elu";
const char *ACOSH = "Acosh";
const char *ASINH = "Asinh";
const char *MINIMUM = "Minimum";
const char *CLIP = "Clip";
const char *L2NORMALIZE = "L2Normalize";
const char *CROPANDRESIZE = "CropAndResize";
const char *UNUSEDCONST = "UnusedConst";
const char *SPARSETODENSE = "SparseToDense";
const char *NONMAXSUPPRESSION = "NonMaxSuppression";
const char *TOPKV2 = "TopKV2";
const char *INVERTPERMUTATION = "InvertPermutation";
const char *MULTINOMIAL = "Multinomial";
const char *REVERSESEQUENCE = "ReverseSequence";
const char *REDUCEPROD = "ReduceProd";
const char *REDUCEMAX = "ReduceMax";
const char *REDUCEMIN = "ReduceMin";
const char *EXTRACTIMAGEPATCHES = "ExtractImagePatches";
const char *SQRT = "Sqrt";
const char *REDUCEALL = "ReduceAll";
const char *RESIZENEARESTNEIGHBOR = "ResizeNearestNeighbor";
const char *SPACETOBATCHND = "SpaceToBatchND";
const char *BATCHTOSPACEND = "BatchToSpaceND";
const char *ASSERT = "Assert";
const char *GREATEREQUAL = "GreaterEqual";
const char *FLOOR = "Floor";
const char *RANDOMUNIFORM = "RandomUniform";
const char *BATCHMATMUL = "BatchMatMul";
const char *SPACETODEPTH = "SpaceToDepth";
const char *DEPTHTOSPACE = "DepthToSpace";
const char *RINT = "Rint";
const char *ATAN = "Atan";
const char *ATAN2 = "Atan2";
const char *ATANH = "Atanh";
const char *ACOS = "Acos";
const char *ASIN = "Asin";
const char *NEG = "Neg";
const char *LOG = "Log";
const char *TAN = "Tan";
const char *ROUND = "Round";
const char *UPSAMPLE = "Upsample";
const char *FLOORMOD = "FloorMod";
const char *LESS = "Less";
const char *LESSEQUAL = "LessEqual";
const char *ONEHOT = "OneHot";
const char *REFSWITCH = "RefSwitch";
const char *REFMERGE = "RefMerge";
const char *ENTER = "Enter";
const char *REFENTER = "RefEnter";
const char *LOOPCOND = "LoopCond";
const char *NEXTITERATION = "NextIteration";
const char *REFNEXTITERATION = "RefNextIteration";
const char *EXIT = "Exit";
const char *REFEXIT = "RefExit";
const char *CONTROLTRIGGER = "ControlTrigger";
const char *ZEROSLIKE = "ZerosLike";
const char *EXP = "Exp";
const char *WHERE = "Where";
const char *FAKEQUANTWITHMINMAXVARS = "FakeQuantWithMinMaxVars";
const char *SOFTPLUS = "Softplus";
const char *SOFTSIGN = "Softsign";
const char *COSH = "Cosh";
const char *SINH = "Sinh";
const char *SQUAREDDIFFERENCE = "SquaredDifference";
const char *REQUIREDSPACETOBATCHPADDINGS = "RequiredSpaceToBatchPaddings"; // for retinanet scope fusion
const char *SSDPOSTPROCESSOR = "SSDPostProcessor";
const char *RETINANETBOXES = "RetinanetBoxes";
const char *RETINAMULTIANCHORS = "RetinaMultiAnchor";
const char *RETINANETCLIPPEDBOXES = "RetinanetClippedBoxes";
const char *RETINANETFILTEREDDETECTIONS = "RetinanetFilteredDetections";
const char *RETINANETPOSTPROCESSOR = "RetinanetPostProcessor";
const char *RETINANETANCHORS = "RetinanetAnchors";
const char *FASTERRCNNMAP = "FasterRCNNMap";
const char *FASTERRCNNMAP1 = "FasterRCNNMap1";
const char *FASTERRCNNSECONDSTAGEPOSTPROCESSOR = "FasterRCNNSecondStagePostprocessor";
const char *FASTERRCNNROIINTERPOOLING = "FasterRCNNROIInterPooling";
const char *FASTERRCNNFIRSTSTAGEPOSTPROCESSOR = "FasterRCNNFirstStagePostprocessor";
const char *FASTERRCNNGRIDANCHORGENERATOR = "FasterRCNNGridAnchorGenerator";
const char *ROIINTERPOOLING = "ROIInterPooling";
const char *FASTERRCNNCLIPTOWINDOW = "FasterRCNNClipToWindow";
const char *EMBEDLOOKUP = "EmbedLookup";
const char *HASHLOOKUP = "HashLookup";
const char *LSH_PROJ = "LshProject";
const char *SVDF = "SVDF";
const char *SSDANCHORGENERATOR = "SSDAnchorGenerator";
const char *IDENTITY = "Identity";
const char *IDENTITYN = "IdentityN";
const char *PLACEHOLDERWITHDEFAULT = "PlaceholderWithDefault";
const char *SELECT = "Select";
const char *GETSPAN = "GetSpan";
const char *STOPGRADIENT = "StopGradient";
const char *PREVENTGRADIENT = "PreventGradient";
const char *GUARANTEECONST = "GuaranteeConst";
const char *BROADCASTGRADIENTARGS = "BroadcastGradientArgs";
const char *BROADCASTARGS = "BroadcastArgs";
const char *CONFUSIONMATRIX = "ConfusionMatrix";
const char *RANK = "Rank";
const char *PLACEHOLDER = "PlaceHolder";
const char *END = "End";
const char *BASICLSTMCELL = "BasicLSTMCell";
const char *GETNEXT = "GetNext";
const char *INITDATA = "InitData";
const char *REFIDENTITY = "RefIdentity";
const char *BITCAST = "Bitcast";
const char_t * const DATA = "Data";
const char_t * const AIPPDATA = "AippData";
const char_t * const CONVOLUTION = "Convolution";
const char_t * const CORRELATION = "Correlation";
const char_t * const CORRELATIONV2 = "Correlation_V2";
const char_t * const DECONVOLUTION = "Deconvolution";
const char_t * const POOLING = "Pooling";
const char_t * const ELTWISE = "Eltwise";
const char_t * const RELU = "ReLU";
const char_t * const RELU6 = "ReLU6";
const char_t * const SIGMOID = "Sigmoid";
const char_t * const ABSVAL = "AbsVal";
const char_t * const TANH = "TanH";
const char_t * const PRELU = "PReLU";
const char_t * const BATCHNORM = "BatchNorm";
const char_t * const FUSIONBATCHNORM = "FusionBatchNorm";
const char_t * const SCALE = "Scale";
const char_t * const FULL_CONNECTION = "FullConnection";
const char_t * const SOFTMAX = "Softmax";
const char_t * const PLUS = "Plus";
const char_t * const ACTIVATION = "Activation";
const char_t * const FLATTEN = "Flatten";
const char_t * const ADD = "Add";
const char_t * const SUB = "Sub";
const char_t * const MUL = "Mul";
const char_t * const MATMUL = "MatMul";
const char_t * const RSQRT = "Rsqrt";
const char_t * const BIASADD = "BiasAdd";
const char_t * const RESHAPE = "Reshape";
const char_t * const REFORMAT = "ReFormat";
const char_t * const DEPCONVOLUTION = "ConvolutionDepthwise";
const char_t * const DROPOUT = "Dropout";
const char_t * const DROPOUTGENMASK = "DropOutGenMask";
const char_t * const DROPOUTDOMASK = "DropOutDoMask";
const char_t * const CONCAT = "Concat";
const char_t * const ROIPOOLING = "ROIPooling";
const char_t * const PROPOSAL = "Proposal";
const char_t * const FSRDETECTIONOUTPUT = "FSRDetectionOutput";
const char_t * const DETECTIONPOSTPROCESS = "Detectpostprocess";
const char_t * const LRN = "LRN";
const char_t * const TRANSDATA = "TransData";
const char_t * const PERMUTE = "Permute";
const char_t * const SSDNORMALIZE = "SSDNormalize";
const char_t * const SSDPRIORBOX = "SSDPriorBox";
const char_t * const NETOUTPUT = "NetOutput";
const char_t * const SSDDETECTIONOUTPUT = "SSDDetectionOutput";
const char_t * const REFINEDETDETECTIONOUTPUT = "RefinedetDetectionOutput";
const char_t * const CHANNELAXPY = "ChannelAxpy";
const char_t * const PSROIPOOLING = "PSROIPooling";
const char_t * const POWER = "Power";
const char_t * const POW = "Pow";
const char_t * const ROIALIGN = "ROIAlign";
const char_t * const PYTHON = "Python";
const char_t * const FREESPACEEXTRACT = "FreespaceExtract";
const char_t * const SPATIALTF = "SpatialTransform";
const char_t * const SHAPE = "Shape";
const char_t * const SHAPEN = "ShapeN";
const char_t * const ARGMAX = "ArgMax";
const char_t * const GATHERND = "GatherNd";
const char_t * const GATHER = "Gather";
const char_t * const REALDIV = "RealDiv";
const char_t * const PACK = "Pack";
const char_t * const SLICE = "Slice";
const char_t * const SLICED = "SliceD";
const char_t * const FLOORDIV = "FloorDiv";
const char_t * const SQUEEZE = "Squeeze";
const char_t * const UNSQUEEZE = "Unsqueeze";
const char_t * const STRIDEDSLICE = "StridedSlice";
const char_t * const RANGE = "Range";
const char_t * const RPNPROPOSALS = "RpnProposals";
const char_t * const DECODEBBOX = "DecodeBbox";
const char_t * const PAD = "Pad";
const char_t * const PADV2 = "PadV2";
const char_t * const MIRRORPAD = "MirrorPad";
const char_t * const TILE = "Tile";
const char_t * const SIZE = "Size";
const char_t * const CLIPBOXES = "ClipBoxes";
const char_t * const FASTRCNNPREDICTIONS = "FastrcnnPredictions";
const char_t * const SPLIT = "Split";
const char_t * const SPLITV = "SplitV";
const char_t * const EXPANDDIMS = "ExpandDims";
const char_t * const EMPTY = "Empty";
const char_t * const MEAN = "Mean";
const char_t * const GREATER = "Greater";
const char_t * const SWITCH = "Switch";
const char_t * const SWITCHN = "SwitchN";
const char_t * const MERGE = "Merge";
const char_t * const SYMBOLICGRADIENT = "SymbolicGradient";
const char_t * const REMOTECALL = "RemoteCall";
const char_t * const _IF = "_If";
const char_t * const STATELESSIF = "StatelessIf";
const char_t * const IF = "If";
const char_t * const CASE = "Case";
const char_t * const _WHILE = "_While";
const char_t * const WHILE = "While";
const char_t * const STATELESSWHILE = "StatelessWhile";
const char_t * const FOR = "For";
const char_t * const PARTITIONEDCALL = "PartitionedCall";
const char_t * const STATEFULPARTITIONEDCALL = "StatefulPartitionedCall";
const char_t * const FAKEPARAM = "FakeParam";
const char_t * const TRANSPOSE = "Transpose";
const char_t * const TRANSPOSED = "TransposeD";
const char_t * const CAST = "Cast";
const char_t * const REGION = "Region";
const char_t * const YOLO = "Yolo";
const char_t * const YOLODETECTIONOUTPUT = "YoloDetectionOutput";
const char_t * const FILL = "Fill";
const char_t * const REVERSE = "Reverse";
const char_t * const UNPACK = "Unpack";
const char_t * const YOLO2REORG = "Yolo2Reorg";
const char_t * const REDUCESUM = "ReduceSum";
const char_t * const SUM = "Sum";
const char_t * const CONSTANT = "Const";
const char_t * const FILECONSTANT = "FileConstant";
const char_t * const RESIZEBILINEAR = "ResizeBilinear";
const char_t * const RESIZEBILINEARGRAD = "ResizeBilinearGrad";
const char_t * const MAXIMUM = "Maximum";
const char_t * const FRAMEWORKOP = "FrameworkOp";
const char_t * const ARG = "_Arg";
const char_t * const FUSEDBATCHNORMGRAD = "FusedBatchNormGrad";
const char_t * const LSTM = "LSTM";
const char_t * const HIGHWAY = "HighWay";
const char_t * const RNN = "RNN";
const char_t * const ATTENTIONDECODER = "AttentionDecoder";
const char_t * const LOGICAL_NOT = "LogicalNot";
const char_t * const LOGICAL_AND = "LogicalAnd";
const char_t * const LOGICAL_OR = "LogicalOr";
const char_t * const EQUAL = "Equal";
const char_t * const NOTEQUAL = "NotEqual";
const char_t * const INTERP = "Interp";
const char_t * const SHUFFLECHANNEL = "ShuffleChannel";
const char_t * const AIPP = "Aipp";
const char_t * const MULTISHAPE = "MultiShape";
const char_t * const RECIPROCAL = "Reciprocal";
const char_t * const SELU = "Selu";
const char_t * const ELU = "Elu";
const char_t * const ACOSH = "Acosh";
const char_t * const ASINH = "Asinh";
const char_t * const MINIMUM = "Minimum";
const char_t * const CLIP = "Clip";
const char_t * const L2NORMALIZE = "L2Normalize";
const char_t * const CROPANDRESIZE = "CropAndResize";
const char_t * const UNUSEDCONST = "UnusedConst";
const char_t * const SPARSETODENSE = "SparseToDense";
const char_t * const NONMAXSUPPRESSION = "NonMaxSuppression";
const char_t * const TOPKV2 = "TopKV2";
const char_t * const INVERTPERMUTATION = "InvertPermutation";
const char_t * const MULTINOMIAL = "Multinomial";
const char_t * const REVERSESEQUENCE = "ReverseSequence";
const char_t * const REDUCEPROD = "ReduceProd";
const char_t * const REDUCEMAX = "ReduceMax";
const char_t * const REDUCEMIN = "ReduceMin";
const char_t * const EXTRACTIMAGEPATCHES = "ExtractImagePatches";
const char_t * const SQRT = "Sqrt";
const char_t * const REDUCEALL = "ReduceAll";
const char_t * const RESIZENEARESTNEIGHBOR = "ResizeNearestNeighbor";
const char_t * const SPACETOBATCHND = "SpaceToBatchND";
const char_t * const BATCHTOSPACEND = "BatchToSpaceND";
const char_t * const ASSERT = "Assert";
const char_t * const GREATEREQUAL = "GreaterEqual";
const char_t * const FLOOR = "Floor";
const char_t * const RANDOMUNIFORM = "RandomUniform";
const char_t * const BATCHMATMUL = "BatchMatMul";
const char_t * const SPACETODEPTH = "SpaceToDepth";
const char_t * const DEPTHTOSPACE = "DepthToSpace";
const char_t * const RINT = "Rint";
const char_t * const ATAN = "Atan";
const char_t * const ATAN2 = "Atan2";
const char_t * const ATANH = "Atanh";
const char_t * const ACOS = "Acos";
const char_t * const ASIN = "Asin";
const char_t * const NEG = "Neg";
const char_t * const LOG = "Log";
const char_t * const TAN = "Tan";
const char_t * const ROUND = "Round";
const char_t * const UPSAMPLE = "Upsample";
const char_t * const FLOORMOD = "FloorMod";
const char_t * const LESS = "Less";
const char_t * const LESSEQUAL = "LessEqual";
const char_t * const ONEHOT = "OneHot";
const char_t * const REFSWITCH = "RefSwitch";
const char_t * const REFMERGE = "RefMerge";
const char_t * const ENTER = "Enter";
const char_t * const REFENTER = "RefEnter";
const char_t * const LOOPCOND = "LoopCond";
const char_t * const NEXTITERATION = "NextIteration";
const char_t * const REFNEXTITERATION = "RefNextIteration";
const char_t * const EXIT = "Exit";
const char_t * const REFEXIT = "RefExit";
const char_t * const CONTROLTRIGGER = "ControlTrigger";
const char_t * const ZEROSLIKE = "ZerosLike";
const char_t * const EXP = "Exp";
const char_t * const WHERE = "Where";
const char_t * const FAKEQUANTWITHMINMAXVARS = "FakeQuantWithMinMaxVars";
const char_t * const SOFTPLUS = "Softplus";
const char_t * const SOFTSIGN = "Softsign";
const char_t * const COSH = "Cosh";
const char_t * const SINH = "Sinh";
const char_t * const SQUAREDDIFFERENCE = "SquaredDifference";
const char_t * const REQUIREDSPACETOBATCHPADDINGS = "RequiredSpaceToBatchPaddings"; // for retinanet scope fusion
const char_t * const SSDPOSTPROCESSOR = "SSDPostProcessor";
const char_t * const RETINANETBOXES = "RetinanetBoxes";
const char_t * const RETINAMULTIANCHORS = "RetinaMultiAnchor";
const char_t * const RETINANETCLIPPEDBOXES = "RetinanetClippedBoxes";
const char_t * const RETINANETFILTEREDDETECTIONS = "RetinanetFilteredDetections";
const char_t * const RETINANETPOSTPROCESSOR = "RetinanetPostProcessor";
const char_t * const RETINANETANCHORS = "RetinanetAnchors";
const char_t * const FASTERRCNNMAP = "FasterRCNNMap";
const char_t * const FASTERRCNNMAP1 = "FasterRCNNMap1";
const char_t * const FASTERRCNNSECONDSTAGEPOSTPROCESSOR = "FasterRCNNSecondStagePostprocessor";
const char_t * const FASTERRCNNROIINTERPOOLING = "FasterRCNNROIInterPooling";
const char_t * const FASTERRCNNFIRSTSTAGEPOSTPROCESSOR = "FasterRCNNFirstStagePostprocessor";
const char_t * const FASTERRCNNGRIDANCHORGENERATOR = "FasterRCNNGridAnchorGenerator";
const char_t * const ROIINTERPOOLING = "ROIInterPooling";
const char_t * const FASTERRCNNCLIPTOWINDOW = "FasterRCNNClipToWindow";
const char_t * const EMBEDLOOKUP = "EmbedLookup";
const char_t * const HASHLOOKUP = "HashLookup";
const char_t * const LSH_PROJ = "LshProject";
const char_t * const SVDF = "SVDF";
const char_t * const SSDANCHORGENERATOR = "SSDAnchorGenerator";
const char_t * const IDENTITY = "Identity";
const char_t * const IDENTITYN = "IdentityN";
const char_t * const PLACEHOLDERWITHDEFAULT = "PlaceholderWithDefault";
const char_t * const SELECT = "Select";
const char_t * const GETSPAN = "GetSpan";
const char_t * const STOPGRADIENT = "StopGradient";
const char_t * const PREVENTGRADIENT = "PreventGradient";
const char_t * const GUARANTEECONST = "GuaranteeConst";
const char_t * const BROADCASTGRADIENTARGS = "BroadcastGradientArgs";
const char_t * const BROADCASTARGS = "BroadcastArgs";
const char_t * const CONFUSIONMATRIX = "ConfusionMatrix";
const char_t * const RANK = "Rank";
const char_t * const PLACEHOLDER = "PlaceHolder";
const char_t * const END = "End";
const char_t * const BASICLSTMCELL = "BasicLSTMCell";
const char_t * const GETNEXT = "GetNext";
const char_t * const INITDATA = "InitData";
const char_t * const REFIDENTITY = "RefIdentity";
const char_t * const BITCAST = "Bitcast";

/***************Ann special operator*************************/
const char *ANN_MEAN = "AnnMean";
const char *ANN_CONVOLUTION = "AnnConvolution";
const char *ANN_DEPCONVOLUTION = "AnnDepthConv";
const char *ANN_FULLCONNECTION = "AnnFullConnection";
const char *ANN_NETOUTPUT = "AnnNetOutput";
const char *ANN_DATA = "AnnData";
const char *ANN_RESHAPE = "AnnReshape";
const char *ANN_ADD = "AnnAdd";
const char *ANN_MUL = "AnnMul";
const char *ANN_SUB = "AnnSub";
const char *ANN_DIV = "AnnDiv";
const char *ANN_DEQUANTIZE = "AnnDequant";
const char *ANN_QUANTIZE = "AnnQuant";
const char *ANN_PAD = "AnnPad";
const char *ANN_RESIZE_BILINEAR = "AnnResizeBilinear";
const char_t * const ANN_MEAN = "AnnMean";
const char_t * const ANN_CONVOLUTION = "AnnConvolution";
const char_t * const ANN_DEPCONVOLUTION = "AnnDepthConv";
const char_t * const ANN_FULLCONNECTION = "AnnFullConnection";
const char_t * const ANN_NETOUTPUT = "AnnNetOutput";
const char_t * const ANN_DATA = "AnnData";
const char_t * const ANN_RESHAPE = "AnnReshape";
const char_t * const ANN_ADD = "AnnAdd";
const char_t * const ANN_MUL = "AnnMul";
const char_t * const ANN_SUB = "AnnSub";
const char_t * const ANN_DIV = "AnnDiv";
const char_t * const ANN_DEQUANTIZE = "AnnDequant";
const char_t * const ANN_QUANTIZE = "AnnQuant";
const char_t * const ANN_PAD = "AnnPad";
const char_t * const ANN_RESIZE_BILINEAR = "AnnResizeBilinear";

/***************************************************/
/******************Training operator*************************/
const char *GATHERV2 = "GatherV2";
const char *CONVGRADFILTER = "Conv2DBackpropFilter";
const char *CONV2D = "Conv2D";
const char *CONV2DBACKPROPINPUT = "Conv2DBackpropInput";
const char *FUSEDBATCHNORM = "FusedBatchNorm";
const char *BIASADDGRAD = "BiasAddGrad";
const char *ACTIVATIONGRAD = "ReluGrad";
const char *MAXPOOLWITHARGMAX = "MaxPoolWithArgmax";
const char *MAXPOOLGRADWITHARGMAX = "MaxPoolGradWithArgmax";
const char *SPARSESOFTMAXCROSSENTROPYWITHLOGITS = "SparseSoftmaxCrossEntropyWithLogits";
const char *SNAPSHOT = "Snapshot";
const char *VAR = "Var";
const char *MEANGRAD = "MeanGrad";
const char *TRANSLATE = "Translate";
const char *ADDN = "AddN";
const char *L2LOSS = "L2Loss";
const char *MULTIPLY = "Multiply";
const char *HUBERLOSSGRAD = "HuberLossGrad";
const char *HUBERLOSS = "HuberLoss";
const char *NEGATIVE = "Negative";
const char *SSDCAST = "SSDCast";
const char *SPARSESOFTMAXCROSSENTROPY = "SsdSparseSoftmaxCrossEntropy";
const char *SPARSESOFTMAXCROSSENTROPYGRAD = "SsdSparseSoftmaxCrossEntropyGrad";
const char *SSDSQUEEZEFUSION = "SsdSqueezeFusion";
const char *CONCATFOUR2FIVE = "ConcatFour2Five";
const char *CONCATFIVE2FOUR = "ConcatFive2Four";
const char *SSDREALDIVTILEMUL = "SSDRealdivTileMul";
const char *SSDSUMMULREALDIVMEAN = "SSDSumMulRealdivMean";
const char_t * const GATHERV2 = "GatherV2";
const char_t * const CONVGRADFILTER = "Conv2DBackpropFilter";
const char_t * const CONV2D = "Conv2D";
const char_t * const CONV2DBACKPROPINPUT = "Conv2DBackpropInput";
const char_t * const FUSEDBATCHNORM = "FusedBatchNorm";
const char_t * const BIASADDGRAD = "BiasAddGrad";
const char_t * const ACTIVATIONGRAD = "ReluGrad";
const char_t * const MAXPOOLWITHARGMAX = "MaxPoolWithArgmax";
const char_t * const MAXPOOLGRADWITHARGMAX = "MaxPoolGradWithArgmax";
const char_t * const SPARSESOFTMAXCROSSENTROPYWITHLOGITS = "SparseSoftmaxCrossEntropyWithLogits";
const char_t * const SNAPSHOT = "Snapshot";
const char_t * const VAR = "Var";
const char_t * const MEANGRAD = "MeanGrad";
const char_t * const TRANSLATE = "Translate";
const char_t * const ADDN = "AddN";
const char_t * const L2LOSS = "L2Loss";
const char_t * const MULTIPLY = "Multiply";
const char_t * const HUBERLOSSGRAD = "HuberLossGrad";
const char_t * const HUBERLOSS = "HuberLoss";
const char_t * const NEGATIVE = "Negative";
const char_t * const SSDCAST = "SSDCast";
const char_t * const SPARSESOFTMAXCROSSENTROPY = "SsdSparseSoftmaxCrossEntropy";
const char_t * const SPARSESOFTMAXCROSSENTROPYGRAD = "SsdSparseSoftmaxCrossEntropyGrad";
const char_t * const SSDSQUEEZEFUSION = "SsdSqueezeFusion";
const char_t * const CONCATFOUR2FIVE = "ConcatFour2Five";
const char_t * const CONCATFIVE2FOUR = "ConcatFive2Four";
const char_t * const SSDREALDIVTILEMUL = "SSDRealdivTileMul";
const char_t * const SSDSUMMULREALDIVMEAN = "SSDSumMulRealdivMean";

const char *VARIABLEV2 = "VariableV2";
const char *VARHANDLEOP = "VarHandleOp";
const char *TEMPORARYVARIABLE = "TemporaryVariable";
const char *DESTROYTEMPORARYVARIABLE = "DestroyTemporaryVariable";
const char *VARIABLE = "Variable";
const char *ASSIGN = "Assign";
const char *ASSIGNVARIABLEOP = "AssignVariableOp";
const char *ASSIGNADD = "AssignAdd";
const char *ASSIGNADDVARIABLEOP = "AssignAddVariableOp";
const char *ASSIGNSUB = "AssignSub";
const char *ASSIGNSUBVARIABLEOP = "AssignSubVariableOp";
const char *APPLYMOMENTUM = "ApplyMomentum";
const char *RESOURCEAPPLYMOMENTUM = "ResourceApplyMomentum";
const char *SGD = "SGD";
const char *NOOP = "NoOp";
const char *READVARIABLEOP = "ReadVariableOp";
const char *PARALLELCONCATSTART = "_ParallelConcatStart";
const char *CONSTANTOP = "Constant";
const char *DEPTHWISECONV2DBACKPROPFILTER = "DepthwiseConv2dNativeBackpropFilter";
const char *DEPTHWISECONV2DBACKPORPINPUT = "DepthwiseConv2dNativeBackpropInput";
const char *DEPTHWISECONV2DFORWARDNATIVE = "DepthwiseConv2dNative";
const char *DROPOUTGRAD = "DropOutGrad";
const char *APPLYRMSPROPMIXEDPRECISION = "apply_rms_prop_mixed_precision";
const char *APPLYRMSPROP = "ApplyRMSProp";
const char *RELU6GRAD = "Relu6Grad";
const char *AVGPOOLGRAD = "AvgPoolGrad";
const char *CONCATV2 = "ConcatV2";
const char *CONCATOFFSET = "ConcatOffset";
const char *LAYERNORMGRAD = "LayerNormGrad";
const char *LAYERNORM = "LayerNorm";
const char *LARS = "Lars";
const char *DYNAMICSTITCH = "DynamicStitch";
const char_t * const VARIABLEV2 = "VariableV2";
const char_t * const VARHANDLEOP = "VarHandleOp";
const char_t * const TEMPORARYVARIABLE = "TemporaryVariable";
const char_t * const DESTROYTEMPORARYVARIABLE = "DestroyTemporaryVariable";
const char_t * const VARIABLE = "Variable";
const char_t * const ASSIGN = "Assign";
const char_t * const ASSIGNVARIABLEOP = "AssignVariableOp";
const char_t * const ASSIGNADD = "AssignAdd";
const char_t * const ASSIGNADDVARIABLEOP = "AssignAddVariableOp";
const char_t * const ASSIGNSUB = "AssignSub";
const char_t * const ASSIGNSUBVARIABLEOP = "AssignSubVariableOp";
const char_t * const APPLYMOMENTUM = "ApplyMomentum";
const char_t * const RESOURCEAPPLYMOMENTUM = "ResourceApplyMomentum";
const char_t * const SGD = "SGD";
const char_t * const NOOP = "NoOp";
const char_t * const READVARIABLEOP = "ReadVariableOp";
const char_t * const PARALLELCONCATSTART = "_ParallelConcatStart";
const char_t * const CONSTANTOP = "Constant";
const char_t * const DEPTHWISECONV2DBACKPROPFILTER = "DepthwiseConv2dNativeBackpropFilter";
const char_t * const DEPTHWISECONV2DBACKPORPINPUT = "DepthwiseConv2dNativeBackpropInput";
const char_t * const DEPTHWISECONV2DFORWARDNATIVE = "DepthwiseConv2dNative";
const char_t * const DROPOUTGRAD = "DropOutGrad";
const char_t * const APPLYRMSPROPMIXEDPRECISION = "apply_rms_prop_mixed_precision";
const char_t * const APPLYRMSPROP = "ApplyRMSProp";
const char_t * const RELU6GRAD = "Relu6Grad";
const char_t * const AVGPOOLGRAD = "AvgPoolGrad";
const char_t * const CONCATV2 = "ConcatV2";
const char_t * const CONCATOFFSET = "ConcatOffset";
const char_t * const LAYERNORMGRAD = "LayerNormGrad";
const char_t * const LAYERNORM = "LayerNorm";
const char_t * const LARS = "Lars";
const char_t * const DYNAMICSTITCH = "DynamicStitch";

/***************************************************/
const char *SQUARE = "Square";
const char *HCOMBROADCAST = "HcomBroadcast";
const char *HCOMALLGATHER = "HcomAllGather";
const char *HCOMALLREDUCE = "HcomAllReduce";
const char *HCOMREDUCESCATTER = "HcomReduceScatter";
const char *HCOMSEND = "HcomSend";
const char *HCOMRECEIVE = "HcomReceive";
const char *HCOMREMOTEREAD = "HcomRemoteRead";
const char *HCOMREMOTEREFREAD = "HcomRemoteRefRead";
const char *HCOMREMOTEWRITE = "HcomRemoteWrite";
const char *HCOMREMOTESCATTERWRITE = "HcomRemoteScatterWrite";
const char_t * const SQUARE = "Square";
const char_t * const HCOMBROADCAST = "HcomBroadcast";
const char_t * const HCOMALLGATHER = "HcomAllGather";
const char_t * const HCOMALLREDUCE = "HcomAllReduce";
const char_t * const HCOMREDUCESCATTER = "HcomReduceScatter";
const char_t * const HCOMSEND = "HcomSend";
const char_t * const HCOMRECEIVE = "HcomReceive";
const char_t * const HCOMREMOTEREAD = "HcomRemoteRead";
const char_t * const HCOMREMOTEREFREAD = "HcomRemoteRefRead";
const char_t * const HCOMREMOTEWRITE = "HcomRemoteWrite";
const char_t * const HCOMREMOTESCATTERWRITE = "HcomRemoteScatterWrite";

const char *VARASSIGN = "VarAssign";
const char *VARISINITIALIZEDOP = "VarIsInitializedOp";
const char *LogTimeStamp = "LogTimeStamp";
const char *ISVARIABLEINITIALIZED = "IsVariableInitialized";
const char *STREAMSWITCH = "StreamSwitch";
const char *STREAMSWITCHN = "StreamSwitchN";
const char *STREAMACTIVE = "StreamActive";
const char *MEMCPYASYNC = "MemcpyAsync";
const char *MEMCPYADDRASYNC = "MemcpyAddrAsync";
const char *STREAMMERGE = "StreamMerge";
const char *ENDGRAPH = "EndGraph";
const char *SEND = "Send";
const char *RECV = "Recv";
const char *ENDOFSEQUENCE = "EndOfSequence";
const char_t * const VARASSIGN = "VarAssign";
const char_t * const VARISINITIALIZEDOP = "VarIsInitializedOp";
const char_t * const LogTimeStamp = "LogTimeStamp";
const char_t * const ISVARIABLEINITIALIZED = "IsVariableInitialized";
const char_t * const STREAMSWITCH = "StreamSwitch";
const char_t * const STREAMSWITCHN = "StreamSwitchN";
const char_t * const STREAMACTIVE = "StreamActive";
const char_t * const MEMCPYASYNC = "MemcpyAsync";
const char_t * const MEMCPYADDRASYNC = "MemcpyAddrAsync";
const char_t * const STREAMMERGE = "StreamMerge";
const char_t * const ENDGRAPH = "EndGraph";
const char_t * const SEND = "Send";
const char_t * const RECV = "Recv";
const char_t * const ENDOFSEQUENCE = "EndOfSequence";

const char *LABELSET = "LabelSet";
const char *LABELGOTO = "LabelGoto";
const char *LABELGOTOEX = "LabelGotoEx";
const char *LABELSWITCH = "LabelSwitch";
const char *LABELSWITCHBYINDEX = "LabelSwitchByIndex";
const char_t * const LABELSET = "LabelSet";
const char_t * const LABELGOTO = "LabelGoto";
const char_t * const LABELGOTOEX = "LabelGotoEx";
const char_t * const LABELSWITCH = "LabelSwitch";
const char_t * const LABELSWITCHBYINDEX = "LabelSwitchByIndex";

const char *ATOMICADDRCLEAN = "AtomicAddrClean";
const char_t * const ATOMICADDRCLEAN = "AtomicAddrClean";

const char *ABS_GRAD = "AbsGrad";
const char *ACCUMULATE_N_V2 = "AccumulateNV2";
const char *ACOS_GRAD = "AcosGrad";
const char *ACOSH_GRAD = "AcoshGrad";
const char *ANY = "Any";
const char *APPROXIMATE_EQUAL = "ApproximateEqual";
const char *ASIN_GRAD = "AsinGrad";
const char *ASINH_GRAD = "AsinhGrad";
const char *ATAN_GRAD = "AtanGrad";
const char *BROADCAST_TO = "BroadcastTo";
const char *ELU_GRAD = "EluGrad";
const char *ADD_V2 = "AddV2";
const char *DATAFORMATDIMMAP = "DataFormatDimMap";
const char *DATAFORMATVECPERMUTE = "DataFormatVecPermute";
const char *BESSELI0E = "BesselI0e";
const char *BESSELI1E = "BesselI1e";
const char *APPLYADADELTA = "ApplyAdadelta";
const char *APPLYADAGRAD = "ApplyAdagrad";
const char *APPLYADAGRADDA = "ApplyAdagradDA";
const char *APPLYADAM = "ApplyAdam";
const char *APPLYADAMAX = "ApplyAdaMax";
const char *APPLYADDSIGN = "ApplyAddSign";
const char *APPLYCENTEREDRMSPROP = "ApplyCenteredRMSProp";
const char *APPLYFTRL = "ApplyFtrl";
const char *APPLYFTRLV2 = "ApplyFtrlV2";
const char *APPLYGRADIENTDESCENT = "ApplyGradientDescent";
const char *APPLYPOWERSIGN = "ApplyPowerSign";
const char *APPLYPROXIMALADAGRAD = "ApplyProximalAdagrad";
const char *APPLYPROXIMALGRADIENTDESCENT = "ApplyProximalGradientDescent";
const char *DEQUANTIZE = "Dequantize";
const char_t * const ABS_GRAD = "AbsGrad";
const char_t * const ACCUMULATE_N_V2 = "AccumulateNV2";
const char_t * const ACOS_GRAD = "AcosGrad";
const char_t * const ACOSH_GRAD = "AcoshGrad";
const char_t * const ANY = "Any";
const char_t * const APPROXIMATE_EQUAL = "ApproximateEqual";
const char_t * const ASIN_GRAD = "AsinGrad";
const char_t * const ASINH_GRAD = "AsinhGrad";
const char_t * const ATAN_GRAD = "AtanGrad";
const char_t * const BROADCAST_TO = "BroadcastTo";
const char_t * const ELU_GRAD = "EluGrad";
const char_t * const ADD_V2 = "AddV2";
const char_t * const DATAFORMATDIMMAP = "DataFormatDimMap";
const char_t * const DATAFORMATVECPERMUTE = "DataFormatVecPermute";
const char_t * const BESSELI0E = "BesselI0e";
const char_t * const BESSELI1E = "BesselI1e";
const char_t * const APPLYADADELTA = "ApplyAdadelta";
const char_t * const APPLYADAGRAD = "ApplyAdagrad";
const char_t * const APPLYADAGRADDA = "ApplyAdagradDA";
const char_t * const APPLYADAM = "ApplyAdam";
const char_t * const APPLYADAMAX = "ApplyAdaMax";
const char_t * const APPLYADDSIGN = "ApplyAddSign";
const char_t * const APPLYCENTEREDRMSPROP = "ApplyCenteredRMSProp";
const char_t * const APPLYFTRL = "ApplyFtrl";
const char_t * const APPLYFTRLV2 = "ApplyFtrlV2";
const char_t * const APPLYGRADIENTDESCENT = "ApplyGradientDescent";
const char_t * const APPLYPOWERSIGN = "ApplyPowerSign";
const char_t * const APPLYPROXIMALADAGRAD = "ApplyProximalAdagrad";
const char_t * const APPLYPROXIMALGRADIENTDESCENT = "ApplyProximalGradientDescent";
const char_t * const DEQUANTIZE = "Dequantize";

const char *FOCAL_LOSS = "FocalLoss";
const char *FOCAL_LOSS_GRAD = "FocalLossGrad";
const char *SMOOTHL1_LOSS = "SmoothL1Loss";
const char *SMOOTHL1_LOSS_grad = "SmoothL1LossGrad";
const char *REDUCEMEAN = "ReduceMean";
const char *CONCAT_V2 = "ConcatV2";
const char *ONEHOT_V2 = "OneHotV2";
const char *SLICE_V2 = "SliceV2";
const char *TILE_V2 = "TileV2";
const char *SUM_V2 = "SumV2";
const char_t * const FOCAL_LOSS = "FocalLoss";
const char_t * const FOCAL_LOSS_GRAD = "FocalLossGrad";
const char_t * const SMOOTHL1_LOSS = "SmoothL1Loss";
const char_t * const SMOOTHL1_LOSS_grad = "SmoothL1LossGrad";
const char_t * const REDUCEMEAN = "ReduceMean";
const char_t * const CONCAT_V2 = "ConcatV2";
const char_t * const ONEHOT_V2 = "OneHotV2";
const char_t * const SLICE_V2 = "SliceV2";
const char_t * const TILE_V2 = "TileV2";
const char_t * const SUM_V2 = "SumV2";
// Common type when the operator has the same name
const char *DETECTIONOUTPUT = "DetectionOutput";
const char_t * const DETECTIONOUTPUT = "DetectionOutput";
// Custom operator
const char *CUSTOMOP = "CustomOp";
const char *CUSTOMOP_NCHW = "CustomOpNchw";
const char *CUSTOMOP_NHWC = "CustomOpNhwc";
const char *CUSTOMOP_NC1HWC0 = "CustomOpNc1hwc0";
const char_t * const CUSTOMOP = "CustomOp";
const char_t * const CUSTOMOP_NCHW = "CustomOpNchw";
const char_t * const CUSTOMOP_NHWC = "CustomOpNhwc";
const char_t * const CUSTOMOP_NC1HWC0 = "CustomOpNc1hwc0";

// Depthwise 4d_2_6d,6d_2_4d
const char *DEPTHWISEWEIGHT4D26D = "depthwise_weight_4d_2_6d";
const char *DEPTHWISEWEIGHT6D24D = "depthwise_weight_6d_2_4d";
const char_t * const DEPTHWISEWEIGHT4D26D = "depthwise_weight_4d_2_6d";
const char_t * const DEPTHWISEWEIGHT6D24D = "depthwise_weight_6d_2_4d";

const char *SQRTGRAD = "SqrtGrad";
const char *SIGMOIDGRAD = "SigmoidGrad";
const char_t * const SQRTGRAD = "SqrtGrad";
const char_t * const SIGMOIDGRAD = "SigmoidGrad";

const char *TRANSSHAPE = "TransShape";
const char_t * const TRANSSHAPE = "TransShape";

// Horovod operator
const char *HVDCALLBACKALLREDUCE = "HorovodAllreduce";
const char *HVDCALLBACKALLGATHER = "HorovodAllgather";
const char *HVDCALLBACKBROADCAST = "HorovodBroadcast";
const char *HVDWAIT = "HorovodWait";
const char_t * const HVDCALLBACKALLREDUCE = "HorovodAllreduce";
const char_t * const HVDCALLBACKALLGATHER = "HorovodAllgather";
const char_t * const HVDCALLBACKBROADCAST = "HorovodBroadcast";
const char_t * const HVDWAIT = "HorovodWait";

///
/// @brief Magic number of model file


+ 4
- 4
parser/onnx/onnx_parser.cc View File

@@ -52,7 +52,7 @@ const char *kLocation = "location";
}

namespace ge {
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util,
graphStatus PrepareBeforeParse(AclGraphParseUtil &acl_graph_parse_util,
const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph, std::shared_ptr<domi::ModelParser> &model_parser) {
GetParserContext().type = domi::ONNX;
@@ -82,7 +82,7 @@ graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util,
return ge::SUCCESS;
}

graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util,
graphStatus HandleAfterParse(AclGraphParseUtil &acl_graph_parse_util,
const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph) {
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
@@ -104,7 +104,7 @@ graphStatus aclgrphParseONNX(const char *model_file,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file);
// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::shared_ptr<domi::ModelParser> model_parser;

if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) {
@@ -136,7 +136,7 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
GE_CHECK_NOTNULL(buffer);
// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::shared_ptr<domi::ModelParser> model_parser;

if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) {


+ 2
- 2
parser/tensorflow/tensorflow_parser.cc View File

@@ -94,7 +94,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) {
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
if (acl_graph_parse_util.AclParserInitialize(options) != domi::SUCCESS) {
GELOGE(GRAPH_FAILED, "Parser Initialize failed.");
return GRAPH_FAILED;
@@ -142,7 +142,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<Ascend
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
domi::Status status = acl_graph_parse_util.AclParserInitialize(options);
if (status != domi::SUCCESS) {
GELOGE(GRAPH_FAILED, "Parser Initialize failed.");


+ 4
- 0
tests/st/testcase/origin_models/conv2d_depthwise_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os



+ 4
- 0
tests/st/testcase/origin_models/onnx_clip_v9.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto


+ 5
- 1
tests/st/testcase/origin_models/onnx_if_const_intput_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import os
import numpy as np
import onnx
@@ -42,4 +46,4 @@ def gen_onnx():
print(model_def)

if __name__ == "__main__":
gen_onnx()
gen_onnx()

+ 4
- 0
tests/st/testcase/origin_models/tensor_array_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util


+ 5
- 1
tests/st/testcase/origin_models/test_VarIsInitializedOp_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -11,4 +15,4 @@ def generate_VarIsInitializedOp_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_VarIsInitializedOp.pb", as_text=False)

if __name__=='__main__':
generate_VarIsInitializedOp_pb()
generate_VarIsInitializedOp_pb()

+ 5
- 1
tests/st/testcase/origin_models/test_avgpool3dgrad_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os
import numpy as np
@@ -38,4 +42,4 @@ def generate_case_2():
tf.io.write_graph(sess.graph, logdir="./", name="avgpool3dgrad.pb.txt", as_text=False)

if __name__=='__main__':
generate_case_2()
generate_case_2()

+ 4
- 0
tests/st/testcase/origin_models/test_blocklstm_pb.gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os



+ 4
- 0
tests/st/testcase/origin_models/test_constant_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np
from tensorflow.python.framework import graph_util


+ 5
- 1
tests/st/testcase/origin_models/test_conv2d_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
@@ -23,4 +27,4 @@ def generate_add_pb():

if __name__=='__main__':
generate_conv2d_pb()
generate_add_pb()
generate_add_pb()

+ 5
- 1
tests/st/testcase/origin_models/test_enter_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

@@ -10,4 +14,4 @@ def generate_enter_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_enter.pb", as_text=False)

if __name__=='__main__':
generate_enter_pb()
generate_enter_pb()

+ 5
- 1
tests/st/testcase/origin_models/test_fill_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -9,4 +13,4 @@ def generate_fill_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_fill.pb", as_text=False)

if __name__ == "__main__":
generate_fill_pb()
generate_fill_pb()

+ 5
- 1
tests/st/testcase/origin_models/test_identity_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf

def generate_identity_pb():
@@ -10,4 +14,4 @@ def generate_identity_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_identity.pb", as_text=False)

if __name__=='__main__':
generate_identity_pb()
generate_identity_pb()

+ 4
- 0
tests/st/testcase/origin_models/test_merge_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np


+ 5
- 1
tests/st/testcase/origin_models/test_no_op_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -11,4 +15,4 @@ def generate_no_op_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_no_op.pb", as_text=False)

if __name__=='__main__':
generate_no_op_pb()
generate_no_op_pb()

+ 5
- 1
tests/st/testcase/origin_models/test_reshape_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf

def generate_reshape_pb():
@@ -7,4 +11,4 @@ def generate_reshape_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_reshape.pb", as_text=False)

if __name__ == "__main__":
generate_reshape_pb()
generate_reshape_pb()

+ 5
- 1
tests/st/testcase/origin_models/test_sequeeze_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -10,4 +14,4 @@ def generate_sequeeze_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_sequeeze.pb", as_text=False)

if __name__ == "__main__":
generate_sequeeze_pb()
generate_sequeeze_pb()

+ 5
- 1
tests/st/testcase/origin_models/test_shape_n_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -8,4 +12,4 @@ def generate_shape_n_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_shape_n.pb", as_text=False)

if __name__ == "__main__":
generate_shape_n_pb()
generate_shape_n_pb()

+ 5
- 1
tests/st/testcase/origin_models/test_switch_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

@@ -10,4 +14,4 @@ def generate_switch_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_switch.pb", as_text=False)

if __name__=='__main__':
generate_switch_pb()
generate_switch_pb()

+ 5
- 1
tests/st/testcase/origin_models/test_variableV2_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf

def generate_VariableV2_pb():
@@ -10,4 +14,4 @@ def generate_VariableV2_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_VariableV2.pb", as_text=False)

if __name__=='__main__':
generate_VariableV2_pb()
generate_VariableV2_pb()

+ 2
- 2
tests/st/testcase/test_caffe_parser.cc View File

@@ -191,7 +191,7 @@ TEST_F(STestCaffeParser, caffe_parser_user_output_with_default) {
ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto ret = model_parser->Parse(model_file.c_str(), graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> parser_params;
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params);
ASSERT_EQ(status, SUCCESS);
@@ -483,7 +483,7 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_CreateCustomOperator_test)
TEST_F(STestCaffeParser, CaffeWeightsParser_ParseOutputNodeTopInfo_test)
{
CaffeModelParser model_parser;
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;

domi::caffe::NetParameter net;
domi::caffe::LayerParameter *lay0 = net.add_layer();


+ 7
- 7
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -1104,7 +1104,7 @@ TEST_F(STestTensorflowParser, parser_tensorflow_model) {

// parser tensorflow model out_node_size is equal to index
string graph_name;
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> out_nodes_with_node_and_index = {
{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:1")}};
ParerSTestsUtils::ClearParserInnerCtx();
@@ -1356,7 +1356,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserAllGraph_failed)

TEST_F(STestTensorflowParser, test_parse_acl_output_nodes)
{
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
string graph_name;
// case 1: Normal with 'node and index'
ParerSTestsUtils::ClearParserInnerCtx();
@@ -1523,7 +1523,7 @@ TEST_F(STestTensorflowParser, parse_AddFmkNode)
std::string modelFile = caseDir + "/origin_models/tf_add.pb";
ge::Graph graph;
string graph_name;
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::map<ge::AscendString, ge::AscendString> parser_options = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}};
ParerSTestsUtils::ClearParserInnerCtx();
Status ret = acl_graph_parse_util.ParseParamsBeforeGraph(parser_options, graph_name);
@@ -3781,9 +3781,9 @@ TEST_F(STestTensorflowParser, tensorflow_ReadBytesFromBinaryFile_test)
EXPECT_EQ(realPath, "");
}

TEST_F(STestTensorflowParser, tensorflow_AclGrphParseUtil_ParseAclInputFp16Nodes_test)
TEST_F(STestTensorflowParser, tensorflow_AclGraphParseUtil_ParseAclInputFp16Nodes_test)
{
AclGrphParseUtil parserUtil;
AclGraphParseUtil parserUtil;
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
std::string input_fp16_nodes = "Add";
std::string is_input_adjust_hw_layout = "is_input_adjust_hw_layout";
@@ -4010,7 +4010,7 @@ TEST_F(STestTensorflowParser, tensorflow_FP16_parser_test)

TEST_F(STestTensorflowParser, tensorflow_AclParserInitialize_test)
{
AclGrphParseUtil parseUtil;
AclGraphParseUtil parseUtil;
std::map<std::string, std::string> options;
Status ret = parseUtil.AclParserInitialize(options);
EXPECT_EQ(ret, FAILED);
@@ -4022,7 +4022,7 @@ TEST_F(STestTensorflowParser, tensorflow_AclParserInitialize_test)

TEST_F(STestTensorflowParser, tensorflow_GetOutputLeaf_test)
{
AclGrphParseUtil parseUtil;
AclGraphParseUtil parseUtil;
ge::ComputeGraphPtr compute_graph = build_graph(true);
ge::NodePtr output_nodes_info = compute_graph->FindNode("Relu3");
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{output_nodes_info,0}};


+ 4
- 4
tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc View File

@@ -189,7 +189,7 @@ TEST_F(UtestCaffeParser, caffe_parser_user_output_with_name_and_index) {
ge::GetParserContext().user_out_nodes.push_back({"abs", 0});
auto ret = model_parser->Parse(model_file.c_str(), graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> parser_params;
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params);
ASSERT_EQ(status, SUCCESS);
@@ -216,7 +216,7 @@ TEST_F(UtestCaffeParser, caffe_parser_user_output_with_top_name) {
ge::GetParserContext().user_out_tensors.push_back("abs_out");
auto ret = model_parser->Parse(model_file.c_str(), graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> parser_params;
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params);
ASSERT_EQ(status, SUCCESS);
@@ -241,7 +241,7 @@ TEST_F(UtestCaffeParser, caffe_parser_user_output_with_default) {
ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto ret = model_parser->Parse(model_file.c_str(), graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> parser_params;
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params);
ASSERT_EQ(status, SUCCESS);
@@ -543,7 +543,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_CreateCustomOperator_test)
TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseOutputNodeTopInfo_test)
{
CaffeModelParser model_parser;
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;

domi::caffe::NetParameter net;
domi::caffe::LayerParameter *lay0 = net.add_layer();


+ 1
- 1
tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc View File

@@ -53,7 +53,7 @@ class UtestAclGraphParser : public testing::Test {
};

TEST_F(UtestAclGraphParser, test_parse_acl_output_nodes) {
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
string graph_name;
// case 1: Normal with 'node and index'
ParerUTestsUtils::ClearParserInnerCtx();


+ 4
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/onnx_clip_v9.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto


+ 4
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/conv2d_depthwise_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os



+ 4
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/onnx_clip_v9.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto


+ 4
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/tensor_array_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util


+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_VarIsInitializedOp_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -11,4 +15,4 @@ def generate_VarIsInitializedOp_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_VarIsInitializedOp.pb", as_text=False)

if __name__=='__main__':
generate_VarIsInitializedOp_pb()
generate_VarIsInitializedOp_pb()

+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_avgpool3dgrad_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os
import numpy as np
@@ -38,4 +42,4 @@ def generate_case_2():
tf.io.write_graph(sess.graph, logdir="./", name="avgpool3dgrad.pb.txt", as_text=False)

if __name__=='__main__':
generate_case_2()
generate_case_2()

+ 4
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_blocklstm_pb.gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os



+ 4
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_constant_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np
from tensorflow.python.framework import graph_util


+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_conv2d_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
@@ -23,4 +27,4 @@ def generate_add_pb():

if __name__=='__main__':
generate_conv2d_pb()
generate_add_pb()
generate_add_pb()

+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_enter_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

@@ -10,4 +14,4 @@ def generate_enter_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_enter.pb", as_text=False)

if __name__=='__main__':
generate_enter_pb()
generate_enter_pb()

+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_fill_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -9,4 +13,4 @@ def generate_fill_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_fill.pb", as_text=False)

if __name__ == "__main__":
generate_fill_pb()
generate_fill_pb()

+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_identity_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf

def generate_identity_pb():
@@ -10,4 +14,4 @@ def generate_identity_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_identity.pb", as_text=False)

if __name__=='__main__':
generate_identity_pb()
generate_identity_pb()

+ 4
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_merge_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np


+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_no_op_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -11,4 +15,4 @@ def generate_no_op_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_no_op.pb", as_text=False)

if __name__=='__main__':
generate_no_op_pb()
generate_no_op_pb()

+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_reshape_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf

def generate_reshape_pb():
@@ -7,4 +11,4 @@ def generate_reshape_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_reshape.pb", as_text=False)

if __name__ == "__main__":
generate_reshape_pb()
generate_reshape_pb()

+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_sequeeze_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -10,4 +14,4 @@ def generate_sequeeze_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_sequeeze.pb", as_text=False)

if __name__ == "__main__":
generate_sequeeze_pb()
generate_sequeeze_pb()

+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_shape_n_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
import numpy as np

@@ -8,4 +12,4 @@ def generate_shape_n_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_shape_n.pb", as_text=False)

if __name__ == "__main__":
generate_shape_n_pb()
generate_shape_n_pb()

+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_switch_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

@@ -10,4 +14,4 @@ def generate_switch_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_switch.pb", as_text=False)

if __name__=='__main__':
generate_switch_pb()
generate_switch_pb()

+ 5
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/test_variableV2_pb_gen.py View File

@@ -1,3 +1,7 @@
#!/usr/bin/env python3
# -*- coding utf-8 -*-
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved.

import tensorflow as tf

def generate_VariableV2_pb():
@@ -10,4 +14,4 @@ def generate_VariableV2_pb():
tf.io.write_graph(sess.graph, logdir="./", name="test_VariableV2.pb", as_text=False)

if __name__=='__main__':
generate_VariableV2_pb()
generate_VariableV2_pb()

+ 7
- 7
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -1106,7 +1106,7 @@ TEST_F(UtestTensorflowParser, parser_tensorflow_model) {

// parser tensorflow model out_node_size is equal to index
string graph_name;
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> out_nodes_with_node_and_index = {
{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:1")}};
ParerUTestsUtils::ClearParserInnerCtx();
@@ -1452,7 +1452,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserAllGraph_failed)

TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes)
{
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
string graph_name;
// case 1: Normal with 'node and index'
ParerUTestsUtils::ClearParserInnerCtx();
@@ -1621,7 +1621,7 @@ TEST_F(UtestTensorflowParser, parse_AddFmkNode)
std::string modelFile = caseDir + "/tensorflow_model/tf_add.pb";
ge::Graph graph;
string graph_name;
AclGrphParseUtil acl_graph_parse_util;
AclGraphParseUtil acl_graph_parse_util;
std::map<ge::AscendString, ge::AscendString> parser_options = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}};
ParerUTestsUtils::ClearParserInnerCtx();
Status ret = acl_graph_parse_util.ParseParamsBeforeGraph(parser_options, graph_name);
@@ -3885,9 +3885,9 @@ TEST_F(UtestTensorflowParser, tensorflow_ReadBytesFromBinaryFile_test)
EXPECT_EQ(realPath, "");
}

TEST_F(UtestTensorflowParser, tensorflow_AclGrphParseUtil_ParseAclInputFp16Nodes_test)
TEST_F(UtestTensorflowParser, tensorflow_AclGraphParseUtil_ParseAclInputFp16Nodes_test)
{
AclGrphParseUtil parserUtil;
AclGraphParseUtil parserUtil;
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
std::string input_fp16_nodes = "Add";
std::string is_input_adjust_hw_layout = "is_input_adjust_hw_layout";
@@ -4094,7 +4094,7 @@ TEST_F(UtestTensorflowParser, tensorflow_FP16_parser_test)

TEST_F(UtestTensorflowParser, tensorflow_AclParserInitialize_test)
{
AclGrphParseUtil parseUtil;
AclGraphParseUtil parseUtil;
std::map<std::string, std::string> options;
Status ret = parseUtil.AclParserInitialize(options);
EXPECT_EQ(ret, FAILED);
@@ -4106,7 +4106,7 @@ TEST_F(UtestTensorflowParser, tensorflow_AclParserInitialize_test)

TEST_F(UtestTensorflowParser, tensorflow_GetOutputLeaf_test)
{
AclGrphParseUtil parseUtil;
AclGraphParseUtil parseUtil;
ge::ComputeGraphPtr compute_graph = build_graph(true);
ge::NodePtr output_nodes_info = compute_graph->FindNode("Relu3");
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{output_nodes_info,0}};


Loading…
Cancel
Save