diff --git a/atc/CMakeLists.txt b/atc/CMakeLists.txt index 4fa9f8b..acbea71 100644 --- a/atc/CMakeLists.txt +++ b/atc/CMakeLists.txt @@ -54,7 +54,6 @@ target_include_directories(atc_atc.bin PRIVATE ${METADEF_DIR}/third_party/graphengine/inc/ ${METADEF_DIR}/third_party/graphengine/inc/external ${METADEF_DIR}/third_party/graphengine/inc/framework - ${METADEF_DIR}/third_party/graphengine/ge ${METADEF_DIR}/third_party/fwkacllib/inc ${PARSER_DIR}/third_party/graphengine/inc ) @@ -123,7 +122,6 @@ target_include_directories(fwk_atc.bin PRIVATE ${METADEF_DIR}/third_party/graphengine/inc/ ${METADEF_DIR}/third_party/graphengine/inc/external ${METADEF_DIR}/third_party/graphengine/inc/framework - ${METADEF_DIR}/third_party/graphengine/ge ${METADEF_DIR}/third_party/fwkacllib/inc ${PARSER_DIR}/third_party/graphengine/inc/ diff --git a/atc/atc_ir_common.cc b/atc/atc_ir_common.cc index 0114fb7..4e6bcfa 100644 --- a/atc/atc_ir_common.cc +++ b/atc/atc_ir_common.cc @@ -165,6 +165,7 @@ bool CheckDynamicImageSizeShape(const vector &shape, const string &data return false; } } + bool CheckDynamicImagesizeInputShapeValid(map> shape_map, const std::string input_format, std::string &dynamic_image_size) { if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) { diff --git a/atc/main.cc b/atc/main.cc index 620b18f..4847337 100644 --- a/atc/main.cc +++ b/atc/main.cc @@ -63,19 +63,18 @@ using std::shared_ptr; using std::string; using std::vector; +namespace { static bool is_dynamic_input = false; - const char *const kModeSupport = "only support 0(model to framework model), " "1(framework model to json), 3(only pre-check), " "5(pbtxt to json), 6(display model info)"; const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow) 5(Onnx)"; - -static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; -static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; -static const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model"; - +const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; +const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; +const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model"; // limit available mem size 2G const long kMinAvailableMem = 2097152; // 2 * 1024 * 1024 +} // namespace DEFINE_string(model, "", "The model file."); DEFINE_string(output, "", "The output file path&name."); @@ -523,12 +522,12 @@ class GFlagUtils { static bool CheckEncryptModeValid(const int encrypt_mode) { #if !defined(__ANDROID__) && !defined(ANDROID) if (encrypt_mode != 0 && encrypt_mode != -1) { - GELOGE(ge::FAILED, "encrypt mode must be 0 or -1"); + GELOGE(ge::FAILED, "encrypt mode must be 0 or -1"); return false; } #else if (encrypt_mode != -1) { - GELOGE(ge::FAILED, "encrypt mode must be -1"); + GELOGE(ge::FAILED, "encrypt mode must be -1"); return false; } #endif @@ -543,14 +542,14 @@ class GFlagUtils { ErrorManager::GetInstance().ATCReportErrMessage( "E10007", {"parameter", "support"}, {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)"}); - GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: " - "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)."); + GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: " + "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)."); return domi::PARAM_INVALID; } if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) { ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"}); - GELOGE(ge::FAILED, "Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!"); + GELOGE(ge::FAILED, "Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!"); return domi::PARAM_INVALID; } @@ -584,7 +583,7 @@ class GFlagUtils { // Failure if no filename follows the path if (slashPosition == static_cast(fileName.size() - 1)) { ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName}); - GELOGE(ge::FAILED, "Input parameter[--output]'s path[%s] not include file name", fileName.c_str()); + GELOGE(ge::FAILED, "Input parameter[--output]'s path[%s] not include file name", fileName.c_str()); return false; } @@ -628,7 +627,7 @@ static bool CheckInputFormat() { ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport}); GELOGE(ge::FAILED, - "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kCaffeFormatSupport); + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kCaffeFormatSupport); return false; } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) { @@ -638,7 +637,7 @@ static bool CheckInputFormat() { ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport}); GELOGE(ge::FAILED, - "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kTFFormatSupport); + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kTFFormatSupport); return false; } else if (FLAGS_framework == static_cast(domi::ONNX)) { if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) { @@ -648,7 +647,7 @@ static bool CheckInputFormat() { ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport}); GELOGE(ge::FAILED, - "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kONNXFormatSupport); + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kONNXFormatSupport); return false; } return true; @@ -804,7 +803,7 @@ Status CreateInputsForInference(const ge::Graph &graph, vector &in GE_CHECK_NOTNULL(input_node); ge::OpDescPtr op = input_node->GetOpDesc(); GE_CHECK_NOTNULL(op); - if (op->GetType() == "Data") { + if (op->GetType() == ge::DATA) { GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); ge::GeTensorDesc tensor = op->GetInputDesc(0); string data_op_name = op->GetName(); @@ -845,7 +844,7 @@ domi::Status GenerateInfershapeJson() { std::map options; ge::Status geRet = ge_generator.Initialize(options, domi::GetContext()); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "GeGenerator initialize failed!"); + GELOGE(ge::FAILED, "GeGenerator initialize failed!"); return domi::FAILED; } @@ -855,19 +854,19 @@ domi::Status GenerateInfershapeJson() { ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework, "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false); if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); + GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); (void)ge_generator.Finalize(); return domi::FAILED; } geRet = ge_generator.GenerateInfershapeGraph(graph); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "ATC GenerateInfershapeJson failed"); + GELOGE(ge::FAILED, "ATC GenerateInfershapeJson failed"); (void)ge_generator.Finalize(); return domi::FAILED; } if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) { - GELOGE(ge::FAILED, "ATC DumpInfershapeJson failed"); + GELOGE(ge::FAILED, "ATC DumpInfershapeJson failed"); (void)ge_generator.Finalize(); return domi::FAILED; } @@ -935,12 +934,12 @@ domi::Status GenerateModel(std::map &options, std::string output ge::Status geRet = ge::SUCCESS; geRet = ge::GEInit::Initialize(options); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "GE initialize failed!"); + GELOGE(ge::FAILED, "GE initialize failed!"); return domi::FAILED; } geRet = ge_generator.Initialize(options, domi::GetContext()); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "GeGenerator initialize failed!"); + GELOGE(ge::FAILED, "GeGenerator initialize failed!"); (void)ge::GEInit::Finalize(); return domi::FAILED; } @@ -953,8 +952,8 @@ domi::Status GenerateModel(std::map &options, std::string output auto ret1 = load_model.LoadFromFile(FLAGS_model); if (ret1 != ge::GRAPH_SUCCESS) { ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model}); - GELOGE(ge::FAILED, "Load model from %s failed, please check model file or " - "input parameter[--framework] is correct", FLAGS_model.c_str()); + GELOGE(ge::FAILED, "Load model from %s failed, please check model file or " + "input parameter[--framework] is correct", FLAGS_model.c_str()); (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); return domi::FAILED; @@ -995,21 +994,21 @@ domi::Status GenerateModel(std::map &options, std::string output (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "ATC precheck fail."); + GELOGE(ge::FAILED, "ATC precheck fail."); return domi::FAILED; } return domi::SUCCESS; } if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); - GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case + GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); + GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); return domi::FAILED; } if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) { - GELOGE(ge::FAILED, "Set output node info fail."); + GELOGE(ge::FAILED, "Set output node info fail."); (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); return domi::FAILED; @@ -1024,8 +1023,8 @@ domi::Status GenerateModel(std::map &options, std::string output geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "GE GenerateOfflineModel execute failed"); - GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case + GELOGE(ge::FAILED, "GE GenerateOfflineModel execute failed"); + GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case // checking error log) (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); @@ -1061,7 +1060,7 @@ static void SetEnvForSingleOp(std::map &options) { domi::Status GenerateSingleOp(const std::string& json_file_path) { if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) { - GELOGE(ge::FAILED, "output path %s is not valid!", FLAGS_output.c_str()); + GELOGE(ge::FAILED, "output path %s is not valid!", FLAGS_output.c_str()); return domi::FAILED; } // check optypelist_for_implmode and op_select_implmode @@ -1075,21 +1074,21 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) { auto ret = ge::GEInit::Initialize(options); if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "GE initialize failed!"); + GELOGE(ge::FAILED, "GE initialize failed!"); return domi::FAILED; } ge::GeGenerator generator; ret = generator.Initialize(options, domi::GetContext()); if (ret != SUCCESS) { - GELOGE(ge::FAILED, "GeGenerator initialize failed!"); + GELOGE(ge::FAILED, "GeGenerator initialize failed!"); (void)ge::GEInit::Finalize(); return domi::FAILED; } vector build_params; if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) { - GELOGE(ge::FAILED, "parse single op json file failed"); + GELOGE(ge::FAILED, "parse single op json file failed"); (void)generator.Finalize(); (void)ge::GEInit::Finalize(); return domi::FAILED; @@ -1104,7 +1103,7 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) { output_path += param.file_name; ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path); if (ret != SUCCESS) { - GELOGE(ge::FAILED, "Compile op failed. ge ret = %u, op index = %d", ret, index); + GELOGE(ge::FAILED, "Compile op failed. ge ret = %u, op index = %d", ret, index); ret = domi::FAILED; break; } @@ -1320,10 +1319,11 @@ int init(int argc, char* argv[]) { std::string path_base = ge::GEInit::GetPath(); ret = ErrorManager::GetInstance().Init(path_base); if (ret != 0) { - GELOGE(ge::FAILED, "ErrorManager init fail !"); + GELOGE(ge::FAILED, "ErrorManager init fail !"); return ret; } + ErrorManager::GetInstance().GenWorkStreamIdDefault(); return 0; } diff --git a/atc/parse_graph.h b/atc/parse_graph.h index 5e04c73..ad59d34 100644 --- a/atc/parse_graph.h +++ b/atc/parse_graph.h @@ -28,7 +28,6 @@ #include "graph/compute_graph.h" #include "graph/graph.h" #include "graph/model.h" -//#include "runtime/kernel.h" using domi::Status; using std::pair; @@ -42,8 +41,8 @@ namespace ge { * @brief init omg context * @return void */ -GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, - bool is_dynamic_input); +GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, + const string &net_format, bool is_dynamic_input); /** * @ingroup domi_omg @@ -60,9 +59,10 @@ GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const st * @param [in] atc_params multiply atc params * @return Status result code */ -GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map &atc_params, const char *model_file, - const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr, - const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false); +GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map &atc_params, + const char *model_file, const char *weights_file, domi::FrameworkType type, + const char *op_conf = nullptr, const char *target = nullptr, + RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false); /** * @ingroup domi_omg @@ -84,7 +84,8 @@ GE_FUNC_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const char * @param [key] encrypted key * @return Status result code */ -GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file); +GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, + const char *json_file); GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model); @@ -92,17 +93,19 @@ GE_FUNC_VISIBILITY void FindParserSo(const string &path, vector &fileLis GE_FUNC_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); -GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); +GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, + const std::string &output_format); -GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, std::vector> &output_nodes_info); +GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, + std::vector> &output_nodes_info); GE_FUNC_VISIBILITY void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, - std::vector &output_nodes_name); + std::vector &output_nodes_name); GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx(); GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx(); -GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def); +GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size); } // namespace ge #endif // PARSE_GRAPH_H_ diff --git a/third_party/graphengine/inc/external/ge/ge_ir_build.h b/third_party/graphengine/inc/external/ge/ge_ir_build.h new file mode 100644 index 0000000..04e059a --- /dev/null +++ b/third_party/graphengine/inc/external/ge/ge_ir_build.h @@ -0,0 +1,159 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd + +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef INC_EXTERNAL_GE_IR_BUILD_H_ +#define INC_EXTERNAL_GE_IR_BUILD_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include +#include "graph/graph.h" +#include "graph/ge_error_codes.h" + +namespace { +const int IR_MAJOR_VERSION = 1; +const int IR_MINOR_VERSION = 0; +const int IR_PATCH_VERSION = 0; +} // namespace + +namespace ge { + +struct ModelBufferData { + std::shared_ptr data = nullptr; + uint64_t length; +}; + +enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS }; + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + * @param global_options[IN] global init params for build + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map &)) +GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map global_options); + +GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map &global_options); + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + */ +GE_FUNC_VISIBILITY void aclgrphBuildFinalize(); + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + * @param graph[IN] the graph ready to build + * @param options[IN] options used for build + * @param model[OUT] builded model + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, + const std::map &, + ModelBufferData &)) +GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, + const std::map &build_options, + ModelBufferData &model); + +GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, + const std::map &build_options, + ModelBufferData &model); + +/** + * @ingroup AscendCL + * @brief save model buffer to file + * + * @param output_file[IN] the file path to be saved + * @param model[IN] model buffer data + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *, const ModelBufferData &)) +GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); + +GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model); + +/** + * @ingroup AscendCL + * @brief query IR interface version + * + * @param major_version[OUT] IR interface major version + * @param minor_version[OUT] IR interface minor version + * @param patch_version[OUT] IR interface patch version + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version); + +/** + * @ingroup AscendCL + * @brief dump graph + * + * @param graph[IN] the graph ready to build + * @param file[IN] file path + * @param file[IN] file path string len + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len); + +/** + * @ingroup AscendCL + * @brief create single op graph + * + * @param op_type[IN] the op_type + * @param inputs[IN] the inputdesc + * @param outputs[IN] the outputdesc + * @param graph[OUT] the graph + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector &inputs, + const std::vector &outputs, Graph &graph); + +/** + * @name aclgrphSetOpAttr + * @brief set attribute for operators in the configuration file + * @param graph [IN/OUT] compute graph + * @param attr_type [In] attribute type + * @param cfg_path [IN] the config file path + * @return graphStatus + */ +GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path); + +}; // namespace ge +#endif // INC_EXTERNAL_GE_IR_BUILD_H_ diff --git a/third_party/graphengine/inc/framework/generator/ge_generator.h b/third_party/graphengine/inc/framework/generator/ge_generator.h new file mode 100644 index 0000000..1b3cade --- /dev/null +++ b/third_party/graphengine/inc/framework/generator/ge_generator.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_ +#define INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_ + +#include +#include +#include +#include +#include "common/ge_inner_error_codes.h" +#include "graph/ge_tensor.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "graph/detail/attributes_holder.h" +#include "omg/omg_inner_types.h" + +namespace ge { +class GE_FUNC_VISIBILITY GeGenerator { + public: + GeGenerator() = default; + + ~GeGenerator() { (void)Finalize(); } + + GeGenerator(const GeGenerator &) = delete; + + GeGenerator &operator=(const GeGenerator &) = delete; + + Status Initialize(const std::map &options, OmgContext &context); + + Status Finalize(); + + Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, + const std::vector &inputs = std::vector()); + + Status GenerateInfershapeGraph(const Graph &graph); + + Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector &inputs, + const std::vector &outputs, const std::string &model_file_name); + private: + class Impl; + + std::shared_ptr impl_; +}; +} // namespace ge + +#endif // INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_ diff --git a/third_party/graphengine/inc/framework/omg/ge_init.h b/third_party/graphengine/inc/framework/omg/ge_init.h new file mode 100644 index 0000000..b0cd0d6 --- /dev/null +++ b/third_party/graphengine/inc/framework/omg/ge_init.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_GE_INIT_H_ +#define INC_FRAMEWORK_OMG_GE_INIT_H_ +#include +#include +#include "common/ge_inner_error_codes.h" + +using std::string; +using std::map; + +namespace ge { +class GE_FUNC_VISIBILITY GEInit { + public: + // GE Environment Initialize, return Status: SUCCESS,FAILED + static Status Initialize(const map &options); + + static string GetPath(); + + // GE Environment Finalize, return Status: SUCCESS,FAILED + static Status Finalize(); +}; +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_GE_INIT_H_ diff --git a/third_party/graphengine/inc/framework/omg/model_tool.h b/third_party/graphengine/inc/framework/omg/model_tool.h new file mode 100644 index 0000000..93c4e68 --- /dev/null +++ b/third_party/graphengine/inc/framework/omg/model_tool.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_MODEL_TOOL_H_ +#define INC_FRAMEWORK_OMG_MODEL_TOOL_H_ + +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "proto/ge_ir.pb.h" + +namespace ge { +class GE_FUNC_VISIBILITY ModelTool { + public: + static Status GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef &model_def, uint32_t &modeldef_size); + + static Status GetModelInfoFromPbtxt(const char *model_file, ge::proto::ModelDef &model_def); +}; +} // namespace ge +#endif // INC_FRAMEWORK_OMG_MODEL_TOOL_H_