| @@ -21,12 +21,16 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/ascend_string.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/types.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/types.h" | |||
| namespace ge { | |||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph); | |||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||
| const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph); | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | |||
| @@ -22,12 +22,15 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/ascend_string.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/types.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/types.h" | |||
| namespace ge { | |||
| graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); | |||
| graphStatus aclgrphParseTensorFlow(const char *model_file, | |||
| const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph); | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | |||
| @@ -1 +1 @@ | |||
| Subproject commit 9bbce07b846858fa30ef2bd7c662894e20f83ef1 | |||
| Subproject commit 37465b85d30b67a0edcc6ea4acd2f11a9697c7af | |||
| @@ -108,12 +108,67 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||
| return ret; | |||
| } | |||
| GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); | |||
| if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { | |||
| std::map<AscendString, AscendString> parser_params; | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| return ge::SUCCESS; | |||
| } | |||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||
| const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | |||
| GE_CHECK_NOTNULL(model_file); | |||
| GetParserContext().type = domi::CAFFE; | |||
| std::map<string, string> options; | |||
| options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::CAFFE))); | |||
| // load custom plugin so and proto | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| (void)acl_graph_parse_util.AclParserInitialize(options); | |||
| string output_name; | |||
| if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params before graph failed."); | |||
| return ge::FAILED; | |||
| } | |||
| // Create an empty computegraph | |||
| string graph_name = output_name.empty() ? "tmpGraph" : output_name; | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); | |||
| GE_CHECK_NOTNULL(model_parser); | |||
| // parse caffe model_file and weights_file to GE graph | |||
| ge::graphStatus ret = model_parser->Parse(model_file, graph); | |||
| if (ret != ge::SUCCESS) { | |||
| GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("Parser graph %s success.", graph.GetName().c_str()); | |||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||
| return ge::FAILED; | |||
| } | |||
| auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::CAFFE); | |||
| ret = weights_parser->Parse(weights_file, graph); | |||
| if (ret != ge::SUCCESS) { | |||
| GELOGE(ret, "Weights parse failed. graph: %s", graph.GetName().c_str()); | |||
| return ret; | |||
| } | |||
| GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||
| return ge::SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -17,14 +17,18 @@ | |||
| #ifndef ACL_GRAPH_PARSE_UTIL_ | |||
| #define ACL_GRAPH_PARSE_UTIL_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <google/protobuf/text_format.h> | |||
| #include <map> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "framework/omg/parser/parser_types.h" | |||
| #include "register/register_error_codes.h" | |||
| #include "graph/ascend_string.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "register/register_error_codes.h" | |||
| namespace ge { | |||
| @@ -37,13 +41,39 @@ class AclGrphParseUtil { | |||
| domi::Status LoadOpsProtoLib(); | |||
| void SaveCustomCaffeProtoPath(); | |||
| domi::Status AclParserInitialize(const std::map<std::string, std::string> &options); | |||
| domi::Status SetDefaultOutputNode(ge::Graph &graph); | |||
| domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | |||
| domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | |||
| std::string &graph_name); | |||
| domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | |||
| domi::Status ParseOutputInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | |||
| private: | |||
| bool parser_initialized = false; | |||
| domi::Status CheckOptions(const std::map<AscendString, AscendString> &parser_params); | |||
| domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||
| void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name); | |||
| domi::Status ParseAclLogLevel(const std::string &log); | |||
| bool CheckAclInputFormat(string &input_format); | |||
| domi::Status ParseAclFormat(std::string &input_format); | |||
| bool ParseInputShape(const std::string &input_shape, std::unordered_map<std::string, vector<int64_t>> &shape_map, | |||
| vector<pair<std::string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input); | |||
| domi::Status ParseAclShape(const std::string &input_shape, bool is_dynamic_input); | |||
| domi::Status ParseAclOutputNodes(const std::string &out_nodes); | |||
| domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16); | |||
| domi::Status ParseAclOpConf(const std::string &op_conf); | |||
| domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes); | |||
| static void AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, const string &fp16_nodes_name, | |||
| uint32_t index, OpDescPtr &op_desc); | |||
| domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | |||
| const string &is_input_adjust_hw_layout); | |||
| domi::Status ParseAclWeightCompressConf(const ComputeGraphPtr &graph, const string &compress_weight_conf); | |||
| domi::Status ParseAclOutputType(const std::string &output_type, | |||
| std::map<std::string, vector<std::string>> &output_node_dt_map); | |||
| domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | |||
| std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||
| domi::Status CheckAclInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input); | |||
| domi::Status CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf); | |||
| }; | |||
| namespace parser { | |||
| @@ -113,13 +113,59 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { | |||
| return ge::FAILED; | |||
| } | |||
| if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { | |||
| std::map<AscendString, AscendString> parser_params; | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("Parser graph %s success.", graph.GetName().c_str()); | |||
| return ge::SUCCESS; | |||
| } | |||
| graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<AscendString, AscendString> &parser_params, | |||
| ge::Graph &graph) { | |||
| GE_CHECK_NOTNULL(model_file); | |||
| GetParserContext().type = domi::TENSORFLOW; | |||
| std::map<string, string> options; | |||
| options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::TENSORFLOW))); | |||
| // load custom plugin so and proto | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| (void)acl_graph_parse_util.AclParserInitialize(options); | |||
| string output_name; | |||
| if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params before graph failed."); | |||
| return ge::FAILED; | |||
| } | |||
| // Create an empty computegraph | |||
| string graph_name = output_name.empty() ? "tmpGraph" : output_name; | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); | |||
| GE_CHECK_NOTNULL(model_parser); | |||
| // parse tensorflow model_file to GE graph | |||
| ge::graphStatus ret = model_parser->Parse(model_file, graph); | |||
| if (ret != ge::SUCCESS) { | |||
| GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||
| return ge::FAILED; | |||
| } | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||
| return ge::SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| namespace ge { | |||