Browse Source

fix

pull/250/head
wjm 4 years ago
parent
commit
6277ec6c10
8 changed files with 343 additions and 49 deletions
  1. +0
    -2
      atc/CMakeLists.txt
  2. +1
    -0
      atc/atc_ir_common.cc
  3. +36
    -36
      atc/main.cc
  4. +14
    -11
      atc/parse_graph.h
  5. +159
    -0
      third_party/graphengine/inc/external/ge/ge_ir_build.h
  6. +60
    -0
      third_party/graphengine/inc/framework/generator/ge_generator.h
  7. +39
    -0
      third_party/graphengine/inc/framework/omg/ge_init.h
  8. +34
    -0
      third_party/graphengine/inc/framework/omg/model_tool.h

+ 0
- 2
atc/CMakeLists.txt View File

@@ -54,7 +54,6 @@ target_include_directories(atc_atc.bin PRIVATE
${METADEF_DIR}/third_party/graphengine/inc/
${METADEF_DIR}/third_party/graphengine/inc/external
${METADEF_DIR}/third_party/graphengine/inc/framework
${METADEF_DIR}/third_party/graphengine/ge
${METADEF_DIR}/third_party/fwkacllib/inc
${PARSER_DIR}/third_party/graphengine/inc
)
@@ -123,7 +122,6 @@ target_include_directories(fwk_atc.bin PRIVATE
${METADEF_DIR}/third_party/graphengine/inc/
${METADEF_DIR}/third_party/graphengine/inc/external
${METADEF_DIR}/third_party/graphengine/inc/framework
${METADEF_DIR}/third_party/graphengine/ge
${METADEF_DIR}/third_party/fwkacllib/inc
${PARSER_DIR}/third_party/graphengine/inc/



+ 1
- 0
atc/atc_ir_common.cc View File

@@ -165,6 +165,7 @@ bool CheckDynamicImageSizeShape(const vector<int64_t> &shape, const string &data
return false;
}
}

bool CheckDynamicImagesizeInputShapeValid(map<string, vector<int64_t>> shape_map,
const std::string input_format, std::string &dynamic_image_size) {
if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) {


+ 36
- 36
atc/main.cc View File

@@ -63,19 +63,18 @@ using std::shared_ptr;
using std::string;
using std::vector;

namespace {
static bool is_dynamic_input = false;

const char *const kModeSupport = "only support 0(model to framework model), "
"1(framework model to json), 3(only pre-check), "
"5(pbtxt to json), 6(display model info)";
const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow) 5(Onnx)";

static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model";
static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model";
static const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model";

const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model";
const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model";
const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model";
// limit available mem size 2G
const long kMinAvailableMem = 2097152; // 2 * 1024 * 1024
} // namespace

DEFINE_string(model, "", "The model file.");
DEFINE_string(output, "", "The output file path&name.");
@@ -523,12 +522,12 @@ class GFlagUtils {
static bool CheckEncryptModeValid(const int encrypt_mode) {
#if !defined(__ANDROID__) && !defined(ANDROID)
if (encrypt_mode != 0 && encrypt_mode != -1) {
GELOGE(ge::FAILED, "encrypt mode must be 0 or -1");
GELOGE(ge::FAILED, "encrypt mode must be 0 or -1");
return false;
}
#else
if (encrypt_mode != -1) {
GELOGE(ge::FAILED, "encrypt mode must be -1");
GELOGE(ge::FAILED, "encrypt mode must be -1");
return false;
}
#endif
@@ -543,14 +542,14 @@ class GFlagUtils {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10007", {"parameter", "support"},
{"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)"});
GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: "
"0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx).");
GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: "
"0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx).");
return domi::PARAM_INVALID;
}

if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) {
ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"});
GELOGE(ge::FAILED, "Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!");
GELOGE(ge::FAILED, "Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!");
return domi::PARAM_INVALID;
}

@@ -584,7 +583,7 @@ class GFlagUtils {
// Failure if no filename follows the path
if (slashPosition == static_cast<int>(fileName.size() - 1)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName});
GELOGE(ge::FAILED, "Input parameter[--output]'s path[%s] not include file name", fileName.c_str());
GELOGE(ge::FAILED, "Input parameter[--output]'s path[%s] not include file name", fileName.c_str());
return false;
}

@@ -628,7 +627,7 @@ static bool CheckInputFormat() {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport});
GELOGE(ge::FAILED,
"Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kCaffeFormatSupport);
"Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kCaffeFormatSupport);
return false;
} else if ((FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW))) { // tf
if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) {
@@ -638,7 +637,7 @@ static bool CheckInputFormat() {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport});
GELOGE(ge::FAILED,
"Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kTFFormatSupport);
"Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kTFFormatSupport);
return false;
} else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) {
if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) {
@@ -648,7 +647,7 @@ static bool CheckInputFormat() {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport});
GELOGE(ge::FAILED,
"Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kONNXFormatSupport);
"Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kONNXFormatSupport);
return false;
}
return true;
@@ -804,7 +803,7 @@ Status CreateInputsForInference(const ge::Graph &graph, vector<ge::GeTensor> &in
GE_CHECK_NOTNULL(input_node);
ge::OpDescPtr op = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op);
if (op->GetType() == "Data") {
if (op->GetType() == ge::DATA) {
GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size());
ge::GeTensorDesc tensor = op->GetInputDesc(0);
string data_op_name = op->GetName();
@@ -845,7 +844,7 @@ domi::Status GenerateInfershapeJson() {
std::map<string, string> options;
ge::Status geRet = ge_generator.Initialize(options, domi::GetContext());
if (geRet != ge::SUCCESS) {
GELOGE(ge::FAILED, "GeGenerator initialize failed!");
GELOGE(ge::FAILED, "GeGenerator initialize failed!");
return domi::FAILED;
}

@@ -855,19 +854,19 @@ domi::Status GenerateInfershapeJson() {
ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework,
"", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false);
if (ret != ge::SUCCESS) {
GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED");
GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED");
(void)ge_generator.Finalize();
return domi::FAILED;
}

geRet = ge_generator.GenerateInfershapeGraph(graph);
if (geRet != ge::SUCCESS) {
GELOGE(ge::FAILED, "ATC GenerateInfershapeJson failed");
GELOGE(ge::FAILED, "ATC GenerateInfershapeJson failed");
(void)ge_generator.Finalize();
return domi::FAILED;
}
if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) {
GELOGE(ge::FAILED, "ATC DumpInfershapeJson failed");
GELOGE(ge::FAILED, "ATC DumpInfershapeJson failed");
(void)ge_generator.Finalize();
return domi::FAILED;
}
@@ -935,12 +934,12 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output
ge::Status geRet = ge::SUCCESS;
geRet = ge::GEInit::Initialize(options);
if (geRet != ge::SUCCESS) {
GELOGE(ge::FAILED, "GE initialize failed!");
GELOGE(ge::FAILED, "GE initialize failed!");
return domi::FAILED;
}
geRet = ge_generator.Initialize(options, domi::GetContext());
if (geRet != ge::SUCCESS) {
GELOGE(ge::FAILED, "GeGenerator initialize failed!");
GELOGE(ge::FAILED, "GeGenerator initialize failed!");
(void)ge::GEInit::Finalize();
return domi::FAILED;
}
@@ -953,8 +952,8 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output
auto ret1 = load_model.LoadFromFile(FLAGS_model);
if (ret1 != ge::GRAPH_SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model});
GELOGE(ge::FAILED, "Load model from %s failed, please check model file or "
"input parameter[--framework] is correct", FLAGS_model.c_str());
GELOGE(ge::FAILED, "Load model from %s failed, please check model file or "
"input parameter[--framework] is correct", FLAGS_model.c_str());
(void)ge_generator.Finalize();
(void)ge::GEInit::Finalize();
return domi::FAILED;
@@ -995,21 +994,21 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output
(void)ge_generator.Finalize();
(void)ge::GEInit::Finalize();
if (ret != ge::SUCCESS) {
GELOGE(ge::FAILED, "ATC precheck fail.");
GELOGE(ge::FAILED, "ATC precheck fail.");
return domi::FAILED;
}
return domi::SUCCESS;
}

if (ret != ge::SUCCESS) {
GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED");
GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case
GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED");
GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case
(void)ge_generator.Finalize();
(void)ge::GEInit::Finalize();
return domi::FAILED;
}
if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) {
GELOGE(ge::FAILED, "Set output node info fail.");
GELOGE(ge::FAILED, "Set output node info fail.");
(void)ge_generator.Finalize();
(void)ge::GEInit::Finalize();
return domi::FAILED;
@@ -1024,8 +1023,8 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output

geRet = ge_generator.GenerateOfflineModel(graph, output, inputs);
if (geRet != ge::SUCCESS) {
GELOGE(ge::FAILED, "GE GenerateOfflineModel execute failed");
GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case
GELOGE(ge::FAILED, "GE GenerateOfflineModel execute failed");
GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case
// checking error log)
(void)ge_generator.Finalize();
(void)ge::GEInit::Finalize();
@@ -1061,7 +1060,7 @@ static void SetEnvForSingleOp(std::map<string, string> &options) {

domi::Status GenerateSingleOp(const std::string& json_file_path) {
if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) {
GELOGE(ge::FAILED, "output path %s is not valid!", FLAGS_output.c_str());
GELOGE(ge::FAILED, "output path %s is not valid!", FLAGS_output.c_str());
return domi::FAILED;
}
// check optypelist_for_implmode and op_select_implmode
@@ -1075,21 +1074,21 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) {

auto ret = ge::GEInit::Initialize(options);
if (ret != ge::SUCCESS) {
GELOGE(ge::FAILED, "GE initialize failed!");
GELOGE(ge::FAILED, "GE initialize failed!");
return domi::FAILED;
}

ge::GeGenerator generator;
ret = generator.Initialize(options, domi::GetContext());
if (ret != SUCCESS) {
GELOGE(ge::FAILED, "GeGenerator initialize failed!");
GELOGE(ge::FAILED, "GeGenerator initialize failed!");
(void)ge::GEInit::Finalize();
return domi::FAILED;
}

vector<ge::SingleOpBuildParam> build_params;
if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "parse single op json file failed");
GELOGE(ge::FAILED, "parse single op json file failed");
(void)generator.Finalize();
(void)ge::GEInit::Finalize();
return domi::FAILED;
@@ -1104,7 +1103,7 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) {
output_path += param.file_name;
ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path);
if (ret != SUCCESS) {
GELOGE(ge::FAILED, "Compile op failed. ge ret = %u, op index = %d", ret, index);
GELOGE(ge::FAILED, "Compile op failed. ge ret = %u, op index = %d", ret, index);
ret = domi::FAILED;
break;
}
@@ -1320,10 +1319,11 @@ int init(int argc, char* argv[]) {
std::string path_base = ge::GEInit::GetPath();
ret = ErrorManager::GetInstance().Init(path_base);
if (ret != 0) {
GELOGE(ge::FAILED, "ErrorManager init fail !");
GELOGE(ge::FAILED, "ErrorManager init fail !");
return ret;
}

ErrorManager::GetInstance().GenWorkStreamIdDefault();
return 0;
}



+ 14
- 11
atc/parse_graph.h View File

@@ -28,7 +28,6 @@
#include "graph/compute_graph.h"
#include "graph/graph.h"
#include "graph/model.h"
//#include "runtime/kernel.h"

using domi::Status;
using std::pair;
@@ -42,8 +41,8 @@ namespace ge {
* @brief init omg context
* @return void
*/
GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format,
bool is_dynamic_input);
GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format,
const string &net_format, bool is_dynamic_input);

/**
* @ingroup domi_omg
@@ -60,9 +59,10 @@ GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const st
* @param [in] atc_params multiply atc params
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, const char *model_file,
const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr,
const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false);
GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params,
const char *model_file, const char *weights_file, domi::FrameworkType type,
const char *op_conf = nullptr, const char *target = nullptr,
RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false);

/**
* @ingroup domi_omg
@@ -84,7 +84,8 @@ GE_FUNC_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const char
* @param [key] encrypted key
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file);
GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file,
const char *json_file);

GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model);

@@ -92,17 +93,19 @@ GE_FUNC_VISIBILITY void FindParserSo(const string &path, vector<string> &fileLis

GE_FUNC_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);

GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format);
GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type,
const std::string &output_format);

GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);

GE_FUNC_VISIBILITY void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
std::vector<std::string> &output_nodes_name);

GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx();

GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx();

GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def);
GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size);
} // namespace ge
#endif // PARSE_GRAPH_H_

+ 159
- 0
third_party/graphengine/inc/external/ge/ge_ir_build.h View File

@@ -0,0 +1,159 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GE_IR_BUILD_H_
#define INC_EXTERNAL_GE_IR_BUILD_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <string>
#include <map>
#include <memory>
#include "graph/graph.h"
#include "graph/ge_error_codes.h"

namespace {
const int IR_MAJOR_VERSION = 1;
const int IR_MINOR_VERSION = 0;
const int IR_PATCH_VERSION = 0;
} // namespace

namespace ge {

struct ModelBufferData {
std::shared_ptr<uint8_t> data = nullptr;
uint64_t length;
};

enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS };

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
* @param global_options[IN] global init params for build
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &))
GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options);

GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &global_options);

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
*/
GE_FUNC_VISIBILITY void aclgrphBuildFinalize();

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
* @param graph[IN] the graph ready to build
* @param options[IN] options used for build
* @param model[OUT] builded model
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &,
const std::map<AscendString, AscendString> &,
ModelBufferData &))
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph,
const std::map<std::string, std::string> &build_options,
ModelBufferData &model);

GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph,
const std::map<AscendString, AscendString> &build_options,
ModelBufferData &model);

/**
* @ingroup AscendCL
* @brief save model buffer to file
*
* @param output_file[IN] the file path to be saved
* @param model[IN] model buffer data
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *, const ModelBufferData &))
GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model);

GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model);

/**
* @ingroup AscendCL
* @brief query IR interface version
*
* @param major_version[OUT] IR interface major version
* @param minor_version[OUT] IR interface minor version
* @param patch_version[OUT] IR interface patch version
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version);

/**
* @ingroup AscendCL
* @brief dump graph
*
* @param graph[IN] the graph ready to build
* @param file[IN] file path
* @param file[IN] file path string len
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len);

/**
* @ingroup AscendCL
* @brief create single op graph
*
* @param op_type[IN] the op_type
* @param inputs[IN] the inputdesc
* @param outputs[IN] the outputdesc
* @param graph[OUT] the graph
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs,
const std::vector<TensorDesc> &outputs, Graph &graph);

/**
* @name aclgrphSetOpAttr
* @brief set attribute for operators in the configuration file
* @param graph [IN/OUT] compute graph
* @param attr_type [In] attribute type
* @param cfg_path [IN] the config file path
* @return graphStatus
*/
GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path);

}; // namespace ge
#endif // INC_EXTERNAL_GE_IR_BUILD_H_

+ 60
- 0
third_party/graphengine/inc/framework/generator/ge_generator.h View File

@@ -0,0 +1,60 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_
#define INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/ge_inner_error_codes.h"
#include "graph/ge_tensor.h"
#include "graph/graph.h"
#include "graph/op_desc.h"
#include "graph/detail/attributes_holder.h"
#include "omg/omg_inner_types.h"

namespace ge {
class GE_FUNC_VISIBILITY GeGenerator {
public:
GeGenerator() = default;

~GeGenerator() { (void)Finalize(); }

GeGenerator(const GeGenerator &) = delete;

GeGenerator &operator=(const GeGenerator &) = delete;

Status Initialize(const std::map<std::string, std::string> &options, OmgContext &context);

Status Finalize();

Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix,
const std::vector<GeTensor> &inputs = std::vector<GeTensor>());

Status GenerateInfershapeGraph(const Graph &graph);

Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector<GeTensor> &inputs,
const std::vector<GeTensor> &outputs, const std::string &model_file_name);
private:
class Impl;

std::shared_ptr<Impl> impl_;
};
} // namespace ge

#endif // INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_

+ 39
- 0
third_party/graphengine/inc/framework/omg/ge_init.h View File

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_OMG_GE_INIT_H_
#define INC_FRAMEWORK_OMG_GE_INIT_H_
#include <map>
#include <string>
#include "common/ge_inner_error_codes.h"

using std::string;
using std::map;

namespace ge {
class GE_FUNC_VISIBILITY GEInit {
public:
// GE Environment Initialize, return Status: SUCCESS,FAILED
static Status Initialize(const map<string, string> &options);

static string GetPath();

// GE Environment Finalize, return Status: SUCCESS,FAILED
static Status Finalize();
};
} // namespace ge

#endif // INC_FRAMEWORK_OMG_GE_INIT_H_

+ 34
- 0
third_party/graphengine/inc/framework/omg/model_tool.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_OMG_MODEL_TOOL_H_
#define INC_FRAMEWORK_OMG_MODEL_TOOL_H_

#include <memory>
#include <string>

#include "framework/common/debug/ge_log.h"
#include "proto/ge_ir.pb.h"

namespace ge {
class GE_FUNC_VISIBILITY ModelTool {
public:
static Status GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef &model_def, uint32_t &modeldef_size);
static Status GetModelInfoFromPbtxt(const char *model_file, ge::proto::ModelDef &model_def);
};
} // namespace ge
#endif // INC_FRAMEWORK_OMG_MODEL_TOOL_H_

Loading…
Cancel
Save