Browse Source

Pre Merge pull request !250 from wangjiming/master

pull/250/MERGE
wangjiming Gitee 4 years ago
parent
commit
b39ee37e2b
25 changed files with 5432 additions and 0 deletions
  1. +8
    -0
      CMakeLists.txt
  2. +167
    -0
      atc/CMakeLists.txt
  3. +21
    -0
      atc/atc
  4. +596
    -0
      atc/atc_ir_common.cc
  5. +82
    -0
      atc/atc_ir_common.h
  6. +23
    -0
      atc/common/types.cc
  7. +32
    -0
      atc/common/types.h
  8. +1428
    -0
      atc/main.cc
  9. +1004
    -0
      atc/parse_graph.cc
  10. +111
    -0
      atc/parse_graph.h
  11. +609
    -0
      atc/single_op_parser.cc
  12. +90
    -0
      atc/single_op_parser.h
  13. +85
    -0
      atc/util/gflags_util.h
  14. +145
    -0
      atc/util/properties_manager.cc
  15. +82
    -0
      atc/util/properties_manager.h
  16. +171
    -0
      atc/util/string_util.h
  17. +32
    -0
      atc/util/tool.h
  18. +283
    -0
      atc/util/util.cc
  19. +159
    -0
      atc/util/util.h
  20. +10
    -0
      parser/common/convert/pb2json.cc
  21. +4
    -0
      parser/common/convert/pb2json.h
  22. +159
    -0
      third_party/graphengine/inc/external/ge/ge_ir_build.h
  23. +60
    -0
      third_party/graphengine/inc/framework/generator/ge_generator.h
  24. +36
    -0
      third_party/graphengine/inc/framework/omg/ge_init.h
  25. +35
    -0
      third_party/graphengine/inc/framework/omg/model_tool.h

+ 8
- 0
CMakeLists.txt View File

@@ -19,6 +19,7 @@ if (ENABLE_OPEN_SRC)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/external_libs/json.cmake)
include(cmake/external_libs/gflags.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)

@@ -38,6 +39,9 @@ if (ENABLE_OPEN_SRC)
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
find_module(slog libalog.so ${GE_LIB_PATH})
find_module(static_mmpa libmmpa.a ${GE_LIB_PATH})
find_module(ge_compiler libge_compiler.so ${GE_LIB_PATH})
find_module(ge_runner libge_runner.so ${GE_LIB_PATH})
find_module(ge_common libge_common.so ${GE_LIB_PATH})
elseif(ENABLE_GE_COV OR ENABLE_GE_UT)
message(STATUS "Runing on llt mode, no need to depend other component")
else()
@@ -51,6 +55,9 @@ if (ENABLE_OPEN_SRC)

find_module(slog libalog.so ${ASCEND_ATC_DIR})
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
find_module(ge_compiler libge_compiler.so ${ASCEND_ATC_DIR})
#find_module(ge_runner libge_runner.so ${ASCEND_RUNTIME_DIR})
find_module(ge_common libge_common.so ${ASCEND_ATC_DIR})
endif()

if (NOT DEFINED METADEF_DIR)
@@ -69,3 +76,4 @@ add_subdirectory(parser/common)
add_subdirectory(parser/func_to_graph)
add_subdirectory(parser/onnx)
add_subdirectory(parser/proto/caffe)
#add_subdirectory(atc)

+ 167
- 0
atc/CMakeLists.txt View File

@@ -0,0 +1,167 @@
set(PROTO_LIST
"${METADEF_DIR}/proto/om.proto"
"${METADEF_DIR}/proto/ge_ir.proto"
"${METADEF_DIR}/proto/insert_op.proto"
"${METADEF_DIR}/proto/task.proto"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})

set(SRC_LIST
"main.cc"
"single_op_parser.cc"
"parse_graph.cc"
"atc_ir_common.cc"
"common/types.cc"
"util/util.cc"
"util/properties_manager.cc"
)

############ atc_atc.bin ############
add_executable(atc_atc.bin ${SRC_LIST} ${PROTO_HDRS})

target_compile_options(atc_atc.bin PRIVATE
-Werror
-O2
-Wno-deprecated-declarations
-fno-common
-fvisibility=hidden
)

target_compile_definitions(atc_atc.bin PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
COMPILE_OMG_PACKAGE
google=ascend_private
LOG_CPP
FUNC_VISIBILITY
)

target_include_directories(atc_atc.bin PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external/register
${PARSER_DIR}
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
${GE_CODE_DIR}/../inc/common
#### blue zone ####
${METADEF_DIR}/third_party/graphengine/inc/
${METADEF_DIR}/third_party/graphengine/inc/external
${METADEF_DIR}/third_party/graphengine/inc/framework
${METADEF_DIR}/third_party/fwkacllib/inc
${PARSER_DIR}/third_party/graphengine/inc
)

target_link_options(atc_atc.bin PRIVATE
-Wl,-Bsymbolic
)

target_link_libraries(atc_atc.bin PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf
ge_common
register
c_sec
graph
error_manager
ge_compiler
parser_common
gflags
json
slog
static_mmpa
-lrt
-ldl
)

set_target_properties(atc_atc.bin PROPERTIES
OUTPUT_NAME atc.bin
RUNTIME_OUTPUT_DIRECTORY atclib
)

############ fwk_atc.bin ############
add_executable(fwk_atc.bin ${SRC_LIST} ${PROTO_HDRS})

target_compile_options(fwk_atc.bin PRIVATE
-Werror
-O2
-Wno-deprecated-declarations
-fno-common
-fvisibility=hidden
)

target_compile_definitions(fwk_atc.bin PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
COMPILE_OMG_PACKAGE
google=ascend_private
LOG_CPP
FUNC_VISIBILITY
)

target_include_directories(fwk_atc.bin PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external/register
${PARSER_DIR}
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
${GE_CODE_DIR}/../inc/common
#### blue zone ####
${METADEF_DIR}/third_party/graphengine/inc/
${METADEF_DIR}/third_party/graphengine/inc/external
${METADEF_DIR}/third_party/graphengine/inc/framework
${METADEF_DIR}/third_party/fwkacllib/inc
${PARSER_DIR}/third_party/graphengine/inc/

)

target_link_options(fwk_atc.bin PRIVATE
-Wl,-Bsymbolic
)

target_link_libraries(fwk_atc.bin PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf
ge_common
register
c_sec
graph
error_manager
ge_runner
parser_common
gflags
json
slog
static_mmpa
-lrt
-ldl
)

set_target_properties(fwk_atc.bin PROPERTIES
OUTPUT_NAME atc.bin
RUNTIME_OUTPUT_DIRECTORY fwkacl
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS atc_atc.bin OPTIONAL
RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/atclib
)

install(TARGETS fwk_atc.bin OPTIONAL
RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/fwkacl
)

+ 21
- 0
atc/atc View File

@@ -0,0 +1,21 @@
#!/bin/bash
#-------------------------------------------------------------------
# Purpose:
# Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved.
#-------------------------------------------------------------------

real_path=$(readlink "$0")
if [ $? -eq 0 ]; then
LOCAL_PATH=$(cd "$(dirname "$real_path")"; pwd)
else
LOCAL_PATH=$(cd "$(dirname "$0")"; pwd)
fi
PKG_PATH=$(cd ${LOCAL_PATH}/..; pwd)
LIB_P="/lib64"
PYTHON_P="/python/site-packages"
LIB64_PATH="${PKG_PATH}${LIB_P}"
PYTHON_PATH="${PKG_PATH}${PYTHON_P}"
export LD_LIBRARY_PATH="${LIB64_PATH}:${LD_LIBRARY_PATH}"
export PYTHONPATH="${PYTHON_PATH}:${PYTHONPATH}"

${PKG_PATH}/bin/atc.bin "$@"

+ 596
- 0
atc/atc_ir_common.cc View File

@@ -0,0 +1,596 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "atc_ir_common.h"
#include "common/util/error_manager/error_manager.h"
#include "external/ge/ge_api_types.h"
#include "util/string_util.h"
#include "util/util.h"
#include "graph/utils/type_utils.h"

using std::pair;
using std::string;
using std::vector;

namespace ge {
namespace {
const int64_t kDynamicInputDim = -1;
const int64_t kDynamicImageSizeNum = 2;
const uint32_t NCHW_DIM_H = 2;
const uint32_t NCHW_DIM_W = 3;
const uint32_t NHWC_DIM_H = 1;
const uint32_t NHWC_DIM_W = 2;
const int32_t DIM_DEFAULT_SIZE = 4;
const size_t kMaxDynamicDimNum = 100;
const size_t kMaxNDDimNum = 4;
const size_t kMinNDDimNum = 1;
// datatype/formats from user to GE, Unified to util interface file later
const std::map<std::string, ge::DataType> kOutputTypeSupportDatatype = {
{"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}};
const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8";
const std::set<std::string> kBufferOptimizeSupportOption = {"l1_optimize", "l2_optimize", "off_optimize",
"l1_and_l2_optimize"};
// The function is incomplete. Currently, only l2_optimize, off_optimize is supported.
const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision";
const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\"";
const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\"";
const char *const kSplitError1 = "size not equal to 2 split by \":\"";
const char *const kEmptyError = "can not be empty";
const char *const kFloatNumError = "exist float number";
const char *const kDigitError = "is not digit";
const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]";
const char *const kSelectImplmodeError = "only support high_performance, high_precision";
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \"";
const char *const kKeepDtypeError = "file not found";

vector<string> SplitInputShape(const std::string &input_shape) {
vector<string> shape_pair_vec;
size_t pos = input_shape.rfind(":");
if (pos != std::string::npos) {
shape_pair_vec.emplace_back(input_shape.substr(0, pos));
shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos));
}
return shape_pair_vec;
}
} // namespace

Status CheckInputFormat(const string &input_format) {
if (input_format.empty()) {
return ge::SUCCESS;
}
if (!ge::TypeUtils::IsFormatValid(input_format.c_str())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, "input format is invalid!"});
GELOGE(ge::PARAM_INVALID, "input format [%s] is invalid!", input_format.c_str());
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

bool CheckDynamicBatchShape(const vector<int64_t> &shape, const string &data_name) {
if (shape[0] == kDynamicInputDim) {
for (size_t i = 1; i < shape.size(); ++i) {
if (shape[i] < 1) {
ErrorManager::GetInstance().ATCReportErrMessage("E10018", {"index", "shape"},
{std::to_string(i), std::to_string(shape[i])});
GELOGE(ge::PARAM_INVALID,
"Only batch N can be -1 when set --dynamic_batch_size, current data: %s shape[%zu] is %ld",
data_name.c_str(), i, shape[i]);
return false;
}
}
return true;
} else {
return false;
}
}

bool CheckDynamicBatchSizeInputShapeValid(map<string, vector<int64_t>> shape_map,
std::string &dynamic_batch_size) {
int32_t size = 0;
for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) {
vector<int64_t> shape = iter->second;
if (shape.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10012");
GELOGE(ge::PARAM_INVALID, "--input_shape's shape size can not be less than 1 when set --dynamic_batch_size.");
return false;
}

if (std::count(shape.begin(), shape.end(), kDynamicInputDim) == 0) {
continue;
}

bool ret = CheckDynamicBatchShape(shape, iter->first);
if (ret) {
size++;
}
}

if (size == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10031");
GELOGE(ge::PARAM_INVALID, "At least one batch n must be equal to -1 when set --dynamic_batch_size.");
return false;
}

for (char c : dynamic_batch_size) {
if (!isdigit(c) && (c != ',') && (c != ' ')) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10033", {"value", "reason"}, {dynamic_batch_size, kDynamicBatchSizeError});
GELOGE(ge::PARAM_INVALID, "Input parameter[--dynamic_batch_size]'s value[%s] is invalid. reason: %s",
dynamic_batch_size.c_str(), kDynamicBatchSizeError);
return false;
}
}
if (dynamic_batch_size.back() == ',') {
dynamic_batch_size.erase(dynamic_batch_size.end() - 1);
}
return true;
}

bool CheckDynamicImageSizeShape(const vector<int64_t> &shape, const string &data_name,
const std::string &input_format) {
int64_t height = 0;
int64_t width = 0;
if (input_format == "NCHW") {
height = shape[NCHW_DIM_H];
width = shape[NCHW_DIM_W];
}

if (input_format == "NHWC") {
height = shape[NHWC_DIM_H];
width = shape[NHWC_DIM_W];
}

if (height == kDynamicInputDim && width == kDynamicInputDim &&
std::count(shape.begin(), shape.end(), kDynamicInputDim) == kDynamicImageSizeNum) {
return true;
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10019");
GELOGE(ge::PARAM_INVALID,
"--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size.");
return false;
}
}

bool CheckDynamicImagesizeInputShapeValid(map<string, vector<int64_t>> shape_map,
const std::string input_format, std::string &dynamic_image_size) {
if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) {
GELOGE(ge::PARAM_INVALID, "user input format [%s] is not found!", input_format.c_str());
return false;
}
int32_t size = 0;
for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) {
vector<int64_t> shape = iter->second;
// only support four dim
if (shape.size() != DIM_DEFAULT_SIZE) {
if (std::count(shape.begin(), shape.end(), kDynamicInputDim) > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10019");
GELOGE(ge::PARAM_INVALID,
"--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size.");
return false;
}
continue;
}

if (std::count(shape.begin(), shape.end(), kDynamicInputDim) == 0) {
continue;
}
auto ret = CheckDynamicImageSizeShape(shape, iter->first, input_format);
if (ret) {
size++;
} else {
return ret;
}
}
if (size == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10019");
GELOGE(ge::PARAM_INVALID,
"--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size.");
return false;
}

EraseEndSemicolon(dynamic_image_size);
// Different parameter sets are split string by ';'
std::vector<std::string> split_set = StringUtils::Split(dynamic_image_size, ';');
// Different dimensions are split by ','
std::vector<std::string> split_dim;
for (auto str : split_set) {
split_dim = StringUtils::Split(str, ',');
if (split_dim.size() != static_cast<size_t>(kDynamicImageSizeNum)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10020", {"DynamicImageSizeNum"},
{std::to_string(kDynamicImageSizeNum)});
GELOGE(ge::PARAM_INVALID,
"--dynamic_image_size's number of dimensions of each "
"group must be %ld.",
kDynamicImageSizeNum);
return false;
}
}

return true;
}

bool CheckDynamicDimsInputShapeValid(const map<string, vector<int64_t>> &shape_map,
string input_format, string &dynamic_dims) {
if (input_format != "ND") {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--input_format", input_format.c_str(), "input_format must be ND when set dynamic_dims"});
GELOGE(ge::PARAM_INVALID, "input_format must be ND when set dynamic_dims.");
return false;
}

int32_t dynamic_dim = 0;
for (auto &info_shapes : shape_map) {
auto &shapes = info_shapes.second;
if (shapes.size() > kMaxNDDimNum || shapes.size() < kMinNDDimNum) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--input_shape's dim", std::to_string(shapes.size()), "Dim num must within [1, 4] when set dynamic_dims"});
GELOGE(ge::PARAM_INVALID, "Dim num must within [%zu, %zu] when set dynamic_dims.", kMinNDDimNum, kMaxNDDimNum);
return false;
}
dynamic_dim += std::count(shapes.begin(), shapes.end(), kDynamicInputDim);
}
if (dynamic_dim == 0) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--input_shape's dynamic dim num", "0", "at least one dim should be -1 when set dynamic_dims"});
GELOGE(ge::PARAM_INVALID, "input_shape's shape is invalid, at least one dim should be -1 when set dynamic_dims.");
return false;
}

if (!CheckAndParseDynamicDims(dynamic_dim, dynamic_dims)) {
GELOGE(ge::PARAM_INVALID, "Check and parse dynamic dims: %s failed.", dynamic_dims.c_str());
return false;
}

return true;
}

bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims) {
EraseEndSemicolon(dynamic_dims);
if (dynamic_dims.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--dynamic_dims", dynamic_dims.c_str(), "dynamic_dims can not be empty"});
GELOGE(ge::PARAM_INVALID, "dynamic_dims can not be empty.");
return false;
}
// Different parameter sets are split by ';'
vector<string> split_set = StringUtils::Split(dynamic_dims, ';');
if (split_set.size() > kMaxDynamicDimNum) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10042", {"parameter", "reason"}, {"dynamic_dims", "dynamic_dims's num of parameter set can not exceed 100"});
GELOGE(ge::PARAM_INVALID, "dynamic_dims's num of parameter set can not exceed %zu.", kMaxDynamicDimNum);
return false;
}
for (auto split_dim : split_set) {
vector<string> one_set = StringUtils::Split(split_dim, ',');
if (one_set.size() != static_cast<size_t>(dynamic_dim_num)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10042", {"parameter", "reason"},
{"dynamic_dims", "Each gear setting needs to be consistent with the number of -1 in the inputshape"});
GELOGE(ge::PARAM_INVALID, "Input parameter --dynamic_dims parse failed, "
"reason: Each gear setting needs to be consistent with the number of -1 in the inputshape.");
return false;
}
for (auto dim : one_set) {
for (auto c : dim) {
if (!isdigit(c)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--dynamic_dims's parameter", dim.c_str(), "must be positive integer"});
GELOGE(ge::PARAM_INVALID, "dynamic_dims's parameter must be positive integer.");
return false;
}
}
}
}
return true;
}

Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims,
const string input_shape, const string input_format, bool &is_dynamic_input) {
int32_t param_size = static_cast<int32_t>(!dynamic_batch_size.empty()) +
static_cast<int32_t>(!dynamic_image_size.empty()) + static_cast<int32_t>(!dynamic_dims.empty());
if (param_size > 1) {
ErrorManager::GetInstance().ATCReportErrMessage("E10009", {"parameter0", "parameter1", "parameter2"},
{"dynamic_batch_size", "dynamic_image_size", "dynamic_dims"});
GELOGE(ge::PARAM_INVALID, "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one");
return ge::PARAM_INVALID;
}

if (param_size == 0) {
return ge::SUCCESS;
}

map<string, vector<int64_t>> shape_map;
vector<pair<string, vector<int64_t>>> user_shape_map;
is_dynamic_input = true;
if (input_shape.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"input_shape"});
GELOGE(ge::PARAM_INVALID, "The input_shape can not be empty in dynamic input size scenario.");
return ge::PARAM_INVALID;
}

if (!ParseInputShape(input_shape, shape_map, user_shape_map, is_dynamic_input)) {
GELOGE(ge::PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str());
return ge::PARAM_INVALID;
}

if (!dynamic_batch_size.empty()) {
if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) {
GELOGE(ge::PARAM_INVALID, "Check dynamic batch size input shape failed: %s", input_shape.c_str());
return ge::PARAM_INVALID;
}
}

if (!dynamic_image_size.empty()) {
if (!CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size)) {
GELOGE(ge::PARAM_INVALID, "Check dynamic image size input shape failed: %s", input_shape.c_str());
return ge::PARAM_INVALID;
}
}

if (!dynamic_dims.empty()) {
if (!CheckDynamicDimsInputShapeValid(shape_map, input_format, dynamic_dims)) {
GELOGE(ge::PARAM_INVALID, "Check dynamic dims: %s of input shape: %s failed.", dynamic_dims.c_str(),
input_shape.c_str());
return ge::PARAM_INVALID;
}
}
return ge::SUCCESS;
}

bool ParseInputShape(const string &input_shape, map<string, vector<int64_t>> &shape_map,
vector<pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input) {
vector<string> shape_vec = StringUtils::Split(input_shape, ';');
const int DEFAULT_SHAPE_PAIR_SIZE = 2;
for (const auto &shape : shape_vec) {
vector<string> shape_pair_vec = SplitInputShape(shape);
if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kSplitError1, kInputShapeSample1});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kSplitError1, kInputShapeSample1);
return false;
}
if (shape_pair_vec[1].empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kEmptyError, kInputShapeSample1});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kEmptyError, kInputShapeSample1);
return false;
}

vector<string> shape_value_strs = StringUtils::Split(shape_pair_vec[1], ',');
vector<int64_t> shape_values;
for (auto &shape_value_str : shape_value_strs) {
// stoul: The method may throw an exception: invalid_argument/out_of_range
if (std::string::npos != shape_value_str.find('.')) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kFloatNumError, kInputShapeSample2});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kFloatNumError, kInputShapeSample2);
return false;
}

long left_result = 0;
try {
left_result = stol(StringUtils::Trim(shape_value_str));
if (!shape_value_str.empty() && (shape_value_str.front() == '-')) {
// The value maybe dynamic shape [-1], need substr it and verify isdigit.
shape_value_str = shape_value_str.substr(1);
}
for (char c : shape_value_str) {
if (!isdigit(c)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kDigitError, kInputShapeSample2});
GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str());
return false;
}
}
} catch (const std::out_of_range &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"},
{"--input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str());
return false;
} catch (const std::invalid_argument &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"},
{"--input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str());
return false;
} catch (...) {
ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"},
{"--input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str());
return false;
}
int64_t result = left_result;
// - 1 is not currently supported
if (!is_dynamic_input && result <= 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)});
GELOGW(
"Input parameter[--input_shape]’s shape value[%s] is invalid, "
"expect positive integer, but value is %ld.",
shape.c_str(), result);
return false;
}
shape_values.push_back(result);
}

shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
}

return true;
}

Status CheckOutputTypeParamValid(const std::string output_type) {
if ((!output_type.empty()) && (kOutputTypeSupportDatatype.find(output_type) == kOutputTypeSupportDatatype.end())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--output_type", output_type, kOutputTypeSupport});
GELOGE(ge::PARAM_INVALID,
"Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport);
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckBufferOptimizeParamValid(const std::string buffer_optimize) {
if ((!buffer_optimize.empty()) &&
(kBufferOptimizeSupportOption.find(buffer_optimize) == kBufferOptimizeSupportOption.end())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--buffer_optimize", buffer_optimize, kBufferOptimizeSupport});
GELOGE(ge::PARAM_INVALID,
"Invalid value for --buffer_optimize[%s], %s.", buffer_optimize.c_str(), kBufferOptimizeSupport);
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckCompressWeightParamValid(const std::string enable_compress_weight, const std::string compress_weight_conf) {
if ((!compress_weight_conf.empty()) &&
(!CheckInputPathValid(compress_weight_conf, "--compress_weight_conf"))) {
GELOGE(ge::PARAM_INVALID, "compress weight config file not found, file_name:%s", compress_weight_conf.c_str());
return ge::PARAM_INVALID;
}
if ((enable_compress_weight != "") && (enable_compress_weight != "true") && (enable_compress_weight != "false")) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10005", {"parameter", "value"}, {"enable_compress_weight", enable_compress_weight});
GELOGE(ge::PARAM_INVALID,
"Input parameter[--enable_compress_weight]'s value[%s] must be true or false.", enable_compress_weight.c_str());
return ge::PARAM_INVALID;
}

if ((enable_compress_weight == "true") && (!compress_weight_conf.empty())) {
ErrorManager::GetInstance().ATCReportErrMessage("E10047", {"parameter0", "parameter1"},
{"enable_compress_weight", "compress_weight_conf"});
GELOGE(ge::PARAM_INVALID, "enable_compress_weight and compress_weight_conf can not both exist!!");
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckKeepTypeParamValid(const std::string &keep_dtype) {
if ((!keep_dtype.empty()) && (!CheckInputPathValid(keep_dtype, "--keep_dtype"))) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--keep_dtype", keep_dtype, kKeepDtypeError});
GELOGE(ge::PARAM_INVALID, "keep dtype config file not found, file_name:%s", keep_dtype.c_str());
return ge::PARAM_INVALID;
}

return ge::SUCCESS;
}

int CheckLogParamValidAndSetLogLevel(const std::string log) {
int ret = -1;
if (log == "default") {
ret = 0;
} else if (log == "null") {
ret = dlog_setlevel(-1, DLOG_NULL, 0);
} else if (log == "debug") {
ret = dlog_setlevel(-1, DLOG_DEBUG, 1);
} else if (log == "info") {
ret = dlog_setlevel(-1, DLOG_INFO, 1);
} else if (log == "warning") {
ret = dlog_setlevel(-1, DLOG_WARN, 1);
} else if (log == "error") {
ret = dlog_setlevel(-1, DLOG_ERROR, 1);
} else {
GELOGE(ge::PARAM_INVALID, "invalid value for log:%s, only support debug, info, warning, error, null", log.c_str());
return ret;
}
if (ret != 0) {
GELOGE(ge::PARAM_INVALID, "Log setlevel fail !");
}
return ret;
}

Status CheckInsertOpConfParamValid(const std::string insert_op_conf) {
if ((!insert_op_conf.empty()) &&
(!CheckInputPathValid(insert_op_conf, "--insert_op_conf"))) {
GELOGE(ge::PARAM_INVALID, "insert op config file not found: %s", insert_op_conf.c_str());
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory) {
if ((disable_reuse_memory != "") && (disable_reuse_memory != "0") && (disable_reuse_memory != "1")) {
ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"disable_reuse_memory"});
GELOGE(ge::PARAM_INVALID, "Input parameter[--disable_reuse_memory]'s value must be 1 or 0.");
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream) {
if ((enable_single_stream != "") && (enable_single_stream != "true") && (enable_single_stream != "false")) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10005", {"parameter", "value"}, {"enable_single_stream", enable_single_stream});
GELOGE(ge::PARAM_INVALID, "Input parameter[--enable_single_stream]'s value[%s] must be true or false.",
enable_single_stream.c_str());
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode) {
// only appointed op_select_implmode, can user appoint optypelist_for_implmode
if (optypelist_for_implmode != "" && op_select_implmode == "") {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--op_select_implmode", op_select_implmode.c_str(), kCompressWeightError});
GELOGE(ge::PARAM_INVALID, "Invalid value for --op_select_implmode[%s], %s.",
op_select_implmode.c_str(), kCompressWeightError);
return ge::PARAM_INVALID;
}
// op_select_implmode default value is high_performance
if (op_select_implmode == "") {
op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT;
} else {
if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT &&
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--op_select_implmode", op_select_implmode.c_str(), kSelectImplmodeError});
GELOGE(ge::PARAM_INVALID, "Invalid value for --op_select_implmode[%s], %s.",
op_select_implmode.c_str(), kSelectImplmodeError);
return ge::PARAM_INVALID;
}
}

return ge::SUCCESS;
}

void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips) {
for (auto iter = options.begin(); iter != options.end(); iter++) {
std::string key = iter->first;
std::string option_name = iter->second;
GELOGD("%s set successfully, option_key=%s, option_value=%s", tips.c_str(), key.c_str(), option_name.c_str());
}
}

void EraseEndSemicolon(string &param) {
if (param.empty()) {
return;
}
if (param.back() == ';') {
param.erase(param.end() - 1);
}
}
} // namespace ge

+ 82
- 0
atc/atc_ir_common.h View File

@@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ATC_IR_COMMON_H_
#define ATC_IR_COMMON_H_

#include <unistd.h>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <set>

#include "framework/common/debug/ge_log.h"
#include "register/register_types.h"

using std::string;
using std::vector;
using std::map;

namespace ge {
static std::set<std::string> caffe_support_input_format = {"NCHW", "ND"};
static std::set<std::string> tf_support_input_format = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"};
static std::set<std::string> onnx_support_input_format = {"NCHW", "ND"};

static std::map<std::string, domi::domiTensorFormat_t> input_format_str_to_geformat = {
{"ND", domi::DOMI_TENSOR_ND},
{"NCHW", domi::DOMI_TENSOR_NCHW},
{"NHWC", domi::DOMI_TENSOR_NHWC},
{"CHWN", domi::DOMI_TENSOR_CHWN},
{"NC1HWC0", domi::DOMI_TENSOR_NC1HWC0},
{"NHWC1C0", domi::DOMI_TENSOR_NHWC1C0},
{"NCDHW", domi::DOMI_TENSOR_NCDHW},
{"NDHWC", domi::DOMI_TENSOR_NDHWC}
};
static const std::string kEnableCompressWeightTrue = "1";
static const std::string kEnableCompressWeightFalse = "0";

bool CheckDynamicBatchSizeInputShapeValid(map<string, vector<int64_t>> shape_map,
std::string &dynamic_batch_size);

bool CheckDynamicImagesizeInputShapeValid(map<string, vector<int64_t>> shape_map,
const std::string input_format, std::string &dynamic_image_size);

bool CheckDynamicDimsInputShapeValid(const std::map<std::string, std::vector<int64_t>> &shape_map,
std::string input_format, std::string &dynamic_dims);

bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims);

Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size,
std::string &dynamic_dims, const std::string input_shape,
const std::string input_format, bool &is_dynamic_input);

bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map,
std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false);

Status CheckOutputTypeParamValid(const std::string output_type);
Status CheckBufferOptimizeParamValid(const std::string buffer_optimize);
Status CheckCompressWeightParamValid(const std::string enable_compress_weight, const std::string compress_weight_conf);
int CheckLogParamValidAndSetLogLevel(const std::string log);
Status CheckInsertOpConfParamValid(const std::string insert_op_conf);
Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory);
Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream);
Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode);
Status CheckInputFormat(const string &input_format);
Status CheckKeepTypeParamValid(const std::string &keep_dtype);
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips);
void EraseEndSemicolon(std::string &param);
}
#endif // ATC_IR_COMMON_H_

+ 23
- 0
atc/common/types.cc View File

@@ -0,0 +1,23 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "types.h"

namespace ge{
const char *DATA = "Data";
const char *NETOUTPUT = "NetOutput";
const std::string OP_CONF_DELIMITER = ":";
const std::string MODEL_ATTR_FUSION_MODEL_DEF = "fm";
} // namespace ge

+ 32
- 0
atc/common/types.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_TYPES_H_
#define COMMON_TYPES_H_

#include <stdint.h>
#include <string>

#include "register/register_types.h"

namespace ge {
extern const char *DATA;
extern const char *NETOUTPUT;
extern const std::string OP_CONF_DELIMITER;
extern const std::string MODEL_ATTR_FUSION_MODEL_DEF;
} // namespace ge

#endif // COMMON_TYPES_H_

+ 1428
- 0
atc/main.cc
File diff suppressed because it is too large
View File


+ 1004
- 0
atc/parse_graph.cc
File diff suppressed because it is too large
View File


+ 111
- 0
atc/parse_graph.h View File

@@ -0,0 +1,111 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSE_GRAPH_H_
#define PARSE_GRAPH_H_

#include <string>
#include <unordered_map>
#include <vector>
#include "framework/omg/omg_inner_types.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "proto/ge_ir.pb.h"
#include "proto/om.pb.h"

#include "graph/compute_graph.h"
#include "graph/graph.h"
#include "graph/model.h"

using domi::Status;
using std::pair;
using std::string;
using std::unordered_map;
using std::vector;

namespace ge {
/**
* @ingroup domi_omg
* @brief init omg context
* @return void
*/
GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format,
const string &net_format, bool is_dynamic_input);

/**
* @ingroup domi_omg
* @brief generate graph based on the input model file and weight file
* @param [out] graph graph
* @param [in] model_file path of model file
* @param [in] weights_file path of weight file
* @param [in] type type of the input model
* @param [in] op_conf op mapping configuration
* @param [in] target type of platform. If a tiny model is generated, set target to tiny
* @param [in] run_mode run model
* @param [in] enable_l2dynamic enable l2dynamic
* @param [in] is_dynamic_input dynamic input, true of false
* @param [in] atc_params multiply atc params
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params,
const char *model_file, const char *weights_file, domi::FrameworkType type,
const char *op_conf = nullptr, const char *target = nullptr,
RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false);

/**
* @ingroup domi_omg
* @brief generates a simplified JSON file based on the key value of the offline model file in protobuf format
* @param [in] model_file path of offline model file
* @param [out] json_file path of json file
* @param [key] encrypted key
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json);

GE_FUNC_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const char *json_file);
/**
* @ingroup domi_omg
* @brief convert the model file in protobuf format into a JSON file.
* @param [in] framework type of model
* @param [in] om model_file path of offline model file
* @param [out] json_file path of json file
* @param [key] encrypted key
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file,
const char *json_file);

GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model);

GE_FUNC_VISIBILITY void FindParserSo(const string &path, vector<string> &fileList, string &caffe_parser_path);

GE_FUNC_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);

GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type,
const std::string &output_format);

GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);

GE_FUNC_VISIBILITY void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);

GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx();

GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx();

GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size);
} // namespace ge
#endif // PARSE_GRAPH_H_

+ 609
- 0
atc/single_op_parser.cc View File

@@ -0,0 +1,609 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "single_op_parser.h"

#include <vector>
#include <algorithm>
#include <fstream>
#include <sstream>

#include <nlohmann/json.hpp>

#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "util/util.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/operator_factory_impl.h"

using Json = nlohmann::json;
using std::string;
using std::vector;
using std::map;

namespace ge {
namespace {
constexpr char const *kKeyOp = "op";
constexpr char const *kKeyInputDesc = "input_desc";
constexpr char const *kKeyOutputDesc = "output_desc";
constexpr char const *kKeyAttr = "attr";
constexpr char const *kKeyName = "name";
constexpr char const *kKeyType = "type";
constexpr char const *kKeyShape = "shape";
constexpr char const *kKeyOriginShape = "origin_shape";
constexpr char const *kKeyShapeRange = "shape_range";
constexpr char const *kKeyValue = "value";
constexpr char const *kKeyFormat = "format";
constexpr char const *kKeyOriginFormat = "origin_format";
constexpr char const *kFileSuffix = ".om";
constexpr char const *kKeyDynamicInput = "dynamic_input";
constexpr char const *kKeyDynamicOutput = "dynamic_output";
constexpr int kDumpJsonIndent = 2;
constexpr int kShapeRangePairSize = 2;
constexpr int kShapeRangeLow = 0;
constexpr int kShapeRangeHigh = 1;
constexpr int kMaxFileNameLen = 128;

map<string, GeAttrValue::ValueType> kAttrTypeDict = {
{"bool", GeAttrValue::VT_BOOL},
{"int", GeAttrValue::VT_INT},
{"float", GeAttrValue::VT_FLOAT},
{"string", GeAttrValue::VT_STRING},
{"list_bool", GeAttrValue::VT_LIST_BOOL},
{"list_int", GeAttrValue::VT_LIST_INT},
{"list_float", GeAttrValue::VT_LIST_FLOAT},
{"list_string", GeAttrValue::VT_LIST_STRING},
{"list_list_int", GeAttrValue::VT_LIST_LIST_INT},
{"data_type", GeAttrValue::VT_DATA_TYPE},
};

map<string, DataType> kDataTypeDict = {
{"bool", DT_BOOL},
{"int8", DT_INT8},
{"uint8", DT_UINT8},
{"int16", DT_INT16},
{"uint16", DT_UINT16},
{"int32", DT_INT32},
{"uint32", DT_UINT32},
{"int64", DT_INT64},
{"uint64", DT_UINT64},
{"float16", DT_FLOAT16},
{"half", DT_FLOAT16},
{"fp16", DT_FLOAT16},
{"float", DT_FLOAT},
{"float32", DT_FLOAT},
{"double", DT_DOUBLE},
};

map<string, Format> kFormatDict = {
{"nchw", FORMAT_NCHW},
{"nhwc", FORMAT_NHWC},
{"nd", FORMAT_ND},
{"nc1hwc0", FORMAT_NC1HWC0},
{"fractal_z", FORMAT_FRACTAL_Z},
{"nc1c0hwpad", FORMAT_NC1C0HWPAD},
{"nhwc1c0", FORMAT_NHWC1C0},
{"fsr_nchw", FORMAT_FSR_NCHW},
{"fractal_deconv", FORMAT_FRACTAL_DECONV},
{"c1hwnc0", FORMAT_C1HWNC0},
{"fractal_deconv_transpose", FORMAT_FRACTAL_DECONV_TRANSPOSE},
{"fractal_deconv_sp_stride_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS},
{"nc1hwc0_c04", FORMAT_NC1HWC0_C04},
{"fractal_z_c04", FORMAT_FRACTAL_Z_C04},
{"chwn", FORMAT_CHWN},
{"deconv_sp_stride8_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS},
{"nc1khkwhwc0", FORMAT_NC1KHKWHWC0},
{"bn_weight", FORMAT_BN_WEIGHT},
{"filter_hwck", FORMAT_FILTER_HWCK},
{"hwcn", FORMAT_HWCN},
{"lookup_lookups", FORMAT_HASHTABLE_LOOKUP_LOOKUPS},
{"lookup_keys", FORMAT_HASHTABLE_LOOKUP_KEYS},
{"lookup_value", FORMAT_HASHTABLE_LOOKUP_VALUE},
{"lookup_output", FORMAT_HASHTABLE_LOOKUP_OUTPUT},
{"lookup_hits", FORMAT_HASHTABLE_LOOKUP_HITS},
{"md", FORMAT_MD},
{"c1hwncoc0", FORMAT_C1HWNCoC0},
{"fractal_nz", FORMAT_FRACTAL_NZ},
{"ndhwc", FORMAT_NDHWC},
{"ncdhw", FORMAT_NCDHW},
{"dhwcn", FORMAT_DHWCN},
{"dhwnc", FORMAT_DHWNC},
{"ndc1hwc0", FORMAT_NDC1HWC0},
{"fractal_z_3d", FORMAT_FRACTAL_Z_3D},
{"fractal_z_3d_transpose", FORMAT_FRACTAL_Z_3D_TRANSPOSE},
{"cn", FORMAT_CN},
{"nc", FORMAT_NC},
{"fractal_zn_lstm", FORMAT_FRACTAL_ZN_LSTM},
{"fractal_z_g", FORMAT_FRACTAL_Z_G}
};

std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) {
std::stringstream file_name_ss;
file_name_ss << index;
file_name_ss << "_" << single_op_desc.op;
for (auto &desc : single_op_desc.input_desc) {
file_name_ss << "_" << desc.type << "_" << desc.format;
for (auto dim : desc.dims) {
file_name_ss << "_" << dim;
}
}

for (auto &desc : single_op_desc.output_desc) {
file_name_ss << "_" << desc.type << "_" << desc.format;
for (auto dim : desc.dims) {
file_name_ss << "_" << dim;
}
}

std::string file_name = file_name_ss.str();
if (file_name.length() > kMaxFileNameLen) {
GELOGI("Trim file name for it is too long, origin file name = %s", file_name.c_str());
file_name = file_name.substr(0, kMaxFileNameLen);
}
file_name += kFileSuffix;
return file_name;
}
} // namespace

template<typename T>
void SetAttrValue(const Json &j, SingleOpAttr &attr) {
attr.value.SetValue<T>(j.at(kKeyValue).get<T>());
}

template<typename T>
T GetValue(const map<string, T> &dict, string &key, T default_val) {
transform(key.begin(), key.end(), key.begin(), ::tolower);
auto it = dict.find(key);
if (it == dict.end()) {
return default_val;
}

return it->second;
}

void from_json(const Json &j, SingleOpTensorDesc &desc) {
bool is_tensor_valid = true;
desc.dims = j.at(kKeyShape).get<vector<int64_t>>();
auto it = j.find(kKeyShapeRange);
if (it != j.end()) {
desc.dim_ranges = j.at(kKeyShapeRange).get<vector<std::vector<int64_t>>>();
}
it = j.find(kKeyOriginShape);
if (it != j.end()) {
desc.ori_dims = j.at(kKeyOriginShape).get<vector<int64_t>>();
}
string format_str = j.at(kKeyFormat).get<string>();
string type_str = j.at(kKeyType).get<string>();
desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED);
desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED);
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(format_str);
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsDataTypeValid(type_str);
it = j.find(kKeyOriginFormat);
if (it != j.end()) {
string origin_format_str = j.at(kKeyOriginFormat).get<string>();
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(origin_format_str);
desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED);
}
auto tensor_name = j.find(kKeyName);
if (tensor_name != j.end()) {
desc.name = tensor_name->get<string>();
}
auto dynamic_input_name = j.find(kKeyDynamicInput);
if (dynamic_input_name != j.end()) {
desc.dynamic_input_name = dynamic_input_name->get<string>();
}
if (!is_tensor_valid) {
desc.SetValidFlag(is_tensor_valid);
}
}

void from_json(const Json &j, SingleOpAttr &attr) {
attr.name = j.at(kKeyName).get<string>();
attr.type = j.at(kKeyType).get<string>();
auto it = kAttrTypeDict.find(attr.type);
if (it == kAttrTypeDict.end()) {
GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str());
return;
}

switch (it->second) {
case GeAttrValue::VT_BOOL:
SetAttrValue<bool>(j, attr);
break;
case GeAttrValue::VT_INT:
SetAttrValue<int64_t>(j, attr);
break;
case GeAttrValue::VT_FLOAT:
SetAttrValue<float>(j, attr);
break;
case GeAttrValue::VT_STRING:
SetAttrValue<string>(j, attr);
break;
case GeAttrValue::VT_LIST_BOOL:
SetAttrValue<vector<bool>>(j, attr);
break;
case GeAttrValue::VT_LIST_INT:
SetAttrValue<vector<int64_t>>(j, attr);
break;
case GeAttrValue::VT_LIST_FLOAT:
SetAttrValue<vector<float>>(j, attr);
break;
case GeAttrValue::VT_LIST_STRING:
SetAttrValue<vector<string>>(j, attr);
break;
case GeAttrValue::VT_LIST_LIST_INT:
SetAttrValue<vector<vector<int64_t>>>(j, attr);
break;
case GeAttrValue::VT_DATA_TYPE:
SetAttrValue<DataType>(j, attr);
break;
default:
GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str());
break;
}
}

void from_json(const Json &j, SingleOpDesc &desc) {
desc.op = j.at(kKeyOp).get<string>();

auto input_desc = j.find(kKeyInputDesc);
if (input_desc != j.end()) {
desc.input_desc = input_desc->get<vector<SingleOpTensorDesc>>();
}

auto output_desc = j.find(kKeyOutputDesc);
if (output_desc != j.end()) {
desc.output_desc = output_desc->get<vector<SingleOpTensorDesc>>();
}

auto attr_field = j.find(kKeyAttr);
if (attr_field != j.end()) {
desc.attrs = attr_field->get<vector<SingleOpAttr>>();
}
}

Status SingleOpParser::ReadJsonFile(const std::string &file, Json &json_obj) {
std::string real_path = RealPath(file.c_str());
if (real_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10023", {"value"}, {file});
GELOGE(FAILED, "Input parameter[--singleop]'s value[%s] is not a valid path.", file.c_str());
return INTERNAL_ERROR;
}

std::ifstream ifs(real_path);
if (!ifs.is_open()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10024", {"value"}, {file});
GELOGE(FAILED, "Open file[%s] provided in input parameter[--singleop] failed.", file.c_str());
return FAILED;
}
try {
ifs >> json_obj;
} catch (const std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E10025", {"realpath", "errmsg"}, {real_path, e.what()});
GELOGE(PARAM_INVALID, "Parse file[%s] provided in input parameter[--singleop] failed, exception = %s.",
real_path.c_str(), e.what());
return PARAM_INVALID;
}

ifs.close();
return SUCCESS;
}

bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {
if (op_desc.op.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10026");
GELOGE(PARAM_INVALID, "Op name is empty");
return false;
}

int index = 0;
for (auto &tensor_desc : op_desc.input_desc) {
if (!tensor_desc.GetValidFlag()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"intput", "datatype or format", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index);
return false;
}
if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) ||
(tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"intput", "datatype or format", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index);
return false;
}
++index;
}

index = 0;
for (auto &tensor_desc : op_desc.output_desc) {
if (!tensor_desc.GetValidFlag()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "datatype", std::to_string(index)});
GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index);
return false;
}
if (tensor_desc.type == DT_UNDEFINED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "datatype", std::to_string(index)});
GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index);
return false;
}

if (tensor_desc.format == FORMAT_RESERVED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "format", std::to_string(index)});
GELOGE(PARAM_INVALID, "Output's format is invalid when the index is %d", index);
return false;
}
++index;
}

for (auto &attr : op_desc.attrs) {
if (attr.name.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10029");
GELOGE(PARAM_INVALID, "attr name is empty");
return false;
}

if (attr.value.IsEmpty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10030", {"attrname"}, {attr.name});
GELOGE(PARAM_INVALID, "Parse attr \"%s\" failed. ", attr.name.c_str());
return false;
}
}

return true;
}

std::unique_ptr<OpDesc> SingleOpParser::CreateOpDesc(const string &op_type) {
return std::unique_ptr<OpDesc>(new(std::nothrow) OpDesc(op_type, op_type));
}

Status SingleOpParser::UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc) {
std::map<std::string, int> dynamic_name_map;
for (auto &tensor : desc) {
if (tensor.dynamic_input_name.empty()) {
continue;
}
if (dynamic_name_map.find(tensor.dynamic_input_name) == dynamic_name_map.end()) {
dynamic_name_map[tensor.dynamic_input_name] = 0;
} else {
dynamic_name_map[tensor.dynamic_input_name]++;
}
tensor.name = tensor.dynamic_input_name + std::to_string(dynamic_name_map[tensor.dynamic_input_name]);
}
GELOGD("Update dynamic tensor name success!");
return SUCCESS;
}

Status SingleOpParser::ConvertToBuildParam(int index,
const SingleOpDesc &single_op_desc,
SingleOpBuildParam &build_param) {
auto op_desc = CreateOpDesc(single_op_desc.op);
GE_CHECK_NOTNULL(op_desc);

for (auto &desc : single_op_desc.input_desc) {
GeTensorDesc ge_tensor_desc(GeShape(desc.dims),
desc.format,
desc.type);
auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format;
auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims;
ge_tensor_desc.SetOriginFormat(ori_format_to_set);
ge_tensor_desc.SetOriginShape(GeShape(ori_dims));
GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc));
TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size());
TensorUtils::SetInputTensor(ge_tensor_desc, true);
TensorUtils::SetOutputTensor(ge_tensor_desc, false);
if (desc.name.empty()) {
op_desc->AddInputDesc(ge_tensor_desc);
} else {
op_desc->AddInputDesc(desc.name, ge_tensor_desc);
}
build_param.inputs.emplace_back(ge_tensor_desc);
}

for (auto &desc : single_op_desc.output_desc) {
GeTensorDesc ge_tensor_desc(GeShape(desc.dims),
desc.format,
desc.type);
auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format;
auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims;
ge_tensor_desc.SetOriginFormat(ori_format_to_set);
ge_tensor_desc.SetOriginShape(GeShape(ori_dims));
GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc));
TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size());
TensorUtils::SetInputTensor(ge_tensor_desc, false);
TensorUtils::SetOutputTensor(ge_tensor_desc, true);
if (desc.name.empty()) {
op_desc->AddOutputDesc(ge_tensor_desc);
} else {
op_desc->AddOutputDesc(desc.name, ge_tensor_desc);
}
build_param.outputs.emplace_back(ge_tensor_desc);
}

for (const auto &attr : single_op_desc.attrs) {
op_desc->SetAttr(attr.name, attr.value);
}

if (VerifyOpInputOutputSizeByIr(*op_desc) != SUCCESS) {
GELOGE(PARAM_INVALID, "Verify op [%s] input or output size failed.", op_desc->GetType().c_str());
return PARAM_INVALID;
}

build_param.file_name = GenerateFileName(single_op_desc, index);
build_param.op_desc.reset(op_desc.release());
return SUCCESS;
}

Status SingleOpParser::VerifyOpInputOutputSizeByIr(const OpDesc &current_op_desc) {
ge::Operator operator_ir = ge::OperatorFactory::CreateOperator("tmp_operator", current_op_desc.GetType());
if (!operator_ir.IsEmpty()) {
auto opdesc_ir = ge::OpDescUtils::GetOpDescFromOperator(operator_ir);
GE_CHECK_NOTNULL(opdesc_ir);
size_t current_opdesc_inputs_num = current_op_desc.GetInputsSize();
size_t ir_opdesc_inputs_num = opdesc_ir->GetInputsSize();
if (current_opdesc_inputs_num < ir_opdesc_inputs_num) {
string reason = "is smaller than the ir needed input size " + std::to_string(ir_opdesc_inputs_num);
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{current_op_desc.GetName(), "input size " + std::to_string(current_opdesc_inputs_num), reason});
GELOGE(PARAM_INVALID, "This op [%s] input size %zu is smaller than the ir needed input size %zu",
current_op_desc.GetName().c_str(), current_opdesc_inputs_num, ir_opdesc_inputs_num);
return PARAM_INVALID;
}
size_t current_opdesc_outputs_num = current_op_desc.GetOutputsSize();
size_t ir_opdesc_outputs_num = opdesc_ir->GetOutputsSize();
if (current_opdesc_outputs_num < ir_opdesc_outputs_num) {
string reason = "is smaller than the ir needed output size " + std::to_string(ir_opdesc_outputs_num);
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{current_op_desc.GetName(), "output size " + std::to_string(current_opdesc_outputs_num), reason});
GELOGE(PARAM_INVALID, "This op [%s] output size %zu is smaller than the ir needed output size %zu",
current_op_desc.GetName().c_str(), current_opdesc_outputs_num, ir_opdesc_outputs_num);
return PARAM_INVALID;
}
}
return SUCCESS;
}

Status SingleOpParser::SetShapeRange(const std::string &op_name,
const SingleOpTensorDesc &tensor_desc,
GeTensorDesc &ge_tensor_desc) {
auto num_shape_ranges = tensor_desc.dim_ranges.size();
GELOGD("Number of shape ranges = %zu", num_shape_ranges);
auto it = std::find(tensor_desc.dims.begin(), tensor_desc.dims.end(), ge::UNKNOWN_DIM_NUM);
if (it != tensor_desc.dims.end()) {
if (tensor_desc.dims != ge::UNKNOWN_RANK) {
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape",
"has unknown rank but dim size is not one"});
GELOGE(PARAM_INVALID, "Invalid tensor shape: [%s]", ge_tensor_desc.MutableShape().ToString().c_str());
return PARAM_INVALID;
}
if (!tensor_desc.dim_ranges.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape range",
"is not needed while the rank the shape is unknown"});
GELOGE(PARAM_INVALID, "Shape range is not needed while the rank the shape is unknown");
return PARAM_INVALID;
}

GELOGD("Shape is unknown rank, do not set shape range");
return SUCCESS;
}

std::vector<std::pair<int64_t, int64_t>> shape_range;
size_t range_index = 0;
for (auto dim : tensor_desc.dims) {
if (dim >= 0) {
shape_range.emplace_back(dim, dim);
GELOGD("Adding shape range: [%ld, %ld]", dim, dim);
} else {
GELOGD("To get shape range by index = %zu", range_index);
if (range_index >= num_shape_ranges) {
string reason = "is smaller than the unknown dim size " + std::to_string(++range_index);
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape range size " + std::to_string(num_shape_ranges),
reason});
GELOGE(PARAM_INVALID, "The number of shape_range mismatches that of unknown dims.");
return PARAM_INVALID;
}

auto &range = tensor_desc.dim_ranges[range_index];
if (range.size() != kShapeRangePairSize) {
string reason = "has " + std::to_string(range.size()) + " item(s)";
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape range " + std::to_string(range_index),
reason});
GELOGE(PARAM_INVALID, "Invalid shape range entry. index = %zu, size = %zu", range_index, range.size());
return PARAM_INVALID;
}

shape_range.emplace_back(range[kShapeRangeLow], range[kShapeRangeHigh]);
GELOGD("Adding shape range: [%ld, %ld]", range[kShapeRangeLow], range[kShapeRangeHigh]);
++range_index;
}
}

if (num_shape_ranges != range_index) {
string reason = "is greater than the unknown dim size " + std::to_string(range_index);
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape range size " + std::to_string(num_shape_ranges),
reason});
GELOGE(PARAM_INVALID,
"The number of shape_range(%zu) mismatches that of unknown dims(%zu).",
num_shape_ranges,
range_index);
return PARAM_INVALID;
}

if (range_index > 0) {
ge_tensor_desc.SetShapeRange(shape_range);
}

return SUCCESS;
}

Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector<SingleOpBuildParam> &op_list) {
int index = 0;
try {
Json single_op_list_json;
auto ret = ReadJsonFile(file, single_op_list_json);
if (ret != SUCCESS) {
return ret;
}

for (const Json &single_op_json : single_op_list_json) {
SingleOpDesc single_op_desc;
GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str());
single_op_desc = single_op_json;
if (UpdateDynamicTensorName(single_op_desc.input_desc) != SUCCESS) {
GELOGE(FAILED, "Update dynamic tensor name failed!");
return FAILED;
}

if (!Validate(single_op_desc)) {
GELOGE(PARAM_INVALID, "Validate the index[%d] of op failed when read json file[%s].", index, file.c_str());
return PARAM_INVALID;
}

SingleOpBuildParam param;
ret = ConvertToBuildParam(index, single_op_desc, param);
if (ret != SUCCESS) {
return ret;
}

op_list.emplace_back(param);
GELOGI("Parse the index[%d] of op success", index);
index += 1;
}
} catch (const nlohmann::json::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E10032", {"index", "jsonfile", "exception"},
{std::to_string(index), file, e.what()});
GELOGE(PARAM_INVALID, "Parse the index[%d] of op failed when read json file[%s], exception %s",
index, file.c_str(), e.what());
return PARAM_INVALID;
}

return SUCCESS;
}
} // namespace ge


+ 90
- 0
atc/single_op_parser.h View File

@@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_TOOLS_COMPILE_PARSER_H
#define ACL_TOOLS_COMPILE_PARSER_H

#include <vector>
#include <string>

#include <nlohmann/json.hpp>

#include "ge/ge_api_error_codes.h"
#include "graph/types.h"
#include "graph/ge_attr_value.h"
#include "graph/op_desc.h"

namespace ge {
struct SingleOpTensorDesc {
public:
bool GetValidFlag() const { return is_valid_; }
void SetValidFlag(bool is_valid) { is_valid_ = is_valid; }
public:
std::string name;
std::vector<int64_t> dims;
std::vector<int64_t> ori_dims;
std::vector<std::vector<int64_t>> dim_ranges;
ge::Format format = ge::FORMAT_RESERVED;
ge::Format ori_format = ge::FORMAT_RESERVED;
ge::DataType type = ge::DT_UNDEFINED;
std::string dynamic_input_name;
private:
bool is_valid_ = true;
};

struct SingleOpAttr {
std::string name;
std::string type;
ge::GeAttrValue value;
};

struct SingleOpDesc {
std::string op;
std::vector<SingleOpTensorDesc> input_desc;
std::vector<SingleOpTensorDesc> output_desc;
std::vector<SingleOpAttr> attrs;
};

struct SingleOpBuildParam {
ge::OpDescPtr op_desc;
std::vector<ge::GeTensor> inputs;
std::vector<ge::GeTensor> outputs;
std::string file_name;
};

void from_json(const nlohmann::json &json, SingleOpTensorDesc &desc);

void from_json(const nlohmann::json &json, SingleOpAttr &desc);

void from_json(const nlohmann::json &json, SingleOpDesc &desc);

class SingleOpParser {
public:
static Status ParseSingleOpList(const std::string &file, std::vector<SingleOpBuildParam> &op_list);

private:
static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj);
static bool Validate(const SingleOpDesc &op_desc);
static std::unique_ptr<OpDesc> CreateOpDesc(const std::string &op_type);
static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param);
static Status UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc);
static Status VerifyOpInputOutputSizeByIr(const OpDesc &current_op_desc);
static Status SetShapeRange(const std::string &op_name,
const SingleOpTensorDesc &tensor_desc,
GeTensorDesc &ge_tensor_desc);
};
} // namespace ge

#endif // ACL_TOOLS_COMPILE_PARSER_H

+ 85
- 0
atc/util/gflags_util.h View File

@@ -0,0 +1,85 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef UTIL_GFLAGS_UTIL_H_
#define UTIL_GFLAGS_UTIL_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <gflags/gflags.h>
#include <string>

namespace ge {
class GE_FUNC_VISIBILITY GflagsUtils {
public:
static bool IsSetCommandTrue(const char *name) {
std::string out;
return gflags::GetCommandLineOption(name, &out) && out == "true";
}

///
/// @brief Determines whether the parameter is empty
/// @param name name parameter name
/// @return true if empty otherwise false
///
static bool IsSetCommandNotEmpty(const char *name) {
std::string out;
return gflags::GetCommandLineOption(name, &out) && !out.empty();
}

///
/// @brief Determines whether the parameter is not default
/// @param flag_name name parameter name
/// @return true if not default otherwise false
///
static bool IsCommandLineNotDefault(const char *flag_name) {
google::CommandLineFlagInfo info;
return GetCommandLineFlagInfo(flag_name, &info) && !info.is_default;
}

///
/// @brief Modify gflags to print help information
/// @param flags_h Pass in the self-defined help parameter, it is recommended to be FLAGS_h
/// @return void
///
static void ChangeHelpFlags(bool flags_h) {
if (flags_h || IsSetCommandTrue("help") || IsSetCommandTrue("helpfull") || IsSetCommandNotEmpty("helpon") ||
IsSetCommandNotEmpty("helpmatch") || IsSetCommandTrue("helppackage") || IsSetCommandTrue("helpxml")) {
gflags::SetCommandLineOption("help", "false");
gflags::SetCommandLineOption("helpfull", "false");
gflags::SetCommandLineOption("helpon", "");
gflags::SetCommandLineOption("helpmatch", "");
gflags::SetCommandLineOption("helppackage", "false");
gflags::SetCommandLineOption("helpxml", "false");
gflags::SetCommandLineOption("helpshort", "true");
}
}
};
} // namespace ge

#endif // UTIL_GFLAGS_UTIL_H_

+ 145
- 0
atc/util/properties_manager.cc View File

@@ -0,0 +1,145 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "properties_manager.h"

#include <climits>
#include <cstdio>
#include <fstream>

#include "framework/common/debug/ge_log.h"
#include "util/util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_context.h"
#include "graph/utils/attr_utils.h"

namespace ge {
namespace atc {
PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {}
PropertiesManager::~PropertiesManager() {}

// singleton
PropertiesManager &PropertiesManager::Instance() {
static PropertiesManager instance;
return instance;
}

// Initialize property configuration
bool PropertiesManager::Init(const std::string &file_path) {
std::lock_guard<std::mutex> lock(mutex_);
if (is_inited_) {
GELOGW("Already inited, will be initialized again");
properties_map_.clear();
is_inited_ = false;
return is_inited_;
}

if (!LoadFileContent(file_path)) {
return false;
}

is_inited_ = true;
return is_inited_;
}

// Load file contents
bool PropertiesManager::LoadFileContent(const std::string &file_path) {
// Normalize the path
string resolved_file_path = RealPath(file_path.c_str());
if (resolved_file_path.empty()) {
GELOGE(FAILED, "Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
return false;
}
std::ifstream fs(resolved_file_path, std::ifstream::in);

if (!fs.is_open()) {
GELOGE(PARAM_INVALID, "Open %s failed.", file_path.c_str());
return false;
}

std::string line;

while (getline(fs, line)) { // line not with \n
if (!ParseLine(line)) {
GELOGE(PARAM_INVALID, "Parse line failed. content is [%s].", line.c_str());
fs.close();
return false;
}
}

fs.close(); // close the file

GELOGI("LoadFileContent success.");
return true;
}

// Parsing the command line
bool PropertiesManager::ParseLine(const std::string &line) {
std::string temp = Trim(line);
// Comment or newline returns true directly
if (temp.find_first_of('#') == 0 || *(temp.c_str()) == '\n') {
return true;
}

if (!temp.empty()) {
std::string::size_type pos = temp.find_first_of(delimiter);
if (pos == std::string::npos) {
GELOGE(PARAM_INVALID, "Incorrect line [%s], it must include [%s].Perhaps you use illegal chinese symbol",
line.c_str(), delimiter.c_str());
return false;
}

std::string map_key = Trim(temp.substr(0, pos));
std::string value = Trim(temp.substr(pos + 1));
if (map_key.empty() || value.empty()) {
GELOGE(PARAM_INVALID, "Map_key or value empty. %s", line.c_str());
return false;
}

properties_map_[map_key] = value;
}

return true;
}

// Remove the space and tab before and after the string
std::string PropertiesManager::Trim(const std::string &str) {
if (str.empty()) {
return str;
}

std::string::size_type start = str.find_first_not_of(" \t\r\n");
if (start == std::string::npos) {
return str;
}

std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
return str.substr(start, end);
}

// return properties_map_
std::map<std::string, std::string> PropertiesManager::GetPropertyMap() {
std::lock_guard<std::mutex> lock(mutex_);
return properties_map_;
}

// Set separator
void PropertiesManager::SetPropertyDelimiter(const std::string &de) {
std::lock_guard<std::mutex> lock(mutex_);
delimiter = de;
}
}
} // namespace ge

+ 82
- 0
atc/util/properties_manager.h View File

@@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef UTIL_PROPERTIES_MANAGER_H_
#define UTIL_PROPERTIES_MANAGER_H_

#include <map>
#include <mutex>
#include <string>

namespace ge {
namespace atc {
class PropertiesManager {
public:
// Singleton
static PropertiesManager &Instance();

/**
* @ingroup domi_ome
* @brief Initialize configuration parameters, which must be invoked in main.
* @param [in] file_path Property profile path
* @return true success
* @return false fail
* @author
*/
bool Init(const std::string &file_path);

/**
* @ingroup domi_ome
* @brief Return configuration parameters
* @return properties_map_
* @author
*/
std::map<std::string, std::string> GetPropertyMap();

/**
* @ingroup domi_ome
* @brief Adapt key value pair form, set different separators
* @param [in] delimiter
* @author
*/
void SetPropertyDelimiter(const std::string &de);

private:
// Private construct, destructor
PropertiesManager();
~PropertiesManager();

// Get file content
bool LoadFileContent(const std::string &file_path);

// Parsing a single line file
bool ParseLine(const std::string &line);

// Remove space before and after string
std::string Trim(const std::string &str);

bool is_inited_;

// Configuration item separator, default is "="
std::string delimiter;

std::map<std::string, std::string> properties_map_;
std::mutex mutex_;
};
}
} // namespace ge

#endif // UTIL_PROPERTIES_MANAGER_H_

+ 171
- 0
atc/util/string_util.h View File

@@ -0,0 +1,171 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_COMMON_STRING_UTIL_H_
#define INC_FRAMEWORK_COMMON_STRING_UTIL_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <cctype>
#include <securec.h>

#include <algorithm>
#include <functional>
#include <sstream>
#include <string>
#include <vector>

namespace ge {
class GE_FUNC_VISIBILITY StringUtils {
public:
static std::string &Ltrim(std::string &s) {
#if __cplusplus >= 201103L
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); }));
#else
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), std::not1(std::ptr_fun<int, int>(std::isspace))));
#endif
return s;
}
// lint -esym(551,*)
static std::string &Rtrim(std::string &s) { /*lint !e618*/
#if __cplusplus >= 201103L
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); }));
#else
(void)s.erase(std::find_if(s.rbegin(), s.rend(), std::not1(std::ptr_fun<int, int>(std::isspace))).base(), s.end());
#endif
return s;
}
// lint -esym(551,*)
///
/// @ingroup domi_common
/// @brief delete spaces at the beginning and end of a string
/// @param [in] string to be trimmed
/// @return string after trim
///
static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); }

///
/// @ingroup domi_common
/// @brief string splitting
/// @param [in] str string to be trimmed
/// @param [in] delim separator
/// @return string array after segmentation
///
static std::vector<std::string> Split(const std::string &str, char delim) {
std::vector<std::string> elems;

if (str.empty()) {
elems.emplace_back("");
return elems;
}

std::stringstream ss(str);
std::string item;

while (getline(ss, item, delim)) {
elems.push_back(item);
}

auto str_size = str.size();
if (str_size > 0 && str[str_size - 1] == delim) {
elems.emplace_back("");
}

return elems;
}
///
/// @ingroup domi_common
/// @brief obtain the file name
/// @param [in] s path name
/// @return file name
///
static std::string GetFileName(std::string &s) {
if (s.empty()) {
return "";
}
std::vector<std::string> files = StringUtils::Split(s, '/');

return files.empty() ? "" : files[files.size() - 1];
}
///
/// @ingroup domi_common
/// @brief full replacement
/// @link
/// @param [in] str str string to be replaced
/// @param [in] old_value old Characters Before Replacement
/// @param [in] new_value new Characters Before Replacement
/// @return string after replacement
///
static std::string ReplaceAll(std::string str, const std::string &old_value, const std::string &new_value) {
std::string::size_type cur_pos = 0;
std::string::size_type old_length = old_value.length();
std::string::size_type new_length = new_value.length();
// cycle replace
for (; cur_pos != std::string::npos; cur_pos += new_length) {
if ((cur_pos = str.find(old_value, cur_pos)) != std::string::npos) {
(void)str.replace(cur_pos, old_length, new_value);
} else {
break;
}
}
return str;
}

///
/// @ingroup domi_common
/// @brief checks whether a character string starts with a character string (prefix)
/// @link
/// @param [in] str string to be compared
/// @param [in] str_x prefix
/// @return if the value is a prefix, true is returned. Otherwise, false is returned
///
static bool StartWith(const std::string &str, const std::string str_x) {
return ((str.size() >= str_x.size()) && (str.compare(0, str_x.size(), str_x) == 0));
}

///
/// @ingroup domi_common
/// @brief format string
/// @link
/// @param [in] format specifies the character string format
/// @param [in] ... format Filling Content
/// @return formatted string
///
static std::string FormatString(const char *format, ...) {
const uint32_t MAX_BUFFER_LEN = 1024; // the stack memory plint check result must be less than 1024
va_list args;
va_start(args, format);
char buffer[MAX_BUFFER_LEN] = {0};
int32_t ret = vsnprintf_s(buffer, MAX_BUFFER_LEN, MAX_BUFFER_LEN - 1, format, args);
va_end(args);
return ret > 0 ? buffer : "";
}
};
} // namespace ge

#endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_

+ 32
- 0
atc/util/tool.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef UTIL_TOOL_H_
#define UTIL_TOOL_H_

#include <iostream>
#include <memory>
#include <utility>

namespace ge {
template <typename T, typename... Args>
static inline std::shared_ptr<T> MakeShared(Args &&... args) {
typedef typename std::remove_const<T>::type T_nc;
std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...));
return ret;
}
} // namespace ge
#endif // UTIL_TOOL_H_

+ 283
- 0
atc/util/util.cc View File

@@ -0,0 +1,283 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "util.h"

#include <sys/stat.h>
#ifdef __GNUC__
#include <regex.h>
#else
#include <regex>
#endif
#include <algorithm>
#include <climits>
#include <cstdlib>
#include <ctime>
#include <fstream>

#include "common/util/error_manager/error_manager.h"
#include "external/ge/ge_api_error_codes.h"
#include "framework/common/debug/ge_log.h"
#include "mmpa/mmpa_api.h"

namespace {
const int kMaxBuffSize = 256;
const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character";
} // namespace

namespace ge {
/**
* @ingroup domi_common
* @brief Create directory, support to create multi-level directory
* @param [in] directory_path Path, can be multi-level directory
* @return -1 fail
* @return 0 success
*/
int CreateDirectory(const std::string &directory_path) {
GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty.");
auto dir_path_len = directory_path.length();
if (dir_path_len >= MMPA_MAX_PATH) {
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"},
{directory_path, std::to_string(MMPA_MAX_PATH)});
GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH);
return -1;
}
char tmp_dir_path[MMPA_MAX_PATH] = {0};
for (size_t i = 0; i < dir_path_len; i++) {
tmp_dir_path[i] = directory_path[i];
if ((tmp_dir_path[i] == '\\') || (tmp_dir_path[i] == '/')) {
if (mmAccess2(tmp_dir_path, M_F_OK) != EN_OK) {
int32_t ret = mmMkdir(tmp_dir_path, M_IRUSR | M_IWUSR | M_IXUSR); // 700
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret;
}
}
}
}
}
int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), M_IRUSR | M_IWUSR | M_IXUSR); // 700
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret;
}
}
return 0;
}

std::string CurrentTimeInStr() {
std::time_t now = std::time(nullptr);
std::tm *ptm = std::localtime(&now);
if (ptm == nullptr) {
GELOGE(ge::FAILED, "Localtime failed.");
return "";
}

const int kTimeBufferLen = 32;
char buffer[kTimeBufferLen + 1] = {0};
// format: 20171122042550
std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm);
return std::string(buffer);
}

std::string RealPath(const char *path) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH,
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"},
{path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH);

// Nullptr is returned when the path does not exist or there is no permission
// Return absolute path when path is accessible
std::string res;
char resolved_path[MMPA_MAX_PATH] = {0};
if (mmRealPath(path, resolved_path, MMPA_MAX_PATH) == EN_OK) {
res = resolved_path;
}

return res;
}

bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param) {
// The specified path is empty
std::map<std::string, std::string> args_map;
if (file_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param});
GELOGW("Input parameter %s is empty.", file_path.c_str());
return false;
}
std::string real_path = RealPath(file_path.c_str());
// Unable to get absolute path (does not exist or does not have permission to access)
if (real_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file_path, strerror(errno)});
GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno));
return false;
}

// A regular matching expression to verify the validity of the input file path
// Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores
// File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.)
#ifdef __GNUC__
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
#else
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
#endif

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(real_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, real_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason);

// The absolute path points to a file that is not readable
if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) {
ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"file", "errmsg"}, {file_path.c_str(), strerror(errno)});
GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno));
return false;
}

return true;
}

bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param) {
// The specified path is empty
if (file_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param});
GELOGW("Input parameter's value is empty.");
return false;
}

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH,
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(),
MMPA_MAX_PATH);

// A regular matching expression to verify the validity of the input file path
// Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores
// File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.)
#ifdef __GNUC__
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
#else
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
#endif

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(file_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, file_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason);

std::string real_path = RealPath(file_path.c_str());
// Can get absolute path (file exists)
if (!real_path.empty()) {
// File is not readable or writable
if (mmAccess2(real_path.c_str(), M_W_OK | M_F_OK) != EN_OK) {
ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {real_path, strerror(errno)});
GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno));
return false;
}
} else {
// Find the last separator
int path_split_pos = static_cast<int>(file_path.size() - 1);
for (; path_split_pos >= 0; path_split_pos--) {
if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') {
break;
}
}
if (path_split_pos == 0) {
return true;
}
if (path_split_pos != -1) {
std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos));
// Determine whether the specified path is valid by creating the path
if (CreateDirectory(prefix_path) != 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {file_path});
GELOGW("Can not create directory[%s].", file_path.c_str());
return false;
}
}
}

return true;
}

bool ValidateStr(const std::string &str, const std::string &mode) {
#ifdef __GNUC__
char ebuff[kMaxBuffSize];
regex_t reg;
int cflags = REG_EXTENDED | REG_NOSUB;
int ret = regcomp(&reg, mode.c_str(), cflags);
if (ret) {
regerror(ret, &reg, ebuff, kMaxBuffSize);
GELOGW("regcomp failed, reason: %s", ebuff);
regfree(&reg);
return true;
}

ret = regexec(&reg, str.c_str(), 0, NULL, 0);
if (ret) {
regerror(ret, &reg, ebuff, kMaxBuffSize);
GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff);
regfree(&reg);
return false;
}

regfree(&reg);
return true;
#else
std::wstring wstr(str.begin(), str.end());
std::wstring wmode(mode.begin(), mode.end());
std::wsmatch match;
bool res = false;

try {
std::wregex reg(wmode, std::regex::icase);
// Matching string part
res = regex_match(wstr, match, reg);
res = regex_search(str, std::regex("[`!@#$%^&*()|{}';',<>?]"));
} catch (std::exception &ex) {
GELOGW("The directory %s is invalid, error: %s.", str.c_str(), ex.what());
return false;
}
return !(res) && (str.size() == match.str().size());
#endif
}

bool GetNameFromFileName(const std::string &file_name, std::string &base_name) {
if (file_name.empty()) {
GELOGW("File path may not valid, check params --output");
return false;
}
size_t start_position = 0;
// using output as base_name (ignore ".om")
size_t filename_suffixes = 3;
if (file_name.find_last_of('/') != std::string::npos) {
start_position = file_name.find_last_of('/') + 1;
}
size_t end_position = file_name.length() - filename_suffixes;
base_name = file_name.substr(start_position, end_position - start_position);
if (base_name.empty()) {
GELOGW("File path may not valid, check params --output");
return false;
}
return true;
}
} // namespace ge

+ 159
- 0
atc/util/util.h View File

@@ -0,0 +1,159 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef UTIL_UTIL_H_
#define UTIL_UTIL_H_

#include <limits.h>
#include <string>

#include "framework/common/debug/ge_log.h"
#include "mmpa/mmpa_api.h"

// For propagating errors when calling a function.
#define GE_RETURN_IF_ERROR(expr) \
do { \
const ::ge::Status _status = (expr); \
if (_status) return _status; \
} while (0)

// If expr is not true, print the log and return the specified status
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0);

// If expr is not true, print the log and execute a custom statement
#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (!b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
exec_expr; \
} \
}

// If expr is true, print logs and execute custom statements
#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
exec_expr; \
} \
}

// If expr is not SUCCESS, return the same value
#define GE_CHK_STATUS_RET_NOLOG(expr) \
do { \
const ge::Status _status = (expr); \
if (_status != ge::SUCCESS) { \
return _status; \
} \
} while (0);

#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \
do { \
const ::ge::Status _status = (expr); \
if (_status) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0)

// If expr is not SUCCESS, print the log and execute a custom statement
#define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \
do { \
const ge::Status _status = (expr); \
GE_CHK_BOOL_EXEC(_status == SUCCESS, exec_expr, __VA_ARGS__); \
} while (0);

// Check if the parameter is null. If yes, return PARAM_INVALID and record the error
#define GE_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
GELOGE(ge::FAILED, "param[%s] must not be null.", #val); \
return ge::PARAM_INVALID; \
} \
} while (0)

// If expr is true, execute exec_expr without printing logs
#define GE_IF_BOOL_EXEC(expr, exec_expr) \
{ \
if (expr) { \
exec_expr; \
} \
}

namespace ge {
///
/// @ingroup domi_common
/// @brief Recursively Creating a Directory
/// @param [in] directory_path Path, which can be a multi-level directory.
/// @return 0 success
/// @return -1 fail
///
extern int CreateDirectory(const std::string &directory_path);

///
/// @ingroup domi_common
/// @brief Obtains the current time string.
/// @return Time character string in the format : %Y%m%d%H%M%S, eg: 20171011083555
///
std::string CurrentTimeInStr();

///
/// @ingroup domi_common
/// @brief Absolute path for obtaining files.
/// @param [in] path of input file
/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned
///
std::string RealPath(const char *path);

///
/// @ingroup domi_common
/// @brief Check whether the specified input file path is valid.
/// 1. The specified path cannot be empty.
/// 2. The path can be converted to an absolute path.
/// 3. The file path exists and is readable.
/// @param [in] file_path path of input file
/// @param [out] result
///
bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param = "");

///
/// @ingroup domi_common
/// @brief Checks whether the specified output file path is valid.
/// @param [in] file_path path of output file
/// @param [out] result
///
bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param = "");

///
/// @ingroup domi_common
/// @brief Check whether the file path meets the whitelist verification requirements.
/// @param [in] filePath file path
/// @param [out] result
///
bool ValidateStr(const std::string &filePath, const std::string &mode);

bool GetNameFromFileName(const std::string &file_name, std::string &base_name);
} // namespace ge
#endif // INC_FRAMEWORK_COMMON_UTIL_H_

+ 10
- 0
parser/common/convert/pb2json.cc View File

@@ -23,6 +23,7 @@
#include "securec.h"
#include "framework/common/fmk_types.h"
#include "framework/common/debug/ge_log.h"
#include "parser/common/model_saver.h"

using std::set;
using std::string;
@@ -31,6 +32,15 @@ namespace ge {
namespace {
const int kSignificantDigits = 10;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status Pb2Json::ToJson(const ProtobufMsg &message,
const set<string> &black_fields,
const char *json_file, bool enum2str) {
Json j;
Message2Json(message, black_fields, j, enum2str);
return ge::parser::ModelSaver::SaveJsonToFile(json_file, j);
}

// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message,
const set<string> &black_fields, Json &json,


+ 4
- 0
parser/common/convert/pb2json.h View File

@@ -23,6 +23,7 @@
#include <memory>
#include <set>
#include <string>
#include "ge/ge_api_error_codes.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "nlohmann/json.hpp"
@@ -51,6 +52,9 @@ class Pb2Json {
const ProtobufReflection *reflection, const std::set<std::string> &black_fields,
Json &json, bool enum2str);

static Status ToJson(const ProtobufMsg &message, const std::set<std::string> &black_fields, const char *json_file,
bool enum2str = false);

protected:
static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field,
bool enum2str, Json &json);


+ 159
- 0
third_party/graphengine/inc/external/ge/ge_ir_build.h View File

@@ -0,0 +1,159 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GE_IR_BUILD_H_
#define INC_EXTERNAL_GE_IR_BUILD_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <string>
#include <map>
#include <memory>
#include "graph/graph.h"
#include "graph/ge_error_codes.h"

namespace {
const int IR_MAJOR_VERSION = 1;
const int IR_MINOR_VERSION = 0;
const int IR_PATCH_VERSION = 0;
} // namespace

namespace ge {

struct ModelBufferData {
std::shared_ptr<uint8_t> data = nullptr;
uint64_t length;
};

enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS };

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
* @param global_options[IN] global init params for build
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &))
GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options);

GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &global_options);

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
*/
GE_FUNC_VISIBILITY void aclgrphBuildFinalize();

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
* @param graph[IN] the graph ready to build
* @param options[IN] options used for build
* @param model[OUT] builded model
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &,
const std::map<AscendString, AscendString> &,
ModelBufferData &))
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph,
const std::map<std::string, std::string> &build_options,
ModelBufferData &model);

GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph,
const std::map<AscendString, AscendString> &build_options,
ModelBufferData &model);

/**
* @ingroup AscendCL
* @brief save model buffer to file
*
* @param output_file[IN] the file path to be saved
* @param model[IN] model buffer data
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *, const ModelBufferData &))
GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model);

GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model);

/**
* @ingroup AscendCL
* @brief query IR interface version
*
* @param major_version[OUT] IR interface major version
* @param minor_version[OUT] IR interface minor version
* @param patch_version[OUT] IR interface patch version
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version);

/**
* @ingroup AscendCL
* @brief dump graph
*
* @param graph[IN] the graph ready to build
* @param file[IN] file path
* @param file[IN] file path string len
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len);

/**
* @ingroup AscendCL
* @brief create single op graph
*
* @param op_type[IN] the op_type
* @param inputs[IN] the inputdesc
* @param outputs[IN] the outputdesc
* @param graph[OUT] the graph
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs,
const std::vector<TensorDesc> &outputs, Graph &graph);

/**
* @name aclgrphSetOpAttr
* @brief set attribute for operators in the configuration file
* @param graph [IN/OUT] compute graph
* @param attr_type [In] attribute type
* @param cfg_path [IN] the config file path
* @return graphStatus
*/
GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path);

}; // namespace ge
#endif // INC_EXTERNAL_GE_IR_BUILD_H_

+ 60
- 0
third_party/graphengine/inc/framework/generator/ge_generator.h View File

@@ -0,0 +1,60 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_
#define INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/ge_inner_error_codes.h"
#include "graph/ge_tensor.h"
#include "graph/graph.h"
#include "graph/op_desc.h"
#include "graph/detail/attributes_holder.h"
#include "omg/omg_inner_types.h"

namespace ge {
class GE_FUNC_VISIBILITY GeGenerator {
public:
GeGenerator() = default;

~GeGenerator() { (void)Finalize(); }

GeGenerator(const GeGenerator &) = delete;

GeGenerator &operator=(const GeGenerator &) = delete;

Status Initialize(const std::map<std::string, std::string> &options, OmgContext &context);

Status Finalize();

Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix,
const std::vector<GeTensor> &inputs = std::vector<GeTensor>());

Status GenerateInfershapeGraph(const Graph &graph);

Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector<GeTensor> &inputs,
const std::vector<GeTensor> &outputs, const std::string &model_file_name);
private:
class Impl;

std::shared_ptr<Impl> impl_;
};
} // namespace ge

#endif // INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_

+ 36
- 0
third_party/graphengine/inc/framework/omg/ge_init.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_OMG_GE_INIT_H_
#define INC_FRAMEWORK_OMG_GE_INIT_H_
#include <map>
#include <string>
#include "common/ge_inner_error_codes.h"

namespace ge {
class GE_FUNC_VISIBILITY GEInit {
public:
// GE Environment Initialize, return Status: SUCCESS,FAILED
static Status Initialize(const std::map<std::string, std::string> &options);

static std::string GetPath();

// GE Environment Finalize, return Status: SUCCESS,FAILED
static Status Finalize();
};
} // namespace ge

#endif // INC_FRAMEWORK_OMG_GE_INIT_H_

+ 35
- 0
third_party/graphengine/inc/framework/omg/model_tool.h View File

@@ -0,0 +1,35 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_OMG_MODEL_TOOL_H_
#define INC_FRAMEWORK_OMG_MODEL_TOOL_H_

#include <memory>
#include <string>

#include "framework/common/debug/ge_log.h"
#include "proto/ge_ir.pb.h"

namespace ge {
class GE_FUNC_VISIBILITY ModelTool {
public:
static Status GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef &model_def, uint32_t &modeldef_size);

static Status GetModelInfoFromPbtxt(const char *model_file, ge::proto::ModelDef &model_def);
};
} // namespace ge

#endif // INC_FRAMEWORK_OMG_MODEL_TOOL_H_

Loading…
Cancel
Save