Browse Source

atc

pull/242/head
wjm 5 years ago
parent
commit
be3f670f82
27 changed files with 6433 additions and 0 deletions
  1. +168
    -0
      atc/CMakeLists.txt
  2. +21
    -0
      atc/atc
  3. +595
    -0
      atc/atc_ir_common.cc
  4. +82
    -0
      atc/atc_ir_common.h
  5. +1428
    -0
      atc/main.cc
  6. +164
    -0
      atc/module.mk
  7. +1002
    -0
      atc/parse_graph.cc
  8. +108
    -0
      atc/parse_graph.h
  9. +191
    -0
      atc/proto/ge_ir.proto
  10. +139
    -0
      atc/proto/insert_op.proto
  11. +396
    -0
      atc/proto/om.proto
  12. +179
    -0
      atc/proto/task.proto
  13. +609
    -0
      atc/single_op_parser.cc
  14. +90
    -0
      atc/single_op_parser.h
  15. +85
    -0
      atc/util/gflags_util.h
  16. +145
    -0
      atc/util/properties_manager.cc
  17. +82
    -0
      atc/util/properties_manager.h
  18. +171
    -0
      atc/util/string_util.h
  19. +32
    -0
      atc/util/tool.h
  20. +283
    -0
      atc/util/util.cc
  21. +159
    -0
      atc/util/util.h
  22. +9
    -0
      parser/common/convert/pb2json.cc
  23. +3
    -0
      parser/common/convert/pb2json.h
  24. +159
    -0
      third_party/graphengine/inc/external/ge/ge_ir_build.h
  25. +60
    -0
      third_party/graphengine/inc/framework/generator/ge_generator.h
  26. +39
    -0
      third_party/graphengine/inc/framework/omg/ge_init.h
  27. +34
    -0
      third_party/graphengine/inc/framework/omg/model_tool.h

+ 168
- 0
atc/CMakeLists.txt View File

@@ -0,0 +1,168 @@
set(PROTO_LIST
"${METADEF_DIR}/proto/om.proto"
"${METADEF_DIR}/proto/ge_ir.proto"
"${METADEF_DIR}/proto/insert_op.proto"
"${METADEF_DIR}/proto/task.proto"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})

set(SRC_LIST
"main.cc"
"single_op_parser.cc"
"parse_graph.cc"
"atc_ir_common.cc"
"util/util.cc"
"util/properties_manager.cc"
)

############ atc_atc.bin ############
add_executable(atc_atc.bin ${SRC_LIST} ${PROTO_HDRS})

target_compile_options(atc_atc.bin PRIVATE
-Werror
-O2
-Wno-deprecated-declarations
-fno-common
-fvisibility=hidden
)

target_compile_definitions(atc_atc.bin PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
COMPILE_OMG_PACKAGE
google=ascend_private
LOG_CPP
FUNC_VISIBILITY
)

target_include_directories(atc_atc.bin PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external/register
${PARSER_DIR}
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
${GE_CODE_DIR}/../inc/common
#### blue zone ####
${METADEF_DIR}/third_party/graphengine/inc/
${METADEF_DIR}/third_party/graphengine/inc/external
${METADEF_DIR}/third_party/graphengine/inc/framework
${METADEF_DIR}/third_party/graphengine/ge
${METADEF_DIR}/third_party/fwkacllib/inc
${PARSER_DIR}/third_party/graphengine/inc
)

target_link_options(atc_atc.bin PRIVATE
-Wl,-Bsymbolic
)

target_link_libraries(atc_atc.bin PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf
ge_common
register
c_sec
graph
error_manager
ge_compiler
parser_common
gflags
json
slog
static_mmpa
-lrt
-ldl
)

set_target_properties(atc_atc.bin PROPERTIES
OUTPUT_NAME atc.bin
RUNTIME_OUTPUT_DIRECTORY atclib
)

############ fwk_atc.bin ############
add_executable(fwk_atc.bin ${SRC_LIST} ${PROTO_HDRS})

target_compile_options(fwk_atc.bin PRIVATE
-Werror
-O2
-Wno-deprecated-declarations
-fno-common
-fvisibility=hidden
)

target_compile_definitions(fwk_atc.bin PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
COMPILE_OMG_PACKAGE
google=ascend_private
LOG_CPP
FUNC_VISIBILITY
)

target_include_directories(fwk_atc.bin PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external/register
${PARSER_DIR}
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
${GE_CODE_DIR}/../inc/common
#### blue zone ####
${METADEF_DIR}/third_party/graphengine/inc/
${METADEF_DIR}/third_party/graphengine/inc/external
${METADEF_DIR}/third_party/graphengine/inc/framework
${METADEF_DIR}/third_party/graphengine/ge
${METADEF_DIR}/third_party/fwkacllib/inc
${PARSER_DIR}/third_party/graphengine/inc/

)

target_link_options(fwk_atc.bin PRIVATE
-Wl,-Bsymbolic
)

target_link_libraries(fwk_atc.bin PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf
ge_common
register
c_sec
graph
error_manager
ge_runner
parser_common
gflags
json
slog
static_mmpa
-lrt
-ldl
)

set_target_properties(fwk_atc.bin PROPERTIES
OUTPUT_NAME atc.bin
RUNTIME_OUTPUT_DIRECTORY fwkacl
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS atc_atc.bin OPTIONAL
RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/atclib
)

install(TARGETS fwk_atc.bin OPTIONAL
RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/fwkacl
)

+ 21
- 0
atc/atc View File

@@ -0,0 +1,21 @@
#!/bin/bash
#-------------------------------------------------------------------
# Purpose:
# Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved.
#-------------------------------------------------------------------

real_path=$(readlink "$0")
if [ $? -eq 0 ]; then
LOCAL_PATH=$(cd "$(dirname "$real_path")"; pwd)
else
LOCAL_PATH=$(cd "$(dirname "$0")"; pwd)
fi
PKG_PATH=$(cd ${LOCAL_PATH}/..; pwd)
LIB_P="/lib64"
PYTHON_P="/python/site-packages"
LIB64_PATH="${PKG_PATH}${LIB_P}"
PYTHON_PATH="${PKG_PATH}${PYTHON_P}"
export LD_LIBRARY_PATH="${LIB64_PATH}:${LD_LIBRARY_PATH}"
export PYTHONPATH="${PYTHON_PATH}:${PYTHONPATH}"

${PKG_PATH}/bin/atc.bin "$@"

+ 595
- 0
atc/atc_ir_common.cc View File

@@ -0,0 +1,595 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "atc_ir_common.h"
#include "common/util/error_manager/error_manager.h"
#include "external/ge/ge_api_types.h"
#include "util/string_util.h"
#include "util/util.h"
#include "graph/utils/type_utils.h"

using std::pair;
using std::string;
using std::vector;

namespace ge {
namespace {
const int64_t kDynamicInputDim = -1;
const int64_t kDynamicImageSizeNum = 2;
const uint32_t NCHW_DIM_H = 2;
const uint32_t NCHW_DIM_W = 3;
const uint32_t NHWC_DIM_H = 1;
const uint32_t NHWC_DIM_W = 2;
const int32_t DIM_DEFAULT_SIZE = 4;
const size_t kMaxDynamicDimNum = 100;
const size_t kMaxNDDimNum = 4;
const size_t kMinNDDimNum = 1;
// datatype/formats from user to GE, Unified to util interface file later
const std::map<std::string, ge::DataType> kOutputTypeSupportDatatype = {
{"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}};
const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8";
const std::set<std::string> kBufferOptimizeSupportOption = {"l1_optimize", "l2_optimize", "off_optimize",
"l1_and_l2_optimize"};
// The function is incomplete. Currently, only l2_optimize, off_optimize is supported.
const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision";
const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\"";
const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\"";
const char *const kSplitError1 = "size not equal to 2 split by \":\"";
const char *const kEmptyError = "can not be empty";
const char *const kFloatNumError = "exist float number";
const char *const kDigitError = "is not digit";
const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]";
const char *const kSelectImplmodeError = "only support high_performance, high_precision";
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \"";
const char *const kKeepDtypeError = "file not found";

vector<string> SplitInputShape(const std::string &input_shape) {
vector<string> shape_pair_vec;
size_t pos = input_shape.rfind(":");
if (pos != std::string::npos) {
shape_pair_vec.emplace_back(input_shape.substr(0, pos));
shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos));
}
return shape_pair_vec;
}
} // namespace

Status CheckInputFormat(const string &input_format) {
if (input_format.empty()) {
return ge::SUCCESS;
}
if (!ge::TypeUtils::IsFormatValid(input_format.c_str())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, "input format is invalid!"});
GELOGE(ge::PARAM_INVALID, "input format [%s] is invalid!", input_format.c_str());
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

bool CheckDynamicBatchShape(const vector<int64_t> &shape, const string &data_name) {
if (shape[0] == kDynamicInputDim) {
for (size_t i = 1; i < shape.size(); ++i) {
if (shape[i] < 1) {
ErrorManager::GetInstance().ATCReportErrMessage("E10018", {"index", "shape"},
{std::to_string(i), std::to_string(shape[i])});
GELOGE(ge::PARAM_INVALID,
"Only batch N can be -1 when set --dynamic_batch_size, current data: %s shape[%zu] is %ld",
data_name.c_str(), i, shape[i]);
return false;
}
}
return true;
} else {
return false;
}
}

bool CheckDynamicBatchSizeInputShapeValid(map<string, vector<int64_t>> shape_map,
std::string &dynamic_batch_size) {
int32_t size = 0;
for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) {
vector<int64_t> shape = iter->second;
if (shape.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10012");
GELOGE(ge::PARAM_INVALID, "--input_shape's shape size can not be less than 1 when set --dynamic_batch_size.");
return false;
}

if (std::count(shape.begin(), shape.end(), kDynamicInputDim) == 0) {
continue;
}

bool ret = CheckDynamicBatchShape(shape, iter->first);
if (ret) {
size++;
}
}

if (size == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10031");
GELOGE(ge::PARAM_INVALID, "At least one batch n must be equal to -1 when set --dynamic_batch_size.");
return false;
}

for (char c : dynamic_batch_size) {
if (!isdigit(c) && (c != ',') && (c != ' ')) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10033", {"value", "reason"}, {dynamic_batch_size, kDynamicBatchSizeError});
GELOGE(ge::PARAM_INVALID, "Input parameter[--dynamic_batch_size]'s value[%s] is invalid. reason: %s",
dynamic_batch_size.c_str(), kDynamicBatchSizeError);
return false;
}
}
if (dynamic_batch_size.back() == ',') {
dynamic_batch_size.erase(dynamic_batch_size.end() - 1);
}
return true;
}

bool CheckDynamicImageSizeShape(const vector<int64_t> &shape, const string &data_name,
const std::string &input_format) {
int64_t height = 0;
int64_t width = 0;
if (input_format == "NCHW") {
height = shape[NCHW_DIM_H];
width = shape[NCHW_DIM_W];
}

if (input_format == "NHWC") {
height = shape[NHWC_DIM_H];
width = shape[NHWC_DIM_W];
}

if (height == kDynamicInputDim && width == kDynamicInputDim &&
std::count(shape.begin(), shape.end(), kDynamicInputDim) == kDynamicImageSizeNum) {
return true;
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10019");
GELOGE(ge::PARAM_INVALID,
"--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size.");
return false;
}
}
bool CheckDynamicImagesizeInputShapeValid(map<string, vector<int64_t>> shape_map,
const std::string input_format, std::string &dynamic_image_size) {
if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) {
GELOGE(ge::PARAM_INVALID, "user input format [%s] is not found!", input_format.c_str());
return false;
}
int32_t size = 0;
for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) {
vector<int64_t> shape = iter->second;
// only support four dim
if (shape.size() != DIM_DEFAULT_SIZE) {
if (std::count(shape.begin(), shape.end(), kDynamicInputDim) > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10019");
GELOGE(ge::PARAM_INVALID,
"--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size.");
return false;
}
continue;
}

if (std::count(shape.begin(), shape.end(), kDynamicInputDim) == 0) {
continue;
}
auto ret = CheckDynamicImageSizeShape(shape, iter->first, input_format);
if (ret) {
size++;
} else {
return ret;
}
}
if (size == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10019");
GELOGE(ge::PARAM_INVALID,
"--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size.");
return false;
}

EraseEndSemicolon(dynamic_image_size);
// Different parameter sets are split string by ';'
std::vector<std::string> split_set = StringUtils::Split(dynamic_image_size, ';');
// Different dimensions are split by ','
std::vector<std::string> split_dim;
for (auto str : split_set) {
split_dim = StringUtils::Split(str, ',');
if (split_dim.size() != static_cast<size_t>(kDynamicImageSizeNum)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10020", {"DynamicImageSizeNum"},
{std::to_string(kDynamicImageSizeNum)});
GELOGE(ge::PARAM_INVALID,
"--dynamic_image_size's number of dimensions of each "
"group must be %ld.",
kDynamicImageSizeNum);
return false;
}
}

return true;
}

bool CheckDynamicDimsInputShapeValid(const map<string, vector<int64_t>> &shape_map,
string input_format, string &dynamic_dims) {
if (input_format != "ND") {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--input_format", input_format.c_str(), "input_format must be ND when set dynamic_dims"});
GELOGE(ge::PARAM_INVALID, "input_format must be ND when set dynamic_dims.");
return false;
}

int32_t dynamic_dim = 0;
for (auto &info_shapes : shape_map) {
auto &shapes = info_shapes.second;
if (shapes.size() > kMaxNDDimNum || shapes.size() < kMinNDDimNum) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--input_shape's dim", std::to_string(shapes.size()), "Dim num must within [1, 4] when set dynamic_dims"});
GELOGE(ge::PARAM_INVALID, "Dim num must within [%zu, %zu] when set dynamic_dims.", kMinNDDimNum, kMaxNDDimNum);
return false;
}
dynamic_dim += std::count(shapes.begin(), shapes.end(), kDynamicInputDim);
}
if (dynamic_dim == 0) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--input_shape's dynamic dim num", "0", "at least one dim should be -1 when set dynamic_dims"});
GELOGE(ge::PARAM_INVALID, "input_shape's shape is invalid, at least one dim should be -1 when set dynamic_dims.");
return false;
}

if (!CheckAndParseDynamicDims(dynamic_dim, dynamic_dims)) {
GELOGE(ge::PARAM_INVALID, "Check and parse dynamic dims: %s failed.", dynamic_dims.c_str());
return false;
}

return true;
}

bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims) {
EraseEndSemicolon(dynamic_dims);
if (dynamic_dims.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--dynamic_dims", dynamic_dims.c_str(), "dynamic_dims can not be empty"});
GELOGE(ge::PARAM_INVALID, "dynamic_dims can not be empty.");
return false;
}
// Different parameter sets are split by ';'
vector<string> split_set = StringUtils::Split(dynamic_dims, ';');
if (split_set.size() > kMaxDynamicDimNum) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10042", {"parameter", "reason"}, {"dynamic_dims", "dynamic_dims's num of parameter set can not exceed 100"});
GELOGE(ge::PARAM_INVALID, "dynamic_dims's num of parameter set can not exceed %zu.", kMaxDynamicDimNum);
return false;
}
for (auto split_dim : split_set) {
vector<string> one_set = StringUtils::Split(split_dim, ',');
if (one_set.size() != static_cast<size_t>(dynamic_dim_num)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10042", {"parameter", "reason"},
{"dynamic_dims", "Each gear setting needs to be consistent with the number of -1 in the inputshape"});
GELOGE(ge::PARAM_INVALID, "Input parameter --dynamic_dims parse failed, "
"reason: Each gear setting needs to be consistent with the number of -1 in the inputshape.");
return false;
}
for (auto dim : one_set) {
for (auto c : dim) {
if (!isdigit(c)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--dynamic_dims's parameter", dim.c_str(), "must be positive integer"});
GELOGE(ge::PARAM_INVALID, "dynamic_dims's parameter must be positive integer.");
return false;
}
}
}
}
return true;
}

Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims,
const string input_shape, const string input_format, bool &is_dynamic_input) {
int32_t param_size = static_cast<int32_t>(!dynamic_batch_size.empty()) +
static_cast<int32_t>(!dynamic_image_size.empty()) + static_cast<int32_t>(!dynamic_dims.empty());
if (param_size > 1) {
ErrorManager::GetInstance().ATCReportErrMessage("E10009", {"parameter0", "parameter1", "parameter2"},
{"dynamic_batch_size", "dynamic_image_size", "dynamic_dims"});
GELOGE(ge::PARAM_INVALID, "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one");
return ge::PARAM_INVALID;
}

if (param_size == 0) {
return ge::SUCCESS;
}

map<string, vector<int64_t>> shape_map;
vector<pair<string, vector<int64_t>>> user_shape_map;
is_dynamic_input = true;
if (input_shape.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"input_shape"});
GELOGE(ge::PARAM_INVALID, "The input_shape can not be empty in dynamic input size scenario.");
return ge::PARAM_INVALID;
}

if (!ParseInputShape(input_shape, shape_map, user_shape_map, is_dynamic_input)) {
GELOGE(ge::PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str());
return ge::PARAM_INVALID;
}

if (!dynamic_batch_size.empty()) {
if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) {
GELOGE(ge::PARAM_INVALID, "Check dynamic batch size input shape failed: %s", input_shape.c_str());
return ge::PARAM_INVALID;
}
}

if (!dynamic_image_size.empty()) {
if (!CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size)) {
GELOGE(ge::PARAM_INVALID, "Check dynamic image size input shape failed: %s", input_shape.c_str());
return ge::PARAM_INVALID;
}
}

if (!dynamic_dims.empty()) {
if (!CheckDynamicDimsInputShapeValid(shape_map, input_format, dynamic_dims)) {
GELOGE(ge::PARAM_INVALID, "Check dynamic dims: %s of input shape: %s failed.", dynamic_dims.c_str(),
input_shape.c_str());
return ge::PARAM_INVALID;
}
}
return ge::SUCCESS;
}

bool ParseInputShape(const string &input_shape, map<string, vector<int64_t>> &shape_map,
vector<pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input) {
vector<string> shape_vec = StringUtils::Split(input_shape, ';');
const int DEFAULT_SHAPE_PAIR_SIZE = 2;
for (const auto &shape : shape_vec) {
vector<string> shape_pair_vec = SplitInputShape(shape);
if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kSplitError1, kInputShapeSample1});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kSplitError1, kInputShapeSample1);
return false;
}
if (shape_pair_vec[1].empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kEmptyError, kInputShapeSample1});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kEmptyError, kInputShapeSample1);
return false;
}

vector<string> shape_value_strs = StringUtils::Split(shape_pair_vec[1], ',');
vector<int64_t> shape_values;
for (auto &shape_value_str : shape_value_strs) {
// stoul: The method may throw an exception: invalid_argument/out_of_range
if (std::string::npos != shape_value_str.find('.')) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kFloatNumError, kInputShapeSample2});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kFloatNumError, kInputShapeSample2);
return false;
}

long left_result = 0;
try {
left_result = stol(StringUtils::Trim(shape_value_str));
if (!shape_value_str.empty() && (shape_value_str.front() == '-')) {
// The value maybe dynamic shape [-1], need substr it and verify isdigit.
shape_value_str = shape_value_str.substr(1);
}
for (char c : shape_value_str) {
if (!isdigit(c)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kDigitError, kInputShapeSample2});
GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str());
return false;
}
}
} catch (const std::out_of_range &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"},
{"--input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str());
return false;
} catch (const std::invalid_argument &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"},
{"--input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str());
return false;
} catch (...) {
ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"},
{"--input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str());
return false;
}
int64_t result = left_result;
// - 1 is not currently supported
if (!is_dynamic_input && result <= 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)});
GELOGW(
"Input parameter[--input_shape]’s shape value[%s] is invalid, "
"expect positive integer, but value is %ld.",
shape.c_str(), result);
return false;
}
shape_values.push_back(result);
}

shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
}

return true;
}

Status CheckOutputTypeParamValid(const std::string output_type) {
if ((!output_type.empty()) && (kOutputTypeSupportDatatype.find(output_type) == kOutputTypeSupportDatatype.end())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--output_type", output_type, kOutputTypeSupport});
GELOGE(ge::PARAM_INVALID,
"Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport);
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckBufferOptimizeParamValid(const std::string buffer_optimize) {
if ((!buffer_optimize.empty()) &&
(kBufferOptimizeSupportOption.find(buffer_optimize) == kBufferOptimizeSupportOption.end())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--buffer_optimize", buffer_optimize, kBufferOptimizeSupport});
GELOGE(ge::PARAM_INVALID,
"Invalid value for --buffer_optimize[%s], %s.", buffer_optimize.c_str(), kBufferOptimizeSupport);
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckCompressWeightParamValid(const std::string enable_compress_weight, const std::string compress_weight_conf) {
if ((!compress_weight_conf.empty()) &&
(!CheckInputPathValid(compress_weight_conf, "--compress_weight_conf"))) {
GELOGE(ge::PARAM_INVALID, "compress weight config file not found, file_name:%s", compress_weight_conf.c_str());
return ge::PARAM_INVALID;
}
if ((enable_compress_weight != "") && (enable_compress_weight != "true") && (enable_compress_weight != "false")) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10005", {"parameter", "value"}, {"enable_compress_weight", enable_compress_weight});
GELOGE(ge::PARAM_INVALID,
"Input parameter[--enable_compress_weight]'s value[%s] must be true or false.", enable_compress_weight.c_str());
return ge::PARAM_INVALID;
}

if ((enable_compress_weight == "true") && (!compress_weight_conf.empty())) {
ErrorManager::GetInstance().ATCReportErrMessage("E10047", {"parameter0", "parameter1"},
{"enable_compress_weight", "compress_weight_conf"});
GELOGE(ge::PARAM_INVALID, "enable_compress_weight and compress_weight_conf can not both exist!!");
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckKeepTypeParamValid(const std::string &keep_dtype) {
if ((!keep_dtype.empty()) && (!CheckInputPathValid(keep_dtype, "--keep_dtype"))) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--keep_dtype", keep_dtype, kKeepDtypeError});
GELOGE(ge::PARAM_INVALID, "keep dtype config file not found, file_name:%s", keep_dtype.c_str());
return ge::PARAM_INVALID;
}

return ge::SUCCESS;
}

int CheckLogParamValidAndSetLogLevel(const std::string log) {
int ret = -1;
if (log == "default") {
ret = 0;
} else if (log == "null") {
ret = dlog_setlevel(-1, DLOG_NULL, 0);
} else if (log == "debug") {
ret = dlog_setlevel(-1, DLOG_DEBUG, 1);
} else if (log == "info") {
ret = dlog_setlevel(-1, DLOG_INFO, 1);
} else if (log == "warning") {
ret = dlog_setlevel(-1, DLOG_WARN, 1);
} else if (log == "error") {
ret = dlog_setlevel(-1, DLOG_ERROR, 1);
} else {
GELOGE(ge::PARAM_INVALID, "invalid value for log:%s, only support debug, info, warning, error, null", log.c_str());
return ret;
}
if (ret != 0) {
GELOGE(ge::PARAM_INVALID, "Log setlevel fail !");
}
return ret;
}

Status CheckInsertOpConfParamValid(const std::string insert_op_conf) {
if ((!insert_op_conf.empty()) &&
(!CheckInputPathValid(insert_op_conf, "--insert_op_conf"))) {
GELOGE(ge::PARAM_INVALID, "insert op config file not found: %s", insert_op_conf.c_str());
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory) {
if ((disable_reuse_memory != "") && (disable_reuse_memory != "0") && (disable_reuse_memory != "1")) {
ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"disable_reuse_memory"});
GELOGE(ge::PARAM_INVALID, "Input parameter[--disable_reuse_memory]'s value must be 1 or 0.");
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream) {
if ((enable_single_stream != "") && (enable_single_stream != "true") && (enable_single_stream != "false")) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10005", {"parameter", "value"}, {"enable_single_stream", enable_single_stream});
GELOGE(ge::PARAM_INVALID, "Input parameter[--enable_single_stream]'s value[%s] must be true or false.",
enable_single_stream.c_str());
return ge::PARAM_INVALID;
}
return ge::SUCCESS;
}

Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode) {
// only appointed op_select_implmode, can user appoint optypelist_for_implmode
if (optypelist_for_implmode != "" && op_select_implmode == "") {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--op_select_implmode", op_select_implmode.c_str(), kCompressWeightError});
GELOGE(ge::PARAM_INVALID, "Invalid value for --op_select_implmode[%s], %s.",
op_select_implmode.c_str(), kCompressWeightError);
return ge::PARAM_INVALID;
}
// op_select_implmode default value is high_performance
if (op_select_implmode == "") {
op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT;
} else {
if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT &&
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--op_select_implmode", op_select_implmode.c_str(), kSelectImplmodeError});
GELOGE(ge::PARAM_INVALID, "Invalid value for --op_select_implmode[%s], %s.",
op_select_implmode.c_str(), kSelectImplmodeError);
return ge::PARAM_INVALID;
}
}

return ge::SUCCESS;
}

void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips) {
for (auto iter = options.begin(); iter != options.end(); iter++) {
std::string key = iter->first;
std::string option_name = iter->second;
GELOGD("%s set successfully, option_key=%s, option_value=%s", tips.c_str(), key.c_str(), option_name.c_str());
}
}

void EraseEndSemicolon(string &param) {
if (param.empty()) {
return;
}
if (param.back() == ';') {
param.erase(param.end() - 1);
}
}
} // namespace ge

+ 82
- 0
atc/atc_ir_common.h View File

@@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ATC_IR_COMMON_H_
#define ATC_IR_COMMON_H_

#include <unistd.h>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <set>

#include "framework/common/debug/ge_log.h"
#include "register/register_types.h"

using std::string;
using std::vector;
using std::map;

namespace ge {
static std::set<std::string> caffe_support_input_format = {"NCHW", "ND"};
static std::set<std::string> tf_support_input_format = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"};
static std::set<std::string> onnx_support_input_format = {"NCHW", "ND"};

static std::map<std::string, domi::domiTensorFormat_t> input_format_str_to_geformat = {
{"ND", domi::DOMI_TENSOR_ND},
{"NCHW", domi::DOMI_TENSOR_NCHW},
{"NHWC", domi::DOMI_TENSOR_NHWC},
{"CHWN", domi::DOMI_TENSOR_CHWN},
{"NC1HWC0", domi::DOMI_TENSOR_NC1HWC0},
{"NHWC1C0", domi::DOMI_TENSOR_NHWC1C0},
{"NCDHW", domi::DOMI_TENSOR_NCDHW},
{"NDHWC", domi::DOMI_TENSOR_NDHWC}
};
static const std::string kEnableCompressWeightTrue = "1";
static const std::string kEnableCompressWeightFalse = "0";

bool CheckDynamicBatchSizeInputShapeValid(map<string, vector<int64_t>> shape_map,
std::string &dynamic_batch_size);

bool CheckDynamicImagesizeInputShapeValid(map<string, vector<int64_t>> shape_map,
const std::string input_format, std::string &dynamic_image_size);

bool CheckDynamicDimsInputShapeValid(const std::map<std::string, std::vector<int64_t>> &shape_map,
std::string input_format, std::string &dynamic_dims);

bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims);

Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size,
std::string &dynamic_dims, const std::string input_shape,
const std::string input_format, bool &is_dynamic_input);

bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map,
std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false);

Status CheckOutputTypeParamValid(const std::string output_type);
Status CheckBufferOptimizeParamValid(const std::string buffer_optimize);
Status CheckCompressWeightParamValid(const std::string enable_compress_weight, const std::string compress_weight_conf);
int CheckLogParamValidAndSetLogLevel(const std::string log);
Status CheckInsertOpConfParamValid(const std::string insert_op_conf);
Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory);
Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream);
Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode);
Status CheckInputFormat(const string &input_format);
Status CheckKeepTypeParamValid(const std::string &keep_dtype);
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips);
void EraseEndSemicolon(std::string &param);
}
#endif // ATC_IR_COMMON_H_

+ 1428
- 0
atc/main.cc
File diff suppressed because it is too large
View File


+ 164
- 0
atc/module.mk View File

@@ -0,0 +1,164 @@

LOCAL_PATH := $(call my-dir)

include $(CLEAR_VARS)

LOCAL_MODULE := atc

LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations
LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private

LOCAL_SRC_FILES := \
main.cc \
keep_dtype_option.cc \
single_op_parser.cc \
../session/omg.cc \
../ir_build/atc_ir_common.cc \

LOCAL_C_INCLUDES := \
$(LOCAL_PATH)/../ ./ \
$(TOPDIR)inc \
$(TOPDIR)metadef/inc \
$(TOPDIR)graphengine/inc \
$(TOPDIR)inc/external \
$(TOPDIR)metadef/inc/external \
$(TOPDIR)graphengine/inc/external \
$(TOPDIR)metadef/inc/external/graph \
$(TOPDIR)graphengine/inc/framework \
$(TOPDIR)libc_sec/include \
$(TOPDIR)metadef/inc/common/util \
$(TOPDIR)parser \
third_party/json/include \
third_party/gflags/include \
third_party/protobuf/include \
proto/om.proto \
proto/ge_ir.proto \
proto/task.proto \
proto/insert_op.proto \

LOCAL_SHARED_LIBRARIES := \
libc_sec \
libge_common \
libascend_protobuf \
libslog \
libgraph \
libregister \
liberror_manager \
libge_compiler \
libruntime_compile \
libparser_common \
liberror_manager \

LOCAL_STATIC_LIBRARIES := libgflags

LOCAL_LDFLAGS := -lrt -ldl

include $(BUILD_HOST_EXECUTABLE)

include $(CLEAR_VARS)

LOCAL_MODULE := atclib/atc.bin

LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations
LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private

LOCAL_SRC_FILES := \
main.cc \
keep_dtype_option.cc \
single_op_parser.cc \
../session/omg.cc \
../ir_build/atc_ir_common.cc \

LOCAL_C_INCLUDES := \
$(LOCAL_PATH)/../ ./ \
$(TOPDIR)inc \
$(TOPDIR)metadef/inc \
$(TOPDIR)graphengine/inc \
$(TOPDIR)inc/external \
$(TOPDIR)metadef/inc/external \
$(TOPDIR)graphengine/inc/external \
$(TOPDIR)metadef/inc/external/graph \
$(TOPDIR)graphengine/inc/framework \
$(TOPDIR)libc_sec/include \
$(TOPDIR)metadef/inc/common/util \
$(TOPDIR)parser \
third_party/json/include \
third_party/gflags/include \
third_party/protobuf/include \
proto/om.proto \
proto/ge_ir.proto \
proto/task.proto \
proto/insert_op.proto \

LOCAL_SHARED_LIBRARIES := \
libc_sec \
libge_common \
libascend_protobuf \
libslog \
libgraph \
libregister \
liberror_manager \
libge_compiler \
libruntime_compile \
libparser_common \
liberror_manager \

LOCAL_STATIC_LIBRARIES := libgflags

LOCAL_LDFLAGS := -lrt -ldl

include $(BUILD_HOST_EXECUTABLE)

include $(CLEAR_VARS)

LOCAL_MODULE := fwkacl/atc.bin

LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations
LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private

LOCAL_SRC_FILES := \
main.cc \
keep_dtype_option.cc \
single_op_parser.cc \
../session/omg.cc \
../ir_build/atc_ir_common.cc \

LOCAL_C_INCLUDES := \
$(LOCAL_PATH)/../ ./ \
$(TOPDIR)inc \
$(TOPDIR)metadef/inc \
$(TOPDIR)graphengine/inc \
$(TOPDIR)inc/external \
$(TOPDIR)metadef/inc/external \
$(TOPDIR)graphengine/inc/external \
$(TOPDIR)metadef/inc/external/graph \
$(TOPDIR)graphengine/inc/framework \
$(TOPDIR)libc_sec/include \
$(TOPDIR)metadef/inc/common/util \
$(TOPDIR)parser \
third_party/json/include \
third_party/gflags/include \
third_party/protobuf/include \
proto/om.proto \
proto/ge_ir.proto \
proto/task.proto \
proto/insert_op.proto \

LOCAL_SHARED_LIBRARIES := \
libc_sec \
libge_common \
libascend_protobuf \
libslog \
libgraph \
libregister \
liberror_manager \
libge_runner \
libruntime \
libparser_common \
liberror_manager \

LOCAL_STATIC_LIBRARIES := libgflags

LOCAL_LDFLAGS := -lrt -ldl

include $(BUILD_HOST_EXECUTABLE)

+ 1002
- 0
atc/parse_graph.cc
File diff suppressed because it is too large
View File


+ 108
- 0
atc/parse_graph.h View File

@@ -0,0 +1,108 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSE_GRAPH_H_
#define PARSE_GRAPH_H_

#include <string>
#include <unordered_map>
#include <vector>
#include "framework/omg/omg_inner_types.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "proto/ge_ir.pb.h"
#include "proto/om.pb.h"

#include "graph/compute_graph.h"
#include "graph/graph.h"
#include "graph/model.h"
//#include "runtime/kernel.h"

using domi::Status;
using std::pair;
using std::string;
using std::unordered_map;
using std::vector;

namespace ge {
/**
* @ingroup domi_omg
* @brief init omg context
* @return void
*/
GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format,
bool is_dynamic_input);

/**
* @ingroup domi_omg
* @brief generate graph based on the input model file and weight file
* @param [out] graph graph
* @param [in] model_file path of model file
* @param [in] weights_file path of weight file
* @param [in] type type of the input model
* @param [in] op_conf op mapping configuration
* @param [in] target type of platform. If a tiny model is generated, set target to tiny
* @param [in] run_mode run model
* @param [in] enable_l2dynamic enable l2dynamic
* @param [in] is_dynamic_input dynamic input, true of false
* @param [in] atc_params multiply atc params
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, const char *model_file,
const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr,
const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false);

/**
* @ingroup domi_omg
* @brief generates a simplified JSON file based on the key value of the offline model file in protobuf format
* @param [in] model_file path of offline model file
* @param [out] json_file path of json file
* @param [key] encrypted key
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json);

GE_FUNC_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const char *json_file);
/**
* @ingroup domi_omg
* @brief convert the model file in protobuf format into a JSON file.
* @param [in] framework type of model
* @param [in] om model_file path of offline model file
* @param [out] json_file path of json file
* @param [key] encrypted key
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file);

GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model);

GE_FUNC_VISIBILITY void FindParserSo(const string &path, vector<string> &fileList, string &caffe_parser_path);

GE_FUNC_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);

GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format);

GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);

GE_FUNC_VISIBILITY void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);

GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx();

GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx();

GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def);
} // namespace ge
#endif // PARSE_GRAPH_H_

+ 191
- 0
atc/proto/ge_ir.proto View File

@@ -0,0 +1,191 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto verion
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef

map<string, AttrDef> attr = 11; // Extended field
}


+ 139
- 0
atc/proto/insert_op.proto View File

@@ -0,0 +1,139 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP模式,区分静态AIPP和动态AIPP
AippMode aipp_mode = 1;

// related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。
uint32 related_input_rank = 2;

// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;

// input_edge_idx参数为可选,类型为整型,配置范围为>=0。
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。
// 配置值 <= Data算子输出边的个数。
repeated uint32 input_edge_idx = 3;

// [Begin] 动态AIPP参数,配置静态AIPP时无效
uint32 max_src_image_size = 4;

// 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失
bool support_rotation = 5;

// [End] 动态AIPP参数


// [Begin] 静态AIPP参数,配置动态AIPP时无效
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] 静态AIPP参数

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; //动态batch
resolution = 1; //动态分辨率,扩展用
}

MultiShapeMode mode = 1; //算子模式
uint32 related_input_rank = 2; //新增算子插入到哪个输入


repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间
}

+ 396
- 0
atc/proto/om.proto View File

@@ -0,0 +1,396 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4;//optinal,[default = 1e-5]
bool use_global_stats = 5; //optinal,by default true,testing mode
float moving_average_fraction = 6; //optinal,[default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // In default, we will use NPU.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load directtiono 0:col load 1:row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


+ 179
- 0
atc/proto/task.proto View File

@@ -0,0 +1,179 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 609
- 0
atc/single_op_parser.cc View File

@@ -0,0 +1,609 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "single_op_parser.h"

#include <vector>
#include <algorithm>
#include <fstream>
#include <sstream>

#include <nlohmann/json.hpp>

#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "util/util.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/operator_factory_impl.h"

using Json = nlohmann::json;
using std::string;
using std::vector;
using std::map;

namespace ge {
namespace {
constexpr char const *kKeyOp = "op";
constexpr char const *kKeyInputDesc = "input_desc";
constexpr char const *kKeyOutputDesc = "output_desc";
constexpr char const *kKeyAttr = "attr";
constexpr char const *kKeyName = "name";
constexpr char const *kKeyType = "type";
constexpr char const *kKeyShape = "shape";
constexpr char const *kKeyOriginShape = "origin_shape";
constexpr char const *kKeyShapeRange = "shape_range";
constexpr char const *kKeyValue = "value";
constexpr char const *kKeyFormat = "format";
constexpr char const *kKeyOriginFormat = "origin_format";
constexpr char const *kFileSuffix = ".om";
constexpr char const *kKeyDynamicInput = "dynamic_input";
constexpr char const *kKeyDynamicOutput = "dynamic_output";
constexpr int kDumpJsonIndent = 2;
constexpr int kShapeRangePairSize = 2;
constexpr int kShapeRangeLow = 0;
constexpr int kShapeRangeHigh = 1;
constexpr int kMaxFileNameLen = 128;

map<string, GeAttrValue::ValueType> kAttrTypeDict = {
{"bool", GeAttrValue::VT_BOOL},
{"int", GeAttrValue::VT_INT},
{"float", GeAttrValue::VT_FLOAT},
{"string", GeAttrValue::VT_STRING},
{"list_bool", GeAttrValue::VT_LIST_BOOL},
{"list_int", GeAttrValue::VT_LIST_INT},
{"list_float", GeAttrValue::VT_LIST_FLOAT},
{"list_string", GeAttrValue::VT_LIST_STRING},
{"list_list_int", GeAttrValue::VT_LIST_LIST_INT},
{"data_type", GeAttrValue::VT_DATA_TYPE},
};

map<string, DataType> kDataTypeDict = {
{"bool", DT_BOOL},
{"int8", DT_INT8},
{"uint8", DT_UINT8},
{"int16", DT_INT16},
{"uint16", DT_UINT16},
{"int32", DT_INT32},
{"uint32", DT_UINT32},
{"int64", DT_INT64},
{"uint64", DT_UINT64},
{"float16", DT_FLOAT16},
{"half", DT_FLOAT16},
{"fp16", DT_FLOAT16},
{"float", DT_FLOAT},
{"float32", DT_FLOAT},
{"double", DT_DOUBLE},
};

map<string, Format> kFormatDict = {
{"nchw", FORMAT_NCHW},
{"nhwc", FORMAT_NHWC},
{"nd", FORMAT_ND},
{"nc1hwc0", FORMAT_NC1HWC0},
{"fractal_z", FORMAT_FRACTAL_Z},
{"nc1c0hwpad", FORMAT_NC1C0HWPAD},
{"nhwc1c0", FORMAT_NHWC1C0},
{"fsr_nchw", FORMAT_FSR_NCHW},
{"fractal_deconv", FORMAT_FRACTAL_DECONV},
{"c1hwnc0", FORMAT_C1HWNC0},
{"fractal_deconv_transpose", FORMAT_FRACTAL_DECONV_TRANSPOSE},
{"fractal_deconv_sp_stride_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS},
{"nc1hwc0_c04", FORMAT_NC1HWC0_C04},
{"fractal_z_c04", FORMAT_FRACTAL_Z_C04},
{"chwn", FORMAT_CHWN},
{"deconv_sp_stride8_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS},
{"nc1khkwhwc0", FORMAT_NC1KHKWHWC0},
{"bn_weight", FORMAT_BN_WEIGHT},
{"filter_hwck", FORMAT_FILTER_HWCK},
{"hwcn", FORMAT_HWCN},
{"lookup_lookups", FORMAT_HASHTABLE_LOOKUP_LOOKUPS},
{"lookup_keys", FORMAT_HASHTABLE_LOOKUP_KEYS},
{"lookup_value", FORMAT_HASHTABLE_LOOKUP_VALUE},
{"lookup_output", FORMAT_HASHTABLE_LOOKUP_OUTPUT},
{"lookup_hits", FORMAT_HASHTABLE_LOOKUP_HITS},
{"md", FORMAT_MD},
{"c1hwncoc0", FORMAT_C1HWNCoC0},
{"fractal_nz", FORMAT_FRACTAL_NZ},
{"ndhwc", FORMAT_NDHWC},
{"ncdhw", FORMAT_NCDHW},
{"dhwcn", FORMAT_DHWCN},
{"dhwnc", FORMAT_DHWNC},
{"ndc1hwc0", FORMAT_NDC1HWC0},
{"fractal_z_3d", FORMAT_FRACTAL_Z_3D},
{"fractal_z_3d_transpose", FORMAT_FRACTAL_Z_3D_TRANSPOSE},
{"cn", FORMAT_CN},
{"nc", FORMAT_NC},
{"fractal_zn_lstm", FORMAT_FRACTAL_ZN_LSTM},
{"fractal_z_g", FORMAT_FRACTAL_Z_G}
};

std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) {
std::stringstream file_name_ss;
file_name_ss << index;
file_name_ss << "_" << single_op_desc.op;
for (auto &desc : single_op_desc.input_desc) {
file_name_ss << "_" << desc.type << "_" << desc.format;
for (auto dim : desc.dims) {
file_name_ss << "_" << dim;
}
}

for (auto &desc : single_op_desc.output_desc) {
file_name_ss << "_" << desc.type << "_" << desc.format;
for (auto dim : desc.dims) {
file_name_ss << "_" << dim;
}
}

std::string file_name = file_name_ss.str();
if (file_name.length() > kMaxFileNameLen) {
GELOGI("Trim file name for it is too long, origin file name = %s", file_name.c_str());
file_name = file_name.substr(0, kMaxFileNameLen);
}
file_name += kFileSuffix;
return file_name;
}
} // namespace

template<typename T>
void SetAttrValue(const Json &j, SingleOpAttr &attr) {
attr.value.SetValue<T>(j.at(kKeyValue).get<T>());
}

template<typename T>
T GetValue(const map<string, T> &dict, string &key, T default_val) {
transform(key.begin(), key.end(), key.begin(), ::tolower);
auto it = dict.find(key);
if (it == dict.end()) {
return default_val;
}

return it->second;
}

void from_json(const Json &j, SingleOpTensorDesc &desc) {
bool is_tensor_valid = true;
desc.dims = j.at(kKeyShape).get<vector<int64_t>>();
auto it = j.find(kKeyShapeRange);
if (it != j.end()) {
desc.dim_ranges = j.at(kKeyShapeRange).get<vector<std::vector<int64_t>>>();
}
it = j.find(kKeyOriginShape);
if (it != j.end()) {
desc.ori_dims = j.at(kKeyOriginShape).get<vector<int64_t>>();
}
string format_str = j.at(kKeyFormat).get<string>();
string type_str = j.at(kKeyType).get<string>();
desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED);
desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED);
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(format_str);
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsDataTypeValid(type_str);
it = j.find(kKeyOriginFormat);
if (it != j.end()) {
string origin_format_str = j.at(kKeyOriginFormat).get<string>();
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(origin_format_str);
desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED);
}
auto tensor_name = j.find(kKeyName);
if (tensor_name != j.end()) {
desc.name = tensor_name->get<string>();
}
auto dynamic_input_name = j.find(kKeyDynamicInput);
if (dynamic_input_name != j.end()) {
desc.dynamic_input_name = dynamic_input_name->get<string>();
}
if (!is_tensor_valid) {
desc.SetValidFlag(is_tensor_valid);
}
}

void from_json(const Json &j, SingleOpAttr &attr) {
attr.name = j.at(kKeyName).get<string>();
attr.type = j.at(kKeyType).get<string>();
auto it = kAttrTypeDict.find(attr.type);
if (it == kAttrTypeDict.end()) {
GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str());
return;
}

switch (it->second) {
case GeAttrValue::VT_BOOL:
SetAttrValue<bool>(j, attr);
break;
case GeAttrValue::VT_INT:
SetAttrValue<int64_t>(j, attr);
break;
case GeAttrValue::VT_FLOAT:
SetAttrValue<float>(j, attr);
break;
case GeAttrValue::VT_STRING:
SetAttrValue<string>(j, attr);
break;
case GeAttrValue::VT_LIST_BOOL:
SetAttrValue<vector<bool>>(j, attr);
break;
case GeAttrValue::VT_LIST_INT:
SetAttrValue<vector<int64_t>>(j, attr);
break;
case GeAttrValue::VT_LIST_FLOAT:
SetAttrValue<vector<float>>(j, attr);
break;
case GeAttrValue::VT_LIST_STRING:
SetAttrValue<vector<string>>(j, attr);
break;
case GeAttrValue::VT_LIST_LIST_INT:
SetAttrValue<vector<vector<int64_t>>>(j, attr);
break;
case GeAttrValue::VT_DATA_TYPE:
SetAttrValue<DataType>(j, attr);
break;
default:
GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str());
break;
}
}

void from_json(const Json &j, SingleOpDesc &desc) {
desc.op = j.at(kKeyOp).get<string>();

auto input_desc = j.find(kKeyInputDesc);
if (input_desc != j.end()) {
desc.input_desc = input_desc->get<vector<SingleOpTensorDesc>>();
}

auto output_desc = j.find(kKeyOutputDesc);
if (output_desc != j.end()) {
desc.output_desc = output_desc->get<vector<SingleOpTensorDesc>>();
}

auto attr_field = j.find(kKeyAttr);
if (attr_field != j.end()) {
desc.attrs = attr_field->get<vector<SingleOpAttr>>();
}
}

Status SingleOpParser::ReadJsonFile(const std::string &file, Json &json_obj) {
std::string real_path = RealPath(file.c_str());
if (real_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10023", {"value"}, {file});
GELOGE(FAILED, "Input parameter[--singleop]'s value[%s] is not a valid path.", file.c_str());
return INTERNAL_ERROR;
}

std::ifstream ifs(real_path);
if (!ifs.is_open()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10024", {"value"}, {file});
GELOGE(FAILED, "Open file[%s] provided in input parameter[--singleop] failed.", file.c_str());
return FAILED;
}
try {
ifs >> json_obj;
} catch (const std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E10025", {"realpath", "errmsg"}, {real_path, e.what()});
GELOGE(PARAM_INVALID, "Parse file[%s] provided in input parameter[--singleop] failed, exception = %s.",
real_path.c_str(), e.what());
return PARAM_INVALID;
}

ifs.close();
return SUCCESS;
}

bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {
if (op_desc.op.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10026");
GELOGE(PARAM_INVALID, "Op name is empty");
return false;
}

int index = 0;
for (auto &tensor_desc : op_desc.input_desc) {
if (!tensor_desc.GetValidFlag()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"intput", "datatype or format", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index);
return false;
}
if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) ||
(tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"intput", "datatype or format", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index);
return false;
}
++index;
}

index = 0;
for (auto &tensor_desc : op_desc.output_desc) {
if (!tensor_desc.GetValidFlag()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "datatype", std::to_string(index)});
GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index);
return false;
}
if (tensor_desc.type == DT_UNDEFINED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "datatype", std::to_string(index)});
GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index);
return false;
}

if (tensor_desc.format == FORMAT_RESERVED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "format", std::to_string(index)});
GELOGE(PARAM_INVALID, "Output's format is invalid when the index is %d", index);
return false;
}
++index;
}

for (auto &attr : op_desc.attrs) {
if (attr.name.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10029");
GELOGE(PARAM_INVALID, "attr name is empty");
return false;
}

if (attr.value.IsEmpty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10030", {"attrname"}, {attr.name});
GELOGE(PARAM_INVALID, "Parse attr \"%s\" failed. ", attr.name.c_str());
return false;
}
}

return true;
}

std::unique_ptr<OpDesc> SingleOpParser::CreateOpDesc(const string &op_type) {
return std::unique_ptr<OpDesc>(new(std::nothrow) OpDesc(op_type, op_type));
}

Status SingleOpParser::UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc) {
std::map<std::string, int> dynamic_name_map;
for (auto &tensor : desc) {
if (tensor.dynamic_input_name.empty()) {
continue;
}
if (dynamic_name_map.find(tensor.dynamic_input_name) == dynamic_name_map.end()) {
dynamic_name_map[tensor.dynamic_input_name] = 0;
} else {
dynamic_name_map[tensor.dynamic_input_name]++;
}
tensor.name = tensor.dynamic_input_name + std::to_string(dynamic_name_map[tensor.dynamic_input_name]);
}
GELOGD("Update dynamic tensor name success!");
return SUCCESS;
}

Status SingleOpParser::ConvertToBuildParam(int index,
const SingleOpDesc &single_op_desc,
SingleOpBuildParam &build_param) {
auto op_desc = CreateOpDesc(single_op_desc.op);
GE_CHECK_NOTNULL(op_desc);

for (auto &desc : single_op_desc.input_desc) {
GeTensorDesc ge_tensor_desc(GeShape(desc.dims),
desc.format,
desc.type);
auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format;
auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims;
ge_tensor_desc.SetOriginFormat(ori_format_to_set);
ge_tensor_desc.SetOriginShape(GeShape(ori_dims));
GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc));
TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size());
TensorUtils::SetInputTensor(ge_tensor_desc, true);
TensorUtils::SetOutputTensor(ge_tensor_desc, false);
if (desc.name.empty()) {
op_desc->AddInputDesc(ge_tensor_desc);
} else {
op_desc->AddInputDesc(desc.name, ge_tensor_desc);
}
build_param.inputs.emplace_back(ge_tensor_desc);
}

for (auto &desc : single_op_desc.output_desc) {
GeTensorDesc ge_tensor_desc(GeShape(desc.dims),
desc.format,
desc.type);
auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format;
auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims;
ge_tensor_desc.SetOriginFormat(ori_format_to_set);
ge_tensor_desc.SetOriginShape(GeShape(ori_dims));
GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc));
TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size());
TensorUtils::SetInputTensor(ge_tensor_desc, false);
TensorUtils::SetOutputTensor(ge_tensor_desc, true);
if (desc.name.empty()) {
op_desc->AddOutputDesc(ge_tensor_desc);
} else {
op_desc->AddOutputDesc(desc.name, ge_tensor_desc);
}
build_param.outputs.emplace_back(ge_tensor_desc);
}

for (const auto &attr : single_op_desc.attrs) {
op_desc->SetAttr(attr.name, attr.value);
}

if (VerifyOpInputOutputSizeByIr(*op_desc) != SUCCESS) {
GELOGE(PARAM_INVALID, "Verify op [%s] input or output size failed.", op_desc->GetType().c_str());
return PARAM_INVALID;
}

build_param.file_name = GenerateFileName(single_op_desc, index);
build_param.op_desc.reset(op_desc.release());
return SUCCESS;
}

Status SingleOpParser::VerifyOpInputOutputSizeByIr(const OpDesc &current_op_desc) {
ge::Operator operator_ir = ge::OperatorFactory::CreateOperator("tmp_operator", current_op_desc.GetType());
if (!operator_ir.IsEmpty()) {
auto opdesc_ir = ge::OpDescUtils::GetOpDescFromOperator(operator_ir);
GE_CHECK_NOTNULL(opdesc_ir);
size_t current_opdesc_inputs_num = current_op_desc.GetInputsSize();
size_t ir_opdesc_inputs_num = opdesc_ir->GetInputsSize();
if (current_opdesc_inputs_num < ir_opdesc_inputs_num) {
string reason = "is smaller than the ir needed input size " + std::to_string(ir_opdesc_inputs_num);
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{current_op_desc.GetName(), "input size " + std::to_string(current_opdesc_inputs_num), reason});
GELOGE(PARAM_INVALID, "This op [%s] input size %zu is smaller than the ir needed input size %zu",
current_op_desc.GetName().c_str(), current_opdesc_inputs_num, ir_opdesc_inputs_num);
return PARAM_INVALID;
}
size_t current_opdesc_outputs_num = current_op_desc.GetOutputsSize();
size_t ir_opdesc_outputs_num = opdesc_ir->GetOutputsSize();
if (current_opdesc_outputs_num < ir_opdesc_outputs_num) {
string reason = "is smaller than the ir needed output size " + std::to_string(ir_opdesc_outputs_num);
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{current_op_desc.GetName(), "output size " + std::to_string(current_opdesc_outputs_num), reason});
GELOGE(PARAM_INVALID, "This op [%s] output size %zu is smaller than the ir needed output size %zu",
current_op_desc.GetName().c_str(), current_opdesc_outputs_num, ir_opdesc_outputs_num);
return PARAM_INVALID;
}
}
return SUCCESS;
}

Status SingleOpParser::SetShapeRange(const std::string &op_name,
const SingleOpTensorDesc &tensor_desc,
GeTensorDesc &ge_tensor_desc) {
auto num_shape_ranges = tensor_desc.dim_ranges.size();
GELOGD("Number of shape ranges = %zu", num_shape_ranges);
auto it = std::find(tensor_desc.dims.begin(), tensor_desc.dims.end(), ge::UNKNOWN_DIM_NUM);
if (it != tensor_desc.dims.end()) {
if (tensor_desc.dims != ge::UNKNOWN_RANK) {
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape",
"has unknown rank but dim size is not one"});
GELOGE(PARAM_INVALID, "Invalid tensor shape: [%s]", ge_tensor_desc.MutableShape().ToString().c_str());
return PARAM_INVALID;
}
if (!tensor_desc.dim_ranges.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape range",
"is not needed while the rank the shape is unknown"});
GELOGE(PARAM_INVALID, "Shape range is not needed while the rank the shape is unknown");
return PARAM_INVALID;
}

GELOGD("Shape is unknown rank, do not set shape range");
return SUCCESS;
}

std::vector<std::pair<int64_t, int64_t>> shape_range;
size_t range_index = 0;
for (auto dim : tensor_desc.dims) {
if (dim >= 0) {
shape_range.emplace_back(dim, dim);
GELOGD("Adding shape range: [%ld, %ld]", dim, dim);
} else {
GELOGD("To get shape range by index = %zu", range_index);
if (range_index >= num_shape_ranges) {
string reason = "is smaller than the unknown dim size " + std::to_string(++range_index);
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape range size " + std::to_string(num_shape_ranges),
reason});
GELOGE(PARAM_INVALID, "The number of shape_range mismatches that of unknown dims.");
return PARAM_INVALID;
}

auto &range = tensor_desc.dim_ranges[range_index];
if (range.size() != kShapeRangePairSize) {
string reason = "has " + std::to_string(range.size()) + " item(s)";
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape range " + std::to_string(range_index),
reason});
GELOGE(PARAM_INVALID, "Invalid shape range entry. index = %zu, size = %zu", range_index, range.size());
return PARAM_INVALID;
}

shape_range.emplace_back(range[kShapeRangeLow], range[kShapeRangeHigh]);
GELOGD("Adding shape range: [%ld, %ld]", range[kShapeRangeLow], range[kShapeRangeHigh]);
++range_index;
}
}

if (num_shape_ranges != range_index) {
string reason = "is greater than the unknown dim size " + std::to_string(range_index);
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{op_name,
"shape range size " + std::to_string(num_shape_ranges),
reason});
GELOGE(PARAM_INVALID,
"The number of shape_range(%zu) mismatches that of unknown dims(%zu).",
num_shape_ranges,
range_index);
return PARAM_INVALID;
}

if (range_index > 0) {
ge_tensor_desc.SetShapeRange(shape_range);
}

return SUCCESS;
}

Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector<SingleOpBuildParam> &op_list) {
int index = 0;
try {
Json single_op_list_json;
auto ret = ReadJsonFile(file, single_op_list_json);
if (ret != SUCCESS) {
return ret;
}

for (const Json &single_op_json : single_op_list_json) {
SingleOpDesc single_op_desc;
GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str());
single_op_desc = single_op_json;
if (UpdateDynamicTensorName(single_op_desc.input_desc) != SUCCESS) {
GELOGE(FAILED, "Update dynamic tensor name failed!");
return FAILED;
}

if (!Validate(single_op_desc)) {
GELOGE(PARAM_INVALID, "Validate the index[%d] of op failed when read json file[%s].", index, file.c_str());
return PARAM_INVALID;
}

SingleOpBuildParam param;
ret = ConvertToBuildParam(index, single_op_desc, param);
if (ret != SUCCESS) {
return ret;
}

op_list.emplace_back(param);
GELOGI("Parse the index[%d] of op success", index);
index += 1;
}
} catch (const nlohmann::json::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E10032", {"index", "jsonfile", "exception"},
{std::to_string(index), file, e.what()});
GELOGE(PARAM_INVALID, "Parse the index[%d] of op failed when read json file[%s], exception %s",
index, file.c_str(), e.what());
return PARAM_INVALID;
}

return SUCCESS;
}
} // namespace ge


+ 90
- 0
atc/single_op_parser.h View File

@@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_TOOLS_COMPILE_PARSER_H
#define ACL_TOOLS_COMPILE_PARSER_H

#include <vector>
#include <string>

#include <nlohmann/json.hpp>

#include "ge/ge_api_error_codes.h"
#include "graph/types.h"
#include "graph/ge_attr_value.h"
#include "graph/op_desc.h"

namespace ge {
struct SingleOpTensorDesc {
public:
bool GetValidFlag() const { return is_valid_; }
void SetValidFlag(bool is_valid) { is_valid_ = is_valid; }
public:
std::string name;
std::vector<int64_t> dims;
std::vector<int64_t> ori_dims;
std::vector<std::vector<int64_t>> dim_ranges;
ge::Format format = ge::FORMAT_RESERVED;
ge::Format ori_format = ge::FORMAT_RESERVED;
ge::DataType type = ge::DT_UNDEFINED;
std::string dynamic_input_name;
private:
bool is_valid_ = true;
};

struct SingleOpAttr {
std::string name;
std::string type;
ge::GeAttrValue value;
};

struct SingleOpDesc {
std::string op;
std::vector<SingleOpTensorDesc> input_desc;
std::vector<SingleOpTensorDesc> output_desc;
std::vector<SingleOpAttr> attrs;
};

struct SingleOpBuildParam {
ge::OpDescPtr op_desc;
std::vector<ge::GeTensor> inputs;
std::vector<ge::GeTensor> outputs;
std::string file_name;
};

void from_json(const nlohmann::json &json, SingleOpTensorDesc &desc);

void from_json(const nlohmann::json &json, SingleOpAttr &desc);

void from_json(const nlohmann::json &json, SingleOpDesc &desc);

class SingleOpParser {
public:
static Status ParseSingleOpList(const std::string &file, std::vector<SingleOpBuildParam> &op_list);

private:
static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj);
static bool Validate(const SingleOpDesc &op_desc);
static std::unique_ptr<OpDesc> CreateOpDesc(const std::string &op_type);
static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param);
static Status UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc);
static Status VerifyOpInputOutputSizeByIr(const OpDesc &current_op_desc);
static Status SetShapeRange(const std::string &op_name,
const SingleOpTensorDesc &tensor_desc,
GeTensorDesc &ge_tensor_desc);
};
} // namespace ge

#endif // ACL_TOOLS_COMPILE_PARSER_H

+ 85
- 0
atc/util/gflags_util.h View File

@@ -0,0 +1,85 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef UTIL_GFLAGS_UTIL_H_
#define UTIL_GFLAGS_UTIL_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <gflags/gflags.h>
#include <string>

namespace ge {
class GE_FUNC_VISIBILITY GflagsUtils {
public:
static bool IsSetCommandTrue(const char *name) {
std::string out;
return gflags::GetCommandLineOption(name, &out) && out == "true";
}

///
/// @brief Determines whether the parameter is empty
/// @param name name parameter name
/// @return true if empty otherwise false
///
static bool IsSetCommandNotEmpty(const char *name) {
std::string out;
return gflags::GetCommandLineOption(name, &out) && !out.empty();
}

///
/// @brief Determines whether the parameter is not default
/// @param flag_name name parameter name
/// @return true if not default otherwise false
///
static bool IsCommandLineNotDefault(const char *flag_name) {
google::CommandLineFlagInfo info;
return GetCommandLineFlagInfo(flag_name, &info) && !info.is_default;
}

///
/// @brief Modify gflags to print help information
/// @param flags_h Pass in the self-defined help parameter, it is recommended to be FLAGS_h
/// @return void
///
static void ChangeHelpFlags(bool flags_h) {
if (flags_h || IsSetCommandTrue("help") || IsSetCommandTrue("helpfull") || IsSetCommandNotEmpty("helpon") ||
IsSetCommandNotEmpty("helpmatch") || IsSetCommandTrue("helppackage") || IsSetCommandTrue("helpxml")) {
gflags::SetCommandLineOption("help", "false");
gflags::SetCommandLineOption("helpfull", "false");
gflags::SetCommandLineOption("helpon", "");
gflags::SetCommandLineOption("helpmatch", "");
gflags::SetCommandLineOption("helppackage", "false");
gflags::SetCommandLineOption("helpxml", "false");
gflags::SetCommandLineOption("helpshort", "true");
}
}
};
} // namespace ge

#endif // UTIL_GFLAGS_UTIL_H_

+ 145
- 0
atc/util/properties_manager.cc View File

@@ -0,0 +1,145 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "properties_manager.h"

#include <climits>
#include <cstdio>
#include <fstream>

#include "framework/common/debug/ge_log.h"
#include "util/util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_context.h"
#include "graph/utils/attr_utils.h"

namespace ge {
namespace atc {
PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {}
PropertiesManager::~PropertiesManager() {}

// singleton
PropertiesManager &PropertiesManager::Instance() {
static PropertiesManager instance;
return instance;
}

// Initialize property configuration
bool PropertiesManager::Init(const std::string &file_path) {
std::lock_guard<std::mutex> lock(mutex_);
if (is_inited_) {
GELOGW("Already inited, will be initialized again");
properties_map_.clear();
is_inited_ = false;
return is_inited_;
}

if (!LoadFileContent(file_path)) {
return false;
}

is_inited_ = true;
return is_inited_;
}

// Load file contents
bool PropertiesManager::LoadFileContent(const std::string &file_path) {
// Normalize the path
string resolved_file_path = RealPath(file_path.c_str());
if (resolved_file_path.empty()) {
GELOGE(FAILED, "Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
return false;
}
std::ifstream fs(resolved_file_path, std::ifstream::in);

if (!fs.is_open()) {
GELOGE(PARAM_INVALID, "Open %s failed.", file_path.c_str());
return false;
}

std::string line;

while (getline(fs, line)) { // line not with \n
if (!ParseLine(line)) {
GELOGE(PARAM_INVALID, "Parse line failed. content is [%s].", line.c_str());
fs.close();
return false;
}
}

fs.close(); // close the file

GELOGI("LoadFileContent success.");
return true;
}

// Parsing the command line
bool PropertiesManager::ParseLine(const std::string &line) {
std::string temp = Trim(line);
// Comment or newline returns true directly
if (temp.find_first_of('#') == 0 || *(temp.c_str()) == '\n') {
return true;
}

if (!temp.empty()) {
std::string::size_type pos = temp.find_first_of(delimiter);
if (pos == std::string::npos) {
GELOGE(PARAM_INVALID, "Incorrect line [%s], it must include [%s].Perhaps you use illegal chinese symbol",
line.c_str(), delimiter.c_str());
return false;
}

std::string map_key = Trim(temp.substr(0, pos));
std::string value = Trim(temp.substr(pos + 1));
if (map_key.empty() || value.empty()) {
GELOGE(PARAM_INVALID, "Map_key or value empty. %s", line.c_str());
return false;
}

properties_map_[map_key] = value;
}

return true;
}

// Remove the space and tab before and after the string
std::string PropertiesManager::Trim(const std::string &str) {
if (str.empty()) {
return str;
}

std::string::size_type start = str.find_first_not_of(" \t\r\n");
if (start == std::string::npos) {
return str;
}

std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
return str.substr(start, end);
}

// return properties_map_
std::map<std::string, std::string> PropertiesManager::GetPropertyMap() {
std::lock_guard<std::mutex> lock(mutex_);
return properties_map_;
}

// Set separator
void PropertiesManager::SetPropertyDelimiter(const std::string &de) {
std::lock_guard<std::mutex> lock(mutex_);
delimiter = de;
}
}
} // namespace ge

+ 82
- 0
atc/util/properties_manager.h View File

@@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef UTIL_PROPERTIES_MANAGER_H_
#define UTIL_PROPERTIES_MANAGER_H_

#include <map>
#include <mutex>
#include <string>

namespace ge {
namespace atc {
class PropertiesManager {
public:
// Singleton
static PropertiesManager &Instance();

/**
* @ingroup domi_ome
* @brief Initialize configuration parameters, which must be invoked in main.
* @param [in] file_path Property profile path
* @return true success
* @return false fail
* @author
*/
bool Init(const std::string &file_path);

/**
* @ingroup domi_ome
* @brief Return configuration parameters
* @return properties_map_
* @author
*/
std::map<std::string, std::string> GetPropertyMap();

/**
* @ingroup domi_ome
* @brief Adapt key value pair form, set different separators
* @param [in] delimiter
* @author
*/
void SetPropertyDelimiter(const std::string &de);

private:
// Private construct, destructor
PropertiesManager();
~PropertiesManager();

// Get file content
bool LoadFileContent(const std::string &file_path);

// Parsing a single line file
bool ParseLine(const std::string &line);

// Remove space before and after string
std::string Trim(const std::string &str);

bool is_inited_;

// Configuration item separator, default is "="
std::string delimiter;

std::map<std::string, std::string> properties_map_;
std::mutex mutex_;
};
}
} // namespace ge

#endif // UTIL_PROPERTIES_MANAGER_H_

+ 171
- 0
atc/util/string_util.h View File

@@ -0,0 +1,171 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_COMMON_STRING_UTIL_H_
#define INC_FRAMEWORK_COMMON_STRING_UTIL_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <cctype>
#include <securec.h>

#include <algorithm>
#include <functional>
#include <sstream>
#include <string>
#include <vector>

namespace ge {
class GE_FUNC_VISIBILITY StringUtils {
public:
static std::string &Ltrim(std::string &s) {
#if __cplusplus >= 201103L
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); }));
#else
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), std::not1(std::ptr_fun<int, int>(std::isspace))));
#endif
return s;
}
// lint -esym(551,*)
static std::string &Rtrim(std::string &s) { /*lint !e618*/
#if __cplusplus >= 201103L
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); }));
#else
(void)s.erase(std::find_if(s.rbegin(), s.rend(), std::not1(std::ptr_fun<int, int>(std::isspace))).base(), s.end());
#endif
return s;
}
// lint -esym(551,*)
///
/// @ingroup domi_common
/// @brief delete spaces at the beginning and end of a string
/// @param [in] string to be trimmed
/// @return string after trim
///
static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); }

///
/// @ingroup domi_common
/// @brief string splitting
/// @param [in] str string to be trimmed
/// @param [in] delim separator
/// @return string array after segmentation
///
static std::vector<std::string> Split(const std::string &str, char delim) {
std::vector<std::string> elems;

if (str.empty()) {
elems.emplace_back("");
return elems;
}

std::stringstream ss(str);
std::string item;

while (getline(ss, item, delim)) {
elems.push_back(item);
}

auto str_size = str.size();
if (str_size > 0 && str[str_size - 1] == delim) {
elems.emplace_back("");
}

return elems;
}
///
/// @ingroup domi_common
/// @brief obtain the file name
/// @param [in] s path name
/// @return file name
///
static std::string GetFileName(std::string &s) {
if (s.empty()) {
return "";
}
std::vector<std::string> files = StringUtils::Split(s, '/');

return files.empty() ? "" : files[files.size() - 1];
}
///
/// @ingroup domi_common
/// @brief full replacement
/// @link
/// @param [in] str str string to be replaced
/// @param [in] old_value old Characters Before Replacement
/// @param [in] new_value new Characters Before Replacement
/// @return string after replacement
///
static std::string ReplaceAll(std::string str, const std::string &old_value, const std::string &new_value) {
std::string::size_type cur_pos = 0;
std::string::size_type old_length = old_value.length();
std::string::size_type new_length = new_value.length();
// cycle replace
for (; cur_pos != std::string::npos; cur_pos += new_length) {
if ((cur_pos = str.find(old_value, cur_pos)) != std::string::npos) {
(void)str.replace(cur_pos, old_length, new_value);
} else {
break;
}
}
return str;
}

///
/// @ingroup domi_common
/// @brief checks whether a character string starts with a character string (prefix)
/// @link
/// @param [in] str string to be compared
/// @param [in] str_x prefix
/// @return if the value is a prefix, true is returned. Otherwise, false is returned
///
static bool StartWith(const std::string &str, const std::string str_x) {
return ((str.size() >= str_x.size()) && (str.compare(0, str_x.size(), str_x) == 0));
}

///
/// @ingroup domi_common
/// @brief format string
/// @link
/// @param [in] format specifies the character string format
/// @param [in] ... format Filling Content
/// @return formatted string
///
static std::string FormatString(const char *format, ...) {
const uint32_t MAX_BUFFER_LEN = 1024; // the stack memory plint check result must be less than 1024
va_list args;
va_start(args, format);
char buffer[MAX_BUFFER_LEN] = {0};
int32_t ret = vsnprintf_s(buffer, MAX_BUFFER_LEN, MAX_BUFFER_LEN - 1, format, args);
va_end(args);
return ret > 0 ? buffer : "";
}
};
} // namespace ge

#endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_

+ 32
- 0
atc/util/tool.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef UTIL_TOOL_H_
#define UTIL_TOOL_H_

#include <iostream>
#include <memory>
#include <utility>

namespace ge {
template <typename T, typename... Args>
static inline std::shared_ptr<T> MakeShared(Args &&... args) {
typedef typename std::remove_const<T>::type T_nc;
std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...));
return ret;
}
} // namespace ge
#endif // UTIL_TOOL_H_

+ 283
- 0
atc/util/util.cc View File

@@ -0,0 +1,283 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "util.h"

#include <sys/stat.h>
#ifdef __GNUC__
#include <regex.h>
#else
#include <regex>
#endif
#include <algorithm>
#include <climits>
#include <cstdlib>
#include <ctime>
#include <fstream>

#include "common/util/error_manager/error_manager.h"
#include "external/ge/ge_api_error_codes.h"
#include "framework/common/debug/ge_log.h"
#include "mmpa/mmpa_api.h"

namespace {
const int kMaxBuffSize = 256;
const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character";
} // namespace

namespace ge {
/**
* @ingroup domi_common
* @brief Create directory, support to create multi-level directory
* @param [in] directory_path Path, can be multi-level directory
* @return -1 fail
* @return 0 success
*/
int CreateDirectory(const std::string &directory_path) {
GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty.");
auto dir_path_len = directory_path.length();
if (dir_path_len >= MMPA_MAX_PATH) {
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"},
{directory_path, std::to_string(MMPA_MAX_PATH)});
GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH);
return -1;
}
char tmp_dir_path[MMPA_MAX_PATH] = {0};
for (size_t i = 0; i < dir_path_len; i++) {
tmp_dir_path[i] = directory_path[i];
if ((tmp_dir_path[i] == '\\') || (tmp_dir_path[i] == '/')) {
if (mmAccess2(tmp_dir_path, M_F_OK) != EN_OK) {
int32_t ret = mmMkdir(tmp_dir_path, M_IRUSR | M_IWUSR | M_IXUSR); // 700
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret;
}
}
}
}
}
int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), M_IRUSR | M_IWUSR | M_IXUSR); // 700
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret;
}
}
return 0;
}

std::string CurrentTimeInStr() {
std::time_t now = std::time(nullptr);
std::tm *ptm = std::localtime(&now);
if (ptm == nullptr) {
GELOGE(ge::FAILED, "Localtime failed.");
return "";
}

const int kTimeBufferLen = 32;
char buffer[kTimeBufferLen + 1] = {0};
// format: 20171122042550
std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm);
return std::string(buffer);
}

std::string RealPath(const char *path) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH,
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"},
{path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH);

// Nullptr is returned when the path does not exist or there is no permission
// Return absolute path when path is accessible
std::string res;
char resolved_path[MMPA_MAX_PATH] = {0};
if (mmRealPath(path, resolved_path, MMPA_MAX_PATH) == EN_OK) {
res = resolved_path;
}

return res;
}

bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param) {
// The specified path is empty
std::map<std::string, std::string> args_map;
if (file_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param});
GELOGW("Input parameter %s is empty.", file_path.c_str());
return false;
}
std::string real_path = RealPath(file_path.c_str());
// Unable to get absolute path (does not exist or does not have permission to access)
if (real_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file_path, strerror(errno)});
GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno));
return false;
}

// A regular matching expression to verify the validity of the input file path
// Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores
// File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.)
#ifdef __GNUC__
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
#else
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
#endif

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(real_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, real_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason);

// The absolute path points to a file that is not readable
if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) {
ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"file", "errmsg"}, {file_path.c_str(), strerror(errno)});
GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno));
return false;
}

return true;
}

bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param) {
// The specified path is empty
if (file_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param});
GELOGW("Input parameter's value is empty.");
return false;
}

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH,
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(),
MMPA_MAX_PATH);

// A regular matching expression to verify the validity of the input file path
// Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores
// File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.)
#ifdef __GNUC__
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
#else
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
#endif

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(file_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, file_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason);

std::string real_path = RealPath(file_path.c_str());
// Can get absolute path (file exists)
if (!real_path.empty()) {
// File is not readable or writable
if (mmAccess2(real_path.c_str(), M_W_OK | M_F_OK) != EN_OK) {
ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {real_path, strerror(errno)});
GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno));
return false;
}
} else {
// Find the last separator
int path_split_pos = static_cast<int>(file_path.size() - 1);
for (; path_split_pos >= 0; path_split_pos--) {
if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') {
break;
}
}
if (path_split_pos == 0) {
return true;
}
if (path_split_pos != -1) {
std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos));
// Determine whether the specified path is valid by creating the path
if (CreateDirectory(prefix_path) != 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {file_path});
GELOGW("Can not create directory[%s].", file_path.c_str());
return false;
}
}
}

return true;
}

bool ValidateStr(const std::string &str, const std::string &mode) {
#ifdef __GNUC__
char ebuff[kMaxBuffSize];
regex_t reg;
int cflags = REG_EXTENDED | REG_NOSUB;
int ret = regcomp(&reg, mode.c_str(), cflags);
if (ret) {
regerror(ret, &reg, ebuff, kMaxBuffSize);
GELOGW("regcomp failed, reason: %s", ebuff);
regfree(&reg);
return true;
}

ret = regexec(&reg, str.c_str(), 0, NULL, 0);
if (ret) {
regerror(ret, &reg, ebuff, kMaxBuffSize);
GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff);
regfree(&reg);
return false;
}

regfree(&reg);
return true;
#else
std::wstring wstr(str.begin(), str.end());
std::wstring wmode(mode.begin(), mode.end());
std::wsmatch match;
bool res = false;

try {
std::wregex reg(wmode, std::regex::icase);
// Matching string part
res = regex_match(wstr, match, reg);
res = regex_search(str, std::regex("[`!@#$%^&*()|{}';',<>?]"));
} catch (std::exception &ex) {
GELOGW("The directory %s is invalid, error: %s.", str.c_str(), ex.what());
return false;
}
return !(res) && (str.size() == match.str().size());
#endif
}

bool GetNameFromFileName(const std::string &file_name, std::string &base_name) {
if (file_name.empty()) {
GELOGW("File path may not valid, check params --output");
return false;
}
size_t start_position = 0;
// using output as base_name (ignore ".om")
size_t filename_suffixes = 3;
if (file_name.find_last_of('/') != std::string::npos) {
start_position = file_name.find_last_of('/') + 1;
}
size_t end_position = file_name.length() - filename_suffixes;
base_name = file_name.substr(start_position, end_position - start_position);
if (base_name.empty()) {
GELOGW("File path may not valid, check params --output");
return false;
}
return true;
}
} // namespace ge

+ 159
- 0
atc/util/util.h View File

@@ -0,0 +1,159 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef UTIL_UTIL_H_
#define UTIL_UTIL_H_

#include <limits.h>
#include <string>

#include "framework/common/debug/ge_log.h"
#include "mmpa/mmpa_api.h"

// For propagating errors when calling a function.
#define GE_RETURN_IF_ERROR(expr) \
do { \
const ::ge::Status _status = (expr); \
if (_status) return _status; \
} while (0)

// If expr is not true, print the log and return the specified status
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0);

// If expr is not true, print the log and execute a custom statement
#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (!b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
exec_expr; \
} \
}

// If expr is true, print logs and execute custom statements
#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
exec_expr; \
} \
}

// If expr is not SUCCESS, return the same value
#define GE_CHK_STATUS_RET_NOLOG(expr) \
do { \
const ge::Status _status = (expr); \
if (_status != ge::SUCCESS) { \
return _status; \
} \
} while (0);

#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \
do { \
const ::ge::Status _status = (expr); \
if (_status) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0)

// If expr is not SUCCESS, print the log and execute a custom statement
#define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \
do { \
const ge::Status _status = (expr); \
GE_CHK_BOOL_EXEC(_status == SUCCESS, exec_expr, __VA_ARGS__); \
} while (0);

// Check if the parameter is null. If yes, return PARAM_INVALID and record the error
#define GE_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
GELOGE(ge::FAILED, "param[%s] must not be null.", #val); \
return ge::PARAM_INVALID; \
} \
} while (0)

// If expr is true, execute exec_expr without printing logs
#define GE_IF_BOOL_EXEC(expr, exec_expr) \
{ \
if (expr) { \
exec_expr; \
} \
}

namespace ge {
///
/// @ingroup domi_common
/// @brief Recursively Creating a Directory
/// @param [in] directory_path Path, which can be a multi-level directory.
/// @return 0 success
/// @return -1 fail
///
extern int CreateDirectory(const std::string &directory_path);

///
/// @ingroup domi_common
/// @brief Obtains the current time string.
/// @return Time character string in the format : %Y%m%d%H%M%S, eg: 20171011083555
///
std::string CurrentTimeInStr();

///
/// @ingroup domi_common
/// @brief Absolute path for obtaining files.
/// @param [in] path of input file
/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned
///
std::string RealPath(const char *path);

///
/// @ingroup domi_common
/// @brief Check whether the specified input file path is valid.
/// 1. The specified path cannot be empty.
/// 2. The path can be converted to an absolute path.
/// 3. The file path exists and is readable.
/// @param [in] file_path path of input file
/// @param [out] result
///
bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param = "");

///
/// @ingroup domi_common
/// @brief Checks whether the specified output file path is valid.
/// @param [in] file_path path of output file
/// @param [out] result
///
bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param = "");

///
/// @ingroup domi_common
/// @brief Check whether the file path meets the whitelist verification requirements.
/// @param [in] filePath file path
/// @param [out] result
///
bool ValidateStr(const std::string &filePath, const std::string &mode);

bool GetNameFromFileName(const std::string &file_name, std::string &base_name);
} // namespace ge
#endif // INC_FRAMEWORK_COMMON_UTIL_H_

+ 9
- 0
parser/common/convert/pb2json.cc View File

@@ -23,6 +23,7 @@
#include "securec.h"
#include "framework/common/fmk_types.h"
#include "framework/common/debug/ge_log.h"
#include "parser/common/model_saver.h"

using std::set;
using std::string;
@@ -31,6 +32,14 @@ namespace ge {
namespace {
const int kSignificantDigits = 10;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::ToJson(const ProtobufMsg &message,
const set<string> &black_fields,
const char *json_file, bool enum2str) {
Json j;
Message2Json(message, black_fields, j, enum2str);
ge::parser::ModelSaver::SaveJsonToFile(json_file, j);
}
// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message,
const set<string> &black_fields, Json &json,


+ 3
- 0
parser/common/convert/pb2json.h View File

@@ -51,6 +51,9 @@ class Pb2Json {
const ProtobufReflection *reflection, const std::set<std::string> &black_fields,
Json &json, bool enum2str);

static void ToJson(const ProtobufMsg &message, const std::set<std::string> &black_fields, const char *json_file,
bool enum2str = false);

protected:
static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field,
bool enum2str, Json &json);


+ 159
- 0
third_party/graphengine/inc/external/ge/ge_ir_build.h View File

@@ -0,0 +1,159 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GE_IR_BUILD_H_
#define INC_EXTERNAL_GE_IR_BUILD_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <string>
#include <map>
#include <memory>
#include "graph/graph.h"
#include "graph/ge_error_codes.h"

namespace {
const int IR_MAJOR_VERSION = 1;
const int IR_MINOR_VERSION = 0;
const int IR_PATCH_VERSION = 0;
} // namespace

namespace ge {

struct ModelBufferData {
std::shared_ptr<uint8_t> data = nullptr;
uint64_t length;
};

enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS };

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
* @param global_options[IN] global init params for build
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &))
GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options);

GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &global_options);

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
*/
GE_FUNC_VISIBILITY void aclgrphBuildFinalize();

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
* @param graph[IN] the graph ready to build
* @param options[IN] options used for build
* @param model[OUT] builded model
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &,
const std::map<AscendString, AscendString> &,
ModelBufferData &))
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph,
const std::map<std::string, std::string> &build_options,
ModelBufferData &model);

GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph,
const std::map<AscendString, AscendString> &build_options,
ModelBufferData &model);

/**
* @ingroup AscendCL
* @brief save model buffer to file
*
* @param output_file[IN] the file path to be saved
* @param model[IN] model buffer data
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *, const ModelBufferData &))
GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model);

GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model);

/**
* @ingroup AscendCL
* @brief query IR interface version
*
* @param major_version[OUT] IR interface major version
* @param minor_version[OUT] IR interface minor version
* @param patch_version[OUT] IR interface patch version
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version);

/**
* @ingroup AscendCL
* @brief dump graph
*
* @param graph[IN] the graph ready to build
* @param file[IN] file path
* @param file[IN] file path string len
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len);

/**
* @ingroup AscendCL
* @brief create single op graph
*
* @param op_type[IN] the op_type
* @param inputs[IN] the inputdesc
* @param outputs[IN] the outputdesc
* @param graph[OUT] the graph
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs,
const std::vector<TensorDesc> &outputs, Graph &graph);

/**
* @name aclgrphSetOpAttr
* @brief set attribute for operators in the configuration file
* @param graph [IN/OUT] compute graph
* @param attr_type [In] attribute type
* @param cfg_path [IN] the config file path
* @return graphStatus
*/
GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path);

}; // namespace ge
#endif // INC_EXTERNAL_GE_IR_BUILD_H_

+ 60
- 0
third_party/graphengine/inc/framework/generator/ge_generator.h View File

@@ -0,0 +1,60 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_
#define INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/ge_inner_error_codes.h"
#include "graph/ge_tensor.h"
#include "graph/graph.h"
#include "graph/op_desc.h"
#include "graph/detail/attributes_holder.h"
#include "omg/omg_inner_types.h"

namespace ge {
class GE_FUNC_VISIBILITY GeGenerator {
public:
GeGenerator() = default;

~GeGenerator() { (void)Finalize(); }

GeGenerator(const GeGenerator &) = delete;

GeGenerator &operator=(const GeGenerator &) = delete;

Status Initialize(const std::map<std::string, std::string> &options, OmgContext &context);

Status Finalize();

Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix,
const std::vector<GeTensor> &inputs = std::vector<GeTensor>());

Status GenerateInfershapeGraph(const Graph &graph);

Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector<GeTensor> &inputs,
const std::vector<GeTensor> &outputs, const std::string &model_file_name);
private:
class Impl;

std::shared_ptr<Impl> impl_;
};
} // namespace ge

#endif // INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_

+ 39
- 0
third_party/graphengine/inc/framework/omg/ge_init.h View File

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_OMG_GE_INIT_H_
#define INC_FRAMEWORK_OMG_GE_INIT_H_
#include <map>
#include <string>
#include "common/ge_inner_error_codes.h"

using std::string;
using std::map;

namespace ge {
class GE_FUNC_VISIBILITY GEInit {
public:
// GE Environment Initialize, return Status: SUCCESS,FAILED
static Status Initialize(const map<string, string> &options);

static string GetPath();

// GE Environment Finalize, return Status: SUCCESS,FAILED
static Status Finalize();
};
} // namespace ge

#endif // INC_FRAMEWORK_OMG_GE_INIT_H_

+ 34
- 0
third_party/graphengine/inc/framework/omg/model_tool.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_OMG_MODEL_TOOL_H_
#define INC_FRAMEWORK_OMG_MODEL_TOOL_H_

#include <memory>
#include <string>

#include "framework/common/debug/ge_log.h"
#include "proto/ge_ir.pb.h"

namespace ge {
class GE_FUNC_VISIBILITY ModelTool {
public:
static Status GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef &model_def);
static Status GetModelInfoFromPbtxt(const char *model_file, ge::proto::ModelDef &model_def);
};
} // namespace ge
#endif // INC_FRAMEWORK_OMG_MODEL_TOOL_H_

Loading…
Cancel
Save