From 2dcaf0d3493efc7a4fda2f3bf029ee898f40f491 Mon Sep 17 00:00:00 2001 From: wjm Date: Fri, 26 Feb 2021 17:23:19 +0800 Subject: [PATCH 1/3] move atc to parser --- CMakeLists.txt | 8 + atc/CMakeLists.txt | 169 ++++ atc/atc | 21 + atc/atc_ir_common.cc | 595 +++++++++++++ atc/atc_ir_common.h | 82 ++ atc/common/types.cc | 23 + atc/common/types.h | 32 + atc/main.cc | 1428 ++++++++++++++++++++++++++++++ atc/parse_graph.cc | 1004 +++++++++++++++++++++ atc/parse_graph.h | 108 +++ atc/single_op_parser.cc | 609 +++++++++++++ atc/single_op_parser.h | 90 ++ atc/util/gflags_util.h | 85 ++ atc/util/properties_manager.cc | 145 +++ atc/util/properties_manager.h | 82 ++ atc/util/string_util.h | 171 ++++ atc/util/tool.h | 32 + atc/util/util.cc | 283 ++++++ atc/util/util.h | 159 ++++ parser/common/convert/pb2json.cc | 10 + parser/common/convert/pb2json.h | 4 + 21 files changed, 5140 insertions(+) create mode 100644 atc/CMakeLists.txt create mode 100644 atc/atc create mode 100644 atc/atc_ir_common.cc create mode 100644 atc/atc_ir_common.h create mode 100644 atc/common/types.cc create mode 100644 atc/common/types.h create mode 100644 atc/main.cc create mode 100644 atc/parse_graph.cc create mode 100644 atc/parse_graph.h create mode 100644 atc/single_op_parser.cc create mode 100644 atc/single_op_parser.h create mode 100644 atc/util/gflags_util.h create mode 100644 atc/util/properties_manager.cc create mode 100644 atc/util/properties_manager.h create mode 100644 atc/util/string_util.h create mode 100644 atc/util/tool.h create mode 100644 atc/util/util.cc create mode 100644 atc/util/util.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 5394587..8f74bc0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,7 @@ if (ENABLE_OPEN_SRC) include(cmake/external_libs/protoc.cmake) include(cmake/external_libs/securec.cmake) include(cmake/external_libs/json.cmake) + include(cmake/external_libs/gflags.cmake) include(cmake/FindModule.cmake) include(cmake/intf_pub_linux.cmake) @@ -35,6 +36,9 @@ if (ENABLE_OPEN_SRC) set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) find_module(slog libalog.so ${GE_LIB_PATH}) find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) + find_module(ge_compiler libge_compiler.so ${GE_LIB_PATH}) + find_module(ge_runner libge_runner.so ${GE_LIB_PATH}) + find_module(ge_common libge_common.so ${GE_LIB_PATH}) elseif(ENABLE_GE_COV OR ENABLE_GE_UT) message(STATUS "Runing on llt mode, no need to depend other component") else() @@ -48,6 +52,9 @@ if (ENABLE_OPEN_SRC) find_module(slog libalog.so ${ASCEND_ATC_DIR}) find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) + find_module(ge_compiler libge_compiler.so ${ASCEND_ATC_DIR}) + #find_module(ge_runner libge_runner.so ${ASCEND_RUNTIME_DIR}) + find_module(ge_common libge_common.so ${ASCEND_ATC_DIR}) endif() if (NOT DEFINED METADEF_DIR) @@ -66,3 +73,4 @@ add_subdirectory(parser/common) add_subdirectory(parser/func_to_graph) add_subdirectory(parser/onnx) add_subdirectory(parser/proto/caffe) +#add_subdirectory(atc) diff --git a/atc/CMakeLists.txt b/atc/CMakeLists.txt new file mode 100644 index 0000000..4fa9f8b --- /dev/null +++ b/atc/CMakeLists.txt @@ -0,0 +1,169 @@ +set(PROTO_LIST + "${METADEF_DIR}/proto/om.proto" + "${METADEF_DIR}/proto/ge_ir.proto" + "${METADEF_DIR}/proto/insert_op.proto" + "${METADEF_DIR}/proto/task.proto" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +set(SRC_LIST + "main.cc" + "single_op_parser.cc" + "parse_graph.cc" + "atc_ir_common.cc" + "common/types.cc" + "util/util.cc" + "util/properties_manager.cc" +) + +############ atc_atc.bin ############ +add_executable(atc_atc.bin ${SRC_LIST} ${PROTO_HDRS}) + +target_compile_options(atc_atc.bin PRIVATE + -Werror + -O2 + -Wno-deprecated-declarations + -fno-common + -fvisibility=hidden +) + +target_compile_definitions(atc_atc.bin PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + COMPILE_OMG_PACKAGE + google=ascend_private + LOG_CPP + FUNC_VISIBILITY +) + +target_include_directories(atc_atc.bin PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/register + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external/register + ${PARSER_DIR} + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + ${GE_CODE_DIR}/../inc/common + #### blue zone #### + ${METADEF_DIR}/third_party/graphengine/inc/ + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/graphengine/inc/framework + ${METADEF_DIR}/third_party/graphengine/ge + ${METADEF_DIR}/third_party/fwkacllib/inc + ${PARSER_DIR}/third_party/graphengine/inc +) + +target_link_options(atc_atc.bin PRIVATE + -Wl,-Bsymbolic +) + +target_link_libraries(atc_atc.bin PRIVATE + $ + ascend_protobuf + ge_common + register + c_sec + graph + error_manager + ge_compiler + parser_common + gflags + json + slog + static_mmpa + -lrt + -ldl +) + +set_target_properties(atc_atc.bin PROPERTIES + OUTPUT_NAME atc.bin + RUNTIME_OUTPUT_DIRECTORY atclib +) + +############ fwk_atc.bin ############ +add_executable(fwk_atc.bin ${SRC_LIST} ${PROTO_HDRS}) + +target_compile_options(fwk_atc.bin PRIVATE + -Werror + -O2 + -Wno-deprecated-declarations + -fno-common + -fvisibility=hidden +) + +target_compile_definitions(fwk_atc.bin PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + COMPILE_OMG_PACKAGE + google=ascend_private + LOG_CPP + FUNC_VISIBILITY +) + +target_include_directories(fwk_atc.bin PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/register + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external/register + ${PARSER_DIR} + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + ${GE_CODE_DIR}/../inc/common + #### blue zone #### + ${METADEF_DIR}/third_party/graphengine/inc/ + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/graphengine/inc/framework + ${METADEF_DIR}/third_party/graphengine/ge + ${METADEF_DIR}/third_party/fwkacllib/inc + ${PARSER_DIR}/third_party/graphengine/inc/ + +) + +target_link_options(fwk_atc.bin PRIVATE + -Wl,-Bsymbolic +) + +target_link_libraries(fwk_atc.bin PRIVATE + $ + ascend_protobuf + ge_common + register + c_sec + graph + error_manager + ge_runner + parser_common + gflags + json + slog + static_mmpa + -lrt + -ldl +) + +set_target_properties(fwk_atc.bin PROPERTIES + OUTPUT_NAME atc.bin + RUNTIME_OUTPUT_DIRECTORY fwkacl +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS atc_atc.bin OPTIONAL + RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/atclib +) + +install(TARGETS fwk_atc.bin OPTIONAL + RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/fwkacl +) diff --git a/atc/atc b/atc/atc new file mode 100644 index 0000000..05c65c2 --- /dev/null +++ b/atc/atc @@ -0,0 +1,21 @@ +#!/bin/bash +#------------------------------------------------------------------- +# Purpose: +# Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. +#------------------------------------------------------------------- + +real_path=$(readlink "$0") +if [ $? -eq 0 ]; then + LOCAL_PATH=$(cd "$(dirname "$real_path")"; pwd) +else + LOCAL_PATH=$(cd "$(dirname "$0")"; pwd) +fi +PKG_PATH=$(cd ${LOCAL_PATH}/..; pwd) +LIB_P="/lib64" +PYTHON_P="/python/site-packages" +LIB64_PATH="${PKG_PATH}${LIB_P}" +PYTHON_PATH="${PKG_PATH}${PYTHON_P}" +export LD_LIBRARY_PATH="${LIB64_PATH}:${LD_LIBRARY_PATH}" +export PYTHONPATH="${PYTHON_PATH}:${PYTHONPATH}" + +${PKG_PATH}/bin/atc.bin "$@" diff --git a/atc/atc_ir_common.cc b/atc/atc_ir_common.cc new file mode 100644 index 0000000..0114fb7 --- /dev/null +++ b/atc/atc_ir_common.cc @@ -0,0 +1,595 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "atc_ir_common.h" +#include "common/util/error_manager/error_manager.h" +#include "external/ge/ge_api_types.h" +#include "util/string_util.h" +#include "util/util.h" +#include "graph/utils/type_utils.h" + +using std::pair; +using std::string; +using std::vector; + +namespace ge { +namespace { +const int64_t kDynamicInputDim = -1; +const int64_t kDynamicImageSizeNum = 2; +const uint32_t NCHW_DIM_H = 2; +const uint32_t NCHW_DIM_W = 3; +const uint32_t NHWC_DIM_H = 1; +const uint32_t NHWC_DIM_W = 2; +const int32_t DIM_DEFAULT_SIZE = 4; +const size_t kMaxDynamicDimNum = 100; +const size_t kMaxNDDimNum = 4; +const size_t kMinNDDimNum = 1; +// datatype/formats from user to GE, Unified to util interface file later +const std::map kOutputTypeSupportDatatype = { + {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; +const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; +const std::set kBufferOptimizeSupportOption = {"l1_optimize", "l2_optimize", "off_optimize", + "l1_and_l2_optimize"}; +// The function is incomplete. Currently, only l2_optimize, off_optimize is supported. +const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; +const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; +const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision"; +const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; +const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; +const char *const kSplitError1 = "size not equal to 2 split by \":\""; +const char *const kEmptyError = "can not be empty"; +const char *const kFloatNumError = "exist float number"; +const char *const kDigitError = "is not digit"; +const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; +const char *const kSelectImplmodeError = "only support high_performance, high_precision"; +const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; +const char *const kKeepDtypeError = "file not found"; + +vector SplitInputShape(const std::string &input_shape) { + vector shape_pair_vec; + size_t pos = input_shape.rfind(":"); + if (pos != std::string::npos) { + shape_pair_vec.emplace_back(input_shape.substr(0, pos)); + shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos)); + } + return shape_pair_vec; +} +} // namespace + +Status CheckInputFormat(const string &input_format) { + if (input_format.empty()) { + return ge::SUCCESS; + } + if (!ge::TypeUtils::IsFormatValid(input_format.c_str())) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, "input format is invalid!"}); + GELOGE(ge::PARAM_INVALID, "input format [%s] is invalid!", input_format.c_str()); + return ge::PARAM_INVALID; + } + return ge::SUCCESS; +} + +bool CheckDynamicBatchShape(const vector &shape, const string &data_name) { + if (shape[0] == kDynamicInputDim) { + for (size_t i = 1; i < shape.size(); ++i) { + if (shape[i] < 1) { + ErrorManager::GetInstance().ATCReportErrMessage("E10018", {"index", "shape"}, + {std::to_string(i), std::to_string(shape[i])}); + GELOGE(ge::PARAM_INVALID, + "Only batch N can be -1 when set --dynamic_batch_size, current data: %s shape[%zu] is %ld", + data_name.c_str(), i, shape[i]); + return false; + } + } + return true; + } else { + return false; + } +} + +bool CheckDynamicBatchSizeInputShapeValid(map> shape_map, + std::string &dynamic_batch_size) { + int32_t size = 0; + for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) { + vector shape = iter->second; + if (shape.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10012"); + GELOGE(ge::PARAM_INVALID, "--input_shape's shape size can not be less than 1 when set --dynamic_batch_size."); + return false; + } + + if (std::count(shape.begin(), shape.end(), kDynamicInputDim) == 0) { + continue; + } + + bool ret = CheckDynamicBatchShape(shape, iter->first); + if (ret) { + size++; + } + } + + if (size == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10031"); + GELOGE(ge::PARAM_INVALID, "At least one batch n must be equal to -1 when set --dynamic_batch_size."); + return false; + } + + for (char c : dynamic_batch_size) { + if (!isdigit(c) && (c != ',') && (c != ' ')) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10033", {"value", "reason"}, {dynamic_batch_size, kDynamicBatchSizeError}); + GELOGE(ge::PARAM_INVALID, "Input parameter[--dynamic_batch_size]'s value[%s] is invalid. reason: %s", + dynamic_batch_size.c_str(), kDynamicBatchSizeError); + return false; + } + } + if (dynamic_batch_size.back() == ',') { + dynamic_batch_size.erase(dynamic_batch_size.end() - 1); + } + return true; +} + +bool CheckDynamicImageSizeShape(const vector &shape, const string &data_name, + const std::string &input_format) { + int64_t height = 0; + int64_t width = 0; + if (input_format == "NCHW") { + height = shape[NCHW_DIM_H]; + width = shape[NCHW_DIM_W]; + } + + if (input_format == "NHWC") { + height = shape[NHWC_DIM_H]; + width = shape[NHWC_DIM_W]; + } + + if (height == kDynamicInputDim && width == kDynamicInputDim && + std::count(shape.begin(), shape.end(), kDynamicInputDim) == kDynamicImageSizeNum) { + return true; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E10019"); + GELOGE(ge::PARAM_INVALID, + "--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size."); + return false; + } +} +bool CheckDynamicImagesizeInputShapeValid(map> shape_map, + const std::string input_format, std::string &dynamic_image_size) { + if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) { + GELOGE(ge::PARAM_INVALID, "user input format [%s] is not found!", input_format.c_str()); + return false; + } + int32_t size = 0; + for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) { + vector shape = iter->second; + // only support four dim + if (shape.size() != DIM_DEFAULT_SIZE) { + if (std::count(shape.begin(), shape.end(), kDynamicInputDim) > 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10019"); + GELOGE(ge::PARAM_INVALID, + "--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size."); + return false; + } + continue; + } + + if (std::count(shape.begin(), shape.end(), kDynamicInputDim) == 0) { + continue; + } + auto ret = CheckDynamicImageSizeShape(shape, iter->first, input_format); + if (ret) { + size++; + } else { + return ret; + } + } + if (size == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10019"); + GELOGE(ge::PARAM_INVALID, + "--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size."); + return false; + } + + EraseEndSemicolon(dynamic_image_size); + // Different parameter sets are split string by ';' + std::vector split_set = StringUtils::Split(dynamic_image_size, ';'); + // Different dimensions are split by ',' + std::vector split_dim; + for (auto str : split_set) { + split_dim = StringUtils::Split(str, ','); + if (split_dim.size() != static_cast(kDynamicImageSizeNum)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10020", {"DynamicImageSizeNum"}, + {std::to_string(kDynamicImageSizeNum)}); + GELOGE(ge::PARAM_INVALID, + "--dynamic_image_size's number of dimensions of each " + "group must be %ld.", + kDynamicImageSizeNum); + return false; + } + } + + return true; +} + +bool CheckDynamicDimsInputShapeValid(const map> &shape_map, + string input_format, string &dynamic_dims) { + if (input_format != "ND") { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--input_format", input_format.c_str(), "input_format must be ND when set dynamic_dims"}); + GELOGE(ge::PARAM_INVALID, "input_format must be ND when set dynamic_dims."); + return false; + } + + int32_t dynamic_dim = 0; + for (auto &info_shapes : shape_map) { + auto &shapes = info_shapes.second; + if (shapes.size() > kMaxNDDimNum || shapes.size() < kMinNDDimNum) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--input_shape's dim", std::to_string(shapes.size()), "Dim num must within [1, 4] when set dynamic_dims"}); + GELOGE(ge::PARAM_INVALID, "Dim num must within [%zu, %zu] when set dynamic_dims.", kMinNDDimNum, kMaxNDDimNum); + return false; + } + dynamic_dim += std::count(shapes.begin(), shapes.end(), kDynamicInputDim); + } + if (dynamic_dim == 0) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--input_shape's dynamic dim num", "0", "at least one dim should be -1 when set dynamic_dims"}); + GELOGE(ge::PARAM_INVALID, "input_shape's shape is invalid, at least one dim should be -1 when set dynamic_dims."); + return false; + } + + if (!CheckAndParseDynamicDims(dynamic_dim, dynamic_dims)) { + GELOGE(ge::PARAM_INVALID, "Check and parse dynamic dims: %s failed.", dynamic_dims.c_str()); + return false; + } + + return true; +} + +bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims) { + EraseEndSemicolon(dynamic_dims); + if (dynamic_dims.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--dynamic_dims", dynamic_dims.c_str(), "dynamic_dims can not be empty"}); + GELOGE(ge::PARAM_INVALID, "dynamic_dims can not be empty."); + return false; + } + // Different parameter sets are split by ';' + vector split_set = StringUtils::Split(dynamic_dims, ';'); + if (split_set.size() > kMaxDynamicDimNum) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10042", {"parameter", "reason"}, {"dynamic_dims", "dynamic_dims's num of parameter set can not exceed 100"}); + GELOGE(ge::PARAM_INVALID, "dynamic_dims's num of parameter set can not exceed %zu.", kMaxDynamicDimNum); + return false; + } + for (auto split_dim : split_set) { + vector one_set = StringUtils::Split(split_dim, ','); + if (one_set.size() != static_cast(dynamic_dim_num)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10042", {"parameter", "reason"}, + {"dynamic_dims", "Each gear setting needs to be consistent with the number of -1 in the inputshape"}); + GELOGE(ge::PARAM_INVALID, "Input parameter --dynamic_dims parse failed, " + "reason: Each gear setting needs to be consistent with the number of -1 in the inputshape."); + return false; + } + for (auto dim : one_set) { + for (auto c : dim) { + if (!isdigit(c)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--dynamic_dims's parameter", dim.c_str(), "must be positive integer"}); + GELOGE(ge::PARAM_INVALID, "dynamic_dims's parameter must be positive integer."); + return false; + } + } + } + } + return true; +} + +Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, + const string input_shape, const string input_format, bool &is_dynamic_input) { + int32_t param_size = static_cast(!dynamic_batch_size.empty()) + + static_cast(!dynamic_image_size.empty()) + static_cast(!dynamic_dims.empty()); + if (param_size > 1) { + ErrorManager::GetInstance().ATCReportErrMessage("E10009", {"parameter0", "parameter1", "parameter2"}, + {"dynamic_batch_size", "dynamic_image_size", "dynamic_dims"}); + GELOGE(ge::PARAM_INVALID, "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one"); + return ge::PARAM_INVALID; + } + + if (param_size == 0) { + return ge::SUCCESS; + } + + map> shape_map; + vector>> user_shape_map; + is_dynamic_input = true; + if (input_shape.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"input_shape"}); + GELOGE(ge::PARAM_INVALID, "The input_shape can not be empty in dynamic input size scenario."); + return ge::PARAM_INVALID; + } + + if (!ParseInputShape(input_shape, shape_map, user_shape_map, is_dynamic_input)) { + GELOGE(ge::PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); + return ge::PARAM_INVALID; + } + + if (!dynamic_batch_size.empty()) { + if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) { + GELOGE(ge::PARAM_INVALID, "Check dynamic batch size input shape failed: %s", input_shape.c_str()); + return ge::PARAM_INVALID; + } + } + + if (!dynamic_image_size.empty()) { + if (!CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size)) { + GELOGE(ge::PARAM_INVALID, "Check dynamic image size input shape failed: %s", input_shape.c_str()); + return ge::PARAM_INVALID; + } + } + + if (!dynamic_dims.empty()) { + if (!CheckDynamicDimsInputShapeValid(shape_map, input_format, dynamic_dims)) { + GELOGE(ge::PARAM_INVALID, "Check dynamic dims: %s of input shape: %s failed.", dynamic_dims.c_str(), + input_shape.c_str()); + return ge::PARAM_INVALID; + } + } + return ge::SUCCESS; +} + +bool ParseInputShape(const string &input_shape, map> &shape_map, + vector>> &user_shape_map, bool is_dynamic_input) { + vector shape_vec = StringUtils::Split(input_shape, ';'); + const int DEFAULT_SHAPE_PAIR_SIZE = 2; + for (const auto &shape : shape_vec) { + vector shape_pair_vec = SplitInputShape(shape); + if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kSplitError1, kInputShapeSample1}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kSplitError1, kInputShapeSample1); + return false; + } + if (shape_pair_vec[1].empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kEmptyError, kInputShapeSample1}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kEmptyError, kInputShapeSample1); + return false; + } + + vector shape_value_strs = StringUtils::Split(shape_pair_vec[1], ','); + vector shape_values; + for (auto &shape_value_str : shape_value_strs) { + // stoul: The method may throw an exception: invalid_argument/out_of_range + if (std::string::npos != shape_value_str.find('.')) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kFloatNumError, kInputShapeSample2}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kFloatNumError, kInputShapeSample2); + return false; + } + + long left_result = 0; + try { + left_result = stol(StringUtils::Trim(shape_value_str)); + if (!shape_value_str.empty() && (shape_value_str.front() == '-')) { + // The value maybe dynamic shape [-1], need substr it and verify isdigit. + shape_value_str = shape_value_str.substr(1); + } + for (char c : shape_value_str) { + if (!isdigit(c)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kDigitError, kInputShapeSample2}); + GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str()); + return false; + } + } + } catch (const std::out_of_range &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, + {"--input_shape", shape_value_str}); + GELOGW("Input parameter[--input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str()); + return false; + } catch (const std::invalid_argument &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, + {"--input_shape", shape_value_str}); + GELOGW("Input parameter[--input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str()); + return false; + } catch (...) { + ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"}, + {"--input_shape", shape_value_str}); + GELOGW("Input parameter[--input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str()); + return false; + } + int64_t result = left_result; + // - 1 is not currently supported + if (!is_dynamic_input && result <= 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)}); + GELOGW( + "Input parameter[--input_shape]’s shape value[%s] is invalid, " + "expect positive integer, but value is %ld.", + shape.c_str(), result); + return false; + } + shape_values.push_back(result); + } + + shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + } + + return true; +} + +Status CheckOutputTypeParamValid(const std::string output_type) { + if ((!output_type.empty()) && (kOutputTypeSupportDatatype.find(output_type) == kOutputTypeSupportDatatype.end())) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--output_type", output_type, kOutputTypeSupport}); + GELOGE(ge::PARAM_INVALID, + "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport); + return ge::PARAM_INVALID; + } + return ge::SUCCESS; +} + +Status CheckBufferOptimizeParamValid(const std::string buffer_optimize) { + if ((!buffer_optimize.empty()) && + (kBufferOptimizeSupportOption.find(buffer_optimize) == kBufferOptimizeSupportOption.end())) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--buffer_optimize", buffer_optimize, kBufferOptimizeSupport}); + GELOGE(ge::PARAM_INVALID, + "Invalid value for --buffer_optimize[%s], %s.", buffer_optimize.c_str(), kBufferOptimizeSupport); + return ge::PARAM_INVALID; + } + return ge::SUCCESS; +} + +Status CheckCompressWeightParamValid(const std::string enable_compress_weight, const std::string compress_weight_conf) { + if ((!compress_weight_conf.empty()) && + (!CheckInputPathValid(compress_weight_conf, "--compress_weight_conf"))) { + GELOGE(ge::PARAM_INVALID, "compress weight config file not found, file_name:%s", compress_weight_conf.c_str()); + return ge::PARAM_INVALID; + } + if ((enable_compress_weight != "") && (enable_compress_weight != "true") && (enable_compress_weight != "false")) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10005", {"parameter", "value"}, {"enable_compress_weight", enable_compress_weight}); + GELOGE(ge::PARAM_INVALID, + "Input parameter[--enable_compress_weight]'s value[%s] must be true or false.", enable_compress_weight.c_str()); + return ge::PARAM_INVALID; + } + + if ((enable_compress_weight == "true") && (!compress_weight_conf.empty())) { + ErrorManager::GetInstance().ATCReportErrMessage("E10047", {"parameter0", "parameter1"}, + {"enable_compress_weight", "compress_weight_conf"}); + GELOGE(ge::PARAM_INVALID, "enable_compress_weight and compress_weight_conf can not both exist!!"); + return ge::PARAM_INVALID; + } + return ge::SUCCESS; +} + +Status CheckKeepTypeParamValid(const std::string &keep_dtype) { + if ((!keep_dtype.empty()) && (!CheckInputPathValid(keep_dtype, "--keep_dtype"))) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--keep_dtype", keep_dtype, kKeepDtypeError}); + GELOGE(ge::PARAM_INVALID, "keep dtype config file not found, file_name:%s", keep_dtype.c_str()); + return ge::PARAM_INVALID; + } + + return ge::SUCCESS; +} + +int CheckLogParamValidAndSetLogLevel(const std::string log) { + int ret = -1; + if (log == "default") { + ret = 0; + } else if (log == "null") { + ret = dlog_setlevel(-1, DLOG_NULL, 0); + } else if (log == "debug") { + ret = dlog_setlevel(-1, DLOG_DEBUG, 1); + } else if (log == "info") { + ret = dlog_setlevel(-1, DLOG_INFO, 1); + } else if (log == "warning") { + ret = dlog_setlevel(-1, DLOG_WARN, 1); + } else if (log == "error") { + ret = dlog_setlevel(-1, DLOG_ERROR, 1); + } else { + GELOGE(ge::PARAM_INVALID, "invalid value for log:%s, only support debug, info, warning, error, null", log.c_str()); + return ret; + } + if (ret != 0) { + GELOGE(ge::PARAM_INVALID, "Log setlevel fail !"); + } + return ret; +} + +Status CheckInsertOpConfParamValid(const std::string insert_op_conf) { + if ((!insert_op_conf.empty()) && + (!CheckInputPathValid(insert_op_conf, "--insert_op_conf"))) { + GELOGE(ge::PARAM_INVALID, "insert op config file not found: %s", insert_op_conf.c_str()); + return ge::PARAM_INVALID; + } + return ge::SUCCESS; +} + +Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory) { + if ((disable_reuse_memory != "") && (disable_reuse_memory != "0") && (disable_reuse_memory != "1")) { + ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"disable_reuse_memory"}); + GELOGE(ge::PARAM_INVALID, "Input parameter[--disable_reuse_memory]'s value must be 1 or 0."); + return ge::PARAM_INVALID; + } + return ge::SUCCESS; +} + +Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream) { + if ((enable_single_stream != "") && (enable_single_stream != "true") && (enable_single_stream != "false")) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10005", {"parameter", "value"}, {"enable_single_stream", enable_single_stream}); + GELOGE(ge::PARAM_INVALID, "Input parameter[--enable_single_stream]'s value[%s] must be true or false.", + enable_single_stream.c_str()); + return ge::PARAM_INVALID; + } + return ge::SUCCESS; +} + +Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode) { + // only appointed op_select_implmode, can user appoint optypelist_for_implmode + if (optypelist_for_implmode != "" && op_select_implmode == "") { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--op_select_implmode", op_select_implmode.c_str(), kCompressWeightError}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --op_select_implmode[%s], %s.", + op_select_implmode.c_str(), kCompressWeightError); + return ge::PARAM_INVALID; + } + // op_select_implmode default value is high_performance + if (op_select_implmode == "") { + op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT; + } else { + if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT && + op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--op_select_implmode", op_select_implmode.c_str(), kSelectImplmodeError}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --op_select_implmode[%s], %s.", + op_select_implmode.c_str(), kSelectImplmodeError); + return ge::PARAM_INVALID; + } + } + + return ge::SUCCESS; +} + +void PrintOptionMap(std::map &options, std::string tips) { + for (auto iter = options.begin(); iter != options.end(); iter++) { + std::string key = iter->first; + std::string option_name = iter->second; + GELOGD("%s set successfully, option_key=%s, option_value=%s", tips.c_str(), key.c_str(), option_name.c_str()); + } +} + +void EraseEndSemicolon(string ¶m) { + if (param.empty()) { + return; + } + if (param.back() == ';') { + param.erase(param.end() - 1); + } +} +} // namespace ge diff --git a/atc/atc_ir_common.h b/atc/atc_ir_common.h new file mode 100644 index 0000000..c48cd21 --- /dev/null +++ b/atc/atc_ir_common.h @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATC_IR_COMMON_H_ +#define ATC_IR_COMMON_H_ + +#include +#include +#include +#include +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "register/register_types.h" + +using std::string; +using std::vector; +using std::map; + +namespace ge { +static std::set caffe_support_input_format = {"NCHW", "ND"}; +static std::set tf_support_input_format = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; +static std::set onnx_support_input_format = {"NCHW", "ND"}; + +static std::map input_format_str_to_geformat = { + {"ND", domi::DOMI_TENSOR_ND}, + {"NCHW", domi::DOMI_TENSOR_NCHW}, + {"NHWC", domi::DOMI_TENSOR_NHWC}, + {"CHWN", domi::DOMI_TENSOR_CHWN}, + {"NC1HWC0", domi::DOMI_TENSOR_NC1HWC0}, + {"NHWC1C0", domi::DOMI_TENSOR_NHWC1C0}, + {"NCDHW", domi::DOMI_TENSOR_NCDHW}, + {"NDHWC", domi::DOMI_TENSOR_NDHWC} +}; +static const std::string kEnableCompressWeightTrue = "1"; +static const std::string kEnableCompressWeightFalse = "0"; + +bool CheckDynamicBatchSizeInputShapeValid(map> shape_map, + std::string &dynamic_batch_size); + +bool CheckDynamicImagesizeInputShapeValid(map> shape_map, + const std::string input_format, std::string &dynamic_image_size); + +bool CheckDynamicDimsInputShapeValid(const std::map> &shape_map, + std::string input_format, std::string &dynamic_dims); + +bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims); + +Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, + std::string &dynamic_dims, const std::string input_shape, + const std::string input_format, bool &is_dynamic_input); + +bool ParseInputShape(const std::string &input_shape, std::map> &shape_map, + std::vector>> &user_shape_map, bool is_dynamic_input = false); + +Status CheckOutputTypeParamValid(const std::string output_type); +Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); +Status CheckCompressWeightParamValid(const std::string enable_compress_weight, const std::string compress_weight_conf); +int CheckLogParamValidAndSetLogLevel(const std::string log); +Status CheckInsertOpConfParamValid(const std::string insert_op_conf); +Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory); +Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); +Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); +Status CheckInputFormat(const string &input_format); +Status CheckKeepTypeParamValid(const std::string &keep_dtype); +void PrintOptionMap(std::map &options, std::string tips); +void EraseEndSemicolon(std::string ¶m); +} +#endif // ATC_IR_COMMON_H_ diff --git a/atc/common/types.cc b/atc/common/types.cc new file mode 100644 index 0000000..55a7f49 --- /dev/null +++ b/atc/common/types.cc @@ -0,0 +1,23 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "types.h" + +namespace ge{ +const char *DATA = "Data"; +const char *NETOUTPUT = "NetOutput"; +const std::string OP_CONF_DELIMITER = ":"; +const std::string MODEL_ATTR_FUSION_MODEL_DEF = "fm"; +} // namespace ge diff --git a/atc/common/types.h b/atc/common/types.h new file mode 100644 index 0000000..bde7a05 --- /dev/null +++ b/atc/common/types.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_TYPES_H_ +#define COMMON_TYPES_H_ + +#include +#include + +#include "register/register_types.h" + +namespace ge { +extern const char *DATA; +extern const char *NETOUTPUT; +extern const std::string OP_CONF_DELIMITER; +extern const std::string MODEL_ATTR_FUSION_MODEL_DEF; +} // namespace ge + +#endif // COMMON_TYPES_H_ diff --git a/atc/main.cc b/atc/main.cc new file mode 100644 index 0000000..620b18f --- /dev/null +++ b/atc/main.cc @@ -0,0 +1,1428 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "util/gflags_util.h" +#include "util/string_util.h" +#include "util/util.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/generator/ge_generator.h" +#include "framework/omg/ge_init.h" +#include "ge/ge_api_types.h" +#include "graph/anchor.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/type_utils.h" +#include "atc_ir_common.h" + +#include "parse_graph.h" +#include "omg/parser/parser_factory.h" +#include "omg/parser/parser_inner_ctx.h" +#include "parser/common/register_tbe.h" +#include "register/op_registry.h" +#include "single_op_parser.h" +#include "external/ge/ge_ir_build.h" + +using domi::OpRegistrationData; +using domi::OpRegistry; +using domi::Status; +using domi::SUCCESS; +using ge::GEN_OM_MODEL; +using ge::GflagsUtils; +using ge::MODEL_TO_JSON; +using ge::ONLY_PRE_CHECK; +using ge::ParseInputShape; +using ge::PBTXT_TO_JSON; +using std::map; +using std::pair; +using std::shared_ptr; +using std::string; +using std::vector; + +static bool is_dynamic_input = false; + +const char *const kModeSupport = "only support 0(model to framework model), " + "1(framework model to json), 3(only pre-check), " + "5(pbtxt to json), 6(display model info)"; +const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow) 5(Onnx)"; + +static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; +static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; +static const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model"; + +// limit available mem size 2G +const long kMinAvailableMem = 2097152; // 2 * 1024 * 1024 + +DEFINE_string(model, "", "The model file."); +DEFINE_string(output, "", "The output file path&name."); +DEFINE_int32(framework, -1, "Framework type(0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx)."); +DEFINE_string(weight, "", "Optional; weight file. Required when framework is Caffe."); + +DEFINE_string(input_shape, "", + "Optional; shape of input data. Required when framework is caffe " + "or TensorFLow or MindSpore or Onnx. " + "Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\""); +DEFINE_bool(h, false, "show this help message"); +DEFINE_string(cal_conf, "", "Optional; the calibration config file."); + +DEFINE_string(insert_op_conf, "", "Optional; the config file to insert new op, for example AIPP op."); +DEFINE_string(op_name_map, "", "Optional; custom op name mapping file."); + +DEFINE_string(target, "", "Optional; mini."); + +DEFINE_string(om, "", "The model file to be converted to json."); +DEFINE_string(json, "", "The output json file path&name which is converted from a model."); +DEFINE_int32(mode, 0, + "Optional; run mode, 0(default): model => framework model; 1: " + "framework model => json; 3: only pre-check; 5: txt => json."); + +#if !defined(__ANDROID__) && !defined(ANDROID) +DEFINE_int32(encrypt_mode, -1, "Optional; the encrypt flag. 0: encrypt; -1(default): not encrypt"); +DEFINE_string(encrypt_key, "", "Optional; the encrypt_key file."); +DEFINE_string(certificate, "", "Optional; the certificate file."); +DEFINE_string(hardware_key, "", "Optional; the ISV key file."); +DEFINE_string(private_key, "", "Optional; the private key file."); +#endif + +DEFINE_string(out_nodes, "", + "Optional; output nodes designated by users." + "Format: \"node_name1:0;node_name1:1;node_name2:0\""); + +DEFINE_string(precision_mode, "force_fp16", + "Optional; precision mode." + "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); + +DEFINE_string(keep_dtype, "", + "Optional; config file to specify the precision used by the operator during compilation."); + +DEFINE_string(input_format, "", + "Optional; input_format, format of input data, NCHW;NHWC." + "Format:\"NHWC\""); + +DEFINE_string(check_report, "check_result.json", "Optional; the pre-checking report file."); + +DEFINE_string(input_fp16_nodes, "", + "Optional; input node datatype is fp16 and format is NC1HWC0." + "Format:\"node_name1;node_name2\""); + +DEFINE_string(is_output_adjust_hw_layout, "", + "Optional; Net output node's datatype is fp16 and format is " + "NC1HWC0, or not." + "Format:\"false,true,false,true\""); + +DEFINE_string(is_input_adjust_hw_layout, "", + "Optional; Intput node's datatype is fp16 and format is " + "NC1HWC0, or not." + "Format:\"false,true,false,true\""); + +DEFINE_string(output_type, "", + "Optional; output type! " + "Support FP32,FP16,INT8,INT16,UINT16,UINT8,INT32,INT64,UINT32,UINT64,DOUBLE."); + +DEFINE_string(op_select_implmode, "", + "Optional; op select implmode! " + "Support high_precision, high_performance."); + +DEFINE_string(optypelist_for_implmode, "", + "Optional; Nodes need use implmode selected in op_select_implmode " + "Format:\"node_name1,node_name2\""); + +DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file."); + +DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if."); + +DEFINE_string(auto_tune_mode, "", "Optional; Set tune mode."); + +DEFINE_string(soc_version, "", "The soc version."); + +DEFINE_string(core_type, "AiCore", "Optional; If set to VectorCore, only use vector core."); + +DEFINE_string(aicore_num, "", "Optional; Set aicore num"); + +DEFINE_string(buffer_optimize, "l2_optimize", "Optional; buffer optimize"); + +DEFINE_string(fusion_switch_file, "", "Optional; Set fusion switch file path"); + +DEFINE_string(save_original_model, "", "Optional; enable output original offline model. false(default)"); + +DEFINE_string(dynamic_batch_size, "", + "Optional; If set, generate dynamic multi batch model. " + "Different batch sizes are split by ','." + "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); + +DEFINE_string(dynamic_image_size, "", + "Optional; If set, generate dynamic multi image size model." + "Different groups of image size are split by ';'," + "while different dimensions of each group are split by ','." + "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); + +DEFINE_string(dynamic_dims, "", + "Optional; If set, generate dynamic input size model. " + "Different groups of size are split by ';', while different dimensions of each group are split by ','." + "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); + +DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled."); + +DEFINE_string(enable_compress_weight, "false", + "Optional; enable compress weight. true: enable; false(default): disable"); + +DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight"); + +DEFINE_string(enable_single_stream, "", "Optional; enable single stream. true: enable; false(default): disable"); + +DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, warning, error, null"); + +DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); + +DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug;" + "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); +DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass," + "multiple names can be set and separated by ','."); +DEFINE_string(debug_dir, "", "Optional; the path to save the intermediate files of operator compilation"); + +DEFINE_string(op_compiler_cache_dir, "", "Optional; the path to cache operator compilation files"); + +DEFINE_string(op_compiler_cache_mode, "", "Optional; choose the operator compiler cache mode"); + +DEFINE_string(mdl_bank_path, "", "Optional; model bank path"); + +DEFINE_string(op_bank_path, "", "Optional; op bank path"); + +DEFINE_string(display_model_info, "0", "Optional; display model info"); + +class GFlagUtils { + public: + /** + * @name InitGFlag + * @brief initialize gflag + * @return void + */ + static void InitGFlag(int argc, char *argv[]) { + // -help + gflags::SetUsageMessage( + "usage: ./atc \n" + "generate offline model example:\n" + "./atc --model=./alexnet.prototxt --weight=./alexnet.caffemodel \n" + "--framework=0 --output=./domi \n" + "generate offline model for single op example:\n" + "./atc --singleop=./op_list.json --output=./op_model \n" + "===== Basic Functionality =====\n" + "[General]\n" + " --h/help Show this help message\n" + " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format; " + "3: only pre-check; 5: convert ge dump txt file to JSON format; 6: display model info\n" + "\n[Input]\n" + " --model Model file\n" + " --weight Weight file. Required when framework is Caffe\n" + " --om The model file to be converted to json\n" + " --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n" + " --input_format Format of input data. E.g.: \"NCHW\"\n" + " --input_shape Shape of input data. Separate multiple nodes with semicolons (;). " + "Use double quotation marks (\") to enclose each argument.\n" + " E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" + " --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n" + " --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;). " + "Use double quotation marks (\") to enclose each argument.\n" + " E.g.: \"imagesize1_height,imagesize1_width;imagesize2_height,imagesize2_width\"\n" + " --dynamic_dims Set dynamic dims. Separate multiple nodes with semicolons (;). " + "Use double quotation marks (\") to enclose each argument.\n" + " E.g.: \"dims1_n1,dims1_n2;dims2_n1,dims2_n2\"\n" + " --singleop Single op definition file. atc will generate offline " + "model(s) for single op if --singleop is set.\n" + "\n[Output]\n" + " --output Output file path&name(needn't suffix, will add .om automatically). \n" + " If --singleop is set, this arg specifies the directory to " + "which the single op offline model will be generated\n" + " --output_type Set net output type. Support FP32, FP16, UINT8. " + "E.g.: FP16, indicates that all out nodes are set to FP16.\n" + " \"node1:0:FP16;node2:1:FP32\", indicates setting the datatype of multiple out nodes.\n" + " --check_report The pre-checking report file. Default value is: \"check_result.json\"\n" + " --json The output json file path&name which is converted from a model\n" + "\n[Target]\n" + " --soc_version The soc version.\n" + " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. " + "Default value is: AiCore\n" + " --aicore_num Set aicore num\n" + "===== Advanced Functionality =====\n" + "[Feature]\n" + " --out_nodes Output nodes designated by users. Separate multiple nodes with semicolons (;)." + "Use double quotation marks (\") to enclose each argument.\n" + " E.g.: \"node_name1:0;node_name1:1;node_name2:0\"\n" + " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons (;). " + "Use double quotation marks (\") to enclose each argument. " + "E.g.: \"node_name1;node_name2\"\n" + " --insert_op_conf Config file to insert new op\n" + " --op_name_map Custom op name mapping file\n" + " Note: A semicolon(;) cannot be included in each " + "path, otherwise the resolved path will not match the expected one.\n" + " --is_input_adjust_hw_layout Intput node datatype is fp16 and format is " + "NC1HWC0, used with input_fp16_nodes. E.g.: \"true,true,false,true\"\n" + " --is_output_adjust_hw_layout Net output node datatype is fp16 and format is " + "NC1HWC0, used with out_nodes. E.g.: \"true,true,false,true\"\n" + "\n[Model Tuning]\n" + " --disable_reuse_memory The switch of reuse memory. Default value is : 0. " + "0 means reuse memory, 1 means do not reuse memory.\n" + " --fusion_switch_file Set fusion switch file path\n" + " --enable_scope_fusion_passes validate the non-general scope fusion passes, " + "multiple names can be set and separated by ','. E.g.: ScopePass1,ScopePass2,...\n" + " --enable_single_stream Enable single stream. true: enable; false(default): disable\n" + " --enable_small_channel Set enable small channel. 0(default): disable; 1: enable\n" + " --enable_compress_weight Enable compress weight. true: enable; false(default): disable\n" + " --compress_weight_conf Config file to compress weight\n" + " --buffer_optimize Set buffer optimize. Support \"l2_optimize\" (default), " + "\"l1_optimize\", \"off_optimize\"\n" + " --mdl_bank_path Set the path of the custom repository generated after model tuning.\n" + "\n[Operator Tuning]\n" + " --precision_mode precision mode, support force_fp16(default), allow_mix_precision, " + "allow_fp32_to_fp16, must_keep_origin_dtype.\n" + " --keep_dtype Retains the precision of certain operators in inference " + "scenarios by using a configuration file.\n" + " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" + " --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n" + " --op_select_implmode Set op select implmode. Support high_precision, high_performance. " + "default: high_performance\n" + " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n" + " Separate multiple nodes with commas (,). Use double quotation marks (\") " + "to enclose each argument. E.g.: \"node_name1,node_name2\"\n" + " --op_debug_level Debug enable for TBE operator building.\n" + " 0 (default): Disable debug; 1: Enable TBE pipe_all, " + "and generate the operator CCE file and Python-CCE mapping file (.json);\n" + " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file " + "(.json), and enable the CCE compiler -O0-g.\n" + " 3: Disable debug, and keep generating kernel file (.o and .json)\n" + "\n[Debug]\n" + " --save_original_model Control whether to output original model. E.g.: true: output original model\n" + " --log Generate log with level. Support debug, info, warning, error, null\n" + " --dump_mode The switch of dump json with shape, to be used with mode 1. " + "0(default): disable; 1: enable.\n" + " --debug_dir Set the save path of operator compilation intermediate files.\n" + "Default value: ./kernel_meta\n" + " --op_compiler_cache_dir Set the save path of operator compilation cache files.\n" + "Default value: $HOME/atc_data\n" + " --op_compiler_cache_mode Set the operator compilation cache mode." + "Options are disable(default), enable and force(force to refresh the cache)\n" + " --display_model_info enable for display model info; 0(default): close display, 1: open display"); + + gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); + // Using gflags to analyze input parameters + GflagsUtils::ChangeHelpFlags(FLAGS_h); + gflags::HandleCommandLineHelpFlags(); + } + + static Status CheckDumpInfershapeJsonFlags() { + Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, + "check custom aicpu run so failed!"); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), + return domi::FAILED, "Input parameter[--weight]'s value[%s] is invalid!", + FLAGS_weight.c_str()); + return domi::SUCCESS; + } + + static Status CheckFlags() { + Status ret = ge::SUCCESS; + // No model file information passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_model == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"}); + ret = ge::FAILED, "Input parameter[--model]'s value is empty!"); + + // check param disable_reuse_memory + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) != ge::SUCCESS, + ret = ge::FAILED, "check disable_reuse_memory failed!"); + + // check optypelist_for_implmode and op_select_implmode + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, + FLAGS_op_select_implmode) != ge::SUCCESS, + ret = ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); + // No output file information passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"}); + ret = ge::FAILED, "Input parameter[--output]'s value is empty!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + CheckFrameWorkValid(FLAGS_framework, FLAGS_weight) != ge::SUCCESS, + ret = ge::FAILED, + "CheckFrameWorkValid failed"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size, + FLAGS_dynamic_dims, FLAGS_input_shape, + FLAGS_input_format, is_dynamic_input) != ge::SUCCESS, + ret = ge::FAILED, "check dynamic size(batch size, image size or dims) failed!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + !FLAGS_insert_op_conf.empty() && !FLAGS_dynamic_dims.empty(), + ErrorManager::GetInstance().ATCReportErrMessage("E10001", + {"parameter", "value", "reason"}, + {"--insert_op_conf", FLAGS_insert_op_conf, + "dynamic dims function does not support aipp"}); + ret = ge::FAILED, "dynamic dims function does not support aipp"); + +#if !defined(__ANDROID__) && !defined(ANDROID) + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!CheckEncryptModeValid(FLAGS_encrypt_mode), ret = ge::FAILED, + "encrypt_mode %d not valid!!", FLAGS_encrypt_mode); + + if (FLAGS_encrypt_mode == 0) { // Encryption mode + GELOGI("ge will run with encrypt!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_encrypt_key), ret = ge::FAILED, + "encrypt_key file not found!!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_certificate), ret = ge::FAILED, + "certificate file not found!!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_hardware_key), ret = ge::FAILED, + "hardware_key file not found!!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_private_key), ret = ge::FAILED, + "private_key file not found!!"); + } else { // No encryption + GELOGI("ge will run without encrypt!"); + } +#endif + + /** + * Check the validity of the I / O file path + */ + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_model != "" && !ge::CheckInputPathValid(FLAGS_model, "--model"), ret = ge::FAILED, + "model file %s not found!!", FLAGS_model.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), + ret = ge::FAILED, "weight file %s not found!!", + FLAGS_weight.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"), + ret = ge::FAILED, "calibration config file %s not found!!", + FLAGS_cal_conf.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"), + ret = ge::FAILED, "op config file %s not found!!", + FLAGS_op_name_map.c_str()); + + GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS, + ret = ge::FAILED, "check insert op conf failed!"); + + GE_CHK_BOOL_EXEC(ge::CheckCompressWeightParamValid( + FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, + ret = ge::FAILED, "check compress weight failed!"); + + GE_CHK_BOOL_EXEC(ge::CheckKeepTypeParamValid(FLAGS_keep_dtype) == ge::SUCCESS, + ret = ge::FAILED, "check keep dtype failed!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED, + "check_report file %s not found!!", FLAGS_check_report.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_mode == GEN_OM_MODEL && FLAGS_output != "" && + (!ge::CheckOutputPathValid(FLAGS_output, "--output") || !CheckPathWithName(FLAGS_output)), + ret = ge::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_save_original_model != "" && + FLAGS_save_original_model != "true" && + FLAGS_save_original_model != "false", + ErrorManager::GetInstance().ATCReportErrMessage( + "E10005", {"parameter", "value"}, {"save_original_model", FLAGS_save_original_model}); + ret = ge::FAILED, + "Input parameter[--save_original_model]'s value[%s] must be true or false.", + FLAGS_save_original_model.c_str()); + GE_CHK_BOOL_EXEC(ge::CheckBufferOptimizeParamValid(FLAGS_buffer_optimize) == ge::SUCCESS, + ret = ge::FAILED, "check output type failed!"); + + GE_CHK_BOOL_EXEC( + ge::CheckEnableSingleStreamParamValid(std::string(FLAGS_enable_single_stream)) == ge::SUCCESS, + ret = ge::FAILED, "check enable single stream failed!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((FLAGS_display_model_info != "0") && (FLAGS_display_model_info != "1"), + ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"display_model_info"}); + ret = ge::FAILED, "Input parameter[--display_model_info]'s value must be 1 or 0."); + + return ret; + } + + /** + * Verifying the parameters of converting model to JSON + * 1. Fmk_model + * 2. out_json + **/ + static Status CheckConverJsonParamFlags() { + Status ret = ge::SUCCESS; + + // No model path passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); + ret = ge::FAILED, + "Input parameter[--om]'s value is empty!!"); + + // JSON path not passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"}); + ret = ge::FAILED, + "Input parameter[--json]'s value is empty!!"); + + // Check if the model path is valid + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"), + ret = ge::FAILED, + "model file path is invalid: %s.", FLAGS_om.c_str()); + + // Check whether the JSON path is valid + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_json != "" && !ge::CheckOutputPathValid(FLAGS_json, "--json"), + ret = ge::FAILED, + "json file path is invalid: %s.", FLAGS_json.c_str()); + + return ret; + } + + /** + * Check command line parameters for explicit settings + * true: Explicit setup + * false: Not set up + * */ + static bool CheckFlagSet(string flag) { + gflags::CommandLineFlagInfo info; + return !(gflags::GetCommandLineFlagInfo(flag.c_str(), &info) && info.is_default); + } + + private: + static bool CheckEncryptModeValid(const int encrypt_mode) { +#if !defined(__ANDROID__) && !defined(ANDROID) + if (encrypt_mode != 0 && encrypt_mode != -1) { + GELOGE(ge::FAILED, "encrypt mode must be 0 or -1"); + return false; + } +#else + if (encrypt_mode != -1) { + GELOGE(ge::FAILED, "encrypt mode must be -1"); + return false; + } +#endif + + return true; + } + + static Status CheckFrameWorkValid(int framework, const std::string weight_file) { + if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW && + framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) { + // No framework information was passed in or the entered framework is illegal + ErrorManager::GetInstance().ATCReportErrMessage( + "E10007", {"parameter", "support"}, + {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)"}); + GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: " + "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)."); + return domi::PARAM_INVALID; + } + + if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) { + ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"}); + GELOGE(ge::FAILED, "Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!"); + return domi::PARAM_INVALID; + } + + if ((framework == (int32_t)domi::TENSORFLOW) && (weight_file != "")) { + GELOGW("Parameter weight is ignored for TensorFlow."); + } + + if ((framework == (int32_t)domi::ONNX) && (weight_file != "")) { + GELOGW("Parameter weight is ignored for Onnx."); + } + return domi::SUCCESS; + } + + static bool CheckPathWithName(const std::string &fileName) { + // Determine file path length + if (fileName.size() > static_cast(PATH_MAX)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)}); + GELOGE(ge::FAILED, "Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX); + return false; + } + + // Find the last separator + int slashPosition = fileName.size() - 1; + for (; slashPosition >= 0; slashPosition--) { + if (fileName[slashPosition] == '\\' || fileName[slashPosition] == '/') { + break; + } + } + + // Failure if no filename follows the path + if (slashPosition == static_cast(fileName.size() - 1)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName}); + GELOGE(ge::FAILED, "Input parameter[--output]'s path[%s] not include file name", fileName.c_str()); + return false; + } + + return true; + } +}; + +void SetDynamicInputSizeOptions() { + if (!FLAGS_dynamic_batch_size.empty()) { + domi::GetContext().dynamic_batch_size = FLAGS_dynamic_batch_size; + } + if (!FLAGS_dynamic_image_size.empty()) { + domi::GetContext().dynamic_image_size = FLAGS_dynamic_image_size; + } + if (!FLAGS_dynamic_dims.empty()) { + domi::GetContext().dynamic_dims = FLAGS_dynamic_dims; + } +} + +/// Validate the non-general scope fusion pass. +/// The parameter is set to the name of the fusion rule. +/// Multiple names can be set and separated by ",". +void SetEnableScopeFusionPasses(const std::string pass_names) { + ge::GetParserContext().enable_scope_fusion_passes = pass_names; +} + +static bool CheckInputFormat() { + if (FLAGS_input_format.empty()) { + // Set default format + if (FLAGS_framework == static_cast(domi::TENSORFLOW)) { + FLAGS_input_format = "NHWC"; + } else { + FLAGS_input_format = "NCHW"; + } + return true; + } else if ((FLAGS_framework == static_cast(domi::CAFFE))) { // caffe + if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) { + return true; + } + // only support NCHW ND + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport}); + GELOGE(ge::FAILED, + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kCaffeFormatSupport); + return false; + } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf + if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) { + return true; + } + // only support NCHW NHWC ND NCDHW NDHWC + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport}); + GELOGE(ge::FAILED, + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kTFFormatSupport); + return false; + } else if (FLAGS_framework == static_cast(domi::ONNX)) { + if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) { + return true; + } + // only support NCHW ND + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport}); + GELOGE(ge::FAILED, + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kONNXFormatSupport); + return false; + } + return true; +} + +#if !defined(__ANDROID__) && !defined(ANDROID) +static void GetCustomOpPath(std::string &customop_path) { + GELOGI("Enter get custom op path schedule"); + std::string fmk_type = ge::TypeUtils::FmkTypeToSerialString(static_cast(FLAGS_framework)); + GELOGI("Framework type is %s.", fmk_type.c_str()); + + const char *path_env = std::getenv("ASCEND_OPP_PATH"); + if (path_env != nullptr) { + std::string path = path_env; + customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); + GELOGI("Get custom so path from env : %s", path_env); + return; + } + std::string path_base = ge::GEInit::GetPath(); + GELOGI("path_base is %s", path_base.c_str()); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type); + return; +} + +void GetPluginSoFileList(const string &path, vector &fileList, string &caffe_parser_path) { + // Support to split multiple so directories by ":" + GELOGI("path is %s", path.c_str()); + vector v_path = ge::StringUtils::Split(path, ':'); + for (size_t i = 0; i < v_path.size(); ++i) { + ge::FindParserSo(v_path[i], fileList, caffe_parser_path); + GELOGI("CustomOpLib full name = %s", v_path[i].c_str()); + } +} + +void LoadModelParserLib(std::string caffe_parser_path) { + if (FLAGS_framework == static_cast(domi::TENSORFLOW)) { + void *tf_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); + if (tf_handle == nullptr) { + GELOGW("dlopen fmk library [libfmk_parser.so] failed."); + return; + } + GELOGI("plugin load libfmk_parser.so success."); + } else if (FLAGS_framework == static_cast(domi::CAFFE)) { + // What we are dealing with here is that the user modifies the caffe.proto scenario. + // If no lib_Caffe_Parser.so is found under the plugin path, use the default lib_Caffe_Parser.so path. + caffe_parser_path = caffe_parser_path.empty() ? "lib_caffe_parser.so" : caffe_parser_path; + + void *handle = dlopen(caffe_parser_path.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + GELOGW("dlopen failed, plugin name:%s. Message(%s).", caffe_parser_path.c_str(), dlerror()); + return; + } + GELOGI("plugin load %s success.", caffe_parser_path.c_str()); + // According to the dependency, the Caffe parsing module of the framework is loaded here( libfmk_parser.so). + // (depend on the lib_caffe_parser.so) + void *fmk_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); + if (fmk_handle == nullptr) { + GELOGW("dlopen fmk library [libfmk_parser.so] failed."); + if (dlclose(handle) != 0) { + GELOGW("dlclose lib_caffe_parser.so failed."); + } + return; + } + GELOGI("plugin load libfmk_parser.so success."); + } else if (FLAGS_framework == static_cast(domi::ONNX)) { + void *handle = dlopen("libfmk_onnx_parser.so", RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + GELOGW("dlopen fmk library [libfmk_onnx_parser.so] failed."); + return; + } + GELOGI("plugin load libfmk_onnx_parser.so success."); + } else { + GELOGW("Framework:%s is not support.", + ge::TypeUtils::FmkTypeToSerialString(static_cast(FLAGS_framework)).c_str()); + return; + } + return; +} + +void LoadCustomOpLib(bool need_load_ops_plugin) { + std::string plugin_path; + GetCustomOpPath(plugin_path); + + vector fileList; + string caffe_parser_path = ""; + + // whether there are files in the plugin so path + GetPluginSoFileList(plugin_path, fileList, caffe_parser_path); + + // no file + if (fileList.empty() && caffe_parser_path.empty()) { + GELOGW("can not find any plugin file in plugin_path: %s", plugin_path.c_str()); + } + + LoadModelParserLib(caffe_parser_path); + if (!need_load_ops_plugin) { + GELOGI("No need to load ops plugin so."); + return; + } + OpRegistry::Instance()->registrationDatas.clear(); + // load other so files except lib_caffe_parser.so in the plugin so path + for (auto elem : fileList) { + ge::StringUtils::Trim(elem); + + void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); + } else { + GELOGI("plugin load %s success.", elem.c_str()); + } + } + + std::vector registrationDatas = OpRegistry::Instance()->registrationDatas; + for (OpRegistrationData reg_data : registrationDatas) { + if (reg_data.GetFrameworkType() == static_cast(FLAGS_framework)) { + (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); + (void)OpRegistry::Instance()->Register(reg_data); + } + } +} + +void SaveCustomCaffeProtoPath() { + GELOGI("Enter save custom caffe proto path."); + + std::string path_base = ge::GEInit::GetPath(); + GELOGI("path_base is %s", path_base.c_str()); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + ge::GetParserContext().caffe_proto_path = path_base + "include/proto/"; + + string customop_path; + const char *path_env = std::getenv("ASCEND_OPP_PATH"); + if (path_env != nullptr) { + std::string path = path_env; + customop_path = path + "/framework/custom/caffe/"; + GELOGI("Get custom proto path from env : %s", path_env); + ge::GetParserContext().custom_proto_path = customop_path; + return; + } + customop_path = path_base + "ops/framework/custom/caffe/"; + ge::GetParserContext().custom_proto_path = customop_path; + return; +} + +#endif + +Status CreateInputsForInference(const ge::Graph &graph, vector &inputs) { + auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + for (ge::NodePtr &input_node : compute_graph->GetAllNodes()) { + GE_CHECK_NOTNULL(input_node); + ge::OpDescPtr op = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op); + if (op->GetType() == "Data") { + GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); + ge::GeTensorDesc tensor = op->GetInputDesc(0); + string data_op_name = op->GetName(); + GELOGI("Data op name is: %s", data_op_name.c_str()); + ge::GeShape data_shape; + auto iter = domi::GetContext().input_dims.find(data_op_name); + if (iter != domi::GetContext().input_dims.end()) { + data_shape = ge::GeShape(iter->second); + GELOGI("Data op get shape from Context."); + } else { + data_shape = tensor.GetShape(); + GELOGI("Data op get shape from InputDesc in geir graph."); + } + + ge::DataType data_type = tensor.GetDataType(); + string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type); + GELOGI("Data op get data type:%s from InputDesc in geir graph.", data_type_str.c_str()); + + ge::GeTensor input_tensor; + ge::GeTensorDesc desc(data_shape, ge::Format(domi::GetContext().format), data_type); + input_tensor.SetTensorDesc(desc); + inputs.push_back(input_tensor); + } + } + GELOGI("Build ME model, inputs size is: %zu", inputs.size()); + return ge::SUCCESS; +} + +domi::Status GenerateInfershapeJson() { + if (!CheckInputFormat()) { + GELOGE(ge::FAILED, "Check input_format failed"); + return domi::FAILED; + } + Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags(); + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check flags failed!"); + + ge::GeGenerator ge_generator; + std::map options; + ge::Status geRet = ge_generator.Initialize(options, domi::GetContext()); + if (geRet != ge::SUCCESS) { + GELOGE(ge::FAILED, "GeGenerator initialize failed!"); + return domi::FAILED; + } + + ge::Graph graph; + std::map atc_params; + atc_params.insert(std::pair("input_format", FLAGS_input_format)); + ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework, + "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false); + if (ret != ge::SUCCESS) { + GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); + (void)ge_generator.Finalize(); + return domi::FAILED; + } + + geRet = ge_generator.GenerateInfershapeGraph(graph); + if (geRet != ge::SUCCESS) { + GELOGE(ge::FAILED, "ATC GenerateInfershapeJson failed"); + (void)ge_generator.Finalize(); + return domi::FAILED; + } + if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) { + GELOGE(ge::FAILED, "ATC DumpInfershapeJson failed"); + (void)ge_generator.Finalize(); + return domi::FAILED; + } + (void)ge_generator.Finalize(); + return ge::SUCCESS; +} + +static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) { + Status ret = ge::SUCCESS; + if (fwk_type == -1) { + ret = ge::ConvertOm(model_file.c_str(), json_file.c_str(), true); + return ret; + } + + if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--framework", std::to_string(fwk_type), kModelToJsonSupport}); + GELOGE(ge::FAILED, "Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport); + ret = ge::FAILED; + } + + if (FLAGS_dump_mode != "0" && FLAGS_dump_mode != "1") { + ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"dump_mode"}); + GELOGE(ge::FAILED, "Input parameter[--dump_mode]'s value must be 1 or 0."); + ret = ge::FAILED; + } + + if (ret != ge::SUCCESS) return ret; + + // Need to save caffe.proto path + SaveCustomCaffeProtoPath(); + + if (FLAGS_dump_mode == "0") { + // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so. + LoadCustomOpLib(false); + ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str()); + } else if (FLAGS_dump_mode == "1") { + // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so and ops plugin so. + LoadCustomOpLib(true); + ret = GenerateInfershapeJson(); + } + + return ret; +} + +static Status SetAttrOptions(ge::Graph &graph) { + if (!FLAGS_keep_dtype.empty()) { + if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_KEEP_DTYPE, FLAGS_keep_dtype.c_str()) != ge::GRAPH_SUCCESS) { + return ge::FAILED; + } + } + if (!FLAGS_compress_weight_conf.empty()) { + if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_WEIGHT_COMPRESS, FLAGS_compress_weight_conf.c_str()) + != ge::GRAPH_SUCCESS) { + return ge::FAILED; + } + } + + return ge::SUCCESS; +} + +domi::Status GenerateModel(std::map &options, std::string output) { + ge::GeGenerator ge_generator; + ge::Status geRet = ge::SUCCESS; + geRet = ge::GEInit::Initialize(options); + if (geRet != ge::SUCCESS) { + GELOGE(ge::FAILED, "GE initialize failed!"); + return domi::FAILED; + } + geRet = ge_generator.Initialize(options, domi::GetContext()); + if (geRet != ge::SUCCESS) { + GELOGE(ge::FAILED, "GeGenerator initialize failed!"); + (void)ge::GEInit::Finalize(); + return domi::FAILED; + } + + ge::Graph graph; + std::vector inputs; + if (FLAGS_framework == domi::MINDSPORE) { + // load model from file + ge::Model load_model = ge::Model("loadmodel", "version2"); + auto ret1 = load_model.LoadFromFile(FLAGS_model); + if (ret1 != ge::GRAPH_SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model}); + GELOGE(ge::FAILED, "Load model from %s failed, please check model file or " + "input parameter[--framework] is correct", FLAGS_model.c_str()); + (void)ge_generator.Finalize(); + (void)ge::GEInit::Finalize(); + return domi::FAILED; + } + + graph = load_model.GetGraph(); + + GE_CHK_STATUS_EXEC(ge::InitDomiOmgContext(FLAGS_input_shape, FLAGS_input_format, "", is_dynamic_input), + GELOGE(ge::FAILED, "ATC Generate call InitDomiOmgContext ret fail"); + (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); return ge::FAILED); + + Status ret = CreateInputsForInference(graph, inputs); + if (ret != ge::SUCCESS) { + GELOGE(ge::FAILED, "create inputs for inference failed."); + (void)ge_generator.Finalize(); + (void)ge::GEInit::Finalize(); + return domi::FAILED; + } + + } else { + std::map atc_params; + atc_params.insert(std::pair("input_shape", FLAGS_input_shape)); + atc_params.insert(std::pair("out_nodes", FLAGS_out_nodes)); + atc_params.insert(std::pair("input_format", FLAGS_input_format)); + atc_params.insert(std::pair("check_report", FLAGS_check_report)); + atc_params.insert(std::pair("input_fp16_nodes", FLAGS_input_fp16_nodes)); + atc_params.insert(std::pair("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); + atc_params.insert(std::pair("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); + atc_params.insert(std::pair(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); + atc_params.insert(std::pair("output", output)); + + Status ret = + ParseGraph(graph, atc_params, FLAGS_model.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType)FLAGS_framework, + FLAGS_op_name_map.c_str(), FLAGS_target.c_str(), (ge::RunMode)FLAGS_mode, is_dynamic_input); + + // in ONLY_PRE_CHECK mode, pre-checking report has already saved in ParseGraph + if (FLAGS_mode == ge::ONLY_PRE_CHECK) { + (void)ge_generator.Finalize(); + (void)ge::GEInit::Finalize(); + if (ret != ge::SUCCESS) { + GELOGE(ge::FAILED, "ATC precheck fail."); + return domi::FAILED; + } + return domi::SUCCESS; + } + + if (ret != ge::SUCCESS) { + GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); + GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case + (void)ge_generator.Finalize(); + (void)ge::GEInit::Finalize(); + return domi::FAILED; + } + if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) { + GELOGE(ge::FAILED, "Set output node info fail."); + (void)ge_generator.Finalize(); + (void)ge::GEInit::Finalize(); + return domi::FAILED; + } + } + + if (SetAttrOptions(graph) != ge::SUCCESS) { + (void)ge_generator.Finalize(); + (void)ge::GEInit::Finalize(); + return domi::FAILED; + } + + geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); + if (geRet != ge::SUCCESS) { + GELOGE(ge::FAILED, "GE GenerateOfflineModel execute failed"); + GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case + // checking error log) + (void)ge_generator.Finalize(); + (void)ge::GEInit::Finalize(); + return domi::FAILED; + } + (void)ge_generator.Finalize(); + (void)ge::GEInit::Finalize(); + return ge::SUCCESS; +} + +static void SetEnvForSingleOp(std::map &options) { + string flag_on = "1"; + string flag_off = "0"; + options.emplace(ge::GE_FE_FLAG, flag_on); + options.emplace(ge::STREAM_NUM, "1"); // single op only use one stream + options.emplace(ge::RUN_FLAG, flag_off); + options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off); + options.emplace(ge::SINGLE_OP_FLAG, flag_on); + options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode); + options.emplace(ge::SOC_VERSION, FLAGS_soc_version); + options.emplace(ge::CORE_TYPE, FLAGS_core_type); + options.emplace(ge::AICORE_NUM, FLAGS_aicore_num); + options.emplace(ge::OP_SELECT_IMPL_MODE, FLAGS_op_select_implmode); + options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode); + options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode); + options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level)); + options.emplace(ge::DEBUG_DIR, FLAGS_debug_dir); + options.emplace(ge::OP_COMPILER_CACHE_DIR, FLAGS_op_compiler_cache_dir); + options.emplace(ge::OP_COMPILER_CACHE_MODE, FLAGS_op_compiler_cache_mode); + options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path); + options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path); +} + +domi::Status GenerateSingleOp(const std::string& json_file_path) { + if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) { + GELOGE(ge::FAILED, "output path %s is not valid!", FLAGS_output.c_str()); + return domi::FAILED; + } + // check optypelist_for_implmode and op_select_implmode + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, + return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); + + std::map options; + // need to be changed when ge.ini plan is done + SetEnvForSingleOp(options); + + auto ret = ge::GEInit::Initialize(options); + if (ret != ge::SUCCESS) { + GELOGE(ge::FAILED, "GE initialize failed!"); + return domi::FAILED; + } + + ge::GeGenerator generator; + ret = generator.Initialize(options, domi::GetContext()); + if (ret != SUCCESS) { + GELOGE(ge::FAILED, "GeGenerator initialize failed!"); + (void)ge::GEInit::Finalize(); + return domi::FAILED; + } + + vector build_params; + if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "parse single op json file failed"); + (void)generator.Finalize(); + (void)ge::GEInit::Finalize(); + return domi::FAILED; + } + + int index = 0; + for (auto ¶m : build_params) { + string output_path; + if (!FLAGS_output.empty()) { + output_path = FLAGS_output + "/"; + } + output_path += param.file_name; + ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path); + if (ret != SUCCESS) { + GELOGE(ge::FAILED, "Compile op failed. ge ret = %u, op index = %d", ret, index); + ret = domi::FAILED; + break; + } + GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str()); + index += 1; + } + + (void)generator.Finalize(); + (void)ge::GEInit::Finalize(); + return ret; +} + +domi::Status GenerateOmModel() { + if (!CheckInputFormat()) { + GELOGE(ge::FAILED, "Check input_format failed"); + return domi::FAILED; + } + Status ret = GFlagUtils::CheckFlags(); + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, + "Check flags failed! Please check whether some atc params that include semicolons[;] use double " + "quotation marks (\") to enclose each argument such as out_nodes, input_shape, dynamic_image_size"); +#if !defined(__ANDROID__) && !defined(ANDROID) + // Load custom operator Library + LoadCustomOpLib(true); + + SaveCustomCaffeProtoPath(); + + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!"); +#endif + + const int f_stream_num = 1; + std::map options; + options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(FLAGS_framework))); + options.insert(std::pair(string(ge::STREAM_NUM), to_string(f_stream_num))); + options.insert(std::pair(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf)); + options.insert(std::pair(string(ge::ENCRYPT_MODE), to_string(FLAGS_encrypt_mode))); + options.insert(std::pair(string(ge::EK_FILE), FLAGS_encrypt_key)); + options.insert(std::pair(string(ge::CERT_FILE), FLAGS_certificate)); + options.insert(std::pair(string(ge::HW_KEY_FILE), FLAGS_hardware_key)); + options.insert(std::pair(string(ge::PRIVATE_KEY_FILE), FLAGS_private_key)); + options.insert(std::pair(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); + options.insert(std::pair(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); + options.insert(std::pair(string(ge::PRECISION_MODE), FLAGS_precision_mode)); + + options.insert(std::pair(string(ge::RUN_FLAG), to_string(0))); + options.insert(std::pair(string(ge::TRAIN_FLAG), to_string(0))); + + if (!FLAGS_output_type.empty()) { + options.insert(std::pair(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); + } + + options.insert(std::pair(string(ge::OP_SELECT_IMPL_MODE), FLAGS_op_select_implmode)); + options.insert(std::pair(string(ge::OPTYPELIST_FOR_IMPLMODE), FLAGS_optypelist_for_implmode)); + + if (!FLAGS_input_fp16_nodes.empty()) { + GELOGI("FLAGS_input_fp16_nodes : %s .", FLAGS_input_fp16_nodes.c_str()); + options.insert(std::pair(ge::INPUT_FP16_NODES, FLAGS_input_fp16_nodes)); + } + + options.insert(std::pair(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode)); + + options.insert( + std::pair(string(ge::OPTION_EXEC_DISABLE_REUSED_MEMORY), to_string(FLAGS_disable_reuse_memory))); + + options.insert(std::pair(string(ge::SOC_VERSION), FLAGS_soc_version)); + + options.insert(std::pair(string(ge::CORE_TYPE), FLAGS_core_type)); + + options.insert(std::pair(string(ge::AICORE_NUM), FLAGS_aicore_num)); + + options.insert(std::pair(string(ge::BUFFER_OPTIMIZE), FLAGS_buffer_optimize)); + + options.insert(std::pair(string(ge::ENABLE_SMALL_CHANNEL), FLAGS_enable_small_channel)); + + options.insert(std::pair(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file)); + + options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), + (FLAGS_enable_compress_weight == "true") ? + ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse)); + + options.insert(std::pair(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream)); + + options.insert(std::pair(string(ge::DEBUG_DIR), FLAGS_debug_dir)); + + options.insert(std::pair(string(ge::OP_COMPILER_CACHE_DIR), FLAGS_op_compiler_cache_dir)); + + options.insert(std::pair(string(ge::OP_COMPILER_CACHE_MODE), FLAGS_op_compiler_cache_mode)); + + SetDynamicInputSizeOptions(); + + if (!FLAGS_save_original_model.empty()) { + options.insert(std::pair(string(ge::SAVE_ORIGINAL_MODEL), FLAGS_save_original_model)); + options.insert(std::pair(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om")); + } + + options.insert(std::pair(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); + + options.insert(std::pair(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path)); + + options.insert(std::pair(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path)); + + options.insert(std::pair(string(ge::DISPLAY_MODEL_INFO), FLAGS_display_model_info)); + // set enable scope fusion passes + SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); + // print atc option map + ge::PrintOptionMap(options, "atc option"); + + // When the ATC module is transferred to a model, the suffix ".om" is automatically added to the model name + FLAGS_output = FLAGS_output + ".om"; + ret = GenerateModel(options, FLAGS_output); + if (ret != domi::SUCCESS) { + return domi::FAILED; + } + + if (FLAGS_display_model_info == "1") { + GELOGI("need to display model info."); + return ge::ConvertOm(FLAGS_output.c_str(), "", false); + } + + return domi::SUCCESS; +} + +domi::Status ConvertModelToJson() { + Status ret = GFlagUtils::CheckConverJsonParamFlags(); + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check convert json params flags failed!"); + + ret = ConvertModelToJson(FLAGS_framework, FLAGS_om, FLAGS_json); + + GE_IF_BOOL_EXEC(ret != domi::SUCCESS, return domi::FAILED); + return domi::SUCCESS; +} + +domi::Status DisplayModelInfo() { + // No model path passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); + return ge::FAILED, + "Input parameter[--om]'s value is empty!!"); + + // Check if the model path is valid + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"), + return ge::FAILED, + "model file path is invalid: %s.", FLAGS_om.c_str()); + + if (FLAGS_framework == -1) { + return ge::ConvertOm(FLAGS_om.c_str(), "", false); + } + + return ge::FAILED; +} + +bool CheckRet(domi::Status ret) { + if (ret != domi::SUCCESS) { + if (FLAGS_mode == ONLY_PRE_CHECK) { + GELOGW("ATC precheck failed."); + } else if (FLAGS_mode == GEN_OM_MODEL) { + GELOGW("ATC generate offline model failed."); + } else if (FLAGS_mode == MODEL_TO_JSON) { + GELOGW("ATC convert model to json file failed."); + } else if (FLAGS_mode == PBTXT_TO_JSON) { + GELOGW("ATC convert pbtxt to json file failed."); + } else { + return false; + } + return false; + } + + if (FLAGS_mode == ONLY_PRE_CHECK) { + GELOGI("ATC precheck success."); + } else if (FLAGS_mode == GEN_OM_MODEL) { + GELOGI("ATC generate offline model success."); + } else if (FLAGS_mode == MODEL_TO_JSON) { + GELOGI("ATC convert model to json file success."); + } else if (FLAGS_mode == PBTXT_TO_JSON) { + GELOGI("ATC convert pbtxt to json file success."); + } + return true; +} + +domi::Status ConvertPbtxtToJson() { + Status ret = GFlagUtils::CheckConverJsonParamFlags(); + if (ret != domi::SUCCESS) { + GELOGE(ge::FAILED, "Check convert json params flags failed!"); + return domi::FAILED; + } + + ret = ge::ConvertPbtxtToJson(FLAGS_om.c_str(), FLAGS_json.c_str()); + if (ret != domi::SUCCESS) { + GELOGE(ge::FAILED, "ConvertPbtxtToJson fail."); + return domi::FAILED; + } + + return domi::SUCCESS; +} + +int init(int argc, char* argv[]) { + GFlagUtils::InitGFlag(argc, argv); + // set log level + int ret = -1; + const std::set log_level = {"null", "debug", "info", "warning", "error"}; + if (log_level.count(FLAGS_log) == 0) { + std::cout << "E10010: invalid value for --log:" << FLAGS_log + <<", only support debug, info, warning, error, null"<< std::endl; + return ret; + } + + ret = ge::CheckLogParamValidAndSetLogLevel(FLAGS_log); + if (ret != 0) { + return ret; + } + + std::string path_base = ge::GEInit::GetPath(); + ret = ErrorManager::GetInstance().Init(path_base); + if (ret != 0) { + GELOGE(ge::FAILED, "ErrorManager init fail !"); + return ret; + } + + return 0; +} + +long GetMemInfo(const std::string &key) { + std::string file_path = "/proc/meminfo"; + std::ifstream fs(file_path, std::ifstream::in); + if (!fs.is_open()) { + GELOGW("Can not open %s .", file_path.c_str()); + return 0; + } + std::string line; + while (getline(fs, line)) { // line not with \n + if (line.find(key) != std::string::npos) { + GELOGI("Find mem [%s] info line [%s]", key.c_str(), line.c_str()); + fs.close(); + size_t pos = line.find(":"); + if (pos == std::string::npos) { + return 0; + } + std::string current_mem_info_str = line.substr(pos + 1); + ge::StringUtils::Trim(current_mem_info_str); + GELOGI("Find mem [%s] info [%s].", key.c_str(), current_mem_info_str.c_str()); + return stol(current_mem_info_str); + } + } + fs.close(); // close the file + return 0; +} + +bool CheckMemInfo() { + if (FLAGS_auto_tune_mode.empty()) { + return true; + } + // only check current available mem when auto_tune_mode is set. + long current_mem_available = GetMemInfo("MemAvailable"); + GELOGI("Get mem available [%lu kB].", current_mem_available); + std::cout << "Current available mem is " << current_mem_available << "kB." << std::endl; + if ((current_mem_available > 0) && (current_mem_available < kMinAvailableMem)) { + GELOGE(ge::PARAM_INVALID, "Current available mem [%lu kB] can not be smaller than [%lu kB] .", + current_mem_available, kMinAvailableMem); + ErrorManager::GetInstance().ATCReportErrMessage("E10044", {"value", "min_value"}, + {to_string(current_mem_available), to_string(kMinAvailableMem)}); + return false; + } + return true; +} + +int main(int argc, char* argv[]) { + Status ret = domi::SUCCESS; + std::cout << "ATC start working now, please wait for a moment." << std::endl; + + // Initialize + if (init(argc, argv) != 0) { + std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; + return -1; + } + do { + if (!CheckMemInfo()) { + GELOGE(ge::PARAM_INVALID, "Current available mem is too small"); + ret = domi::FAILED; + break; + } + if (!FLAGS_singleop.empty()) { + ret = GenerateSingleOp(FLAGS_singleop); + break; + } + + // default mode(mode:0), Open source model to model + if (GEN_OM_MODEL == FLAGS_mode || ONLY_PRE_CHECK == FLAGS_mode) { + GE_IF_BOOL_EXEC(GenerateOmModel() != domi::SUCCESS, ret = domi::FAILED; break); + } else if (MODEL_TO_JSON == FLAGS_mode) { // Mode 1, transfer model to JSON + GE_CHK_BOOL_EXEC(ConvertModelToJson() == domi::SUCCESS, ret = domi::FAILED; + break, "ATC ConvertJson execute failed!!"); + } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) { + GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; + break, "ATC convert pbtxt to json execute failed!!"); + } else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) { + GE_CHK_BOOL_EXEC(DisplayModelInfo() == domi::SUCCESS, ret = domi::FAILED; + break, "ATC DisplayModelInfo failed!!"); + } else { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport); + ret = domi::FAILED; + break; + } + } while (0); + + if (!CheckRet(ret)) { + std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; + int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO); + if (result != 0) { + GELOGE(ge::FAILED, "ErrorManager outputErrMessage fail !"); + } + GELOGI("Current mem available mem is [%lu kB]", GetMemInfo("MemAvailable")); + return ret; + } else { + std::cout << "ATC run success, welcome to the next use." << std::endl; + (void)ErrorManager::GetInstance().OutputMessage(STDOUT_FILENO); + return 0; + } +} /*lint +e530*/ diff --git a/atc/parse_graph.cc b/atc/parse_graph.cc new file mode 100644 index 0000000..2561923 --- /dev/null +++ b/atc/parse_graph.cc @@ -0,0 +1,1004 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parse_graph.h" +#include +#include +#include + +#include "util/tool.h" +#include "util/util.h" +#include "util/properties_manager.h" +#include "util/string_util.h" +#include "common/types.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "framework/omg/model_tool.h" +#include "graph/compute_graph.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/type_utils.h" +#include "atc_ir_common.h" +#include "omg/omg_inner_types.h" +#include "omg/parser/model_parser.h" +#include "omg/parser/parser_factory.h" +#include "omg/parser/weights_parser.h" +#include "parser/common/pre_checker.h" +#include "parser/common/convert/pb2json.h" +#include "proto/ge_ir.pb.h" +#include "register/op_registry.h" + +using domi::ModelParserFactory; +using domi::WeightsParserFactory; +using std::ostringstream; + +namespace ge { +namespace { +const std::string kGraphDefaultName = "domi_default"; +const std::string kScopeIdAttr = "fusion_scope"; +const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; +const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; +const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; +const size_t kNodeNameIndex = 0; +const size_t kIndexStrIndex = 1; +const size_t kDTValueIndex = 2; +const size_t kOmInfoSize = 4; +} // namespace + +// When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored +const std::set kOmBlackFields = {"output", "data_offset", "data", "workspace", "workspace_bytes", + "memory_size", "weight_size", "size", "bt", "quantize_factor"}; + +static std::map output_type_str_to_datatype = { + {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; + +static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_param) { + if ((s == "true") || (s == "false")) { + return true; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {atc_param, s}); + GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str()); + return false; + } +} + +static void ParseAtcParms(const std::map &atc_params, const std::string &key, + std::string ¶m) { + auto iter = atc_params.find(key); + if (iter != atc_params.end()) { + param = iter->second; + } +} + +static Status CheckInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input, RunMode run_mode) { + if (!is_dynamic_input && run_mode != MODEL_TO_JSON) { + for (auto node : graph->GetDirectNode()) { + if (node->GetType() == DATA) { + auto data_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(data_op_desc); + auto tensor_desc = data_op_desc->MutableInputDesc(0); + GE_CHECK_NOTNULL(tensor_desc); + for (auto dim : tensor_desc->GetShape().GetDims()) { + if (dim < 0) { + GELOGE(PARAM_INVALID, + "Input op [%s] shape %ld is negative, maybe you should set input_shape to specify its shape", + node->GetName().c_str(), dim); + const string reason = "maybe you should set input_shape to specify its shape"; + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {node->GetName(), to_string(dim), reason}); + return PARAM_INVALID; + } + } + } + } + } + for (auto it : domi::GetContext().user_input_dims) { + std::string node_name = it.first; + ge::NodePtr node = graph->FindNode(node_name); + if (node == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name}); + GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not exist in model", node_name.c_str()); + return PARAM_INVALID; + } + if (node->GetType() != DATA) { + ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name}); + GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not a input opname", node_name.c_str()); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +void AddAttrsForInputNodes(const vector &adjust_fp16_format_vec, const string &fp16_nodes_name, uint32_t index, + OpDescPtr &op_desc) { + if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) { + if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) { + GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str()); + if (!AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_FORMAT, TypeUtils::FormatToSerialString(FORMAT_NC1HWC0))) { + GELOGW("This node [%s] set NC1HWC0 failed", fp16_nodes_name.c_str()); + } + } + } +} + +static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, + const string &is_input_adjust_hw_layout) { + GE_CHECK_NOTNULL(graph); + vector adjust_fp16_format_vec; + if (!is_input_adjust_hw_layout.empty()) { + adjust_fp16_format_vec = StringUtils::Split(is_input_adjust_hw_layout, ','); + for (auto &s : adjust_fp16_format_vec) { + StringUtils::Trim(s); + if (!CheckInputTrueOrFalse(s, "is_input_adjust_hw_layout")) { + GELOGE(PARAM_INVALID, "Invalid Param, is_input_adjust_hw_layout only support true/false: but is [%s]", + is_input_adjust_hw_layout.c_str()); + return PARAM_INVALID; + } + } + } + if (input_fp16_nodes.empty()) { + return SUCCESS; + } + GELOGI("The input_fp16_nodes is set %s", input_fp16_nodes.c_str()); + vector input_fp16_nodes_vec = StringUtils::Split(input_fp16_nodes, ';'); + for (uint32_t i = 0; i < input_fp16_nodes_vec.size(); ++i) { + ge::NodePtr node = graph->FindNode(input_fp16_nodes_vec[i]); + if (node == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"input_fp16_nodes", input_fp16_nodes_vec[i]}); + GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not exist in model", + input_fp16_nodes_vec[i].c_str()); + return PARAM_INVALID; + } + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetType() != DATA) { + ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, + {"input_fp16_nodes", input_fp16_nodes_vec[i]}); + GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not a input opname", + input_fp16_nodes_vec[i].c_str()); + return PARAM_INVALID; + } + AddAttrsForInputNodes(adjust_fp16_format_vec, input_fp16_nodes_vec[i], i, op_desc); + } + return SUCCESS; +} + +static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { + if (is_output_fp16.empty()) { + return SUCCESS; + } + + vector &output_formats = domi::GetContext().output_formats; + output_formats.clear(); + vector node_format_vec = StringUtils::Split(is_output_fp16, ','); + for (auto &is_fp16 : node_format_vec) { + StringUtils::Trim(is_fp16); + if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) { + GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]", + is_output_fp16.c_str()); + return PARAM_INVALID; + } + if (is_fp16 == "false") { + output_formats.push_back(DOMI_TENSOR_ND); + } else if (is_fp16 == "true") { + output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0); + } + } + return SUCCESS; +} + +void FindParserSo(const string &path, vector &file_list, string &caffe_parser_path) { + // path, Change to absolute path + string real_path = RealPath(path.c_str()); + if (real_path.empty()) { // plugin path does not exist + return; + } + struct stat stat_buf; + if ((stat(real_path.c_str(), &stat_buf) != 0) || (!S_ISDIR(stat_buf.st_mode))) { + GELOGI("The path %s is not a directory.", real_path.c_str()); + return; + } + + struct dirent *dent(nullptr); + DIR *dir = opendir(real_path.c_str()); + + if (nullptr == dir) { // plugin path does not exist + GELOGW("Open directory %s failed.", path.c_str()); + return; + } + + while ((dent = readdir(dir)) != nullptr) { + if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) continue; + string name = dent->d_name; + string full_name = real_path + "/" + name; + const string so_suff = ".so"; + const string caffe_parser_so_suff = "lib_caffe_parser.so"; + if (name.size() >= so_suff.size() && name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { + if (full_name.size() >= caffe_parser_so_suff.size() && + full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(), + caffe_parser_so_suff) == 0) { + caffe_parser_path = full_name; + } else { // save parser so path into file_list vector + file_list.push_back(full_name); + } + continue; + } + + FindParserSo(full_name, file_list, caffe_parser_path); + } + closedir(dir); + return; +} + +Status SetOutFormatAndDataTypeAttr(ge::OpDescPtr op_desc, const ge::Format format, const ge::DataType data_type) { + if (op_desc == nullptr) { + GELOGE(domi::FAILED, "Input op desc invalid."); + return domi::FAILED; + } + (void)ge::AttrUtils::SetInt(op_desc, ATTR_NAME_NET_OUTPUT_FORMAT, format); + (void)ge::AttrUtils::SetInt(op_desc, ATTR_NAME_NET_OUTPUT_DATATYPE, data_type); + return domi::SUCCESS; +} + +bool CheckDigitStr(std::string &str) { + for (char c : str) { + if (!isdigit(c)) { + GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str()); + return false; + } + } + return true; +} + +Status StringToInt(std::string &str, int32_t &value) { + try { + if (!CheckDigitStr(str)) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", str, "is not positive integer"}); + return PARAM_INVALID; + } + value = stoi(str); + } catch (std::invalid_argument &) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"--output_type", str}); + return PARAM_INVALID; + } catch (std::out_of_range &) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"--output_type", str}); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status VerifyOutputTypeAndOutNodes(std::vector &out_type_vec) { + std::vector> user_out_nodes = domi::GetContext().user_out_nodes; + std::set out_nodes_info; + for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { + // out_nodes set should include output_type and output_format + std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second); + out_nodes_info.emplace(tmp); + } + for (uint32_t i = 0; i < out_type_vec.size(); ++i) { + if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", out_type_vec[i], kOutputTypeError}); + GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError); + return domi::FAILED; + } + } + return domi::SUCCESS; +} + +Status CheckOutPutDataTypeSupport(const std::string &output_type) { + auto it = output_type_str_to_datatype.find(output_type); + if (it == output_type_str_to_datatype.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", output_type, kOutputTypeSupport}); + GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport); + return domi::FAILED; + } + return domi::SUCCESS; +} + +Status ParseOutputType(const std::string &output_type, std::map> &output_node_dt_map) { + if (output_type.find(':') == std::string::npos) { + GELOGI("output_type is not multiple nodes, means all out nodes"); + return CheckOutPutDataTypeSupport(output_type); + } + std::vector out_type_vec; + vector nodes_v = StringUtils::Split(output_type, ';'); + for (const string &node : nodes_v) { + vector node_index_type_v = StringUtils::Split(node, ':'); + if (node_index_type_v.size() != 3) { // The size must be 3. + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", node, kOutputTypeSample}); + GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample); + return domi::FAILED; + } + ge::DataType tmp_dt; + std::string node_name = StringUtils::Trim(node_index_type_v[kNodeNameIndex]); + std::string index_str = StringUtils::Trim(node_index_type_v[kIndexStrIndex]); + int32_t index; + if (StringToInt(index_str, index) != SUCCESS) { + GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str()); + return domi::FAILED; + } + std::string dt_value = StringUtils::Trim(node_index_type_v[kDTValueIndex]); + auto it = output_type_str_to_datatype.find(dt_value); + if (it == output_type_str_to_datatype.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", dt_value, kOutputTypeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport); + return domi::FAILED; + } else { + tmp_dt = it->second; + } + out_type_vec.push_back(node_name + ":" + index_str); + std::string index_dt_str = index_str + ":" + TypeUtils::DataTypeToSerialString(tmp_dt); + auto it1 = output_node_dt_map.find(node_name); + if (it1 == output_node_dt_map.end()) { + vector tmp_vec; + tmp_vec.push_back(index_dt_str); + output_node_dt_map.emplace(node_name, tmp_vec); + } else { + it1->second.push_back(index_dt_str); + } + } + return VerifyOutputTypeAndOutNodes(out_type_vec); +} + +Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { + int32_t out_size = op_desc->GetOutputsSize(); + if (index < 0 || index >= out_size) { + GELOGE(domi::FAILED, + "out_node [%s] output index:%d must be smaller " + "than node output size:%d and can not be negative!", + op_desc->GetName().c_str(), index, out_size); + std::string fail_reason = "output index:" + to_string(index) + " must be smaller than output size:" + + to_string(out_size) + " and can not be negative!"; + ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, + {"out_nodes", op_desc->GetName(), fail_reason}); + return domi::FAILED; + } + return domi::SUCCESS; +} +Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, + std::vector> &output_nodes_info) { + std::vector> default_out_nodes = domi::GetContext().default_out_nodes; + if (domi::GetContext().type == domi::CAFFE && !default_out_nodes.empty()) { + for (uint32_t i = 0; i < default_out_nodes.size(); ++i) { + ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first); + if (out_node == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"out_nodes", default_out_nodes[i].first}); + GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", default_out_nodes[i].first.c_str()); + return domi::FAILED; + } + output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second)); + GELOGD("Get default output node:%s.", out_node->GetName().c_str()); + } + return domi::SUCCESS; + } + + for (ge::NodePtr node : compute_graph->GetDirectNode()) { + if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) { + Status ret = GetOutputLeaf(node, output_nodes_info); + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail."); + } + } + return domi::SUCCESS; +} + +Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output) { + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + std::vector> user_out_nodes = domi::GetContext().user_out_nodes; + std::vector output_formats = domi::GetContext().output_formats; + std::vector> output_nodes_info; + std::vector output_nodes_name; + std::map> output_node_dt_map; + if (!output_type.empty()) { + if (ParseOutputType(output_type, output_node_dt_map) != SUCCESS) { + GELOGE(domi::FAILED, "Parse output_type failed."); + return domi::FAILED; + } + } + + // User declared outputs + for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { + ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first); + if (out_node == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"out_nodes", user_out_nodes[i].first}); + GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str()); + return domi::FAILED; + } + auto op_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) { + GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str()); + return domi::FAILED; + } + + // add user_define_output_nodes attr. + (void)ge::AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_OUTPUT_NODES, "true"); + + if (i < output_formats.size()) { + if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) { + GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str()); + vector output_fp16_5hd_vec; + (void)ge::AttrUtils::GetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec); + output_fp16_5hd_vec.push_back(std::to_string(user_out_nodes[i].second) + ":" + "NC1HWC0"); + (void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec); + } + } + auto it = output_node_dt_map.find(user_out_nodes[i].first); + if (it != output_node_dt_map.end()) { + GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str()); + (void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_data_type", it->second); + } + output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); + } + // default output node (leaf) + if (user_out_nodes.empty()) { + if (GetDefaultOutInfo(compute_graph, output_nodes_info) != SUCCESS) { + GELOGE(domi::FAILED, "Get default output info failed."); + return domi::FAILED; + } + } + GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); + compute_graph->SetGraphOutNodesInfo(output_nodes_info); + domi::GetContext().net_out_nodes = output_nodes_name; + return domi::SUCCESS; +} + +void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, + std::vector &output_nodes_name) { + output_nodes_name.clear(); + if (domi::GetContext().out_top_names.empty()) { + // tf process, no top name. + for (const auto output_node_info : output_nodes_info) { + std::string node_name = output_node_info.first->GetName(); + int32_t index = output_node_info.second; + output_nodes_name.push_back(node_name + ":" + std::to_string(index)); + } + return; + } + // caffe process, need add top name after node_name:index + for (size_t i = 0; i < output_nodes_info.size(); ++i) { + std::string node_name = output_nodes_info[i].first->GetName(); + int32_t index = output_nodes_info[i].second; + if (i < domi::GetContext().out_top_names.size()) { + output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + domi::GetContext().out_top_names[i]); + } else { + GELOGW("Get top name of node [%s] fail.", node_name.c_str()); + output_nodes_name.push_back(node_name + ":" + std::to_string(index)); + } + } +} + +Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info) { + ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); + if (tmpDescPtr == nullptr) { + GELOGE(domi::FAILED, "Get outnode op desc fail."); + return domi::FAILED; + } + size_t size = tmpDescPtr->GetOutputsSize(); + if (node->GetType() != NETOUTPUT) { + for (size_t index = 0; index < size; ++index) { + output_nodes_info.push_back(std::make_pair(node, index)); + GELOGD("Get output leaf node:%s.", node->GetName().c_str()); + } + } else { + const auto in_anchors = node->GetAllInDataAnchors(); + for (auto in_anchor : in_anchors) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + GELOGE(domi::FAILED, "Get leaf node op desc fail."); + return domi::FAILED; + } + auto out_node = out_anchor->GetOwnerNode(); + output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); + } + } + return SUCCESS; +} + +/// +/// @ingroup domi_common +/// @brief Initialize omgcontext based on command line input +/// @param [in] input_shape Input shape string to be parsed +/// @return SUCCESS: parse successfully; PARAM_INVALID:parse failed +/// +Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, + bool is_dynamic_input) { + // Clear omgcontext data first + domi::GetContext().input_dims.clear(); + domi::GetContext().user_input_dims.clear(); + domi::GetContext().is_dynamic_input = is_dynamic_input; + + // the default value is ND + domi::GetContext().format = DOMI_TENSOR_ND; + if (!input_format.empty()) { + auto iter = ge::input_format_str_to_geformat.find(input_format); + if (iter != ge::input_format_str_to_geformat.end()) { + domi::GetContext().format = iter->second; + } else { + GELOGE(PARAM_INVALID, "Input format %s not support , expect ND/NCHW/NHWC/CHWN/NC1HWC0/NHWC1C0.", + input_format.c_str()); + return PARAM_INVALID; + } + } + + // Input is empty, do not process + if (input_shape.empty()) { + return SUCCESS; + } + + // Analyze the input shape paramete + map> &shape_map = domi::GetContext().input_dims; + + if (!ge::ParseInputShape(input_shape, domi::GetContext().input_dims, domi::GetContext().user_input_dims, + is_dynamic_input) || + shape_map.empty()) { + GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status ParseOutNodes(const string &out_nodes) { + try { + // parse output node + if (!out_nodes.empty()) { + domi::GetContext().out_nodes_map.clear(); + domi::GetContext().user_out_nodes.clear(); + domi::GetContext().user_out_nodes_top_vec.clear(); + + vector nodes_v = StringUtils::Split(out_nodes, ';'); + for (const string &node : nodes_v) { + vector key_value_v = StringUtils::Split(node, ':'); + if (key_value_v.size() != 2) { // The size must be 2. + if (key_value_v.size() == 1 && domi::GetContext().type == domi::CAFFE) { + domi::GetContext().user_out_nodes_top_vec.push_back(node); + continue; + } + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""}); + GELOGE(PARAM_INVALID, + "The input format of --out_nodes is invalid, the correct format is " + "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.", + node.c_str()); + return PARAM_INVALID; + } + if (!domi::GetContext().user_out_nodes_top_vec.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--out_nodes", out_nodes, "is not all index or top_name"}); + GELOGE(PARAM_INVALID, + "This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str()); + return PARAM_INVALID; + } + // stoi: The method may throw an exception: invalid_argument/out_of_range + if (!CheckDigitStr(key_value_v[1])) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--out_nodes", out_nodes, "is not positive integer"}); + GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str()); + return PARAM_INVALID; + } + + auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); + int32_t index = stoi(StringUtils::Trim(key_value_v[1])); + GELOGD("Get output info: node[%s] and index[%d]", key_value_v[0].c_str(), index); + if (iter != domi::GetContext().out_nodes_map.end()) { + iter->second.emplace_back(index); + } else { + std::vector index_v; + index_v.emplace_back(index); + domi::GetContext().out_nodes_map.emplace(key_value_v[0], index_v); + } + domi::GetContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); + } + } + } catch (std::invalid_argument &) { + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"--out_nodes", out_nodes}); + return PARAM_INVALID; + } catch (std::out_of_range &) { + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"--out_nodes", out_nodes}); + return PARAM_INVALID; + } + return SUCCESS; +} + +/// @ingroup domi_common +/// @brief Judge whether the op_Name_Map parameter matches the network +/// @param [in] graph Input network graph +/// @return SUCCESS: Input parameters are correct; PARAM_INVALID: Input parameters are incorrect +/// +static Status CheckOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf) { + GE_CHECK_NOTNULL(graph); + map graphNodeTypes; + for (const NodePtr &node : graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(PARAM_INVALID, "Invalid parameter for opDesc."); + return PARAM_INVALID; + } + graphNodeTypes[op_desc->GetType()] = ""; + } + std::map &propertiesMap = domi::GetContext().op_conf_map; + if (propertiesMap.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10003", {"parameter", "value", "reason"}, {"op_name_map", op_conf, "the file content is empty"}); + GELOGE(PARAM_INVALID, "op_name_map file content is empty, please check file!"); + return PARAM_INVALID; + } + for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) { + GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(), + ErrorManager::GetInstance().ATCReportErrMessage( + "E10003", {"parameter", "value", "reason"}, + {"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"}); + GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map."); return PARAM_INVALID;); + } + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map &atc_params, + const char *model_file, const char *weights_file, domi::FrameworkType type, + const char *op_conf, const char *target, RunMode run_mode, + bool is_dynamic_input) { + GE_CHECK_NOTNULL(model_file); + GE_CHECK_NOTNULL(weights_file); + domi::GetContext().type = type; + domi::GetContext().run_mode = run_mode; + // Prevent data residue in multiple calls + PreChecker::Instance().Clear(); + + // Create an empty computegraph + std::string om_name; + ParseAtcParms(atc_params, "output", om_name); + + string graph_name = ""; + bool name_ret = GetNameFromFileName(om_name, graph_name); + if (!name_ret) { + graph_name = kGraphDefaultName + "_" + CurrentTimeInStr(); + } + ComputeGraphPtr compute_graph = MakeShared(graph_name); + GE_CHECK_NOTNULL(compute_graph); + graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + + // initialize omgContext + std::string input_shape; + ParseAtcParms(atc_params, "input_shape", input_shape); + std::string input_format; + ParseAtcParms(atc_params, "input_format", input_format); + GE_RETURN_WITH_LOG_IF_ERROR(InitDomiOmgContext(input_shape, input_format, "", is_dynamic_input), + "ATC Generate call InitDomiOmgContext ret fail"); + + std::string is_output_adjust_hw_layout; + ParseAtcParms(atc_params, "is_output_adjust_hw_layout", is_output_adjust_hw_layout); + GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputFp16NodesFormat(is_output_adjust_hw_layout), "Parse is_output_fp16 failed"); + + std::string out_nodes; + ParseAtcParms(atc_params, "out_nodes", out_nodes); + GE_RETURN_WITH_LOG_IF_ERROR(ParseOutNodes(out_nodes), "ATC Generate parse out nodes fail"); + + std::string output_type; + ParseAtcParms(atc_params, "output_type", output_type); + + // parse configuration item + if (op_conf != nullptr && *op_conf != '\0') { + // divided by ":" + atc::PropertiesManager::Instance().SetPropertyDelimiter(OP_CONF_DELIMITER); + // Parsing the op_conf configuration item file + GE_IF_BOOL_EXEC(!atc::PropertiesManager::Instance().Init(op_conf), + ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, + {"op_name_map", op_conf, "file content error"}); + GELOGE(FAILED, "op_name_map init failed!"); return FAILED); + // Return map and put it into ATC global variable + domi::GetContext().op_conf_map = atc::PropertiesManager::Instance().GetPropertyMap(); + } + + // parse network model + auto model_parser = ModelParserFactory::Instance()->CreateModelParser(type); + GE_CHK_BOOL_RET_STATUS(model_parser != nullptr, FAILED, "ATC create model parser ret fail, type:%d.", type); + + UpdateParserCtxWithOmgCtx(); + Status ret = model_parser->Parse(model_file, graph); + UpdateOmgCtxWithParserCtx(); + + // Generate the report in case of pre inspection failure or only pre inspection mode + if (PreChecker::Instance().HasError() || run_mode == ONLY_PRE_CHECK) { + std::string check_report; + ParseAtcParms(atc_params, "check_report", check_report); + GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().Save(check_report), "Generate pre-checking report failed."); + GEEVENT("The pre-checking report has been saved to %s.", check_report.c_str()); + } + + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC model parse ret fail."); + + std::string input_fp16_nodes; + ParseAtcParms(atc_params, "input_fp16_nodes", input_fp16_nodes); + std::string is_input_adjust_hw_layout; + ParseAtcParms(atc_params, "is_input_adjust_hw_layout", is_input_adjust_hw_layout); + compute_graph = GraphUtils::GetComputeGraph(graph); + GE_RETURN_IF_ERROR(CheckInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout)); + + GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode)); + + // Verify the contents of the op_name_map + if (op_conf != nullptr && *op_conf != '\0') { + GE_RETURN_WITH_LOG_IF_ERROR(CheckOpNameMap(compute_graph, op_conf), + "op_name_map parameter is not fit with input net!"); + } + + // Print parse network structure + compute_graph->Dump(); + + // parse weight + graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto weights_parser = WeightsParserFactory::Instance()->CreateWeightsParser(type); + ret = weights_parser->Parse(weights_file, graph); + + // IN ONLY_PRE_CHECK mode, generate pre inspection report only. + if (PreChecker::Instance().HasError() || run_mode == ONLY_PRE_CHECK) { + std::string check_report; + ParseAtcParms(atc_params, "check_report", check_report); + GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().Save(check_report), "Generate pre-checking report failed."); + GEEVENT("The pre-checking report has been saved to %s.", check_report.c_str()); + } + // Prevent data residue in multiple calls + PreChecker::Instance().Clear(); + + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC weights parse ret fail."); + + GELOGI("ATC parser success."); + + return SUCCESS; +} + +void GetGroupName(ge::proto::ModelDef &model_def) { + auto modelAttrMap = model_def.mutable_attr(); + auto fusionModelOpListIter = modelAttrMap->find(MODEL_ATTR_FUSION_MODEL_DEF); + GE_IF_BOOL_EXEC( + fusionModelOpListIter != modelAttrMap->end(), int fusionOpIndex = 0; + for (int i = 0; i < model_def.graph_size(); i++) { + auto graph = model_def.mutable_graph(i); + for (int j = 0; j < graph->op_size(); j++) { + int64_t scope_id = 0; + auto bt = fusionModelOpListIter->second.list().bt(fusionOpIndex++); + ge::proto::OpDef fusion_op_def; + GE_CHK_BOOL_EXEC(bt.size() != 0, GELOGW("Invalid bt size"); return;); + + (void)(fusion_op_def.ParseFromArray(bt.data(), bt.size())); + auto fusion_attr_map = fusion_op_def.mutable_attr(); + auto fusion_iter = fusion_attr_map->find(kScopeIdAttr); + GE_IF_BOOL_EXEC(fusion_iter == fusion_attr_map->end(), continue;); + + scope_id = fusion_iter->second.i(); + ge::proto::OpDef *opdef = graph->mutable_op(j); + auto attr_map = opdef->mutable_attr(); + + int64_t stream_id = opdef->stream_id(); + + uint16_t l1_id = (((uint64_t)scope_id & 0xFFFF0000)) >> 16; + GE_IF_BOOL_EXEC(l1_id != 0, ostringstream groupName; groupName << "group_op_l1_" << l1_id << "_" << stream_id; + (*attr_map)["group_op_name"].set_s(groupName.str()); continue;); + + uint16_t ub_id = ((uint64_t)scope_id & 0xFFFF); + GE_IF_BOOL_EXEC(ub_id != 0, ostringstream groupName; groupName << "group_op_ub_" << ub_id << "_" << stream_id; + (*attr_map)["group_op_name"].set_s(groupName.str());); + } + }); +} + +FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size) { + std::cout << "============ Display Model Info start ============" << std::endl; + + auto model_attr_map = model_def->mutable_attr(); + // system info + auto iter = model_attr_map->find(ATTR_MODEL_ATC_VERSION); + auto atc_version = (iter != model_attr_map->end()) ? iter->second.s() : ""; + iter = model_attr_map->find("soc_version"); + auto soc_version = (iter != model_attr_map->end()) ? iter->second.s() : ""; + iter = model_attr_map->find("framework_type"); + auto framework_type = (iter != model_attr_map->end()) ? iter->second.s() : ""; + std::cout << "system info: " + << ATTR_MODEL_ATC_VERSION + << "[" << atc_version << "], " + << "soc_version" + << "[" << soc_version << "], " + << "framework_type" + << "[" << framework_type << "]." << std::endl; + + // resource info + iter = model_attr_map->find(ATTR_MODEL_MEMORY_SIZE); + auto memory_size = (iter != model_attr_map->end()) ? iter->second.i() : -1; + iter = model_attr_map->find(ATTR_MODEL_WEIGHT_SIZE); + auto weight_size = (iter != model_attr_map->end()) ? iter->second.i() : -1; + iter = model_attr_map->find(ATTR_MODEL_STREAM_NUM); + auto stream_num = (iter != model_attr_map->end()) ? iter->second.i() : -1; + iter = model_attr_map->find(ATTR_MODEL_EVENT_NUM); + auto event_num = (iter != model_attr_map->end()) ? iter->second.i() : -1; + std::cout << "resource info: " + << ATTR_MODEL_MEMORY_SIZE + << "[" << memory_size << " B], " + << ATTR_MODEL_WEIGHT_SIZE + << "[" << weight_size << " B], " + << ATTR_MODEL_STREAM_NUM + << "[" << stream_num << "], " + << ATTR_MODEL_EVENT_NUM + << "[" << event_num << "]." + << std::endl; + + // om info + iter = model_attr_map->find("om_info_list"); + if (iter == model_attr_map->end()) { + std::cout << "Display Model Info failed, attr \"om_info_list\" is not found in om, check the version is matched." + << std::endl; + std::cout << "============ Display Model Info end ============" << std::endl; + return; + } + auto list_size = iter->second.list().i_size(); + if (list_size == kOmInfoSize) { + std::cout << "om info: " + << "modeldef_size" + << "[" << modeldef_size << " B], " + << "weight_data_size" + << "[" << iter->second.list().i(0) << " B], " + << "tbe_kernels_size" + << "[" << iter->second.list().i(1) << " B], " + << "cust_aicpu_kernel_store_size" + << "[" << iter->second.list().i(2) << " B], " + << "task_info_size" + << "[" << iter->second.list().i(3) << " B]." << std::endl; + } else { + std::cout << "Display Model Info error, please check!" << std::endl; + }; + + std::cout << "============ Display Model Info end ============" << std::endl; +} + +FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json) { + GE_CHECK_NOTNULL(model_file); + if (is_covert_to_json) { + GE_CHECK_NOTNULL(json_file); + } + try { + ge::proto::ModelDef model_def; + uint32_t modeldef_size = 0; + Status ret = ModelTool::GetModelInfoFromOm(model_file, model_def, modeldef_size); + if (ret != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E10003", + {"parameter", "value", "reason"}, {"om", model_file, "invalid om file"}); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "ParseModelContent failed because of invalid om file. Please check --om param."); + return ret; + } + + if (is_covert_to_json) { + GetGroupName(model_def); + ret = Pb2Json::ToJson(model_def, kOmBlackFields, json_file, true); + } else { + PrintModelInfo(&model_def, modeldef_size); + } + return ret; + } catch (const std::exception &e) { + ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, + {"Convert om model to json failed, exception message[" + std::string(e.what()) + "]"}); + GELOGE(FAILED, "Convert om model to json failed, exception message : %s.", e.what()); + return FAILED; + } +} + +FMK_FUNC_HOST_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const char *json_file) { + + try { + ge::proto::ModelDef model_def; + Status ret = ModelTool::GetModelInfoFromPbtxt(model_file, model_def); + if (ret != SUCCESS) { + GELOGE(ret, "LoadFromFile failed."); + return ret; + } + + GetGroupName(model_def); + ret = Pb2Json::ToJson(model_def, kOmBlackFields, json_file, true); + return ret; + } catch (const std::exception &e) { + ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, + {"Convert pbtxt to json failed, exception message[" + std::string(e.what()) + "]"}); + GELOGE(FAILED, "Convert pbtxt to json failed, exception message : %s.", e.what()); + return FAILED; + } +} + +FMK_FUNC_HOST_VISIBILITY Status ConvertFwkModelToJson(const domi::FrameworkType framework, const char *model_file, + const char *json_file) { + if (framework == domi::CAFFE || framework == domi::TENSORFLOW || framework == domi::ONNX) { + auto model_parser = ModelParserFactory::Instance()->CreateModelParser(framework); + GE_CHK_BOOL_RET_STATUS(model_parser != nullptr, FAILED, "ATC create model parser ret fail, framework:%d.", + framework); + return model_parser->ToJson(model_file, json_file); + } + + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--framework", std::to_string(framework), "only support 0(Caffe) 3(TensorFlow) 5(Onnx)"}); + GELOGE(PARAM_INVALID, "Input parameter[--framework] is mandatory and it's value must be: 0(Caffe) 3(TensorFlow) " + "or 5(Onnx)."); + return PARAM_INVALID; +} + +FMK_FUNC_HOST_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file) { + // Create buffer + GELOGI("Enter to dump infershape json schedule."); + ge::Model model("", ""); + model.SetGraph(graph); + Buffer buffer; + model.Save(buffer, true); + + ge::proto::ModelDef ge_proto; + if (buffer.GetData() != nullptr) { + std::string str(reinterpret_cast(buffer.GetData()), buffer.GetSize()); + if (!ge_proto.ParseFromString(str)) { + GELOGE(GRAPH_FAILED, "parse from string failed."); + return FAILED; + } + + Pb2Json::ToJson(ge_proto, std::set(), json_file); + } + return SUCCESS; +} + +void UpdateOmgCtxWithParserCtx() { + domi::GetContext().format = GetParserContext().format; + domi::GetContext().input_dims = GetParserContext().input_dims; + domi::GetContext().user_input_dims = GetParserContext().user_input_dims; + domi::GetContext().is_dynamic_input = GetParserContext().is_dynamic_input; + domi::GetContext().type = GetParserContext().type; + domi::GetContext().user_out_nodes = GetParserContext().user_out_nodes; + domi::GetContext().train_flag = GetParserContext().train_flag; + domi::GetContext().run_mode = GetParserContext().run_mode; + domi::GetContext().op_conf_map = GetParserContext().op_conf_map; + domi::GetContext().out_nodes_map = GetParserContext().out_nodes_map; + domi::GetContext().input_nodes_format_map = GetParserContext().input_nodes_format_map; + domi::GetContext().out_top_names = GetParserContext().out_top_names; + domi::GetContext().user_out_nodes_top_vec = GetParserContext().user_out_nodes_top_vec; + domi::GetContext().default_out_nodes = GetParserContext().default_out_nodes; + domi::GetContext().data_top_names = GetParserContext().data_top_names; +} + +void UpdateParserCtxWithOmgCtx() { + GetParserContext().format = domi::GetContext().format; + GetParserContext().input_dims = domi::GetContext().input_dims; + GetParserContext().user_input_dims = domi::GetContext().user_input_dims; + GetParserContext().is_dynamic_input = domi::GetContext().is_dynamic_input; + GetParserContext().type = domi::GetContext().type; + GetParserContext().user_out_nodes = domi::GetContext().user_out_nodes; + GetParserContext().train_flag = domi::GetContext().train_flag; + GetParserContext().run_mode = domi::GetContext().run_mode; + GetParserContext().op_conf_map = domi::GetContext().op_conf_map; + GetParserContext().out_nodes_map = domi::GetContext().out_nodes_map; + GetParserContext().input_nodes_format_map = domi::GetContext().input_nodes_format_map; + GetParserContext().out_top_names = domi::GetContext().out_top_names; + GetParserContext().user_out_nodes_top_vec = domi::GetContext().user_out_nodes_top_vec; + GetParserContext().data_top_names = domi::GetContext().data_top_names; +} +} // namespace ge diff --git a/atc/parse_graph.h b/atc/parse_graph.h new file mode 100644 index 0000000..5e04c73 --- /dev/null +++ b/atc/parse_graph.h @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSE_GRAPH_H_ +#define PARSE_GRAPH_H_ + +#include +#include +#include +#include "framework/omg/omg_inner_types.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "proto/ge_ir.pb.h" +#include "proto/om.pb.h" + +#include "graph/compute_graph.h" +#include "graph/graph.h" +#include "graph/model.h" +//#include "runtime/kernel.h" + +using domi::Status; +using std::pair; +using std::string; +using std::unordered_map; +using std::vector; + +namespace ge { +/** + * @ingroup domi_omg + * @brief init omg context + * @return void + */ +GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, + bool is_dynamic_input); + +/** + * @ingroup domi_omg + * @brief generate graph based on the input model file and weight file + * @param [out] graph graph + * @param [in] model_file path of model file + * @param [in] weights_file path of weight file + * @param [in] type type of the input model + * @param [in] op_conf op mapping configuration + * @param [in] target type of platform. If a tiny model is generated, set target to tiny + * @param [in] run_mode run model + * @param [in] enable_l2dynamic enable l2dynamic + * @param [in] is_dynamic_input dynamic input, true of false + * @param [in] atc_params multiply atc params + * @return Status result code + */ +GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map &atc_params, const char *model_file, + const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr, + const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false); + +/** + * @ingroup domi_omg + * @brief generates a simplified JSON file based on the key value of the offline model file in protobuf format + * @param [in] model_file path of offline model file + * @param [out] json_file path of json file + * @param [key] encrypted key + * @return Status result code + */ +GE_FUNC_VISIBILITY Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json); + +GE_FUNC_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const char *json_file); +/** + * @ingroup domi_omg + * @brief convert the model file in protobuf format into a JSON file. + * @param [in] framework type of model + * @param [in] om model_file path of offline model file + * @param [out] json_file path of json file + * @param [key] encrypted key + * @return Status result code + */ +GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file); + +GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model); + +GE_FUNC_VISIBILITY void FindParserSo(const string &path, vector &fileList, string &caffe_parser_path); + +GE_FUNC_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); + +GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); + +GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, std::vector> &output_nodes_info); + +GE_FUNC_VISIBILITY void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, + std::vector &output_nodes_name); + +GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx(); + +GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx(); + +GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def); +} // namespace ge +#endif // PARSE_GRAPH_H_ diff --git a/atc/single_op_parser.cc b/atc/single_op_parser.cc new file mode 100644 index 0000000..d05053e --- /dev/null +++ b/atc/single_op_parser.cc @@ -0,0 +1,609 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "single_op_parser.h" + +#include +#include +#include +#include + +#include + +#include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" +#include "util/util.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/operator_factory_impl.h" + +using Json = nlohmann::json; +using std::string; +using std::vector; +using std::map; + +namespace ge { +namespace { +constexpr char const *kKeyOp = "op"; +constexpr char const *kKeyInputDesc = "input_desc"; +constexpr char const *kKeyOutputDesc = "output_desc"; +constexpr char const *kKeyAttr = "attr"; +constexpr char const *kKeyName = "name"; +constexpr char const *kKeyType = "type"; +constexpr char const *kKeyShape = "shape"; +constexpr char const *kKeyOriginShape = "origin_shape"; +constexpr char const *kKeyShapeRange = "shape_range"; +constexpr char const *kKeyValue = "value"; +constexpr char const *kKeyFormat = "format"; +constexpr char const *kKeyOriginFormat = "origin_format"; +constexpr char const *kFileSuffix = ".om"; +constexpr char const *kKeyDynamicInput = "dynamic_input"; +constexpr char const *kKeyDynamicOutput = "dynamic_output"; +constexpr int kDumpJsonIndent = 2; +constexpr int kShapeRangePairSize = 2; +constexpr int kShapeRangeLow = 0; +constexpr int kShapeRangeHigh = 1; +constexpr int kMaxFileNameLen = 128; + +map kAttrTypeDict = { + {"bool", GeAttrValue::VT_BOOL}, + {"int", GeAttrValue::VT_INT}, + {"float", GeAttrValue::VT_FLOAT}, + {"string", GeAttrValue::VT_STRING}, + {"list_bool", GeAttrValue::VT_LIST_BOOL}, + {"list_int", GeAttrValue::VT_LIST_INT}, + {"list_float", GeAttrValue::VT_LIST_FLOAT}, + {"list_string", GeAttrValue::VT_LIST_STRING}, + {"list_list_int", GeAttrValue::VT_LIST_LIST_INT}, + {"data_type", GeAttrValue::VT_DATA_TYPE}, +}; + +map kDataTypeDict = { + {"bool", DT_BOOL}, + {"int8", DT_INT8}, + {"uint8", DT_UINT8}, + {"int16", DT_INT16}, + {"uint16", DT_UINT16}, + {"int32", DT_INT32}, + {"uint32", DT_UINT32}, + {"int64", DT_INT64}, + {"uint64", DT_UINT64}, + {"float16", DT_FLOAT16}, + {"half", DT_FLOAT16}, + {"fp16", DT_FLOAT16}, + {"float", DT_FLOAT}, + {"float32", DT_FLOAT}, + {"double", DT_DOUBLE}, +}; + +map kFormatDict = { + {"nchw", FORMAT_NCHW}, + {"nhwc", FORMAT_NHWC}, + {"nd", FORMAT_ND}, + {"nc1hwc0", FORMAT_NC1HWC0}, + {"fractal_z", FORMAT_FRACTAL_Z}, + {"nc1c0hwpad", FORMAT_NC1C0HWPAD}, + {"nhwc1c0", FORMAT_NHWC1C0}, + {"fsr_nchw", FORMAT_FSR_NCHW}, + {"fractal_deconv", FORMAT_FRACTAL_DECONV}, + {"c1hwnc0", FORMAT_C1HWNC0}, + {"fractal_deconv_transpose", FORMAT_FRACTAL_DECONV_TRANSPOSE}, + {"fractal_deconv_sp_stride_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, + {"nc1hwc0_c04", FORMAT_NC1HWC0_C04}, + {"fractal_z_c04", FORMAT_FRACTAL_Z_C04}, + {"chwn", FORMAT_CHWN}, + {"deconv_sp_stride8_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, + {"nc1khkwhwc0", FORMAT_NC1KHKWHWC0}, + {"bn_weight", FORMAT_BN_WEIGHT}, + {"filter_hwck", FORMAT_FILTER_HWCK}, + {"hwcn", FORMAT_HWCN}, + {"lookup_lookups", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, + {"lookup_keys", FORMAT_HASHTABLE_LOOKUP_KEYS}, + {"lookup_value", FORMAT_HASHTABLE_LOOKUP_VALUE}, + {"lookup_output", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, + {"lookup_hits", FORMAT_HASHTABLE_LOOKUP_HITS}, + {"md", FORMAT_MD}, + {"c1hwncoc0", FORMAT_C1HWNCoC0}, + {"fractal_nz", FORMAT_FRACTAL_NZ}, + {"ndhwc", FORMAT_NDHWC}, + {"ncdhw", FORMAT_NCDHW}, + {"dhwcn", FORMAT_DHWCN}, + {"dhwnc", FORMAT_DHWNC}, + {"ndc1hwc0", FORMAT_NDC1HWC0}, + {"fractal_z_3d", FORMAT_FRACTAL_Z_3D}, + {"fractal_z_3d_transpose", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, + {"cn", FORMAT_CN}, + {"nc", FORMAT_NC}, + {"fractal_zn_lstm", FORMAT_FRACTAL_ZN_LSTM}, + {"fractal_z_g", FORMAT_FRACTAL_Z_G} +}; + +std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) { + std::stringstream file_name_ss; + file_name_ss << index; + file_name_ss << "_" << single_op_desc.op; + for (auto &desc : single_op_desc.input_desc) { + file_name_ss << "_" << desc.type << "_" << desc.format; + for (auto dim : desc.dims) { + file_name_ss << "_" << dim; + } + } + + for (auto &desc : single_op_desc.output_desc) { + file_name_ss << "_" << desc.type << "_" << desc.format; + for (auto dim : desc.dims) { + file_name_ss << "_" << dim; + } + } + + std::string file_name = file_name_ss.str(); + if (file_name.length() > kMaxFileNameLen) { + GELOGI("Trim file name for it is too long, origin file name = %s", file_name.c_str()); + file_name = file_name.substr(0, kMaxFileNameLen); + } + file_name += kFileSuffix; + return file_name; +} +} // namespace + +template +void SetAttrValue(const Json &j, SingleOpAttr &attr) { + attr.value.SetValue(j.at(kKeyValue).get()); +} + +template +T GetValue(const map &dict, string &key, T default_val) { + transform(key.begin(), key.end(), key.begin(), ::tolower); + auto it = dict.find(key); + if (it == dict.end()) { + return default_val; + } + + return it->second; +} + +void from_json(const Json &j, SingleOpTensorDesc &desc) { + bool is_tensor_valid = true; + desc.dims = j.at(kKeyShape).get>(); + auto it = j.find(kKeyShapeRange); + if (it != j.end()) { + desc.dim_ranges = j.at(kKeyShapeRange).get>>(); + } + it = j.find(kKeyOriginShape); + if (it != j.end()) { + desc.ori_dims = j.at(kKeyOriginShape).get>(); + } + string format_str = j.at(kKeyFormat).get(); + string type_str = j.at(kKeyType).get(); + desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); + desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); + is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(format_str); + is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsDataTypeValid(type_str); + it = j.find(kKeyOriginFormat); + if (it != j.end()) { + string origin_format_str = j.at(kKeyOriginFormat).get(); + is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(origin_format_str); + desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED); + } + auto tensor_name = j.find(kKeyName); + if (tensor_name != j.end()) { + desc.name = tensor_name->get(); + } + auto dynamic_input_name = j.find(kKeyDynamicInput); + if (dynamic_input_name != j.end()) { + desc.dynamic_input_name = dynamic_input_name->get(); + } + if (!is_tensor_valid) { + desc.SetValidFlag(is_tensor_valid); + } +} + +void from_json(const Json &j, SingleOpAttr &attr) { + attr.name = j.at(kKeyName).get(); + attr.type = j.at(kKeyType).get(); + auto it = kAttrTypeDict.find(attr.type); + if (it == kAttrTypeDict.end()) { + GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str()); + return; + } + + switch (it->second) { + case GeAttrValue::VT_BOOL: + SetAttrValue(j, attr); + break; + case GeAttrValue::VT_INT: + SetAttrValue(j, attr); + break; + case GeAttrValue::VT_FLOAT: + SetAttrValue(j, attr); + break; + case GeAttrValue::VT_STRING: + SetAttrValue(j, attr); + break; + case GeAttrValue::VT_LIST_BOOL: + SetAttrValue>(j, attr); + break; + case GeAttrValue::VT_LIST_INT: + SetAttrValue>(j, attr); + break; + case GeAttrValue::VT_LIST_FLOAT: + SetAttrValue>(j, attr); + break; + case GeAttrValue::VT_LIST_STRING: + SetAttrValue>(j, attr); + break; + case GeAttrValue::VT_LIST_LIST_INT: + SetAttrValue>>(j, attr); + break; + case GeAttrValue::VT_DATA_TYPE: + SetAttrValue(j, attr); + break; + default: + GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str()); + break; + } +} + +void from_json(const Json &j, SingleOpDesc &desc) { + desc.op = j.at(kKeyOp).get(); + + auto input_desc = j.find(kKeyInputDesc); + if (input_desc != j.end()) { + desc.input_desc = input_desc->get>(); + } + + auto output_desc = j.find(kKeyOutputDesc); + if (output_desc != j.end()) { + desc.output_desc = output_desc->get>(); + } + + auto attr_field = j.find(kKeyAttr); + if (attr_field != j.end()) { + desc.attrs = attr_field->get>(); + } +} + +Status SingleOpParser::ReadJsonFile(const std::string &file, Json &json_obj) { + std::string real_path = RealPath(file.c_str()); + if (real_path.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10023", {"value"}, {file}); + GELOGE(FAILED, "Input parameter[--singleop]'s value[%s] is not a valid path.", file.c_str()); + return INTERNAL_ERROR; + } + + std::ifstream ifs(real_path); + if (!ifs.is_open()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10024", {"value"}, {file}); + GELOGE(FAILED, "Open file[%s] provided in input parameter[--singleop] failed.", file.c_str()); + return FAILED; + } + try { + ifs >> json_obj; + } catch (const std::exception &e) { + ErrorManager::GetInstance().ATCReportErrMessage("E10025", {"realpath", "errmsg"}, {real_path, e.what()}); + GELOGE(PARAM_INVALID, "Parse file[%s] provided in input parameter[--singleop] failed, exception = %s.", + real_path.c_str(), e.what()); + return PARAM_INVALID; + } + + ifs.close(); + return SUCCESS; +} + +bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { + if (op_desc.op.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10026"); + GELOGE(PARAM_INVALID, "Op name is empty"); + return false; + } + + int index = 0; + for (auto &tensor_desc : op_desc.input_desc) { + if (!tensor_desc.GetValidFlag()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, + {"intput", "datatype or format", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index); + return false; + } + if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) || + (tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){ + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, + {"intput", "datatype or format", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index); + return false; + } + ++index; + } + + index = 0; + for (auto &tensor_desc : op_desc.output_desc) { + if (!tensor_desc.GetValidFlag()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, + {"output", "datatype", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); + return false; + } + if (tensor_desc.type == DT_UNDEFINED) { + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, + {"output", "datatype", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); + return false; + } + + if (tensor_desc.format == FORMAT_RESERVED) { + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, + {"output", "format", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Output's format is invalid when the index is %d", index); + return false; + } + ++index; + } + + for (auto &attr : op_desc.attrs) { + if (attr.name.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10029"); + GELOGE(PARAM_INVALID, "attr name is empty"); + return false; + } + + if (attr.value.IsEmpty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10030", {"attrname"}, {attr.name}); + GELOGE(PARAM_INVALID, "Parse attr \"%s\" failed. ", attr.name.c_str()); + return false; + } + } + + return true; +} + +std::unique_ptr SingleOpParser::CreateOpDesc(const string &op_type) { + return std::unique_ptr(new(std::nothrow) OpDesc(op_type, op_type)); +} + +Status SingleOpParser::UpdateDynamicTensorName(std::vector &desc) { + std::map dynamic_name_map; + for (auto &tensor : desc) { + if (tensor.dynamic_input_name.empty()) { + continue; + } + if (dynamic_name_map.find(tensor.dynamic_input_name) == dynamic_name_map.end()) { + dynamic_name_map[tensor.dynamic_input_name] = 0; + } else { + dynamic_name_map[tensor.dynamic_input_name]++; + } + tensor.name = tensor.dynamic_input_name + std::to_string(dynamic_name_map[tensor.dynamic_input_name]); + } + GELOGD("Update dynamic tensor name success!"); + return SUCCESS; +} + +Status SingleOpParser::ConvertToBuildParam(int index, + const SingleOpDesc &single_op_desc, + SingleOpBuildParam &build_param) { + auto op_desc = CreateOpDesc(single_op_desc.op); + GE_CHECK_NOTNULL(op_desc); + + for (auto &desc : single_op_desc.input_desc) { + GeTensorDesc ge_tensor_desc(GeShape(desc.dims), + desc.format, + desc.type); + auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format; + auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims; + ge_tensor_desc.SetOriginFormat(ori_format_to_set); + ge_tensor_desc.SetOriginShape(GeShape(ori_dims)); + GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); + TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size()); + TensorUtils::SetInputTensor(ge_tensor_desc, true); + TensorUtils::SetOutputTensor(ge_tensor_desc, false); + if (desc.name.empty()) { + op_desc->AddInputDesc(ge_tensor_desc); + } else { + op_desc->AddInputDesc(desc.name, ge_tensor_desc); + } + build_param.inputs.emplace_back(ge_tensor_desc); + } + + for (auto &desc : single_op_desc.output_desc) { + GeTensorDesc ge_tensor_desc(GeShape(desc.dims), + desc.format, + desc.type); + auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format; + auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims; + ge_tensor_desc.SetOriginFormat(ori_format_to_set); + ge_tensor_desc.SetOriginShape(GeShape(ori_dims)); + GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); + TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size()); + TensorUtils::SetInputTensor(ge_tensor_desc, false); + TensorUtils::SetOutputTensor(ge_tensor_desc, true); + if (desc.name.empty()) { + op_desc->AddOutputDesc(ge_tensor_desc); + } else { + op_desc->AddOutputDesc(desc.name, ge_tensor_desc); + } + build_param.outputs.emplace_back(ge_tensor_desc); + } + + for (const auto &attr : single_op_desc.attrs) { + op_desc->SetAttr(attr.name, attr.value); + } + + if (VerifyOpInputOutputSizeByIr(*op_desc) != SUCCESS) { + GELOGE(PARAM_INVALID, "Verify op [%s] input or output size failed.", op_desc->GetType().c_str()); + return PARAM_INVALID; + } + + build_param.file_name = GenerateFileName(single_op_desc, index); + build_param.op_desc.reset(op_desc.release()); + return SUCCESS; +} + +Status SingleOpParser::VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc) { + ge::Operator operator_ir = ge::OperatorFactory::CreateOperator("tmp_operator", current_op_desc.GetType()); + if (!operator_ir.IsEmpty()) { + auto opdesc_ir = ge::OpDescUtils::GetOpDescFromOperator(operator_ir); + GE_CHECK_NOTNULL(opdesc_ir); + size_t current_opdesc_inputs_num = current_op_desc.GetInputsSize(); + size_t ir_opdesc_inputs_num = opdesc_ir->GetInputsSize(); + if (current_opdesc_inputs_num < ir_opdesc_inputs_num) { + string reason = "is smaller than the ir needed input size " + std::to_string(ir_opdesc_inputs_num); + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {current_op_desc.GetName(), "input size " + std::to_string(current_opdesc_inputs_num), reason}); + GELOGE(PARAM_INVALID, "This op [%s] input size %zu is smaller than the ir needed input size %zu", + current_op_desc.GetName().c_str(), current_opdesc_inputs_num, ir_opdesc_inputs_num); + return PARAM_INVALID; + } + size_t current_opdesc_outputs_num = current_op_desc.GetOutputsSize(); + size_t ir_opdesc_outputs_num = opdesc_ir->GetOutputsSize(); + if (current_opdesc_outputs_num < ir_opdesc_outputs_num) { + string reason = "is smaller than the ir needed output size " + std::to_string(ir_opdesc_outputs_num); + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {current_op_desc.GetName(), "output size " + std::to_string(current_opdesc_outputs_num), reason}); + GELOGE(PARAM_INVALID, "This op [%s] output size %zu is smaller than the ir needed output size %zu", + current_op_desc.GetName().c_str(), current_opdesc_outputs_num, ir_opdesc_outputs_num); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +Status SingleOpParser::SetShapeRange(const std::string &op_name, + const SingleOpTensorDesc &tensor_desc, + GeTensorDesc &ge_tensor_desc) { + auto num_shape_ranges = tensor_desc.dim_ranges.size(); + GELOGD("Number of shape ranges = %zu", num_shape_ranges); + auto it = std::find(tensor_desc.dims.begin(), tensor_desc.dims.end(), ge::UNKNOWN_DIM_NUM); + if (it != tensor_desc.dims.end()) { + if (tensor_desc.dims != ge::UNKNOWN_RANK) { + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape", + "has unknown rank but dim size is not one"}); + GELOGE(PARAM_INVALID, "Invalid tensor shape: [%s]", ge_tensor_desc.MutableShape().ToString().c_str()); + return PARAM_INVALID; + } + if (!tensor_desc.dim_ranges.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape range", + "is not needed while the rank the shape is unknown"}); + GELOGE(PARAM_INVALID, "Shape range is not needed while the rank the shape is unknown"); + return PARAM_INVALID; + } + + GELOGD("Shape is unknown rank, do not set shape range"); + return SUCCESS; + } + + std::vector> shape_range; + size_t range_index = 0; + for (auto dim : tensor_desc.dims) { + if (dim >= 0) { + shape_range.emplace_back(dim, dim); + GELOGD("Adding shape range: [%ld, %ld]", dim, dim); + } else { + GELOGD("To get shape range by index = %zu", range_index); + if (range_index >= num_shape_ranges) { + string reason = "is smaller than the unknown dim size " + std::to_string(++range_index); + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape range size " + std::to_string(num_shape_ranges), + reason}); + GELOGE(PARAM_INVALID, "The number of shape_range mismatches that of unknown dims."); + return PARAM_INVALID; + } + + auto &range = tensor_desc.dim_ranges[range_index]; + if (range.size() != kShapeRangePairSize) { + string reason = "has " + std::to_string(range.size()) + " item(s)"; + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape range " + std::to_string(range_index), + reason}); + GELOGE(PARAM_INVALID, "Invalid shape range entry. index = %zu, size = %zu", range_index, range.size()); + return PARAM_INVALID; + } + + shape_range.emplace_back(range[kShapeRangeLow], range[kShapeRangeHigh]); + GELOGD("Adding shape range: [%ld, %ld]", range[kShapeRangeLow], range[kShapeRangeHigh]); + ++range_index; + } + } + + if (num_shape_ranges != range_index) { + string reason = "is greater than the unknown dim size " + std::to_string(range_index); + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape range size " + std::to_string(num_shape_ranges), + reason}); + GELOGE(PARAM_INVALID, + "The number of shape_range(%zu) mismatches that of unknown dims(%zu).", + num_shape_ranges, + range_index); + return PARAM_INVALID; + } + + if (range_index > 0) { + ge_tensor_desc.SetShapeRange(shape_range); + } + + return SUCCESS; +} + +Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector &op_list) { + int index = 0; + try { + Json single_op_list_json; + auto ret = ReadJsonFile(file, single_op_list_json); + if (ret != SUCCESS) { + return ret; + } + + for (const Json &single_op_json : single_op_list_json) { + SingleOpDesc single_op_desc; + GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str()); + single_op_desc = single_op_json; + if (UpdateDynamicTensorName(single_op_desc.input_desc) != SUCCESS) { + GELOGE(FAILED, "Update dynamic tensor name failed!"); + return FAILED; + } + + if (!Validate(single_op_desc)) { + GELOGE(PARAM_INVALID, "Validate the index[%d] of op failed when read json file[%s].", index, file.c_str()); + return PARAM_INVALID; + } + + SingleOpBuildParam param; + ret = ConvertToBuildParam(index, single_op_desc, param); + if (ret != SUCCESS) { + return ret; + } + + op_list.emplace_back(param); + GELOGI("Parse the index[%d] of op success", index); + index += 1; + } + } catch (const nlohmann::json::exception &e) { + ErrorManager::GetInstance().ATCReportErrMessage("E10032", {"index", "jsonfile", "exception"}, + {std::to_string(index), file, e.what()}); + GELOGE(PARAM_INVALID, "Parse the index[%d] of op failed when read json file[%s], exception %s", + index, file.c_str(), e.what()); + return PARAM_INVALID; + } + + return SUCCESS; +} +} // namespace ge + diff --git a/atc/single_op_parser.h b/atc/single_op_parser.h new file mode 100644 index 0000000..71aa58b --- /dev/null +++ b/atc/single_op_parser.h @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ACL_TOOLS_COMPILE_PARSER_H +#define ACL_TOOLS_COMPILE_PARSER_H + +#include +#include + +#include + +#include "ge/ge_api_error_codes.h" +#include "graph/types.h" +#include "graph/ge_attr_value.h" +#include "graph/op_desc.h" + +namespace ge { +struct SingleOpTensorDesc { +public: + bool GetValidFlag() const { return is_valid_; } + void SetValidFlag(bool is_valid) { is_valid_ = is_valid; } +public: + std::string name; + std::vector dims; + std::vector ori_dims; + std::vector> dim_ranges; + ge::Format format = ge::FORMAT_RESERVED; + ge::Format ori_format = ge::FORMAT_RESERVED; + ge::DataType type = ge::DT_UNDEFINED; + std::string dynamic_input_name; +private: + bool is_valid_ = true; +}; + +struct SingleOpAttr { + std::string name; + std::string type; + ge::GeAttrValue value; +}; + +struct SingleOpDesc { + std::string op; + std::vector input_desc; + std::vector output_desc; + std::vector attrs; +}; + +struct SingleOpBuildParam { + ge::OpDescPtr op_desc; + std::vector inputs; + std::vector outputs; + std::string file_name; +}; + +void from_json(const nlohmann::json &json, SingleOpTensorDesc &desc); + +void from_json(const nlohmann::json &json, SingleOpAttr &desc); + +void from_json(const nlohmann::json &json, SingleOpDesc &desc); + +class SingleOpParser { + public: + static Status ParseSingleOpList(const std::string &file, std::vector &op_list); + + private: + static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj); + static bool Validate(const SingleOpDesc &op_desc); + static std::unique_ptr CreateOpDesc(const std::string &op_type); + static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); + static Status UpdateDynamicTensorName(std::vector &desc); + static Status VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc); + static Status SetShapeRange(const std::string &op_name, + const SingleOpTensorDesc &tensor_desc, + GeTensorDesc &ge_tensor_desc); +}; +} // namespace ge + +#endif // ACL_TOOLS_COMPILE_PARSER_H diff --git a/atc/util/gflags_util.h b/atc/util/gflags_util.h new file mode 100644 index 0000000..c3d490a --- /dev/null +++ b/atc/util/gflags_util.h @@ -0,0 +1,85 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef UTIL_GFLAGS_UTIL_H_ +#define UTIL_GFLAGS_UTIL_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include +#include + +namespace ge { +class GE_FUNC_VISIBILITY GflagsUtils { + public: + static bool IsSetCommandTrue(const char *name) { + std::string out; + return gflags::GetCommandLineOption(name, &out) && out == "true"; + } + + /// + /// @brief Determines whether the parameter is empty + /// @param name name parameter name + /// @return true if empty otherwise false + /// + static bool IsSetCommandNotEmpty(const char *name) { + std::string out; + return gflags::GetCommandLineOption(name, &out) && !out.empty(); + } + + /// + /// @brief Determines whether the parameter is not default + /// @param flag_name name parameter name + /// @return true if not default otherwise false + /// + static bool IsCommandLineNotDefault(const char *flag_name) { + google::CommandLineFlagInfo info; + return GetCommandLineFlagInfo(flag_name, &info) && !info.is_default; + } + + /// + /// @brief Modify gflags to print help information + /// @param flags_h Pass in the self-defined help parameter, it is recommended to be FLAGS_h + /// @return void + /// + static void ChangeHelpFlags(bool flags_h) { + if (flags_h || IsSetCommandTrue("help") || IsSetCommandTrue("helpfull") || IsSetCommandNotEmpty("helpon") || + IsSetCommandNotEmpty("helpmatch") || IsSetCommandTrue("helppackage") || IsSetCommandTrue("helpxml")) { + gflags::SetCommandLineOption("help", "false"); + gflags::SetCommandLineOption("helpfull", "false"); + gflags::SetCommandLineOption("helpon", ""); + gflags::SetCommandLineOption("helpmatch", ""); + gflags::SetCommandLineOption("helppackage", "false"); + gflags::SetCommandLineOption("helpxml", "false"); + gflags::SetCommandLineOption("helpshort", "true"); + } + } +}; +} // namespace ge + +#endif // UTIL_GFLAGS_UTIL_H_ diff --git a/atc/util/properties_manager.cc b/atc/util/properties_manager.cc new file mode 100644 index 0000000..2c29765 --- /dev/null +++ b/atc/util/properties_manager.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "properties_manager.h" + +#include +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "util/util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/ge_context.h" +#include "graph/utils/attr_utils.h" + +namespace ge { +namespace atc { +PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {} +PropertiesManager::~PropertiesManager() {} + +// singleton +PropertiesManager &PropertiesManager::Instance() { + static PropertiesManager instance; + return instance; +} + +// Initialize property configuration +bool PropertiesManager::Init(const std::string &file_path) { + std::lock_guard lock(mutex_); + if (is_inited_) { + GELOGW("Already inited, will be initialized again"); + properties_map_.clear(); + is_inited_ = false; + return is_inited_; + } + + if (!LoadFileContent(file_path)) { + return false; + } + + is_inited_ = true; + return is_inited_; +} + +// Load file contents +bool PropertiesManager::LoadFileContent(const std::string &file_path) { + // Normalize the path + string resolved_file_path = RealPath(file_path.c_str()); + if (resolved_file_path.empty()) { + GELOGE(FAILED, "Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str()); + return false; + } + std::ifstream fs(resolved_file_path, std::ifstream::in); + + if (!fs.is_open()) { + GELOGE(PARAM_INVALID, "Open %s failed.", file_path.c_str()); + return false; + } + + std::string line; + + while (getline(fs, line)) { // line not with \n + if (!ParseLine(line)) { + GELOGE(PARAM_INVALID, "Parse line failed. content is [%s].", line.c_str()); + fs.close(); + return false; + } + } + + fs.close(); // close the file + + GELOGI("LoadFileContent success."); + return true; +} + +// Parsing the command line +bool PropertiesManager::ParseLine(const std::string &line) { + std::string temp = Trim(line); + // Comment or newline returns true directly + if (temp.find_first_of('#') == 0 || *(temp.c_str()) == '\n') { + return true; + } + + if (!temp.empty()) { + std::string::size_type pos = temp.find_first_of(delimiter); + if (pos == std::string::npos) { + GELOGE(PARAM_INVALID, "Incorrect line [%s], it must include [%s].Perhaps you use illegal chinese symbol", + line.c_str(), delimiter.c_str()); + return false; + } + + std::string map_key = Trim(temp.substr(0, pos)); + std::string value = Trim(temp.substr(pos + 1)); + if (map_key.empty() || value.empty()) { + GELOGE(PARAM_INVALID, "Map_key or value empty. %s", line.c_str()); + return false; + } + + properties_map_[map_key] = value; + } + + return true; +} + +// Remove the space and tab before and after the string +std::string PropertiesManager::Trim(const std::string &str) { + if (str.empty()) { + return str; + } + + std::string::size_type start = str.find_first_not_of(" \t\r\n"); + if (start == std::string::npos) { + return str; + } + + std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1; + return str.substr(start, end); +} + +// return properties_map_ +std::map PropertiesManager::GetPropertyMap() { + std::lock_guard lock(mutex_); + return properties_map_; +} + +// Set separator +void PropertiesManager::SetPropertyDelimiter(const std::string &de) { + std::lock_guard lock(mutex_); + delimiter = de; +} +} +} // namespace ge diff --git a/atc/util/properties_manager.h b/atc/util/properties_manager.h new file mode 100644 index 0000000..1f65245 --- /dev/null +++ b/atc/util/properties_manager.h @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef UTIL_PROPERTIES_MANAGER_H_ +#define UTIL_PROPERTIES_MANAGER_H_ + +#include +#include +#include + +namespace ge { +namespace atc { +class PropertiesManager { + public: + // Singleton + static PropertiesManager &Instance(); + + /** + * @ingroup domi_ome + * @brief Initialize configuration parameters, which must be invoked in main. + * @param [in] file_path Property profile path + * @return true success + * @return false fail + * @author + */ + bool Init(const std::string &file_path); + + /** + * @ingroup domi_ome + * @brief Return configuration parameters + * @return properties_map_ + * @author + */ + std::map GetPropertyMap(); + + /** + * @ingroup domi_ome + * @brief Adapt key value pair form, set different separators + * @param [in] delimiter + * @author + */ + void SetPropertyDelimiter(const std::string &de); + + private: + // Private construct, destructor + PropertiesManager(); + ~PropertiesManager(); + + // Get file content + bool LoadFileContent(const std::string &file_path); + + // Parsing a single line file + bool ParseLine(const std::string &line); + + // Remove space before and after string + std::string Trim(const std::string &str); + + bool is_inited_; + + // Configuration item separator, default is "=" + std::string delimiter; + + std::map properties_map_; + std::mutex mutex_; +}; +} +} // namespace ge + +#endif // UTIL_PROPERTIES_MANAGER_H_ diff --git a/atc/util/string_util.h b/atc/util/string_util.h new file mode 100644 index 0000000..f036836 --- /dev/null +++ b/atc/util/string_util.h @@ -0,0 +1,171 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_STRING_UTIL_H_ +#define INC_FRAMEWORK_COMMON_STRING_UTIL_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include +#include + +#include +#include +#include +#include +#include + +namespace ge { +class GE_FUNC_VISIBILITY StringUtils { + public: + static std::string &Ltrim(std::string &s) { +#if __cplusplus >= 201103L + (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); +#else + (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), std::not1(std::ptr_fun(std::isspace)))); +#endif + return s; + } + // lint -esym(551,*) + static std::string &Rtrim(std::string &s) { /*lint !e618*/ +#if __cplusplus >= 201103L + (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); +#else + (void)s.erase(std::find_if(s.rbegin(), s.rend(), std::not1(std::ptr_fun(std::isspace))).base(), s.end()); +#endif + return s; + } + // lint -esym(551,*) + /// + /// @ingroup domi_common + /// @brief delete spaces at the beginning and end of a string + /// @param [in] string to be trimmed + /// @return string after trim + /// + static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); } + + /// + /// @ingroup domi_common + /// @brief string splitting + /// @param [in] str string to be trimmed + /// @param [in] delim separator + /// @return string array after segmentation + /// + static std::vector Split(const std::string &str, char delim) { + std::vector elems; + + if (str.empty()) { + elems.emplace_back(""); + return elems; + } + + std::stringstream ss(str); + std::string item; + + while (getline(ss, item, delim)) { + elems.push_back(item); + } + + auto str_size = str.size(); + if (str_size > 0 && str[str_size - 1] == delim) { + elems.emplace_back(""); + } + + return elems; + } + /// + /// @ingroup domi_common + /// @brief obtain the file name + /// @param [in] s path name + /// @return file name + /// + static std::string GetFileName(std::string &s) { + if (s.empty()) { + return ""; + } + std::vector files = StringUtils::Split(s, '/'); + + return files.empty() ? "" : files[files.size() - 1]; + } + /// + /// @ingroup domi_common + /// @brief full replacement + /// @link + /// @param [in] str str string to be replaced + /// @param [in] old_value old Characters Before Replacement + /// @param [in] new_value new Characters Before Replacement + /// @return string after replacement + /// + static std::string ReplaceAll(std::string str, const std::string &old_value, const std::string &new_value) { + std::string::size_type cur_pos = 0; + std::string::size_type old_length = old_value.length(); + std::string::size_type new_length = new_value.length(); + // cycle replace + for (; cur_pos != std::string::npos; cur_pos += new_length) { + if ((cur_pos = str.find(old_value, cur_pos)) != std::string::npos) { + (void)str.replace(cur_pos, old_length, new_value); + } else { + break; + } + } + return str; + } + + /// + /// @ingroup domi_common + /// @brief checks whether a character string starts with a character string (prefix) + /// @link + /// @param [in] str string to be compared + /// @param [in] str_x prefix + /// @return if the value is a prefix, true is returned. Otherwise, false is returned + /// + static bool StartWith(const std::string &str, const std::string str_x) { + return ((str.size() >= str_x.size()) && (str.compare(0, str_x.size(), str_x) == 0)); + } + + /// + /// @ingroup domi_common + /// @brief format string + /// @link + /// @param [in] format specifies the character string format + /// @param [in] ... format Filling Content + /// @return formatted string + /// + static std::string FormatString(const char *format, ...) { + const uint32_t MAX_BUFFER_LEN = 1024; // the stack memory plint check result must be less than 1024 + va_list args; + va_start(args, format); + char buffer[MAX_BUFFER_LEN] = {0}; + int32_t ret = vsnprintf_s(buffer, MAX_BUFFER_LEN, MAX_BUFFER_LEN - 1, format, args); + va_end(args); + return ret > 0 ? buffer : ""; + } +}; +} // namespace ge + +#endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ diff --git a/atc/util/tool.h b/atc/util/tool.h new file mode 100644 index 0000000..9f95a46 --- /dev/null +++ b/atc/util/tool.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef UTIL_TOOL_H_ +#define UTIL_TOOL_H_ + +#include +#include +#include + +namespace ge { +template +static inline std::shared_ptr MakeShared(Args &&... args) { + typedef typename std::remove_const::type T_nc; + std::shared_ptr ret(new (std::nothrow) T_nc(std::forward(args)...)); + return ret; +} +} // namespace ge +#endif // UTIL_TOOL_H_ diff --git a/atc/util/util.cc b/atc/util/util.cc new file mode 100644 index 0000000..c8321a1 --- /dev/null +++ b/atc/util/util.cc @@ -0,0 +1,283 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util.h" + +#include +#ifdef __GNUC__ +#include +#else +#include +#endif +#include +#include +#include +#include +#include + +#include "common/util/error_manager/error_manager.h" +#include "external/ge/ge_api_error_codes.h" +#include "framework/common/debug/ge_log.h" +#include "mmpa/mmpa_api.h" + +namespace { +const int kMaxBuffSize = 256; +const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; +} // namespace + +namespace ge { +/** + * @ingroup domi_common + * @brief Create directory, support to create multi-level directory + * @param [in] directory_path Path, can be multi-level directory + * @return -1 fail + * @return 0 success + */ +int CreateDirectory(const std::string &directory_path) { + GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); + auto dir_path_len = directory_path.length(); + if (dir_path_len >= MMPA_MAX_PATH) { + ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, + {directory_path, std::to_string(MMPA_MAX_PATH)}); + GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH); + return -1; + } + char tmp_dir_path[MMPA_MAX_PATH] = {0}; + for (size_t i = 0; i < dir_path_len; i++) { + tmp_dir_path[i] = directory_path[i]; + if ((tmp_dir_path[i] == '\\') || (tmp_dir_path[i] == '/')) { + if (mmAccess2(tmp_dir_path, M_F_OK) != EN_OK) { + int32_t ret = mmMkdir(tmp_dir_path, M_IRUSR | M_IWUSR | M_IXUSR); // 700 + if (ret != 0) { + if (errno != EEXIST) { + ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); + GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); + return ret; + } + } + } + } + } + int32_t ret = mmMkdir(const_cast(directory_path.c_str()), M_IRUSR | M_IWUSR | M_IXUSR); // 700 + if (ret != 0) { + if (errno != EEXIST) { + ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); + GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); + return ret; + } + } + return 0; +} + +std::string CurrentTimeInStr() { + std::time_t now = std::time(nullptr); + std::tm *ptm = std::localtime(&now); + if (ptm == nullptr) { + GELOGE(ge::FAILED, "Localtime failed."); + return ""; + } + + const int kTimeBufferLen = 32; + char buffer[kTimeBufferLen + 1] = {0}; + // format: 20171122042550 + std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm); + return std::string(buffer); +} + +std::string RealPath(const char *path) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH, + ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, + {path, std::to_string(MMPA_MAX_PATH)}); + return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH); + + // Nullptr is returned when the path does not exist or there is no permission + // Return absolute path when path is accessible + std::string res; + char resolved_path[MMPA_MAX_PATH] = {0}; + if (mmRealPath(path, resolved_path, MMPA_MAX_PATH) == EN_OK) { + res = resolved_path; + } + + return res; +} + +bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param) { + // The specified path is empty + std::map args_map; + if (file_path.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); + GELOGW("Input parameter %s is empty.", file_path.c_str()); + return false; + } + std::string real_path = RealPath(file_path.c_str()); + // Unable to get absolute path (does not exist or does not have permission to access) + if (real_path.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file_path, strerror(errno)}); + GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); + return false; + } + + // A regular matching expression to verify the validity of the input file path + // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores + // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) +#ifdef __GNUC__ + std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; +#else + std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; +#endif + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + !ValidateStr(real_path, mode), + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {atc_param, real_path, kPathValidReason}); + return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); + + // The absolute path points to a file that is not readable + if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) { + ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"file", "errmsg"}, {file_path.c_str(), strerror(errno)}); + GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); + return false; + } + + return true; +} + +bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param) { + // The specified path is empty + if (file_path.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); + GELOGW("Input parameter's value is empty."); + return false; + } + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH, + ErrorManager::GetInstance().ATCReportErrMessage( + "E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)}); + return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), + MMPA_MAX_PATH); + + // A regular matching expression to verify the validity of the input file path + // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores + // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) +#ifdef __GNUC__ + std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; +#else + std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; +#endif + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + !ValidateStr(file_path, mode), + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {atc_param, file_path, kPathValidReason}); + return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason); + + std::string real_path = RealPath(file_path.c_str()); + // Can get absolute path (file exists) + if (!real_path.empty()) { + // File is not readable or writable + if (mmAccess2(real_path.c_str(), M_W_OK | M_F_OK) != EN_OK) { + ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {real_path, strerror(errno)}); + GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno)); + return false; + } + } else { + // Find the last separator + int path_split_pos = static_cast(file_path.size() - 1); + for (; path_split_pos >= 0; path_split_pos--) { + if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') { + break; + } + } + if (path_split_pos == 0) { + return true; + } + if (path_split_pos != -1) { + std::string prefix_path = std::string(file_path).substr(0, static_cast(path_split_pos)); + // Determine whether the specified path is valid by creating the path + if (CreateDirectory(prefix_path) != 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {file_path}); + GELOGW("Can not create directory[%s].", file_path.c_str()); + return false; + } + } + } + + return true; +} + +bool ValidateStr(const std::string &str, const std::string &mode) { +#ifdef __GNUC__ + char ebuff[kMaxBuffSize]; + regex_t reg; + int cflags = REG_EXTENDED | REG_NOSUB; + int ret = regcomp(®, mode.c_str(), cflags); + if (ret) { + regerror(ret, ®, ebuff, kMaxBuffSize); + GELOGW("regcomp failed, reason: %s", ebuff); + regfree(®); + return true; + } + + ret = regexec(®, str.c_str(), 0, NULL, 0); + if (ret) { + regerror(ret, ®, ebuff, kMaxBuffSize); + GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff); + regfree(®); + return false; + } + + regfree(®); + return true; +#else + std::wstring wstr(str.begin(), str.end()); + std::wstring wmode(mode.begin(), mode.end()); + std::wsmatch match; + bool res = false; + + try { + std::wregex reg(wmode, std::regex::icase); + // Matching string part + res = regex_match(wstr, match, reg); + res = regex_search(str, std::regex("[`!@#$%^&*()|{}';',<>?]")); + } catch (std::exception &ex) { + GELOGW("The directory %s is invalid, error: %s.", str.c_str(), ex.what()); + return false; + } + return !(res) && (str.size() == match.str().size()); +#endif +} + +bool GetNameFromFileName(const std::string &file_name, std::string &base_name) { + if (file_name.empty()) { + GELOGW("File path may not valid, check params --output"); + return false; + } + size_t start_position = 0; + // using output as base_name (ignore ".om") + size_t filename_suffixes = 3; + if (file_name.find_last_of('/') != std::string::npos) { + start_position = file_name.find_last_of('/') + 1; + } + size_t end_position = file_name.length() - filename_suffixes; + base_name = file_name.substr(start_position, end_position - start_position); + if (base_name.empty()) { + GELOGW("File path may not valid, check params --output"); + return false; + } + return true; +} +} // namespace ge diff --git a/atc/util/util.h b/atc/util/util.h new file mode 100644 index 0000000..441f703 --- /dev/null +++ b/atc/util/util.h @@ -0,0 +1,159 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef UTIL_UTIL_H_ +#define UTIL_UTIL_H_ + +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "mmpa/mmpa_api.h" + +// For propagating errors when calling a function. +#define GE_RETURN_IF_ERROR(expr) \ + do { \ + const ::ge::Status _status = (expr); \ + if (_status) return _status; \ + } while (0) + +// If expr is not true, print the log and return the specified status +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return _status; \ + } \ + } while (0); + +// If expr is not true, print the log and execute a custom statement +#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (!b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is true, print logs and execute custom statements +#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is not SUCCESS, return the same value +#define GE_CHK_STATUS_RET_NOLOG(expr) \ + do { \ + const ge::Status _status = (expr); \ + if (_status != ge::SUCCESS) { \ + return _status; \ + } \ + } while (0); + +#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ + do { \ + const ::ge::Status _status = (expr); \ + if (_status) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// If expr is not SUCCESS, print the log and execute a custom statement +#define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \ + do { \ + const ge::Status _status = (expr); \ + GE_CHK_BOOL_EXEC(_status == SUCCESS, exec_expr, __VA_ARGS__); \ + } while (0); + +// Check if the parameter is null. If yes, return PARAM_INVALID and record the error +#define GE_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + GELOGE(ge::FAILED, "param[%s] must not be null.", #val); \ + return ge::PARAM_INVALID; \ + } \ + } while (0) + +// If expr is true, execute exec_expr without printing logs +#define GE_IF_BOOL_EXEC(expr, exec_expr) \ + { \ + if (expr) { \ + exec_expr; \ + } \ + } + +namespace ge { +/// +/// @ingroup domi_common +/// @brief Recursively Creating a Directory +/// @param [in] directory_path Path, which can be a multi-level directory. +/// @return 0 success +/// @return -1 fail +/// +extern int CreateDirectory(const std::string &directory_path); + +/// +/// @ingroup domi_common +/// @brief Obtains the current time string. +/// @return Time character string in the format : %Y%m%d%H%M%S, eg: 20171011083555 +/// +std::string CurrentTimeInStr(); + +/// +/// @ingroup domi_common +/// @brief Absolute path for obtaining files. +/// @param [in] path of input file +/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned +/// +std::string RealPath(const char *path); + +/// +/// @ingroup domi_common +/// @brief Check whether the specified input file path is valid. +/// 1. The specified path cannot be empty. +/// 2. The path can be converted to an absolute path. +/// 3. The file path exists and is readable. +/// @param [in] file_path path of input file +/// @param [out] result +/// +bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param = ""); + +/// +/// @ingroup domi_common +/// @brief Checks whether the specified output file path is valid. +/// @param [in] file_path path of output file +/// @param [out] result +/// +bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param = ""); + +/// +/// @ingroup domi_common +/// @brief Check whether the file path meets the whitelist verification requirements. +/// @param [in] filePath file path +/// @param [out] result +/// +bool ValidateStr(const std::string &filePath, const std::string &mode); + +bool GetNameFromFileName(const std::string &file_name, std::string &base_name); +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_UTIL_H_ diff --git a/parser/common/convert/pb2json.cc b/parser/common/convert/pb2json.cc index af13ed2..791b8b2 100644 --- a/parser/common/convert/pb2json.cc +++ b/parser/common/convert/pb2json.cc @@ -23,6 +23,7 @@ #include "securec.h" #include "framework/common/fmk_types.h" #include "framework/common/debug/ge_log.h" +#include "parser/common/model_saver.h" using std::set; using std::string; @@ -31,6 +32,15 @@ namespace ge { namespace { const int kSignificantDigits = 10; } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status Pb2Json::ToJson(const ProtobufMsg &message, + const set &black_fields, + const char *json_file, bool enum2str) { + Json j; + Message2Json(message, black_fields, j, enum2str); + return ge::parser::ModelSaver::SaveJsonToFile(json_file, j); +} + // JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, const set &black_fields, Json &json, diff --git a/parser/common/convert/pb2json.h b/parser/common/convert/pb2json.h index 4f8e406..4a13fe8 100644 --- a/parser/common/convert/pb2json.h +++ b/parser/common/convert/pb2json.h @@ -23,6 +23,7 @@ #include #include #include +#include "ge/ge_api_error_codes.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "nlohmann/json.hpp" @@ -51,6 +52,9 @@ class Pb2Json { const ProtobufReflection *reflection, const std::set &black_fields, Json &json, bool enum2str); + static Status ToJson(const ProtobufMsg &message, const std::set &black_fields, const char *json_file, + bool enum2str = false); + protected: static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, bool enum2str, Json &json); From 6277ec6c1097be3aced149ca334addc471c0ec5c Mon Sep 17 00:00:00 2001 From: wjm Date: Fri, 26 Feb 2021 20:13:43 +0800 Subject: [PATCH 2/3] fix --- atc/CMakeLists.txt | 2 - atc/atc_ir_common.cc | 1 + atc/main.cc | 72 ++++---- atc/parse_graph.h | 25 +-- .../graphengine/inc/external/ge/ge_ir_build.h | 159 ++++++++++++++++++ .../inc/framework/generator/ge_generator.h | 60 +++++++ .../graphengine/inc/framework/omg/ge_init.h | 39 +++++ .../inc/framework/omg/model_tool.h | 34 ++++ 8 files changed, 343 insertions(+), 49 deletions(-) create mode 100644 third_party/graphengine/inc/external/ge/ge_ir_build.h create mode 100644 third_party/graphengine/inc/framework/generator/ge_generator.h create mode 100644 third_party/graphengine/inc/framework/omg/ge_init.h create mode 100644 third_party/graphengine/inc/framework/omg/model_tool.h diff --git a/atc/CMakeLists.txt b/atc/CMakeLists.txt index 4fa9f8b..acbea71 100644 --- a/atc/CMakeLists.txt +++ b/atc/CMakeLists.txt @@ -54,7 +54,6 @@ target_include_directories(atc_atc.bin PRIVATE ${METADEF_DIR}/third_party/graphengine/inc/ ${METADEF_DIR}/third_party/graphengine/inc/external ${METADEF_DIR}/third_party/graphengine/inc/framework - ${METADEF_DIR}/third_party/graphengine/ge ${METADEF_DIR}/third_party/fwkacllib/inc ${PARSER_DIR}/third_party/graphengine/inc ) @@ -123,7 +122,6 @@ target_include_directories(fwk_atc.bin PRIVATE ${METADEF_DIR}/third_party/graphengine/inc/ ${METADEF_DIR}/third_party/graphengine/inc/external ${METADEF_DIR}/third_party/graphengine/inc/framework - ${METADEF_DIR}/third_party/graphengine/ge ${METADEF_DIR}/third_party/fwkacllib/inc ${PARSER_DIR}/third_party/graphengine/inc/ diff --git a/atc/atc_ir_common.cc b/atc/atc_ir_common.cc index 0114fb7..4e6bcfa 100644 --- a/atc/atc_ir_common.cc +++ b/atc/atc_ir_common.cc @@ -165,6 +165,7 @@ bool CheckDynamicImageSizeShape(const vector &shape, const string &data return false; } } + bool CheckDynamicImagesizeInputShapeValid(map> shape_map, const std::string input_format, std::string &dynamic_image_size) { if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) { diff --git a/atc/main.cc b/atc/main.cc index 620b18f..4847337 100644 --- a/atc/main.cc +++ b/atc/main.cc @@ -63,19 +63,18 @@ using std::shared_ptr; using std::string; using std::vector; +namespace { static bool is_dynamic_input = false; - const char *const kModeSupport = "only support 0(model to framework model), " "1(framework model to json), 3(only pre-check), " "5(pbtxt to json), 6(display model info)"; const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow) 5(Onnx)"; - -static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; -static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; -static const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model"; - +const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; +const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; +const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model"; // limit available mem size 2G const long kMinAvailableMem = 2097152; // 2 * 1024 * 1024 +} // namespace DEFINE_string(model, "", "The model file."); DEFINE_string(output, "", "The output file path&name."); @@ -523,12 +522,12 @@ class GFlagUtils { static bool CheckEncryptModeValid(const int encrypt_mode) { #if !defined(__ANDROID__) && !defined(ANDROID) if (encrypt_mode != 0 && encrypt_mode != -1) { - GELOGE(ge::FAILED, "encrypt mode must be 0 or -1"); + GELOGE(ge::FAILED, "encrypt mode must be 0 or -1"); return false; } #else if (encrypt_mode != -1) { - GELOGE(ge::FAILED, "encrypt mode must be -1"); + GELOGE(ge::FAILED, "encrypt mode must be -1"); return false; } #endif @@ -543,14 +542,14 @@ class GFlagUtils { ErrorManager::GetInstance().ATCReportErrMessage( "E10007", {"parameter", "support"}, {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)"}); - GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: " - "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)."); + GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: " + "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)."); return domi::PARAM_INVALID; } if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) { ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"}); - GELOGE(ge::FAILED, "Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!"); + GELOGE(ge::FAILED, "Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!"); return domi::PARAM_INVALID; } @@ -584,7 +583,7 @@ class GFlagUtils { // Failure if no filename follows the path if (slashPosition == static_cast(fileName.size() - 1)) { ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName}); - GELOGE(ge::FAILED, "Input parameter[--output]'s path[%s] not include file name", fileName.c_str()); + GELOGE(ge::FAILED, "Input parameter[--output]'s path[%s] not include file name", fileName.c_str()); return false; } @@ -628,7 +627,7 @@ static bool CheckInputFormat() { ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport}); GELOGE(ge::FAILED, - "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kCaffeFormatSupport); + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kCaffeFormatSupport); return false; } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) { @@ -638,7 +637,7 @@ static bool CheckInputFormat() { ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport}); GELOGE(ge::FAILED, - "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kTFFormatSupport); + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kTFFormatSupport); return false; } else if (FLAGS_framework == static_cast(domi::ONNX)) { if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) { @@ -648,7 +647,7 @@ static bool CheckInputFormat() { ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport}); GELOGE(ge::FAILED, - "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kONNXFormatSupport); + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kONNXFormatSupport); return false; } return true; @@ -804,7 +803,7 @@ Status CreateInputsForInference(const ge::Graph &graph, vector &in GE_CHECK_NOTNULL(input_node); ge::OpDescPtr op = input_node->GetOpDesc(); GE_CHECK_NOTNULL(op); - if (op->GetType() == "Data") { + if (op->GetType() == ge::DATA) { GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); ge::GeTensorDesc tensor = op->GetInputDesc(0); string data_op_name = op->GetName(); @@ -845,7 +844,7 @@ domi::Status GenerateInfershapeJson() { std::map options; ge::Status geRet = ge_generator.Initialize(options, domi::GetContext()); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "GeGenerator initialize failed!"); + GELOGE(ge::FAILED, "GeGenerator initialize failed!"); return domi::FAILED; } @@ -855,19 +854,19 @@ domi::Status GenerateInfershapeJson() { ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework, "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false); if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); + GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); (void)ge_generator.Finalize(); return domi::FAILED; } geRet = ge_generator.GenerateInfershapeGraph(graph); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "ATC GenerateInfershapeJson failed"); + GELOGE(ge::FAILED, "ATC GenerateInfershapeJson failed"); (void)ge_generator.Finalize(); return domi::FAILED; } if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) { - GELOGE(ge::FAILED, "ATC DumpInfershapeJson failed"); + GELOGE(ge::FAILED, "ATC DumpInfershapeJson failed"); (void)ge_generator.Finalize(); return domi::FAILED; } @@ -935,12 +934,12 @@ domi::Status GenerateModel(std::map &options, std::string output ge::Status geRet = ge::SUCCESS; geRet = ge::GEInit::Initialize(options); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "GE initialize failed!"); + GELOGE(ge::FAILED, "GE initialize failed!"); return domi::FAILED; } geRet = ge_generator.Initialize(options, domi::GetContext()); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "GeGenerator initialize failed!"); + GELOGE(ge::FAILED, "GeGenerator initialize failed!"); (void)ge::GEInit::Finalize(); return domi::FAILED; } @@ -953,8 +952,8 @@ domi::Status GenerateModel(std::map &options, std::string output auto ret1 = load_model.LoadFromFile(FLAGS_model); if (ret1 != ge::GRAPH_SUCCESS) { ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model}); - GELOGE(ge::FAILED, "Load model from %s failed, please check model file or " - "input parameter[--framework] is correct", FLAGS_model.c_str()); + GELOGE(ge::FAILED, "Load model from %s failed, please check model file or " + "input parameter[--framework] is correct", FLAGS_model.c_str()); (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); return domi::FAILED; @@ -995,21 +994,21 @@ domi::Status GenerateModel(std::map &options, std::string output (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "ATC precheck fail."); + GELOGE(ge::FAILED, "ATC precheck fail."); return domi::FAILED; } return domi::SUCCESS; } if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); - GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case + GELOGE(ge::FAILED, "ATC Parse graph domi::FAILED"); + GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); return domi::FAILED; } if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) { - GELOGE(ge::FAILED, "Set output node info fail."); + GELOGE(ge::FAILED, "Set output node info fail."); (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); return domi::FAILED; @@ -1024,8 +1023,8 @@ domi::Status GenerateModel(std::map &options, std::string output geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); if (geRet != ge::SUCCESS) { - GELOGE(ge::FAILED, "GE GenerateOfflineModel execute failed"); - GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case + GELOGE(ge::FAILED, "GE GenerateOfflineModel execute failed"); + GELOGE(ge::FAILED, "ATC Generate execute failed"); // Duplicate log. (for test case // checking error log) (void)ge_generator.Finalize(); (void)ge::GEInit::Finalize(); @@ -1061,7 +1060,7 @@ static void SetEnvForSingleOp(std::map &options) { domi::Status GenerateSingleOp(const std::string& json_file_path) { if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) { - GELOGE(ge::FAILED, "output path %s is not valid!", FLAGS_output.c_str()); + GELOGE(ge::FAILED, "output path %s is not valid!", FLAGS_output.c_str()); return domi::FAILED; } // check optypelist_for_implmode and op_select_implmode @@ -1075,21 +1074,21 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) { auto ret = ge::GEInit::Initialize(options); if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "GE initialize failed!"); + GELOGE(ge::FAILED, "GE initialize failed!"); return domi::FAILED; } ge::GeGenerator generator; ret = generator.Initialize(options, domi::GetContext()); if (ret != SUCCESS) { - GELOGE(ge::FAILED, "GeGenerator initialize failed!"); + GELOGE(ge::FAILED, "GeGenerator initialize failed!"); (void)ge::GEInit::Finalize(); return domi::FAILED; } vector build_params; if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) { - GELOGE(ge::FAILED, "parse single op json file failed"); + GELOGE(ge::FAILED, "parse single op json file failed"); (void)generator.Finalize(); (void)ge::GEInit::Finalize(); return domi::FAILED; @@ -1104,7 +1103,7 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) { output_path += param.file_name; ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path); if (ret != SUCCESS) { - GELOGE(ge::FAILED, "Compile op failed. ge ret = %u, op index = %d", ret, index); + GELOGE(ge::FAILED, "Compile op failed. ge ret = %u, op index = %d", ret, index); ret = domi::FAILED; break; } @@ -1320,10 +1319,11 @@ int init(int argc, char* argv[]) { std::string path_base = ge::GEInit::GetPath(); ret = ErrorManager::GetInstance().Init(path_base); if (ret != 0) { - GELOGE(ge::FAILED, "ErrorManager init fail !"); + GELOGE(ge::FAILED, "ErrorManager init fail !"); return ret; } + ErrorManager::GetInstance().GenWorkStreamIdDefault(); return 0; } diff --git a/atc/parse_graph.h b/atc/parse_graph.h index 5e04c73..ad59d34 100644 --- a/atc/parse_graph.h +++ b/atc/parse_graph.h @@ -28,7 +28,6 @@ #include "graph/compute_graph.h" #include "graph/graph.h" #include "graph/model.h" -//#include "runtime/kernel.h" using domi::Status; using std::pair; @@ -42,8 +41,8 @@ namespace ge { * @brief init omg context * @return void */ -GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, - bool is_dynamic_input); +GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, + const string &net_format, bool is_dynamic_input); /** * @ingroup domi_omg @@ -60,9 +59,10 @@ GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const st * @param [in] atc_params multiply atc params * @return Status result code */ -GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map &atc_params, const char *model_file, - const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr, - const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false); +GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map &atc_params, + const char *model_file, const char *weights_file, domi::FrameworkType type, + const char *op_conf = nullptr, const char *target = nullptr, + RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false); /** * @ingroup domi_omg @@ -84,7 +84,8 @@ GE_FUNC_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const char * @param [key] encrypted key * @return Status result code */ -GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file); +GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, + const char *json_file); GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model); @@ -92,17 +93,19 @@ GE_FUNC_VISIBILITY void FindParserSo(const string &path, vector &fileLis GE_FUNC_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); -GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); +GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, + const std::string &output_format); -GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, std::vector> &output_nodes_info); +GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, + std::vector> &output_nodes_info); GE_FUNC_VISIBILITY void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, - std::vector &output_nodes_name); + std::vector &output_nodes_name); GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx(); GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx(); -GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def); +GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size); } // namespace ge #endif // PARSE_GRAPH_H_ diff --git a/third_party/graphengine/inc/external/ge/ge_ir_build.h b/third_party/graphengine/inc/external/ge/ge_ir_build.h new file mode 100644 index 0000000..04e059a --- /dev/null +++ b/third_party/graphengine/inc/external/ge/ge_ir_build.h @@ -0,0 +1,159 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd + +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef INC_EXTERNAL_GE_IR_BUILD_H_ +#define INC_EXTERNAL_GE_IR_BUILD_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include +#include "graph/graph.h" +#include "graph/ge_error_codes.h" + +namespace { +const int IR_MAJOR_VERSION = 1; +const int IR_MINOR_VERSION = 0; +const int IR_PATCH_VERSION = 0; +} // namespace + +namespace ge { + +struct ModelBufferData { + std::shared_ptr data = nullptr; + uint64_t length; +}; + +enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS }; + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + * @param global_options[IN] global init params for build + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map &)) +GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map global_options); + +GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map &global_options); + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + */ +GE_FUNC_VISIBILITY void aclgrphBuildFinalize(); + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + * @param graph[IN] the graph ready to build + * @param options[IN] options used for build + * @param model[OUT] builded model + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, + const std::map &, + ModelBufferData &)) +GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, + const std::map &build_options, + ModelBufferData &model); + +GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, + const std::map &build_options, + ModelBufferData &model); + +/** + * @ingroup AscendCL + * @brief save model buffer to file + * + * @param output_file[IN] the file path to be saved + * @param model[IN] model buffer data + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *, const ModelBufferData &)) +GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); + +GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model); + +/** + * @ingroup AscendCL + * @brief query IR interface version + * + * @param major_version[OUT] IR interface major version + * @param minor_version[OUT] IR interface minor version + * @param patch_version[OUT] IR interface patch version + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version); + +/** + * @ingroup AscendCL + * @brief dump graph + * + * @param graph[IN] the graph ready to build + * @param file[IN] file path + * @param file[IN] file path string len + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len); + +/** + * @ingroup AscendCL + * @brief create single op graph + * + * @param op_type[IN] the op_type + * @param inputs[IN] the inputdesc + * @param outputs[IN] the outputdesc + * @param graph[OUT] the graph + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector &inputs, + const std::vector &outputs, Graph &graph); + +/** + * @name aclgrphSetOpAttr + * @brief set attribute for operators in the configuration file + * @param graph [IN/OUT] compute graph + * @param attr_type [In] attribute type + * @param cfg_path [IN] the config file path + * @return graphStatus + */ +GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path); + +}; // namespace ge +#endif // INC_EXTERNAL_GE_IR_BUILD_H_ diff --git a/third_party/graphengine/inc/framework/generator/ge_generator.h b/third_party/graphengine/inc/framework/generator/ge_generator.h new file mode 100644 index 0000000..1b3cade --- /dev/null +++ b/third_party/graphengine/inc/framework/generator/ge_generator.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_ +#define INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_ + +#include +#include +#include +#include +#include "common/ge_inner_error_codes.h" +#include "graph/ge_tensor.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "graph/detail/attributes_holder.h" +#include "omg/omg_inner_types.h" + +namespace ge { +class GE_FUNC_VISIBILITY GeGenerator { + public: + GeGenerator() = default; + + ~GeGenerator() { (void)Finalize(); } + + GeGenerator(const GeGenerator &) = delete; + + GeGenerator &operator=(const GeGenerator &) = delete; + + Status Initialize(const std::map &options, OmgContext &context); + + Status Finalize(); + + Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, + const std::vector &inputs = std::vector()); + + Status GenerateInfershapeGraph(const Graph &graph); + + Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector &inputs, + const std::vector &outputs, const std::string &model_file_name); + private: + class Impl; + + std::shared_ptr impl_; +}; +} // namespace ge + +#endif // INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_ diff --git a/third_party/graphengine/inc/framework/omg/ge_init.h b/third_party/graphengine/inc/framework/omg/ge_init.h new file mode 100644 index 0000000..b0cd0d6 --- /dev/null +++ b/third_party/graphengine/inc/framework/omg/ge_init.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_GE_INIT_H_ +#define INC_FRAMEWORK_OMG_GE_INIT_H_ +#include +#include +#include "common/ge_inner_error_codes.h" + +using std::string; +using std::map; + +namespace ge { +class GE_FUNC_VISIBILITY GEInit { + public: + // GE Environment Initialize, return Status: SUCCESS,FAILED + static Status Initialize(const map &options); + + static string GetPath(); + + // GE Environment Finalize, return Status: SUCCESS,FAILED + static Status Finalize(); +}; +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_GE_INIT_H_ diff --git a/third_party/graphengine/inc/framework/omg/model_tool.h b/third_party/graphengine/inc/framework/omg/model_tool.h new file mode 100644 index 0000000..93c4e68 --- /dev/null +++ b/third_party/graphengine/inc/framework/omg/model_tool.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_MODEL_TOOL_H_ +#define INC_FRAMEWORK_OMG_MODEL_TOOL_H_ + +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "proto/ge_ir.pb.h" + +namespace ge { +class GE_FUNC_VISIBILITY ModelTool { + public: + static Status GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef &model_def, uint32_t &modeldef_size); + + static Status GetModelInfoFromPbtxt(const char *model_file, ge::proto::ModelDef &model_def); +}; +} // namespace ge +#endif // INC_FRAMEWORK_OMG_MODEL_TOOL_H_ From 8f50c0175d25ccd6c2ef09696c29699932c26107 Mon Sep 17 00:00:00 2001 From: wjm Date: Fri, 26 Feb 2021 20:16:07 +0800 Subject: [PATCH 3/3] fix --- third_party/graphengine/inc/framework/omg/ge_init.h | 7 ++----- third_party/graphengine/inc/framework/omg/model_tool.h | 3 ++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/third_party/graphengine/inc/framework/omg/ge_init.h b/third_party/graphengine/inc/framework/omg/ge_init.h index b0cd0d6..42fd897 100644 --- a/third_party/graphengine/inc/framework/omg/ge_init.h +++ b/third_party/graphengine/inc/framework/omg/ge_init.h @@ -20,16 +20,13 @@ #include #include "common/ge_inner_error_codes.h" -using std::string; -using std::map; - namespace ge { class GE_FUNC_VISIBILITY GEInit { public: // GE Environment Initialize, return Status: SUCCESS,FAILED - static Status Initialize(const map &options); + static Status Initialize(const std::map &options); - static string GetPath(); + static std::string GetPath(); // GE Environment Finalize, return Status: SUCCESS,FAILED static Status Finalize(); diff --git a/third_party/graphengine/inc/framework/omg/model_tool.h b/third_party/graphengine/inc/framework/omg/model_tool.h index 93c4e68..8c42582 100644 --- a/third_party/graphengine/inc/framework/omg/model_tool.h +++ b/third_party/graphengine/inc/framework/omg/model_tool.h @@ -27,8 +27,9 @@ namespace ge { class GE_FUNC_VISIBILITY ModelTool { public: static Status GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef &model_def, uint32_t &modeldef_size); - + static Status GetModelInfoFromPbtxt(const char *model_file, ge::proto::ModelDef &model_def); }; } // namespace ge + #endif // INC_FRAMEWORK_OMG_MODEL_TOOL_H_