| @@ -1,4 +1,4 @@ | |||
| [submodule "metadef"] | |||
| path = metadef | |||
| url = https://gitee.com/ascend/metadef.git | |||
| branch = development | |||
| branch = r1.2.0 | |||
| @@ -33,11 +33,11 @@ if (ENABLE_OPEN_SRC) | |||
| message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | |||
| endif() | |||
| set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | |||
| find_module(slog libslog.so ${GE_LIB_PATH}) | |||
| find_module(slog libalog.so ${GE_LIB_PATH}) | |||
| find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) | |||
| find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) | |||
| elseif(ENABLE_GE_COV OR ENABLE_GE_UT) | |||
| message(STATUS "Runing on llt mode, no need to depend other component") | |||
| message(STATUS "Running on llt mode, no need to depend other component.") | |||
| else() | |||
| if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||
| set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | |||
| @@ -47,7 +47,7 @@ if (ENABLE_OPEN_SRC) | |||
| set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | |||
| find_module(slog libslog.so ${ASCEND_ATC_DIR}) | |||
| find_module(slog libalog.so ${ASCEND_ATC_DIR}) | |||
| find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | |||
| find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | |||
| endif() | |||
| @@ -2,6 +2,9 @@ approvers: | |||
| - ji_chen | |||
| - wqtshg | |||
| - ljl0711 | |||
| - startzgf168 | |||
| - lbisdaddy | |||
| - andylhy | |||
| reviewers: | |||
| - xchu42 | |||
| - sheng-nan | |||
| @@ -7,7 +7,6 @@ function(find_module module name path) | |||
| if (TARGET ${module}) | |||
| return() | |||
| endif() | |||
| add_library(${module} INTERFACE) | |||
| find_library(${module}_LIBRARY_DIR NAMES ${name} NAMES_PER_DIR PATHS ${path} | |||
| PATH_SUFFIXES lib | |||
| ) | |||
| @@ -16,5 +15,9 @@ function(find_module module name path) | |||
| if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") | |||
| message(FATAL_ERROR "${name} not found in ${path}") | |||
| endif() | |||
| target_link_libraries(${module} INTERFACE ${${module}_LIBRARY_DIR}) | |||
| add_library(${module} SHARED IMPORTED) | |||
| set_target_properties(${module} PROPERTIES | |||
| IMPORTED_LOCATION ${${module}_LIBRARY_DIR} | |||
| ) | |||
| endfunction() | |||
| @@ -16,6 +16,7 @@ target_compile_definitions(intf_pub INTERFACE | |||
| $<$<CONFIG:Debug>:CFG_BUILD_DEBUG> | |||
| WIN64=1 | |||
| LINUX=0 | |||
| LOG_CPP | |||
| ) | |||
| target_link_options(intf_pub INTERFACE | |||
| -Wl,-z,relro | |||
| @@ -17,17 +17,23 @@ | |||
| #ifndef INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | |||
| #define INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | |||
| #include <memory> | |||
| #include <map> | |||
| #include "graph/ascend_string.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/types.h" | |||
| #include "graph/graph.h" | |||
| namespace ge { | |||
| graphStatus aclgrphParseONNX(const char *model_file, | |||
| const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph); | |||
| const std::map<ge::AscendString, | |||
| ge::AscendString> &parser_params, | |||
| ge::Graph &graph); | |||
| graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, | |||
| const std::map<ge::AscendString, ge::AscendString> &parser_params, | |||
| graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t buffer_size, | |||
| const std::map<ge::AscendString, | |||
| ge::AscendString> &parser_params, | |||
| ge::Graph &graph); | |||
| } // namespace ge | |||
| @@ -1 +1 @@ | |||
| Subproject commit c14d2be38171eed63416e71178774103faf1f5cd | |||
| Subproject commit af156f825aa53a24bd30ae4065e3ea356cf555ef | |||
| @@ -193,6 +193,7 @@ const int kMaxParseDepth = 5; | |||
| const int32_t kMinLineWorldSize = 3; | |||
| const int32_t kMaxIdentifier = 536870911; // 2^29 - 1 | |||
| const int32_t kBase = 10; | |||
| const uint32_t kInteval = 2; | |||
| const char *const kPython = "Python"; | |||
| const char *const kProposalLayer = "ProposalLayer"; | |||
| const char *const kDetectionOutput = "DetectionOutput"; | |||
| @@ -793,13 +794,22 @@ Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection * | |||
| CASE_FIELD_TYPE_REPEATED(STRING, String, string); | |||
| #undef CASE_FIELD_TYPE_REPEATED | |||
| case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | |||
| for (int i = 0; i < field_size; ++i) { | |||
| const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); | |||
| if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { | |||
| GELOGE(FAILED, "ParseOperatorAttrs of field: %s failed.", field->name().c_str()); | |||
| return FAILED; | |||
| } | |||
| nlohmann::json message_json; | |||
| Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), | |||
| message_json[field->name()], false); | |||
| std::string repeated_message_str; | |||
| try { | |||
| repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); | |||
| } catch (std::exception &e) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()}); | |||
| GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what()); | |||
| return FAILED; | |||
| } catch (...) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19008"); | |||
| GELOGE(FAILED, "Failed to convert JSON to string."); | |||
| return FAILED; | |||
| } | |||
| (void)ops.SetAttr(field->name(), repeated_message_str); | |||
| break; | |||
| } | |||
| default: { | |||
| @@ -56,20 +56,17 @@ class CaffeModelParser : public domi::ModelParser { | |||
| /** | |||
| * @ingroup domi_omg | |||
| * @brief Parse the relevant data from memory and save it to graph | |||
| * @param [in] memory buffer of model file | |||
| * @param [in] buffer size | |||
| * @brief Parse the relevant data from the memory and save it to graph | |||
| * @param [in] file Path of model file | |||
| * @param [in|out] graph graph for saving model information | |||
| * @return SUCCESS parse successfully | |||
| * @return FAILED parse failed | |||
| */ | |||
| Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | |||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||
| Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override { | |||
| return domi::SUCCESS; | |||
| return domi::SUCCESS; | |||
| } | |||
| #endif | |||
| /** | |||
| * @ingroup domi_omg | |||
| @@ -30,7 +30,6 @@ enum DataType | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| DT_VARIANT = 26; // variant type | |||
| } | |||
| message AttrDef | |||
| @@ -406,53 +406,6 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra | |||
| return SUCCESS; | |||
| } | |||
| domi::Status AclGrphParseUtil::ParseAclWeightCompressConf(const ComputeGraphPtr &graph, | |||
| const string &compress_weight_conf) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| if (compress_weight_conf.empty()) { | |||
| return SUCCESS; | |||
| } | |||
| std::string real_path = ge::parser::RealPath(compress_weight_conf.c_str()); | |||
| if (real_path.empty()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||
| {"compress_weight_conf", compress_weight_conf}); | |||
| GELOGE(PARAM_INVALID, "Can not get real path for %s.", compress_weight_conf.c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| std::ifstream ifs(real_path); | |||
| if (!ifs.is_open()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||
| {"compress_weight_conf", compress_weight_conf}); | |||
| GELOGE(FAILED, "Open file %s failed", compress_weight_conf.c_str()); | |||
| return FAILED; | |||
| } | |||
| std::string compress_nodes; | |||
| ifs >> compress_nodes; | |||
| ifs.close(); | |||
| if (compress_nodes.empty()) { | |||
| GELOGW("Compress weight of nodes info is empty"); | |||
| return SUCCESS; | |||
| } | |||
| GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); | |||
| vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';'); | |||
| for (size_t i = 0; i < compress_node_vec.size(); ++i) { | |||
| ge::NodePtr node = graph->FindNode(compress_node_vec[i]); | |||
| if (node == nullptr) { | |||
| GELOGW("Node %s is not in graph", compress_node_vec[i].c_str()); | |||
| continue; | |||
| } | |||
| auto op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) { | |||
| GELOGE(domi::FAILED, "Node %s SetBool failed.", compress_node_vec[i].c_str()); | |||
| return domi::FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name) { | |||
| output_nodes_name.clear(); | |||
| @@ -641,7 +594,7 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin | |||
| domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||
| const std::map<AscendString, AscendString> &parser_params) { | |||
| // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, compress_weight_conf, | |||
| // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout | |||
| ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| @@ -654,11 +607,6 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||
| ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS, | |||
| return PARAM_INVALID, "Parse input_fp16_nodes failed"); | |||
| string compress_weight_conf; | |||
| GetAclParams(parser_params, ge::ir_option::COMPRESS_WEIGHT_CONF, compress_weight_conf); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclWeightCompressConf(compute_graph, compress_weight_conf) != SUCCESS, | |||
| return PARAM_INVALID, "Parse compress_weight_conf failed"); | |||
| return SUCCESS; | |||
| } | |||
| @@ -784,7 +732,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co | |||
| google::protobuf::io::CodedInputStream coded_stream(&istream); | |||
| bool ret = ReadProtoFromCodedInputStream(coded_stream, proto); | |||
| fs.close(); | |||
| if (!ret) { | |||
| @@ -60,7 +60,6 @@ class AclGrphParseUtil { | |||
| uint32_t index, OpDescPtr &op_desc); | |||
| domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | |||
| const string &is_input_adjust_hw_layout); | |||
| domi::Status ParseAclWeightCompressConf(const ComputeGraphPtr &graph, const string &compress_weight_conf); | |||
| domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | |||
| std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||
| }; | |||
| @@ -47,11 +47,11 @@ class Pb2Json { | |||
| static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | |||
| bool enum2str = false); | |||
| protected: | |||
| static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||
| const ProtobufReflection *reflection, const std::set<std::string> &black_fields, | |||
| Json &json, bool enum2str); | |||
| protected: | |||
| static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | |||
| bool enum2str, Json &json); | |||
| @@ -16,7 +16,6 @@ | |||
| #include "framework/omg/parser/parser_api.h" | |||
| #include "common/debug/log.h" | |||
| #include "tbe_plugin_loader.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "parser/common/register_tbe.h" | |||
| @@ -41,7 +40,7 @@ Status ParserInitialize(const std::map<std::string, std::string> &options) { | |||
| std::string fmk_type = std::to_string(domi::TENSORFLOW); | |||
| auto it = options.find(ge::FRAMEWORK_TYPE); | |||
| if (it != options.end()) { | |||
| fmk_type = it->second; | |||
| fmk_type = it->second; | |||
| } | |||
| std::vector<OpRegistrationData> registrationDatas = domi::OpRegistry::Instance()->registrationDatas; | |||
| GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); | |||
| @@ -28,6 +28,29 @@ | |||
| #include "register/op_registry.h" | |||
| namespace ge { | |||
| namespace { | |||
| Status HandleNewOp(const NodePtr &node, const ComputeGraphPtr &compute_graph, const NodePtr &new_node) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GE_CHECK_NOTNULL(new_node); | |||
| if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Set owner graph for node:%s failed.", new_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| auto op_desc = new_node->GetOpDesc(); | |||
| static std::atomic_long new_node_index(0); | |||
| auto new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++); | |||
| op_desc->SetName(new_name); | |||
| bool ret = ge::AttrUtils::SetListStr(op_desc, | |||
| ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, | |||
| std::move(std::vector<std::string>{node->GetName()})); | |||
| if (!ret) { | |||
| GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str()); | |||
| } | |||
| GELOGD("Handle new op[%s] for node[%s] success.", new_node->GetName().c_str(), node->GetName().c_str()); | |||
| return SUCCESS; | |||
| } | |||
| } | |||
| Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { | |||
| GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); | |||
| for (const auto &gn : graph.GetDirectNode()) { | |||
| @@ -68,17 +91,14 @@ Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &n | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| // add subgraph node to graph. | |||
| std::unordered_map<std::string, NodePtr> all_new_nodes; | |||
| std::vector<NodePtr> input_nodes; | |||
| for (const auto &n : sub_compute_graph->GetDirectNode()) { | |||
| auto new_node = compute_graph->AddNode(n); | |||
| GE_CHECK_NOTNULL(new_node); | |||
| all_new_nodes[new_node->GetName()] = new_node; | |||
| if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Set owner graph for node:%s failed.", new_node->GetName().c_str()); | |||
| if (HandleNewOp(node, compute_graph, new_node) != SUCCESS) { | |||
| GELOGE(FAILED, "Handle new op[%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| if (new_node->GetType() == ge::parser::DATA) { | |||
| input_nodes.emplace_back(new_node); | |||
| } | |||
| @@ -30,7 +30,6 @@ enum DataType | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| DT_VARIANT = 26; // variant type | |||
| } | |||
| message AttrDef | |||
| @@ -115,7 +115,6 @@ target_include_directories(fmk_onnx_parser_stub PRIVATE | |||
| ${PARSER_DIR}/parser | |||
| ${PARSER_DIR}/../inc | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/graph | |||
| ${METADEF_DIR}/inc/external | |||
| ${METADEF_DIR}/inc/external/graph | |||
| ) | |||
| @@ -52,7 +52,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||
| libregister \ | |||
| liberror_manager \ | |||
| LOCAL_STATIC_LIBRARIES += libmmpa | |||
| LOCAL_STATIC_LIBRARIES += libmmpa | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -62,7 +62,6 @@ include $(BUILD_HOST_SHARED_LIBRARY) | |||
| include $(CLEAR_VARS) | |||
| LOCAL_C_INCLUDES := \ | |||
| $(TOPDIR)inc \ | |||
| $(TOPDIR)metadef/inc \ | |||
| $(TOPDIR)parser/inc \ | |||
| $(TOPDIR)inc/external \ | |||
| @@ -88,4 +87,3 @@ LOCAL_SHARED_LIBRARIES := | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| @@ -19,8 +19,8 @@ | |||
| #include <iostream> | |||
| #include "common/convert/pb2json.h" | |||
| #include "common/util.h" | |||
| #include "common/ge_types.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "common/ge_types.h" | |||
| #include "external/graph/operator_factory.h" | |||
| #include "external/register/register_error_codes.h" | |||
| #include "external/parser/onnx_parser.h" | |||
| @@ -39,17 +39,18 @@ | |||
| #include "register/op_registry.h" | |||
| namespace ge { | |||
| graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | |||
| const std::map<AscendString, AscendString> &parser_params, | |||
| ge::Graph &graph, std::shared_ptr<domi::ModelParser> &model_parser) { | |||
| graphStatus aclgrphParseONNX(const char *model_file, | |||
| const std::map<AscendString, | |||
| AscendString> &parser_params, | |||
| ge::Graph &graph) { | |||
| GE_CHECK_NOTNULL(model_file); | |||
| GetParserContext().type = domi::ONNX; | |||
| std::map<string, string> options; | |||
| options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); | |||
| if (acl_graph_parse_util.AclParserInitialize(options) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Acl parser initialize failed."); | |||
| return ge::FAILED; | |||
| } | |||
| // load custom plugin so and proto | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| (void)acl_graph_parse_util.AclParserInitialize(options); | |||
| string output_name; | |||
| if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | |||
| @@ -62,40 +63,9 @@ graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); | |||
| GE_CHECK_NOTNULL(model_parser); | |||
| return ge::SUCCESS; | |||
| } | |||
| graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util, | |||
| const std::map<AscendString, AscendString> &parser_params, | |||
| ge::Graph &graph) { | |||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||
| return ge::FAILED; | |||
| } | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| return ge::SUCCESS; | |||
| } | |||
| graphStatus aclgrphParseONNX(const char *model_file, | |||
| const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | |||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||
| GE_CHECK_NOTNULL(model_file); | |||
| // load custom plugin so and proto | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| std::shared_ptr<domi::ModelParser> model_parser; | |||
| if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Prepare before parse failed."); | |||
| return ge::FAILED; | |||
| } | |||
| GE_CHECK_NOTNULL(model_parser); | |||
| // parse caffe model_file to GE graph | |||
| ge::graphStatus ret = model_parser->Parse(model_file, graph); | |||
| if (ret != ge::SUCCESS) { | |||
| @@ -104,44 +74,65 @@ graphStatus aclgrphParseONNX(const char *model_file, | |||
| } | |||
| GELOGI("Parser graph %s success.", graph.GetName().c_str()); | |||
| if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Handle after parse failed."); | |||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||
| return ge::FAILED; | |||
| } | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||
| #endif | |||
| return ge::SUCCESS; | |||
| } | |||
| graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, | |||
| const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | |||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||
| graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t buffer_size, | |||
| const std::map<AscendString, | |||
| AscendString> &parser_params, | |||
| ge::Graph &graph) { | |||
| GE_CHECK_NOTNULL(buffer); | |||
| GetParserContext().type = domi::ONNX; | |||
| std::map<string, string> options; | |||
| options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); | |||
| // load custom plugin so and proto | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| std::shared_ptr<domi::ModelParser> model_parser; | |||
| (void)acl_graph_parse_util.AclParserInitialize(options); | |||
| if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Prepare before parse failed."); | |||
| string output_name; | |||
| if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params before graph failed."); | |||
| return ge::FAILED; | |||
| } | |||
| // Create an empty computegraph | |||
| string graph_name = output_name.empty() ? "tmpGraph" : output_name; | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| // parse caffe model_file to GE graph | |||
| ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)size, graph); | |||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); | |||
| GE_CHECK_NOTNULL(model_parser); | |||
| // parse caffe model_file and weights_file to GE graph | |||
| ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)buffer_size, graph); | |||
| if (ret != ge::SUCCESS) { | |||
| GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("Parser graph %s success.", graph.GetName().c_str()); | |||
| if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Handle after parse failed."); | |||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||
| #endif | |||
| return ge::SUCCESS; | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||
| return ge::SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -159,7 +150,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | |||
| GELOGE(FAILED, "Onnx graph has zero input"); | |||
| return FAILED; | |||
| } | |||
| // get input value info map | |||
| std::map<std::string, ge::onnx::TensorProto> input_name_tensor; | |||
| for (int i = 0; i < onnx_graph.input_size(); i++) { | |||
| @@ -173,7 +163,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | |||
| initializer_name_tensor.erase(initializer_iter); | |||
| continue; | |||
| } | |||
| ge::onnx::TensorProto tensor_tmp; | |||
| if (value_info.has_type()) { | |||
| const ge::onnx::TypeProto type = value_info.type(); | |||
| @@ -194,7 +183,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | |||
| } | |||
| input_name_tensor[value_info.name()] = tensor_tmp; | |||
| } | |||
| // Construct node for input | |||
| int64_t index = 0; | |||
| for (auto it : input_name_tensor) { | |||
| @@ -350,9 +338,11 @@ Status OnnxModelParser::SetOperatorInputs() { | |||
| for (auto in_iter = inputs_map_.begin(); in_iter != inputs_map_.end(); in_iter++) { | |||
| auto out_iter = outputs_map_.find(in_iter->first); | |||
| if (out_iter == outputs_map_.end()) { | |||
| GELOGE(INTERNAL_ERROR, "Unknown input: %s:%d in node: %s", in_iter->first.c_str(), in_iter->second[0].second, | |||
| GELOGW("Unknown input: %s:%d for node: %s, which maybe option input.", | |||
| in_iter->first.c_str(), | |||
| in_iter->second[0].second, | |||
| in_iter->second[0].first.c_str()); | |||
| return INTERNAL_ERROR; | |||
| continue; | |||
| } | |||
| std::vector<std::pair<std::string, int>> &input_node_indexs = in_iter->second; | |||
| @@ -511,11 +501,10 @@ Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) { | |||
| input_ops.emplace_back(in_op->second); | |||
| GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) { | |||
| Status OnnxModelParser::GetModelFromfile(const char *file, ge::onnx::ModelProto &onnx_model) { | |||
| GE_CHECK_NOTNULL(file); | |||
| GELOGI("File path is %s.", file); | |||
| @@ -529,20 +518,18 @@ Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto | |||
| return SUCCESS; | |||
| } | |||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||
| Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) { | |||
| GE_CHECK_NOTNULL(data); | |||
| // 1. Get graph from onnx model file. | |||
| if (!ge::parser::ReadProtoFromArray(data, size, &onnx_model)) { | |||
| // 1. Get graph from memory. | |||
| if (!ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &onnx_model)) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19021", {"reason"}, {"Read onnx model from memory failed."}); | |||
| GELOGE(PARAM_INVALID, "Read onnx model from memory failed."); | |||
| "E19021", {"reason"}, {"Read onnx model file failed."}); | |||
| GELOGE(PARAM_INVALID, "Read onnx model file failed."); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| #endif | |||
| Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { | |||
| if (!onnx_model.has_graph()) { | |||
| @@ -551,13 +538,11 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||
| return FAILED; | |||
| } | |||
| ge::onnx::GraphProto onnx_graph = onnx_model.graph(); | |||
| auto opset_import = onnx_model.opset_import(); | |||
| for (auto it : opset_import) { | |||
| domain_verseion_[it.domain()] = it.version(); | |||
| GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); | |||
| } | |||
| // 2. Get all inializer. | |||
| std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; | |||
| for (int i = 0; i < onnx_graph.initializer_size(); i++) { | |||
| @@ -567,7 +552,6 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||
| GELOGI("Initializer name: %s .", initializer_tensor.name().c_str()); | |||
| } | |||
| } | |||
| // 3. Parse Input from graph. | |||
| GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); | |||
| Status ret = ParseInput(onnx_graph, initializer_name_tensor); | |||
| @@ -576,21 +560,18 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||
| return ret; | |||
| } | |||
| GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | |||
| // 4. Parse Constant from graph. | |||
| ret = ParseInitializer(onnx_graph, initializer_name_tensor); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Parse initializer for onnx failed."); | |||
| return ret; | |||
| } | |||
| // 5. Update node name for node do not has name. | |||
| ret = UpdateAllNodeName(onnx_graph); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Update all node name for onnx failed."); | |||
| return ret; | |||
| } | |||
| // 6 Precheck. | |||
| ret = Prechecker(onnx_graph); | |||
| bool is_precheck_failed = (ret != SUCCESS) || (ge::PreChecker::Instance().HasError()); | |||
| @@ -624,7 +605,6 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||
| // 9. Construct graph. | |||
| std::vector<ge::Operator> input_ops; | |||
| ret = GetGraphInputs(input_ops); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Get graph inputs failed."); | |||
| @@ -642,35 +622,33 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||
| Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { | |||
| ge::onnx::ModelProto onnx_model; | |||
| Status ret = GetModelFromFile(file, onnx_model); | |||
| Status ret = GetModelFromfile(file, onnx_model); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "get model from file failed."); | |||
| return FAILED; | |||
| GELOGE(ret, "Get model from file failed."); | |||
| return ret; | |||
| } | |||
| ret = ModelParseToGraph(onnx_model, graph); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "parse model failed."); | |||
| return FAILED; | |||
| GELOGE(ret, "Parse model failed."); | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||
| Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { | |||
| ge::onnx::ModelProto onnx_model; | |||
| Status ret = GetModelFromMemory(data, size, onnx_model); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "get model from file failed."); | |||
| return FAILED; | |||
| GELOGE(ret, "Get model from memory failed."); | |||
| return ret; | |||
| } | |||
| ret = ModelParseToGraph(onnx_model, graph); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "parse model failed."); | |||
| return FAILED; | |||
| GELOGE(ret, "Parse model failed."); | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| #endif | |||
| Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) { | |||
| if (model_file == nullptr) { | |||
| @@ -700,4 +678,4 @@ ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) { | |||
| namespace domi { | |||
| REGISTER_MODEL_PARSER_CREATOR(ONNX, ge::OnnxModelParser); | |||
| REGISTER_WEIGHTS_PARSER_CREATOR(ONNX, ge::OnnxWeightsParser); | |||
| } | |||
| } | |||
| @@ -38,11 +38,11 @@ class OnnxModelParser : public domi::ModelParser { | |||
| ge::DataType ConvertToGeDataType(const uint32_t type) override; | |||
| Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override { return domi::SUCCESS; } | |||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||
| Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override; | |||
| #endif | |||
| Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override { | |||
| return domi::SUCCESS; | |||
| } | |||
| Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override { | |||
| return domi::SUCCESS; | |||
| @@ -81,12 +81,10 @@ class OnnxModelParser : public domi::ModelParser { | |||
| Status GetGraphInputs(std::vector<ge::Operator> &input_ops); | |||
| Status Prechecker(ge::onnx::GraphProto &onnx_graph); | |||
| Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model); | |||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||
| Status GetModelFromfile(const char *file, ge::onnx::ModelProto &onnx_model); | |||
| Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model); | |||
| #endif | |||
| Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); | |||
| @@ -30,7 +30,6 @@ enum DataType | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| DT_VARIANT = 26; // variant type | |||
| } | |||
| message AttrDef | |||
| @@ -30,7 +30,6 @@ enum DataType | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| DT_VARIANT = 26; // variant type | |||
| } | |||
| message AttrDef | |||
| @@ -721,23 +721,15 @@ Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||
| GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str()); | |||
| ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor(); | |||
| GE_CHECK_NOTNULL(in_archor_ptr); | |||
| GE_IF_BOOL_EXEC(nodedef_map_[src_op_name]->op() != TENSORFLOWF_NODE_OP_SWITCH, | |||
| ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); | |||
| GE_CHECK_NOTNULL(out_archor_ptr); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E12014", {"opname1", "opname2"}, | |||
| {src->GetName(), dest->GetName()}); | |||
| return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), | |||
| dest->GetName().c_str());); | |||
| GE_IF_BOOL_EXEC(nodedef_map_[src_op_name]->op() == TENSORFLOWF_NODE_OP_SWITCH, | |||
| ge::OutDataAnchorPtr out_data_archor_ptr = src->GetOutDataAnchor(outputpair.first); | |||
| GE_CHECK_NOTNULL(out_data_archor_ptr); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| ge::GraphUtils::AddEdge(out_data_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E12014", {"opname1", "opname2"}, | |||
| {src->GetName(), dest->GetName()}); | |||
| return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), | |||
| dest->GetName().c_str());); | |||
| ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); | |||
| GE_CHECK_NOTNULL(out_archor_ptr); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E12014", {"opname1", "opname2"}, | |||
| {src->GetName(), dest->GetName()}); | |||
| return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), | |||
| dest->GetName().c_str() | |||
| ); | |||
| } | |||
| } | |||
| dest_input_map.erase(input_iter); | |||
| @@ -3221,7 +3213,7 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||
| } | |||
| // 2.4 remove the input const nodes | |||
| Status ret = RemoveInputs(current_node, unused_inputs); | |||
| Status ret = RemoveInputs(graph_def, current_node, unused_inputs, all_nodedef_map); | |||
| if (ret != SUCCESS) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E12006", {"opname"}, {current_op_name}); | |||
| GELOGE(INTERNAL_ERROR, "Op[%s] remove input failed.", current_op_name.c_str()); | |||
| @@ -3232,6 +3224,34 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||
| return SUCCESS; | |||
| } | |||
| Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||
| domi::tensorflow::NodeDef *node_def, | |||
| const map<string, NodeDef *> &all_node_map, | |||
| const vector<string> &removed_inputs_vec) { | |||
| GE_CHECK_NOTNULL(graph_def); | |||
| GE_CHECK_NOTNULL(node_def); | |||
| for (const auto &remove_input : removed_inputs_vec) { | |||
| string input_node_name = NodeNameFromInput(remove_input); | |||
| auto it = all_node_map.find(input_node_name); | |||
| if (it == all_node_map.end()) { | |||
| GELOGE(FAILED, "Can not find node name:%s in all node map.", input_node_name.c_str()); | |||
| return FAILED; | |||
| } | |||
| NodeDef *input_node_def = it->second; | |||
| if (input_node_def->op() == SWITCH || input_node_def->op() == REFSWITCH) { | |||
| NodeDef *identity_node_def = graph_def->add_node(); | |||
| GE_CHECK_NOTNULL(identity_node_def); | |||
| input_node_name = input_node_name + "identity"; | |||
| identity_node_def->set_name(input_node_name); | |||
| identity_node_def->set_op(IDENTITY); | |||
| identity_node_def->add_input(remove_input); | |||
| } | |||
| string control_input = "^" + input_node_name; | |||
| node_def->add_input(control_input); | |||
| GELOGD("Add control input:%s for node:%s", control_input.c_str(), node_def->name().c_str()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| /** | |||
| * @ingroup domi_omg | |||
| * @brief Delete input from nodedef | |||
| @@ -3241,7 +3261,10 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||
| * @return false remove failed | |||
| * | |||
| */ | |||
| Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, const set<uint32_t> &remove_index_set) { | |||
| Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||
| domi::tensorflow::NodeDef *node_def, | |||
| const set<uint32_t> &remove_index_set, | |||
| const map<string, NodeDef *> &all_node_map) { | |||
| GE_CHECK_NOTNULL(node_def); | |||
| if (remove_index_set.empty()) { | |||
| GELOGI("The size of remove_index_set is zero."); | |||
| @@ -3258,6 +3281,7 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, | |||
| RemoveInputAttr(node_def, remove_inputs_map); | |||
| int index = 0; | |||
| vector<string> removed_inputs_vec; | |||
| auto *inputs = node_def->mutable_input(); | |||
| for (auto input_it = inputs->begin(); input_it != inputs->end(); ++index) { | |||
| // 1.decide whether to remove the input | |||
| @@ -3269,6 +3293,7 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, | |||
| std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) { | |||
| GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index); | |||
| flag = true; | |||
| removed_inputs_vec.emplace_back(remove_input_name); | |||
| break; | |||
| } | |||
| } | |||
| @@ -3281,6 +3306,11 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, | |||
| } | |||
| } | |||
| Status ret = AddControlEdgeAfterRemoveInputs(graph_def, node_def, all_node_map, removed_inputs_vec); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "Add control edges for node:%s failed.", node_def->name().c_str()); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -86,26 +86,22 @@ class TensorFlowModelParser : public domi::ModelParser { | |||
| * @param [in|out] graph save model information after parsing | |||
| * @return SUCCESS parse successfully | |||
| * @return FAILED parse failed | |||
| */ | |||
| Status Parse(const char *file, ge::Graph &graph) override; | |||
| /** | |||
| * @ingroup domi_omg | |||
| * @brief Parse the relevant data from memory and save it to graph | |||
| * @param [in] memory buffer of model file | |||
| * @param [in] buffer size | |||
| * @param [in|out] graph graph for saving model information | |||
| * @brief Parse the relevant data from the memory and save it to graph | |||
| * @param [in] file Path of the model file | |||
| * @param [in|out] graph save model information after parsing | |||
| * @return SUCCESS parse successfully | |||
| * @return FAILED parse failed | |||
| */ | |||
| Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | |||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||
| Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override { | |||
| return domi::SUCCESS; | |||
| } | |||
| #endif | |||
| /** | |||
| * @ingroup domi_omg | |||
| @@ -541,7 +537,15 @@ class TensorFlowModelParser : public domi::ModelParser { | |||
| * @return false remove failed | |||
| * | |||
| */ | |||
| Status RemoveInputs(domi::tensorflow::NodeDef *node_def, const set<uint32_t> &remove_index_set); | |||
| Status RemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||
| domi::tensorflow::NodeDef *node_def, | |||
| const set<uint32_t> &remove_index_set, | |||
| const map<string, NodeDef *> &all_node_map); | |||
| Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||
| domi::tensorflow::NodeDef *node_def, | |||
| const map<string, NodeDef *> &all_node_map, | |||
| const vector<string> &removed_inputs_vec); | |||
| void RemoveInputAttr(domi::tensorflow::NodeDef *node_def, const map<string, vector<int>> &remove_inputs_map); | |||