| @@ -21,12 +21,16 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/ascend_string.h" | |||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| #include "graph/types.h" | |||||
| #include "graph/graph.h" | #include "graph/graph.h" | ||||
| #include "graph/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph); | graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph); | ||||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||||
| const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph); | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | #endif // INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | ||||
| @@ -22,12 +22,15 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/ascend_string.h" | |||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| #include "graph/types.h" | |||||
| #include "graph/graph.h" | #include "graph/graph.h" | ||||
| #include "graph/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); | graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); | ||||
| graphStatus aclgrphParseTensorFlow(const char *model_file, | |||||
| const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph); | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | #endif // INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit 9bbce07b846858fa30ef2bd7c662894e20f83ef1 | |||||
| Subproject commit 37465b85d30b67a0edcc6ea4acd2f11a9697c7af | |||||
| @@ -108,12 +108,67 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); | GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); | ||||
| if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { | |||||
| std::map<AscendString, AscendString> parser_params; | |||||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||||
| GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | ||||
| return ge::FAILED; | return ge::FAILED; | ||||
| } | } | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||||
| const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | |||||
| GE_CHECK_NOTNULL(model_file); | |||||
| GetParserContext().type = domi::CAFFE; | |||||
| std::map<string, string> options; | |||||
| options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::CAFFE))); | |||||
| // load custom plugin so and proto | |||||
| AclGrphParseUtil acl_graph_parse_util; | |||||
| (void)acl_graph_parse_util.AclParserInitialize(options); | |||||
| string output_name; | |||||
| if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | |||||
| GELOGE(ge::FAILED, "Parser params before graph failed."); | |||||
| return ge::FAILED; | |||||
| } | |||||
| // Create an empty computegraph | |||||
| string graph_name = output_name.empty() ? "tmpGraph" : output_name; | |||||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); | |||||
| GE_CHECK_NOTNULL(model_parser); | |||||
| // parse caffe model_file and weights_file to GE graph | |||||
| ge::graphStatus ret = model_parser->Parse(model_file, graph); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | |||||
| return ge::FAILED; | |||||
| } | |||||
| GELOGI("Parser graph %s success.", graph.GetName().c_str()); | |||||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||||
| return ge::FAILED; | |||||
| } | |||||
| auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::CAFFE); | |||||
| ret = weights_parser->Parse(weights_file, graph); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(ret, "Weights parse failed. graph: %s", graph.GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); | |||||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||||
| return ge::FAILED; | |||||
| } | |||||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -17,14 +17,18 @@ | |||||
| #ifndef ACL_GRAPH_PARSE_UTIL_ | #ifndef ACL_GRAPH_PARSE_UTIL_ | ||||
| #define ACL_GRAPH_PARSE_UTIL_ | #define ACL_GRAPH_PARSE_UTIL_ | ||||
| #include <map> | |||||
| #include <string> | |||||
| #include <google/protobuf/text_format.h> | #include <google/protobuf/text_format.h> | ||||
| #include <map> | |||||
| #include <sstream> | #include <sstream> | ||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
| #include "register/register_error_codes.h" | |||||
| #include "graph/ascend_string.h" | |||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "register/register_error_codes.h" | |||||
| namespace ge { | namespace ge { | ||||
| @@ -37,13 +41,39 @@ class AclGrphParseUtil { | |||||
| domi::Status LoadOpsProtoLib(); | domi::Status LoadOpsProtoLib(); | ||||
| void SaveCustomCaffeProtoPath(); | void SaveCustomCaffeProtoPath(); | ||||
| domi::Status AclParserInitialize(const std::map<std::string, std::string> &options); | domi::Status AclParserInitialize(const std::map<std::string, std::string> &options); | ||||
| domi::Status SetDefaultOutputNode(ge::Graph &graph); | |||||
| domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | |||||
| domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | |||||
| std::string &graph_name); | |||||
| domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | |||||
| domi::Status ParseOutputInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | |||||
| private: | private: | ||||
| bool parser_initialized = false; | bool parser_initialized = false; | ||||
| domi::Status CheckOptions(const std::map<AscendString, AscendString> &parser_params); | |||||
| domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | ||||
| void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | ||||
| std::vector<std::string> &output_nodes_name); | std::vector<std::string> &output_nodes_name); | ||||
| domi::Status ParseAclLogLevel(const std::string &log); | |||||
| bool CheckAclInputFormat(string &input_format); | |||||
| domi::Status ParseAclFormat(std::string &input_format); | |||||
| bool ParseInputShape(const std::string &input_shape, std::unordered_map<std::string, vector<int64_t>> &shape_map, | |||||
| vector<pair<std::string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input); | |||||
| domi::Status ParseAclShape(const std::string &input_shape, bool is_dynamic_input); | |||||
| domi::Status ParseAclOutputNodes(const std::string &out_nodes); | |||||
| domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16); | |||||
| domi::Status ParseAclOpConf(const std::string &op_conf); | |||||
| domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes); | |||||
| static void AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, const string &fp16_nodes_name, | |||||
| uint32_t index, OpDescPtr &op_desc); | |||||
| domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | |||||
| const string &is_input_adjust_hw_layout); | |||||
| domi::Status ParseAclWeightCompressConf(const ComputeGraphPtr &graph, const string &compress_weight_conf); | |||||
| domi::Status ParseAclOutputType(const std::string &output_type, | |||||
| std::map<std::string, vector<std::string>> &output_node_dt_map); | |||||
| domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | |||||
| std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||||
| domi::Status CheckAclInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input); | |||||
| domi::Status CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf); | |||||
| }; | }; | ||||
| namespace parser { | namespace parser { | ||||
| @@ -113,13 +113,59 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { | |||||
| return ge::FAILED; | return ge::FAILED; | ||||
| } | } | ||||
| if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { | |||||
| std::map<AscendString, AscendString> parser_params; | |||||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||||
| GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | ||||
| return ge::FAILED; | return ge::FAILED; | ||||
| } | } | ||||
| GELOGI("Parser graph %s success.", graph.GetName().c_str()); | GELOGI("Parser graph %s success.", graph.GetName().c_str()); | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<AscendString, AscendString> &parser_params, | |||||
| ge::Graph &graph) { | |||||
| GE_CHECK_NOTNULL(model_file); | |||||
| GetParserContext().type = domi::TENSORFLOW; | |||||
| std::map<string, string> options; | |||||
| options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::TENSORFLOW))); | |||||
| // load custom plugin so and proto | |||||
| AclGrphParseUtil acl_graph_parse_util; | |||||
| (void)acl_graph_parse_util.AclParserInitialize(options); | |||||
| string output_name; | |||||
| if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | |||||
| GELOGE(ge::FAILED, "Parser params before graph failed."); | |||||
| return ge::FAILED; | |||||
| } | |||||
| // Create an empty computegraph | |||||
| string graph_name = output_name.empty() ? "tmpGraph" : output_name; | |||||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); | |||||
| GE_CHECK_NOTNULL(model_parser); | |||||
| // parse tensorflow model_file to GE graph | |||||
| ge::graphStatus ret = model_parser->Parse(model_file, graph); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | |||||
| return ge::FAILED; | |||||
| } | |||||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||||
| return ge::FAILED; | |||||
| } | |||||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||||
| return ge::FAILED; | |||||
| } | |||||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| namespace ge { | namespace ge { | ||||