
!3 Update to the newest code from the yellow zone (20201013)

Merge pull request !3 from taoxiangdong/master
pull/3/MERGE
王涛 Gitee 5 years ago
parent commit d758757c03
29 changed files with 1851 additions and 1377 deletions
  1. +2 -2 CMakeLists.txt
  2. +1 -1 cmake/external_libs/gflags.cmake
  3. +1 -1 cmake/external_libs/json.cmake
  4. +1 -1 cmake/external_libs/onnx.cmake
  5. +1 -1 cmake/external_libs/protobuf_shared.cmake
  6. +1 -1 cmake/external_libs/protobuf_static.cmake
  7. +1 -1 cmake/external_libs/protoc.cmake
  8. +1 -1 cmake/external_libs/securec.cmake
  9. +3 -0 parser/CMakeLists.txt
  10. +0 -27 parser/caffe/CMakeLists.txt
  11. +1 -1 parser/caffe/caffe_custom_parser_adapter.cc
  12. +38 -126 parser/caffe/caffe_parser.cc
  13. +4 -6 parser/caffe/caffe_parser.h
  14. +190 -190 parser/caffe/proto/ge_ir.proto
  15. +3 -0 parser/common/CMakeLists.txt
  16. +22 -21 parser/common/acl_graph_parser_util.h
  17. +190 -190 parser/common/proto/ge_ir.proto
  18. +136 -136 parser/common/proto/insert_op.proto
  19. +16 -16 parser/func_to_graph/proto_python_rule.mk
  20. +3 -0 parser/onnx/CMakeLists.txt
  21. +0 -1 parser/onnx/module.mk
  22. +1 -1 parser/proto/caffe/CMakeLists.txt
  23. +190 -190 parser/proto/ge_ir.proto
  24. +136 -136 parser/proto/insert_op.proto
  25. +0 -1 parser/tensorflow/graph_optimizer.cc
  26. +190 -190 parser/tensorflow/proto/ge_ir.proto
  27. +136 -136 parser/tensorflow/proto/insert_op.proto
  28. +6 -0 stub/Makefile
  29. +577 -0 stub/gen_stubapi.py

+2 -2 CMakeLists.txt

@@ -29,9 +29,9 @@ if (ENABLE_OPEN_SRC)
find_module(ge_common libge_common.so ${ASCEND_RUNTIME_DIR})
find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR})

set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
#set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)

add_subdirectory(metadef)
#add_subdirectory(metadef)
#add_subdirectory(metadef/graph)
#add_subdirectory(metadef/register)



+1 -1 cmake/external_libs/gflags.cmake

@@ -12,7 +12,7 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
endif()

ExternalProject_Add(gflags_build
#URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz
URL https://github.com/gflags/gflags/archive/v2.2.2.tar.gz
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
SOURCE_DIR ${PARSER_DIR}/../third_party/gflags/src/gflags-2.2.2
CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR>


+1 -1 cmake/external_libs/json.cmake

@@ -6,7 +6,7 @@ include(ExternalProject)

set(JSON_SRC_DIR ${PARSER_DIR}/../third_party/json/include)
ExternalProject_Add(json_build
#URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip
URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip
#URL /home/txd/workspace/cloud_code/pkg/include.zip
SOURCE_DIR ${JSON_SRC_DIR}
CONFIGURE_COMMAND ""


+1 -1 cmake/external_libs/onnx.cmake

@@ -7,7 +7,7 @@ set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto)
file(MAKE_DIRECTORY ${ONNX_PROTO_DIR})

ExternalProject_Add(onnx
#URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz
URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz
URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz
#URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345
#SOURCE_DIR ${ONNX_SRC_DIR}


+1 -1 cmake/external_libs/protobuf_shared.cmake

@@ -14,7 +14,7 @@ endif()
set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 $<$<STREQUAL:${PRODUCT_SIDE},host>:-D_GLIBCXX_USE_CXX11_ABI=0> -O2")
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protobuf_build
#URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${PARSER_DIR}/third_party/protobuf/src/protobuf-3.8.0
DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E copy_directory ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 <SOURCE_DIR>


+1 -1 cmake/external_libs/protobuf_static.cmake

@@ -12,7 +12,7 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static)
ExternalProject_Add(protobuf_static_build
#URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
CONFIGURE_COMMAND ${CMAKE_COMMAND}


+1 -1 cmake/external_libs/protoc.cmake

@@ -15,7 +15,7 @@ endif()
set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protoc_build
#URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
SOURCE_DIR ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0
CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake


+1 -1 cmake/external_libs/securec.cmake

@@ -11,7 +11,7 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
endif()

ExternalProject_Add(c_sec_build
#URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz
URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
SOURCE_DIR ${PARSER_DIR}/../libc_sec
CONFIGURE_COMMAND ${CMAKE_COMMAND}


+3 -0 parser/CMakeLists.txt

@@ -59,8 +59,11 @@ target_include_directories(fmk_parser PRIVATE
${PARSER_DIR}
${PARSER_DIR}/inc
${PARSER_DIR}/parser
${PARSER_DIR}/../ge
${PARSER_DIR}/../inc
${PARSER_DIR}/../inc/framework
${PARSER_DIR}/../inc/common/util
${PARSER_DIR}/../inc/external
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register


+0 -27 parser/caffe/CMakeLists.txt

@@ -1,27 +0,0 @@
set(PROTO_LIST
"${METADEF_DIR}/proto/caffe/caffe.proto"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})

############ lib_caffe_parser.so ############
add_library(_caffe_parser SHARED ${PROTO_SRCS})

target_include_directories(_caffe_parser PRIVATE
${CMAKE_CURRENT_LIST_DIR}
)

target_link_libraries(_caffe_parser PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
protobuf
-Wl,--as-needed
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS _caffe_parser OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)

+1 -1 parser/caffe/caffe_custom_parser_adapter.cc

@@ -83,8 +83,8 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr
}

bool bias_en = false;
int start_pos = layer->bottom_size();
bool update_in_turn = (static_cast<int64_t>(op->GetAllInputsSize()) == (layer->bottom_size() + layer->blobs_size()));
int start_pos = layer->bottom_size();
for (int i = 0; i < layer->blobs_size(); ++i) {
ge::GeTensorPtr weight = ge::parser::MakeShared<ge::GeTensor>();
GE_CHECK_NOTNULL(weight);


+38 -126 parser/caffe/caffe_parser.cc

@@ -107,6 +107,10 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
return ret;
}
GELOGI("Weights parse success. graph: %s", graph.GetName().c_str());
if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) {
GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
return ge::SUCCESS;
}
} // namespace ge
@@ -803,10 +807,6 @@ Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter

Status CaffeModelParser::AddBlobsToMap(const domi::caffe::LayerParameter &layer,
std::map<std::string, std::string> &inplace_blob_name_remapping) {
if (layer.type() == ge::parser::NETOUTPUT) {
return SUCCESS;
}

if (layer.top_size() <= 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E19011", {"opname"}, {layer.name()});
GELOGE(FAILED, "The output size of layer %s needs to be greater than zero.", layer.name().c_str());
@@ -1085,34 +1085,6 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom
"while it's original input num is: %d",
layer.bottom_size());
}

// Netoutput node processing
if (op_desc->GetType() == ge::parser::NETOUTPUT) {
size_t input_output_tensor_num = 0;
if (!ge::GetParserContext().user_out_nodes.empty()) {
// User specified output
input_output_tensor_num = ge::GetParserContext().user_out_nodes.size();
} else {
for (auto t_iter = top_blobs_map_.begin(); t_iter != top_blobs_map_.end(); t_iter++) {
auto b_iter = bottom_blobs_map_.find(t_iter->first);
// Find the output node of the network
if (b_iter == bottom_blobs_map_.end()) {
input_output_tensor_num += top_blobs_map_[t_iter->first].size();
}
}
}
// add tensordesc
GELOGD(
"Current op type is NETOUTPUT, add additional input&output num: %zu."
"while it's original input num is: %d, output num is: %d",
input_output_tensor_num, layer.bottom_size(), output_tensor_num);
for (size_t i = 0; i < input_output_tensor_num; i++) {
ge::GeTensorDesc input_tensor;
GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor));
ge::GeTensorDesc output_tensor;
GE_RETURN_IF_ERROR(op_desc->AddOutputDesc(output_tensor));
}
}
return SUCCESS;
}

@@ -1140,10 +1112,12 @@ Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const
GE_CHECK_NOTNULL(op_desc);
auto valid_input_size = layer.bottom_size();
auto blob_size = layer.blobs_size();
GELOGI("After GetOpDescFromOperator op[%s] type[%s] have all input size: %zu, caffe_input_size:%d output size: %zu",
GELOGI("After GetOpDescFromOperator op[%s] type[%s] have all input size: %zu, "
"caffe_input_size:%d blob_size %d output size: %zu",
op_desc->GetName().c_str(), op_desc->GetType().c_str(),
op_desc->GetAllInputsSize(), valid_input_size, op_desc->GetOutputsSize());
bool update_in_turn = (static_cast<int64_t>(op_desc->GetAllInputsSize()) == (valid_input_size + blob_size));
op_desc->GetAllInputsSize(), valid_input_size,
blob_size, op_desc->GetOutputsSize());
bool update_in_turn = (static_cast<int64_t >(op_desc->GetAllInputsSize()) == (valid_input_size + blob_size));
for (int i = 0; i < valid_input_size; i++) {
ge::GeTensorDesc input_tensor;
std::string input_name;
@@ -1151,8 +1125,8 @@ Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const
// Below cases are supported for now when there are optional inputs
// x means optional, o means required input
// a. ooxxx, number of o and x>=layer.bottom_size+layer.blobs_size>=number of o
// b. oxoxoxox, layer.bottom_size+layer.blobs_size>=number of o
// c. oxoxoxox, layer.bottom_size+layer.blobs_size>=number of o and x
// b. oxoxoxox, layer.bottom_size+layer.blobs_size=number of o
// c. oxoxoxox, layer.bottom_size+layer.blobs_size=number of o and x
if (update_in_turn) {
ret = op_desc->UpdateInputDesc(op_desc->GetInputNameByIndex(static_cast<uint32_t>(i)), input_tensor);
} else {
@@ -1164,14 +1138,16 @@ Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const
}
}
GELOGI("op [%s], type[%s], update input(%d) with name %s %s", op_desc->GetName().c_str(),
op_desc->GetType().c_str(), i, input_name.c_str(), ret == ge::GRAPH_SUCCESS ? "success" : "failed");
op_desc->GetType().c_str(), i, input_name.c_str(), ret == ge::GRAPH_SUCCESS ? "success" : "failed");
}

for (int i = 0; i < layer.top_size(); i++) {
ge::GeTensorDesc output_tensor;
auto ret = op_desc->UpdateOutputDesc(op_desc->GetOutputNameByIndex(static_cast<uint32_t>(i)), output_tensor);
GELOGI("op [%s], type[%s], update output(%d) with name %s %s", op_desc->GetName().c_str(),
op_desc->GetType().c_str(), i, op_desc->GetOutputNameByIndex(i).c_str(), ret == ge::GRAPH_SUCCESS ? "success" : "failed");
GELOGI("op [%s], type[%s], update output(%d) with name %s %s",
op_desc->GetName().c_str(), op_desc->GetType().c_str(),
i, op_desc->GetOutputNameByIndex(i).c_str(),
ret == ge::GRAPH_SUCCESS ? "success" : "failed");
}
}
return SUCCESS;
@@ -1266,46 +1242,33 @@ bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) {
return ret;
}

Status CaffeModelParser::AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph) {
GE_CHECK_NOTNULL(graph);
ge::NodePtr net_output_node = graph->FindFirstNodeMatchType(ge::parser::NETOUTPUT);
if (net_output_node == nullptr) {
GELOGE(INTERNAL_ERROR, "Can not find netoutput node.");
return INTERNAL_ERROR;
}
uint32_t net_output_num = net_output_node->GetAllInDataAnchorsSize();
Status CaffeModelParser::AddUserOutNodesTop() {
int32_t index = 0;
const std::vector<std::pair<std::string, int32_t>> &user_out_nodes = ge::GetParserContext().user_out_nodes;
int net_output_num = user_out_nodes.size();
for (const auto &out_pair : user_out_nodes) {
auto node_iter = node_map.find(out_pair.first);
auto layer_iter = layer_tops_map_.find(out_pair.first);
GELOGI("Add to output, node name: %s", out_pair.first.c_str());
if (node_iter != node_map.end()) {
if ((static_cast<uint32_t>(out_pair.second) >= node_iter->second->GetAllOutDataAnchorsSize()) ||
(static_cast<uint32_t>(index) >= net_output_num)) {
if (layer_iter != layer_tops_map_.end()) {
if (static_cast<uint32_t>(out_pair.second) >= (layer_iter->second).size()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11016", {"opname", "outputindex", "totlaloutputindex", "inputindex", "totlalinputindex"},
{out_pair.first.c_str(), std::to_string(out_pair.second),
std::to_string(node_iter->second->GetAllOutDataAnchorsSize()), std::to_string(index),
std::to_string((layer_iter->second).size()), std::to_string(index),
std::to_string(net_output_num)});
GELOGE(INTERNAL_ERROR,
"Add op %s to NetOutput faild, current node output index:%d should < %u. NetOutput"
"input_index:%d should < %u.",
out_pair.first.c_str(), out_pair.second, node_iter->second->GetAllOutDataAnchorsSize(), index,
out_pair.first.c_str(), out_pair.second, (layer_iter->second).size(), index,
net_output_num);
return INTERNAL_ERROR;
}
GELOGD("Start add edge for user out node: From %s:%d To %s:%d.", node_iter->second->GetName().c_str(),
out_pair.second, net_output_node->GetName().c_str(), index);
ge::OutDataAnchorPtr out_archor_ptr = node_iter->second->GetOutDataAnchor(out_pair.second);
GE_CHECK_NOTNULL(out_archor_ptr);
ge::InDataAnchorPtr in_archor_ptr = net_output_node->GetInDataAnchor(index);
GE_CHECK_NOTNULL(in_archor_ptr);
if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E11013", {"opname1", "opname2"},
{node_iter->second->GetName(), net_output_node->GetName()});
GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", node_iter->second->GetName().c_str(),
net_output_node->GetName().c_str());
return INTERNAL_ERROR;

string top_name = layer_iter->second[out_pair.second];
auto top_node_iter = node_map.find(out_pair.first);
if (top_node_iter != node_map.end()) {
ge::GetParserContext().out_top_names.push_back(top_name);
GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str());
}
++index;
} else {
@@ -1317,13 +1280,7 @@ Status CaffeModelParser::AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph) {
return SUCCESS;
}

Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph) {
GE_CHECK_NOTNULL(graph);
ge::NodePtr node = graph->FindFirstNodeMatchType(ge::parser::NETOUTPUT);

GE_RETURN_WITH_LOG_IF_FALSE(node != nullptr, "Net without output, some phase failed in front.");

int32_t index = 0;
Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_message) {
for (int32_t i = 0; i < proto_message.layer_size(); i++) {
const domi::caffe::LayerParameter &layer = proto_message.layer(i);

@@ -1333,6 +1290,7 @@ Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_m

for (int i = 0; i < layer.top_size(); i++) {
string top = layer.top(i);
string top_origin = top;
// Handling 'inplace' scenarios
if (IsInplaceTopBlob(layer, top)) {
top = RemapTopNameByLayer(layer, top, i);
@@ -1354,21 +1312,9 @@ Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_m
auto top_node_iter = node_map.find(layer.name());
GELOGI("output in top_blob: %s", layer.name().c_str());
if (top_node_iter != node_map.end()) {
// add edge
// Output node, output index, input node, input index
GELOGD("Start add edge for out node: From %s:%d To %s:%d.", top_node_iter->second->GetName().c_str(), i,
node->GetName().c_str(), index);
ge::OutDataAnchorPtr out_archor_ptr = top_node_iter->second->GetOutDataAnchor(i);
GE_CHECK_NOTNULL(out_archor_ptr);
ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(index);
GE_CHECK_NOTNULL(in_archor_ptr);
GE_IF_BOOL_EXEC(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS,
ErrorManager::GetInstance().ATCReportErrMessage(
"E11013", {"opname1", "opname2"}, {top_node_iter->second->GetName(), node->GetName()});
GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to to op[%s].",
top_node_iter->second->GetName().c_str(), node->GetName().c_str());
return INTERNAL_ERROR;);
index++;
ge::GetParserContext().out_top_names.push_back(top_origin);
ge::GetParserContext().default_out_nodes.push_back(std::make_pair(layer.name(), (int32_t)i));
GELOGI("The top of out node [%s] is [%s]", layer.name().c_str(), top_origin.c_str());
}
}
}
@@ -1482,12 +1428,6 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true;
GELOGE(FAILED, "ParseInput ret fail."));

// build output layer
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(graph->GetName() + "_" + ge::parser::NODE_NAME_NET_OUTPUT);
layer->set_type(ge::parser::NETOUTPUT);

int32_t layer_count = proto_message.layer_size();
std::map<std::string, std::string> inplace_blob_name_remapping;
// Map of operator name and occurrence times
@@ -1553,9 +1493,9 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail.");

if (!(ge::GetParserContext().user_out_nodes.empty())) {
GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed.");
GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed.");
} else {
GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail.");
GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail.");
}
GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail.");

@@ -1657,12 +1597,6 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true;
GELOGE(FAILED, "ParseInput ret fail."));

// build output layer
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(graph->GetName() + "_" + ge::parser::NODE_NAME_NET_OUTPUT);
layer->set_type(ge::parser::NETOUTPUT);

int32_t layer_count = proto_message.layer_size();

if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) {
@@ -1735,12 +1669,11 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail.");

if (!(ge::GetParserContext().user_out_nodes.empty())) {
GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed.");
GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed.");
} else {
GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail.");
GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail.");
}
GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail.");
GE_RETURN_WITH_LOG_IF_ERROR(GetLeafNodeTops(graph), "Caffe parser get out nodes top names failed.");

auto nodes = graph->GetDirectNode();
GELOGI("graph node size = %zu.", nodes.size());
@@ -2451,27 +2384,6 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::Co
return SUCCESS;
}

Status CaffeModelParser::GetLeafNodeTops(ge::ComputeGraphPtr &graph) {
auto netout = graph->FindFirstNodeMatchType(ge::parser::NETOUTPUT);
GE_CHECK_NOTNULL(netout);
for (const auto &in_anchor : netout->GetAllInDataAnchors()) {
auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_data_anchor);
auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(peer_out_data_node);
int idx = peer_out_data_anchor->GetIdx();
string node_name = peer_out_data_node->GetName();
auto layer_iter = layer_tops_map_.find(node_name);
if (layer_iter != layer_tops_map_.end()) {
ge::GetParserContext().out_top_names.push_back(layer_iter->second[idx]);
GELOGI("The top of out node [%s] is [%s]", node_name.c_str(), layer_iter->second[idx].c_str());
} else {
GELOGW("The out node [%s] can not find its top.", node_name.c_str());
}
}
return SUCCESS;
}

Status CaffeModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) {
return SUCCESS;
}
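Taken together, the caffe_parser.cc hunks above replace anchor wiring to a synthesized NetOutput layer with plain bookkeeping: the parser now records output top names into the parser context (out_top_names, default_out_nodes) and leaves output-node construction to a later stage. A compact sketch of the user-specified-output path under simplified assumptions (ParserContext and LayerTopsMap below are plain STL stand-ins for ge::GetParserContext() and layer_tops_map_, not the real GE types):

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // Plain stand-in for the fields of ge::GetParserContext() used here.
    struct ParserContext {
      std::vector<std::string> out_top_names;                          // tops chosen as outputs
      std::vector<std::pair<std::string, int32_t>> default_out_nodes;  // (layer name, out index)
      std::vector<std::pair<std::string, int32_t>> user_out_nodes;     // user-specified outputs
    };

    // layer name -> its top blob names, collected while parsing the prototxt.
    using LayerTopsMap = std::map<std::string, std::vector<std::string>>;

    // Mirrors AddUserOutNodesTop: resolve each user (layer, index) pair to
    // the layer's top name and record it in the context; the real code
    // reports E11016 when the index is out of range.
    bool AddUserOutNodesTop(const LayerTopsMap &layer_tops, ParserContext &ctx) {
      for (const auto &out_pair : ctx.user_out_nodes) {
        auto iter = layer_tops.find(out_pair.first);
        if (iter == layer_tops.end() ||
            out_pair.second >= static_cast<int32_t>(iter->second.size())) {
          return false;  // unknown layer or output index out of range
        }
        ctx.out_top_names.push_back(iter->second[out_pair.second]);
      }
      return true;
    }

    int main() {
      LayerTopsMap tops{{"prob", {"prob"}}};
      ParserContext ctx;
      ctx.user_out_nodes = {{"prob", 0}};
      if (AddUserOutNodesTop(tops, ctx)) {
        std::cout << "output top: " << ctx.out_top_names.front() << "\n";
      }
      return 0;
    }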


+4 -6 parser/caffe/caffe_parser.h

@@ -279,12 +279,12 @@ class CaffeModelParser : public domi::ModelParser {

/**
* @ingroup domi_omg
* @brief Add edge information to graph
* @param [in|out] graph graph for saving model information
* @brief Add top name information to graph
* @param [in|out] proto_message
* @return SUCCESS add successfully
* @return FAILED add failed
*/
Status AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph);
Status AddOutputTop(const domi::caffe::NetParameter &proto_message);

/**
* @ingroup domi_omg
@@ -324,7 +324,7 @@ class CaffeModelParser : public domi::ModelParser {
Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer,
const string &op_type);

Status AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph);
Status AddUserOutNodesTop();

std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index);

@@ -335,8 +335,6 @@ class CaffeModelParser : public domi::ModelParser {
Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op,
std::shared_ptr<ge::OpParser> &op_parser);

Status GetLeafNodeTops(ge::ComputeGraphPtr &graph);

void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer);

Status ReorderInput(domi::caffe::NetParameter &net);


+190 -190 parser/caffe/proto/ge_ir.proto

@@ -1,190 +1,190 @@
syntax = "proto3";
package ge.proto;
enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
}
message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType
ListValueType val_type = 20;
}
message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}
oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}
// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}
// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}
// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name
DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"
bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;
map<string, AttrDef> attr = 5; // Set of extra parameter fields
}
// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}
// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type
repeated string input = 5; // input original op name + outgoing index. op_name:index
map<string, AttrDef> attr = 10; // Set of operator parameter fields
bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}
// Graph definition
message GraphDef
{
string name = 1; // name
repeated string input = 4; // Graph input
repeated string output = 5; // Graph output
repeated OpDef op = 6; // List of operators
map<string, AttrDef> attr = 11; // Extended field
}
// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR proto version
string custom_version = 3; // User model version number, passed in by user
repeated GraphDef graph = 7; // Graph definition; graph[0] is the main graph in the ModelDef
map<string, AttrDef> attr = 11; // Extended field
}
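For reference, a short sketch of how a client builds messages against this schema, assuming ge_ir.pb.h was generated from the file above with protoc (the node names and attribute are made up for illustration):

    #include <iostream>
    #include "ge_ir.pb.h"  // assumed output of: protoc --cpp_out=. ge_ir.proto

    int main() {
      ge::proto::GraphDef graph;
      graph.set_name("demo");

      ge::proto::OpDef *data = graph.add_op();
      data->set_name("data");
      data->set_type("Data");

      ge::proto::OpDef *relu = graph.add_op();
      relu->set_name("relu");
      relu->set_type("Relu");
      relu->add_input("data:0");  // "op_name:index" convention from OpDef.input
      // attr is a map<string, AttrDef>; the oneof selects the value kind.
      (*relu->mutable_attr())["negative_slope"].set_f(0.0f);

      std::cout << graph.DebugString();
      return 0;
    }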
syntax = "proto3";
package ge.proto;
enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
}
message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType
ListValueType val_type = 20;
}
message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}
oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}
// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}
// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}
// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name
DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"
bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;
map<string, AttrDef> attr = 5; // Set of extra parameter fields
}
// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}
// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type
repeated string input = 5; // input original op name + outgoing index. op_name:index
map<string, AttrDef> attr = 10; // Set of operator parameter fields
bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}
// Graph definition
message GraphDef
{
string name = 1; // name
repeated string input = 4; // Graph input
repeated string output = 5; // Graph output
repeated OpDef op = 6; // List of operators
map<string, AttrDef> attr = 11; // Extended field
}
// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto verion
string custom_version = 3; // User model version number, passed in by user
repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef
map<string, AttrDef> attr = 11; // Extended field
}

+3 -0 parser/common/CMakeLists.txt

@@ -41,8 +41,11 @@ target_include_directories(parser_common PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PARSER_DIR}
${PARSER_DIR}/parser
${PARSER_DIR}/../ge
${PARSER_DIR}/../inc
${PARSER_DIR}/../inc/framework
${PARSER_DIR}/../inc/common/util
${PARSER_DIR}/../inc/external
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register


+22 -21 parser/common/acl_graph_parser_util.h

@@ -139,8 +139,8 @@ bool ValidateStr(const std::string &filePath, const std::string &mode);
std::string CurrentTimeInStr();

template <typename T, typename... Args>
static inline std::shared_ptr<T> ComGraphMakeShared(Args &&... args) {
using T_nc = typename std::remove_const<T>::type;
static inline std::shared_ptr<T> MakeShared(Args &&... args) {
typedef typename std::remove_const<T>::type T_nc;
std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...));
return ret;
}
@@ -150,63 +150,64 @@ static inline std::shared_ptr<T> ComGraphMakeShared(Args &&... args) {
/// @param [in] a multiplicator
/// @param [in] b multiplicator
/// @return Status
inline Status Int64MulCheckOverflow(int64_t a, int64_t b) {
inline domi::Status Int64MulCheckOverflow(int64_t a, int64_t b) {
if (a > 0) {
if (b > 0) {
if (a > (INT64_MAX / b)) {
return FAILED;
return domi::FAILED;
}
} else {
if (b < (INT64_MIN / a)) {
return FAILED;
return domi::FAILED;
}
}
} else {
if (b > 0) {
if (a < (INT64_MIN / b)) {
return FAILED;
return domi::FAILED;
}
} else {
if ((a != 0) && (b < (INT64_MAX / a))) {
return FAILED;
return domi::FAILED;
}
}
}
return SUCCESS;
return domi::SUCCESS;
}

/// @ingroup math_util
/// @brief check whether int64 multiplication can result in overflow
/// @param [in] a multiplicator
/// @param [in] b multiplicator
/// @return Status
inline Status CheckInt64Uint32MulOverflow(int64_t a, uint32_t b) {
inline domi::Status CheckInt64Uint32MulOverflow(int64_t a, uint32_t b) {
if (a == 0 || b == 0) {
return SUCCESS;
return domi::SUCCESS;
}
if (a > 0) {
if (a > (INT64_MAX / b)) {
return FAILED;
return domi::FAILED;
}
} else {
if (a < (INT64_MIN / b)) {
return FAILED;
return domi::FAILED;
}
}
return SUCCESS;
return domi::SUCCESS;
}

#define PARSER_INT64_MULCHECK(a, b) \
if (ge::Int64MulCheckOverflow((a), (b)) != SUCCESS) { \
#define PARSER_INT64_MULCHECK(a, b) \
if (ge::parser::Int64MulCheckOverflow((a), (b)) != SUCCESS) { \
GELOGW("Int64 %ld and %ld multiplication can result in overflow!", static_cast<int64_t>(a), \
static_cast<int64_t>(b)); \
return INTERNAL_ERROR; \
static_cast<int64_t>(b)); \
return INTERNAL_ERROR; \
}

#define PARSER_INT64_UINT32_MULCHECK(a, b) \
if (ge::CheckInt64Uint32MulOverflow((a), (b)) != SUCCESS) { \
#define PARSER_INT64_UINT32_MULCHECK(a, b) \
if (ge::parser::CheckInt64Uint32MulOverflow((a), (b)) != SUCCESS) { \
GELOGW("Int64 %ld and UINT32 %u multiplication can result in overflow!", static_cast<uint32_t>(a), \
static_cast<uint32_t>(b)); \
return INTERNAL_ERROR; \
static_cast<uint32_t>(b)); \
return INTERNAL_ERROR; \
}
} // namespace parser
} // namespace ge
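The guards above detect overflow without performing the multiplication: for positive a and b, a * b exceeds INT64_MAX exactly when a > INT64_MAX / b under integer division, and the other sign combinations compare against INT64_MIN or INT64_MAX the same way. A self-contained restatement with a tiny driver (MulOverflows is an illustrative name, not part of this header):

    #include <cstdint>
    #include <iostream>

    // Standalone restatement of Int64MulCheckOverflow's branches.
    bool MulOverflows(int64_t a, int64_t b) {
      if (a > 0 && b > 0) return a > INT64_MAX / b;  // product would exceed INT64_MAX
      if (a > 0 && b < 0) return b < INT64_MIN / a;  // product would fall below INT64_MIN
      if (a < 0 && b > 0) return a < INT64_MIN / b;
      if (a < 0 && b < 0) return b < INT64_MAX / a;  // two negatives: positive product
      return false;  // either factor is zero
    }

    int main() {
      std::cout << MulOverflows(INT64_MAX / 2, 3) << "\n";  // 1: overflow detected
      std::cout << MulOverflows(1 << 20, 1 << 20) << "\n";  // 0: 2^40 fits in int64
      return 0;
    }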


+190 -190 parser/common/proto/ge_ir.proto

@@ -1,190 +1,190 @@
syntax = "proto3";
package ge.proto;
enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
}
message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType
ListValueType val_type = 20;
}
message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}
oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}
// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}
// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}
// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name
DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"
bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;
map<string, AttrDef> attr = 5; // Set of extra parameter fields
}
// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}
// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type
repeated string input = 5; // input original op name + outgoing index. op_name:index
map<string, AttrDef> attr = 10; // Set of operator parameter fields
bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}
// Graph definition
message GraphDef
{
string name = 1; // name
repeated string input = 4; // Graph input
repeated string output = 5; // Graph output
repeated OpDef op = 6; // List of operators
map<string, AttrDef> attr = 11; // Extended field
}
// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR proto version
string custom_version = 3; // User model version number, passed in by user
repeated GraphDef graph = 7; // Graph definition; graph[0] is the main graph in the ModelDef
map<string, AttrDef> attr = 11; // Extended field
}
syntax = "proto3";
package ge.proto;
enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
}
message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType
ListValueType val_type = 20;
}
message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}
oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}
// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}
// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}
// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name
DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"
bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;
map<string, AttrDef> attr = 5; // Set of extra parameter fields
}
// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}
// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type
repeated string input = 5; // input original op name + outgoing index. op_name:index
map<string, AttrDef> attr = 10; // Set of operator parameter fields
bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}
// Graph definition
message GraphDef
{
string name = 1; // name
repeated string input = 4; // Graph input
repeated string output = 5; // Graph output
repeated OpDef op = 6; // List of operators
map<string, AttrDef> attr = 11; // Extended field
}
// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto verion
string custom_version = 3; // User model version number, passed in by user
repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef
map<string, AttrDef> attr = 11; // Extended field
}

+136 -136 parser/common/proto/insert_op.proto

@@ -1,136 +1,136 @@
syntax = "proto3";
package domi;
message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}
message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}
enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}
// AIPP mode: distinguishes static AIPP from dynamic AIPP
AippMode aipp_mode = 1;
// related_input_rank is required; integer, range >= 0 and <= the number of input Data operators, default 0.
// It identifies which model input gets AIPP processing; e.g. if the model has two inputs and the second needs AIPP, set related_input_rank to 1.
uint32 related_input_rank = 2;
// input_edge_idx is optional; integer, range >= 0.
// It applies different AIPP processing to different outputs of the Data operator; if unset, AIPP is applied by default to all output edges of the model input selected by related_input_rank.
// The configured value must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;
// [Begin] dynamic AIPP parameters; ignored when static AIPP is configured
uint32 max_src_image_size = 4;
// Whether rotation is supported. Off by default; enabling rotation costs extra space and performance.
bool support_rotation = 5;
// [End] dynamic AIPP parameters
// [Begin] static AIPP parameters; ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;
int32 src_image_size_w = 57;
int32 src_image_size_h = 58;
bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;
bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;
bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;
repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;
// [End] static AIPP parameters
// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}
message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}
MultiShapeMode mode = 1; // operator mode
uint32 related_input_rank = 2; // which input the new operator is inserted on
repeated uint32 batch_list = 11; // batch_list values; the number of batch_list entries is between 2 and 8
}
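For illustration, a small sketch of filling a static-AIPP entry programmatically, assuming insert_op.pb.h was generated from the schema above with protoc (the field values are made-up examples; enum constants follow protoc's Message_Enum_VALUE naming):

    #include <iostream>
    #include "insert_op.pb.h"  // assumed output of: protoc --cpp_out=. insert_op.proto

    int main() {
      domi::InsertNewOps ops;
      domi::AippOpParams *aipp = ops.add_aipp_op();

      aipp->set_aipp_mode(domi::AippOpParams_AippMode_static);  // static AIPP
      aipp->set_related_input_rank(0);  // apply AIPP to the first model input
      // Static-AIPP fields: declared source picture format and size.
      aipp->set_input_format(domi::AippOpParams_InputFormat_YUV420SP_U8);
      aipp->set_src_image_size_w(224);
      aipp->set_src_image_size_h(224);
      aipp->set_csc_switch(true);  // enable color space conversion

      std::cout << ops.DebugString();
      return 0;
    }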
syntax = "proto3";
package domi;
message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}
message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}
enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}
// AIPP模式,区分静态AIPP和动态AIPP
AippMode aipp_mode = 1;
// related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。
uint32 related_input_rank = 2;
// input_edge_idx参数为可选,类型为整型,配置范围为>=0。
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。
// 配置值 <= Data算子输出边的个数。
repeated uint32 input_edge_idx = 3;
// [Begin] 动态AIPP参数,配置静态AIPP时无效
uint32 max_src_image_size = 4;
// 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失
bool support_rotation = 5;
// [End] 动态AIPP参数
// [Begin] 静态AIPP参数,配置动态AIPP时无效
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;
int32 src_image_size_w = 57;
int32 src_image_size_h = 58;
bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;
bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;
bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;
repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;
// [End] 静态AIPP参数
// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}
message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; //动态batch
resolution = 1; //动态分辨率,扩展用
}
MultiShapeMode mode = 1; //算子模式
uint32 related_input_rank = 2; //新增算子插入到哪个输入
repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间
}

+16 -16 parser/func_to_graph/proto_python_rule.mk

@@ -1,17 +1,17 @@
include $(BUILD_SYSTEM)/base_rules.mk
FUNCTION_TO_GRAPH_OUT_TIMESTAMP := $(HOST_OUT_ROOT)/func_to_graph/.timestamp
PROTO_SRC_DIR = framework/domi/parser/func_to_graph/proto
PY_PROTO_BUILD_DIR = $(HOST_OUT_ROOT)/tmp/function_to_graph/proto
$(warning PRIVATE_PROTOC is $(PRIVATE_PROTOC))
$(warning protobuf_lib_dir is $(protobuf_lib_dir))
$(FUNCTION_TO_GRAPH_OUT_TIMESTAMP): $(PRIVATE_PROTOC)
mkdir -p $(PY_PROTO_BUILD_DIR)
LD_LIBRARY_PATH=$(protobuf_lib_dir):$$LD_LIBRARY_PATH $(PRIVATE_PROTOC) -I=$(PROTO_SRC_DIR) --python_out=$(PY_PROTO_BUILD_DIR) $(PROTO_SRC_DIR)/*.proto
$(LOCAL_BUILT_MODULE): $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP)
mkdir -p $@
cp -rf $(PY_PROTO_BUILD_DIR)/* $@

+3 -0 parser/onnx/CMakeLists.txt

@@ -29,8 +29,11 @@ target_include_directories(fmk_onnx_parser PRIVATE
${PARSER_DIR}
${PARSER_DIR}/inc
${PARSER_DIR}/parser
${PARSER_DIR}/../ge
${PARSER_DIR}/../inc
${PARSER_DIR}/../inc/common/util
${PARSER_DIR}/../inc/framework
${PARSER_DIR}/../inc/external
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register


+0 -1 parser/onnx/module.mk

@@ -43,7 +43,6 @@ LOCAL_SHARED_LIBRARIES := \
libparser_common \
libgraph \
libregister \
libge_common \

LOCAL_LDFLAGS := -lrt



+1 -1 parser/proto/caffe/CMakeLists.txt

@@ -1,5 +1,5 @@
set(PROTO_LIST
"${TOP_DIR}/inc/register/proto/caffe/caffe.proto"
"${METADEF_DIR}/proto/caffe/caffe.proto"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})


+190 -190 parser/proto/ge_ir.proto

@@ -1,190 +1,190 @@
syntax = "proto3";
package ge.proto;
enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
}
message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType
ListValueType val_type = 20;
}
message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}
oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}
// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}
// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}
// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name
DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"
bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;
map<string, AttrDef> attr = 5; // Set of extra parameter fields
}
// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}
// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type
repeated string input = 5; // input original op name + outgoing index. op_name:index
map<string, AttrDef> attr = 10; // Set of operator parameter fields
bool has_out_attr = 20;
int64 id = 21;
int64 stream_id = 22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}
// Graph definition
message GraphDef
{
string name = 1; // name
repeated string input = 4; // Graph input
repeated string output = 5; // Graph output
repeated OpDef op = 6; // List of operators
map<string, AttrDef> attr = 11; // Extended field
}
// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto version
string custom_version = 3; // User model version number, passed in by the user
repeated GraphDef graph = 7; // Graph definition; graph[0] is the main graph of the model
map<string, AttrDef> attr = 11; // Extended field
}
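
This message layout makes ge_ir models easy to build programmatically: a ModelDef holds GraphDefs, each GraphDef holds OpDefs, and attributes go through the AttrDef oneof. A minimal sketch, assuming Python bindings generated from this file with protoc --python_out (the module name ge_ir_pb2 is an assumption):

import ge_ir_pb2 as ge_ir  # hypothetical module generated from ge_ir.proto

model = ge_ir.ModelDef(name="demo_model", version=1)
graph = model.graph.add()                # graph[0] is the main graph
graph.name = "main"
op = graph.op.add()
op.name = "conv1"
op.type = "Conv2D"
op.input.append("data:0")                # "op_name:index" input convention
desc = op.input_desc.add()
desc.dtype = ge_ir.DT_FLOAT
desc.layout = "NCHW"
desc.shape.dim.extend([1, 3, 224, 224])
op.attr["pad"].i = 1                     # AttrDef oneof: setting i selects the int variant
serialized = model.SerializeToString()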

+ 136
- 136
parser/proto/insert_op.proto View File

@@ -1,136 +1,136 @@
syntax = "proto3";
package domi;
message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}
message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}
enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}
// AIPP mode: distinguishes static AIPP from dynamic AIPP
AippMode aipp_mode = 1;
// related_input_rank is required; integer, range >= 0 and <= the number of input Data operators, default 0.
// It identifies which model input AIPP is applied to: for example, if the model has two inputs and AIPP
// should be applied to the second one, set related_input_rank to 1.
uint32 related_input_rank = 2;
// input_edge_idx is optional; integer, range >= 0.
// It lets different outputs of the Data operator receive different AIPP processing. If it is not set,
// AIPP is applied to all output edges of the model input specified by related_input_rank.
// The configured value must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;
// [Begin] Dynamic AIPP parameters; ignored when static AIPP is configured
uint32 max_src_image_size = 4;
// Whether rotation is supported. Disabled by default; enabling it costs extra memory and performance.
bool support_rotation = 5;
// [End] Dynamic AIPP parameters
// [Begin] Static AIPP parameters; ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;
int32 src_image_size_w = 57;
int32 src_image_size_h = 58;
bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;
bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;
bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;
repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;
// [End] Static AIPP parameters
// The exponent n used when transforming raw/rgbir data to f16.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}
message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}
MultiShapeMode mode = 1; // operator mode
uint32 related_input_rank = 2; // which model input the new operator is inserted on
repeated uint32 batch_list = 11; // batch_list values; the number of entries must be between 2 and 8
}
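
For illustration, a static AIPP configuration could be assembled like this; a minimal sketch assuming bindings generated from this file as insert_op_pb2 (the module name is an assumption):

import insert_op_pb2 as insert_op  # hypothetical module generated from insert_op.proto

ops = insert_op.InsertNewOps()
aipp = ops.aipp_op.add()
aipp.aipp_mode = insert_op.AippOpParams.static       # static AIPP
aipp.related_input_rank = 0                          # apply AIPP to the first model input
aipp.input_format = insert_op.AippOpParams.YUV420SP_U8
aipp.src_image_size_w = 224
aipp.src_image_size_h = 224
aipp.csc_switch = True                               # enable color-space conversion
serialized = ops.SerializeToString()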

+ 0
- 1
parser/tensorflow/graph_optimizer.cc View File

@@ -807,7 +807,6 @@ Status CreateNodeDefBytes(ge::NodePtr n, string originalType, map<string, PIOLis
for (uint32_t j = 0; j < ge_desc->GetShape().GetDimNum(); ++j) {
tmp_dim = ge_desc->GetShape().GetDim(j);
GE_CHECK_GE(tmp_dim, 0);
PARSER_INT64_MULCHECK(real_size, tmp_dim);
real_size *= tmp_dim;
}
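
The PARSER_INT64_MULCHECK line removed above guarded the dimension product against int64 overflow before real_size is multiplied. A sketch of that guard in Python terms (Python integers do not overflow, so the bound is enforced explicitly; the function name is illustrative):

INT64_MAX = 2 ** 63 - 1

def checked_dim_product(dims):
    # Mirrors the PARSER_INT64_MULCHECK idea: fail before the running
    # product would exceed the int64 range.
    real_size = 1
    for dim in dims:
        if dim < 0:
            raise ValueError("negative dimension")
        if dim != 0 and real_size > INT64_MAX // dim:
            raise OverflowError("int64 overflow while computing tensor size")
        real_size *= dim
    return real_size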


+ 190
- 190
parser/tensorflow/proto/ge_ir.proto View File

@@ -1,190 +1,190 @@
syntax = "proto3";
package ge.proto;
enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; // int32 type
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
}
message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType
ListValueType val_type = 20;
}
message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}
oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}
// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}
// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}
// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name
DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, e.g. "NCHW", "NHWC", "CHW", "ND"
bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor = 15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;
map<string, AttrDef> attr = 5; // Set of extra parameter fields
}
// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}
// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type
repeated string input = 5; // inputs as "op_name:index" (original op name + output index)
map<string, AttrDef> attr = 10; // Set of operator parameter fields
bool has_out_attr = 20;
int64 id = 21;
int64 stream_id = 22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}
// Graph definition
message GraphDef
{
string name = 1; // name
repeated string input = 4; // Graph input
repeated string output = 5; // Graph output
repeated OpDef op = 6; // List of operators
map<string, AttrDef> attr = 11; // Extended field
}
// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto version
string custom_version = 3; // User model version number, passed in by the user
repeated GraphDef graph = 7; // Graph definition; graph[0] is the main graph of the model
map<string, AttrDef> attr = 11; // Extended field
}

+ 136
- 136
parser/tensorflow/proto/insert_op.proto View File

@@ -1,136 +1,136 @@
syntax = "proto3";
package domi;
message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}
message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}
enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}
// AIPP mode: distinguishes static AIPP from dynamic AIPP
AippMode aipp_mode = 1;
// related_input_rank is required; integer, range >= 0 and <= the number of input Data operators, default 0.
// It identifies which model input AIPP is applied to: for example, if the model has two inputs and AIPP
// should be applied to the second one, set related_input_rank to 1.
uint32 related_input_rank = 2;
// input_edge_idx is optional; integer, range >= 0.
// It lets different outputs of the Data operator receive different AIPP processing. If it is not set,
// AIPP is applied to all output edges of the model input specified by related_input_rank.
// The configured value must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;
// [Begin] Dynamic AIPP parameters; ignored when static AIPP is configured
uint32 max_src_image_size = 4;
// Whether rotation is supported. Disabled by default; enabling it costs extra memory and performance.
bool support_rotation = 5;
// [End] Dynamic AIPP parameters
// [Begin] Static AIPP parameters; ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;
int32 src_image_size_w = 57;
int32 src_image_size_h = 58;
bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;
bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;
bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;
repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;
// [End] Static AIPP parameters
// The exponent n used when transforming raw/rgbir data to f16.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}
message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}
MultiShapeMode mode = 1; // operator mode
uint32 related_input_rank = 2; // which model input the new operator is inserted on
repeated uint32 batch_list = 11; // batch_list values; the number of entries must be between 2 and 8
}

+ 6
- 0
stub/Makefile View File

@@ -0,0 +1,6 @@
inc_path := $(shell pwd)/inc/external/
out_path := $(shell pwd)/out/graph/lib64/stub/
stub_path := $(shell pwd)/common/graph/stub/

mkdir_stub := $(shell mkdir -p $(out_path))
graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path))
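
The Makefile drives the generator with two positional arguments: the include directory to scan and the output directory for the stub .cc files. An equivalent standalone invocation might look like this (paths mirror the Makefile variables and are illustrative):

import subprocess
import sys

# Scan inc/external/ for headers and write stub .cc files to out/graph/lib64/stub/.
subprocess.run(
    [sys.executable, "stub/gen_stubapi.py", "inc/external/", "out/graph/lib64/stub/"],
    check=True,
)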

+ 577
- 0
stub/gen_stubapi.py View File

@@ -0,0 +1,577 @@
import os
import re
import sys
import logging

logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s',
level=logging.INFO)

"""
this attr makes the stub functions visible in the symbol table
"""
GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY'

"""
generate stub func body by return type
"""
RETURN_STATEMENTS = {
'graphStatus': ' std::cout << "[ERROR]: stub library libgraph or libge_compiler cannot be used for execution, please check your "\n '
' << "environment variables and compilation options to make sure you use the correct library."\n'
' << std::endl;\n'
' return ACL_ERROR_COMPILING_STUB_MODE;',
'Status': ' return SUCCESS;',
'Graph': ' return Graph();',
'Graph&': ' return *this;',
'Format': ' return Format();',
'Format&': ' return *this;',
'Shape': ' return Shape();',
'Shape&': ' return *this;',
'TensorDesc': ' return TensorDesc();',
'TensorDesc&': ' return *this;',
'Tensor': ' return Tensor();',
'Tensor&': ' return *this;',
'Operator': ' return Operator();',
'Operator&': ' return *this;',
'Ptr': ' return nullptr;',
'std::string': ' return "";',
'std::string&': ' return "";',
'string': ' return "";',
'int': ' return 0;',
'DataType': ' return DT_FLOAT;',
'InferenceContextPtr': ' return nullptr;',
'SubgraphBuilder': ' return nullptr;',
'OperatorImplPtr': ' return nullptr;',
'OutHandler': ' return nullptr;',
'std::vector<std::string>': ' return {};',
'std::vector<int64_t>': ' return {};',
'std::map': ' return {};',
'uint32_t': ' return 0;',
'int64_t': ' return 0;',
'uint64_t': ' return 0;',
'size_t': ' return 0;',
'float': ' return 0.0f;',
'bool': ' return false;',
}
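
# Illustration (added comment): implement_function() below maps a declaration's return type
# to one of these bodies. For example
#   implement_function("Graph &Graph::AddOp(const Operator &op)")
# resolves the return type to 'Graph&' and emits a body containing 'return *this;', while an
# unrecognized return type only logs a warning and yields an empty body.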

"""
max code length per line in the Huawei software programming specifications
"""
max_code_len_per_line = 100

"""
white_list_for_debug and include_dir_key_words determine
which header files .cc stub files are generated from
when DEBUG is on
"""
white_list_for_debug = ["tensorflow_parser.h", "caffe_parser.h"]
include_dir_key_words = ["parser"]
DEBUG = True


def need_generate_func(func_line):
"""
:param func_line:
:return:
"""
if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \
or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"):
return False
return True
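
# Illustration (added comment): by the time this is called the trailing ';' has been stripped,
# so declarations such as 'Graph() = default' or 'Foo() = delete', as well as typedef/using
# lines, return False and get no stub body.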


def file_endswith_white_list_suffix(file):
"""
:param file:
:return:
"""
if DEBUG:
for suffix in white_list_for_debug:
if file.endswith(suffix):
return True
return False
else:
return True


"""
below are the patterns used to analyse .h files
"""
# pattern function
pattern_func = re.compile(r"""(^[\s]*) # leading whitespace; captured so it can be stripped later
([a-zA-Z~_] # start of the return type (or ~ for destructors)
.*
[)] # the closing parenthesis of the parameter list
(?!.*{) # exclude inline definitions such as int abc() {
.*)
(;.*) # the trailing ';' and anything after it; replaced later
\n$
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# pattern comment
pattern_comment = re.compile(r'^\s*//')
pattern_comment_2_start = re.compile(r'^\s*/[*]')
pattern_comment_2_end = re.compile(r'[*]/\s*$')
# pattern define
pattern_define = re.compile(r'^\s*#define')
pattern_define_return = re.compile(r'\\\s*$')
# blank line
pattern_blank_line = re.compile(r'^\s*$')
# virtual,explicit,friend,static
pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)')
# lead space
pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]')
# functions appear in patterns such as "func (" or "func("
# overloaded operators are an exception: the name is "operator" plus a symbol,
# e.g. "operator = ()", so the pattern accepts that form as well
pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]')
# template
pattern_template = re.compile(r'^\s*template')
pattern_template_end = re.compile(r'>\s*$')
# namespace
pattern_namespace = re.compile(r'namespace.*{')
# class: handles 'class A' with '{' on a following line; declarations followed by ';' are skipped
pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR)
# {}
pattern_start = re.compile('{')
pattern_end = re.compile('}')

line_index = 0


class H2CC(object):
def __init__(self, input_file, output_file, shared_includes_content):
"""
:param input_file:
:param output_file:
:param shared_includes_content:
"""
self.input_file = input_file
self.output_file = output_file
self.shared_includes_content = shared_includes_content
self.line_index = 0
self.input_fd = open(self.input_file, 'r')
self.input_content = self.input_fd.readlines()
self.output_fd = open(self.output_file, 'w')

# The state may be normal_now (inside {}), class_now, or namespace_now
self.stack = []
self.stack_class = []
self.stack_template = []
# record funcs generated by h2cc func
self.func_list_exist = []

def __del__(self):
self.input_fd.close()
self.output_fd.close()
del self.stack
del self.stack_class
del self.stack_template
del self.func_list_exist

def just_skip(self):
# skip blank line or comment
if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
self.input_content[self.line_index]): # blank line or comment using //
self.line_index += 1
if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /*
while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */
self.line_index += 1
self.line_index += 1
# skip define
if pattern_define.search(self.input_content[self.line_index]):
while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search(
self.input_content[self.line_index]):
self.line_index += 1
self.line_index += 1

def write_inc_content(self):
for shared_include_content in self.shared_includes_content:
self.output_fd.write(shared_include_content)

def h2cc(self):
"""
:return:
"""
logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file)
global pattern_comment
global pattern_comment_2_start
global pattern_comment_2_end
global pattern_blank_line
global pattern_func
global pattern_keyword
global pattern_leading_space
global pattern_func_name
global pattern_template
global pattern_template_end
global pattern_namespace
global pattern_class
global pattern_start
global pattern_end
global line_index
# write inc content
self.write_inc_content()
# core processing cycle, process the input .h file by line
while self.line_index < len(self.input_content):
# handle comment and blank line
self.just_skip()

# match namespace
self.handle_namespace()

# match template
template_string = self.handle_template()
# match class
line = self.input_content[self.line_index]
match_class = pattern_class.search(line)
match_start = pattern_start.search(line)
handle_class_result = self.handle_class(template_string, line, match_start, match_class)
if handle_class_result == "continue":
continue

# match "}"
handle_stack_result = self.handle_stack(match_start)
if handle_stack_result == "continue":
continue
# handle func
handle_func1_result, line, start_i = self.handle_func1(line)
if handle_func1_result == "continue":
continue

# here means func is found
# delete key word
line = pattern_keyword.sub('', line)
logging.info("line[%s]", line)

# Class member function
# if friend we will not add class name
friend_match = re.search('friend ', line)
if len(self.stack_class) > 0 and not friend_match:
line, func_name = self.handle_class_member_func(line, template_string)
# Normal functions
else:
line, func_name = self.handle_normal_func(line, template_string)

need_generate = need_generate_func(line)
# func body
line += self.implement_function(line)
# comment
line = self.gen_comment(start_i) + line
# write to out file
self.write_func_content(line, func_name, need_generate)
# next loop
self.line_index += 1

logging.info('Added %s functions', len(self.func_list_exist))
logging.info('Successfully converted, please see %s', self.output_file)

def handle_func1(self, line):
"""
:param line:
:return:
"""
find1 = re.search('[(]', line)
if not find1:
self.line_index += 1
return "continue", line, None
find2 = re.search('[)]', line)
start_i = self.line_index
space_match = pattern_leading_space.search(line)
# deal with
# int abc(int a,
# int b)
if find1 and (not find2):
self.line_index += 1
line2 = self.input_content[self.line_index]
if space_match:
line2 = re.sub('^' + space_match.group(1), '', line2)
line += line2
while self.line_index < len(self.input_content) and (not re.search('[)]', line2)):
self.line_index += 1
line2 = self.input_content[self.line_index]
line2 = re.sub('^' + space_match.group(1), '', line2)
line += line2

match_start = pattern_start.search(self.input_content[self.line_index])
match_end = pattern_end.search(self.input_content[self.line_index])
if match_start: # like ') {' or ') {}' in the last line
if not match_end:
self.stack.append('normal_now')
ii = start_i
while ii <= self.line_index:
ii += 1
self.line_index += 1
return "continue", line, start_i
logging.info("line[%s]", line)
# ' int abc();'->'int abc()'
(line, match) = pattern_func.subn(r'\2\n', line)
logging.info("line[%s]", line)
# deal with case:
# 'int \n abc(int a, int b)'
if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]):
line = self.input_content[start_i - 1] + line
line = line.lstrip()
if not match:
self.line_index += 1
return "continue", line, start_i
return "pass", line, start_i

def handle_stack(self, match_start):
"""
:param match_start:
:return:
"""
line = self.input_content[self.line_index]
match_end = pattern_end.search(line)
if match_start:
self.stack.append('normal_now')
if match_end:
top_status = self.stack.pop()
if top_status == 'namespace_now':
self.output_fd.write(line + '\n')
elif top_status == 'class_now':
self.stack_class.pop()
self.stack_template.pop()
if match_start or match_end:
self.line_index += 1
return "continue"

if len(self.stack) > 0 and self.stack[-1] == 'normal_now':
self.line_index += 1
return "continue"
return "pass"

def handle_class(self, template_string, line, match_start, match_class):
"""
:param template_string:
:param line:
:param match_start:
:param match_class:
:return:
"""
if match_class: # we face a class
self.stack_template.append(template_string)
self.stack.append('class_now')
class_name = match_class.group(3)

# class template specializations: class A<u,Node<u> >
if '<' in class_name:
k = line.index('<')
fit = 1
for ii in range(k + 1, len(line)):
if line[ii] == '<':
fit += 1
if line[ii] == '>':
fit -= 1
if fit == 0:
break
class_name += line[k + 1:ii + 1]
logging.info('class_name[%s]', class_name)
self.stack_class.append(class_name)
while not match_start:
self.line_index += 1
line = self.input_content[self.line_index]
match_start = pattern_start.search(line)
self.line_index += 1
return "continue"
return "pass"

def handle_template(self):
line = self.input_content[self.line_index]
match_template = pattern_template.search(line)
template_string = ''
if match_template:
match_template_end = pattern_template_end.search(line)
template_string = line
while not match_template_end:
self.line_index += 1
line = self.input_content[self.line_index]
template_string += line
match_template_end = pattern_template_end.search(line)
self.line_index += 1
return template_string

def handle_namespace(self):
line = self.input_content[self.line_index]
match_namespace = pattern_namespace.search(line)
if match_namespace: # we face namespace
self.output_fd.write(line + '\n')
self.stack.append('namespace_now')
self.line_index += 1

def handle_normal_func(self, line, template_string):
template_line = ''
self.stack_template.append(template_string)
if self.stack_template[-1] != '':
template_line = re.sub(r'\s*template', 'template', self.stack_template[-1])
# change '< class T = a, class U = A(3)>' to '<class T, class U>'
template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
template_line = re.sub(r'\s*=.*,', ',', template_line)
template_line = re.sub(r'\s*=.*', '', template_line)
line = re.sub(r'\s*=.*,', ',', line)
line = re.sub(r'\s*=.*\)', ')', line)
line = template_line + line
self.stack_template.pop()
func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
logging.info("line[%s]", line)
logging.info("func_name[%s]", func_name)
return line, func_name

def handle_class_member_func(self, line, template_string):
template_line = ''
x = ''
if template_string != '':
template_string = re.sub(r'\s*template', 'template', template_string)
template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string)
template_string = re.sub(r'\s*=.*,', ',', template_string)
template_string = re.sub(r'\s*=.*', '', template_string)
if self.stack_template[-1] != '':
if not (re.search(r'<\s*>', self.stack_template[-1])):
template_line = re.sub(r'^\s*template', 'template', self.stack_template[-1])
if not (re.search(r'<.*>', self.stack_class[-1])):
# for x we get like template<class T, typename U> -> <T,U>
x = re.sub(r'template\s*<', '<', template_line) # remove template -> <class T, typename U>
x = re.sub(r'\n', '', x)
x = re.sub(r'\s*=.*,', ',', x)
x = re.sub(r'\s*=.*\>', '>', x)
x = x.rstrip() # remove \n
x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '',
x) # remove class,typename -> <T, U>
x = re.sub(r'<\s+', '<', x)
x = re.sub(r'\s+>', '>', x)
x = re.sub(r'\s+,', ',', x)
x = re.sub(r',\s+', ', ', x)
line = re.sub(r'\s*=\s+0', '', line)
line = re.sub(r'\s*=\s+.*,', ',', line)
line = re.sub(r'\s*=\s+.*\)', ')', line)
logging.info("x[%s]\nline[%s]", x, line)
# if the function is long, void ABC::foo()
# breaks into two lines void ABC::\n foo()
temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1)
if len(temp_line) > max_code_len_per_line:
line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1)
else:
line = temp_line
logging.info("line[%s]", line)
# add template as the above if there is one
template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
template_line = re.sub(r'\s*=.*,', ',', template_line)
template_line = re.sub(r'\s*=.*', '', template_line)
line = template_line + template_string + line
func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
logging.info("line[%s]", line)
logging.info("func_name[%s]", func_name)
return line, func_name
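
# Illustration (added comment): inside 'class Graph', the declaration
#   'Graph &AddOp(const Operator &op)'
# is rewritten by the pattern_func_name substitution above into
#   'Graph &Graph::AddOp(const Operator &op)',
# with the class's template argument list inserted after the class name when one is active.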

def write_func_content(self, content, func_name, need_generate):
if not (func_name in self.func_list_exist) and need_generate:
self.output_fd.write(content)
self.func_list_exist.append(func_name)
logging.info('add func:[%s]', func_name)

def gen_comment(self, start_i):
comment_line = ''
# Function comments are on top of function declarations, copy them over
k = start_i - 1 # one line before this func start
if pattern_template.search(self.input_content[k]):
k -= 1
if pattern_comment_2_end.search(self.input_content[k]):
comment_line = self.input_content[k].lstrip()
while not pattern_comment_2_start.search(self.input_content[k]):
k -= 1
comment_line = self.input_content[k].lstrip() + comment_line
else:
for j in range(k, 0, -1):
c_line = self.input_content[j]
if pattern_comment.search(c_line):
c_line = re.sub(r'\s*//', '//', c_line)
comment_line = c_line + comment_line
else:
break
return comment_line

@staticmethod
def implement_function(func):
function_def = ''
function_def += '{\n'

all_items = func.split()
start = 0
return_type = all_items[start]
if return_type == "const":
start += 1
return_type = all_items[start]
if return_type.startswith(('std::map', 'std::set', 'std::vector')):
return_type = "std::map"
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
return_type = "Ptr"
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
return_type += "&"
if return_type in RETURN_STATEMENTS:
function_def += RETURN_STATEMENTS[return_type]
else:
logging.warning("Unhandled return type[%s]", return_type)

function_def += '\n'
function_def += '}\n'
function_def += '\n'
return function_def


def collect_header_files(path):
"""
:param path:
:return:
"""
header_files = []
shared_includes_content = []
for root, dirs, files in os.walk(path):
files.sort()
for file in files:
if file.find("git") >= 0:
continue
if not file.endswith('.h'):
continue
file_path = os.path.join(root, file)
file_path = file_path.replace('\\', '/')
header_files.append(file_path)
include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:])
shared_includes_content.append(include_str)
# for acl error code
shared_includes_content.append('#include <iostream>\n')
shared_includes_content.append('const int ACL_ERROR_COMPILING_STUB_MODE = 100039;\n')
return header_files, shared_includes_content
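
# Illustration (added comment): for inc/external/parser/tensorflow_parser.h this collects the
# header path and prepares '#include "parser/tensorflow_parser.h"', followed by the <iostream>
# and ACL error-code lines written at the top of every generated .cc file.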


def generate_stub_file(inc_dir, out_cc_dir):
"""
:param inc_dir:
:param out_cc_dir:
:return:
"""
target_header_files, shared_includes_content = collect_header_files(inc_dir)
for header_file in target_header_files:
if not file_endswith_white_list_suffix(header_file):
continue
cc_file = re.sub('.h*$', '.cc', header_file)
h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content)
h_2_cc.h2cc()


def gen_code(inc_dir, out_cc_dir):
"""
:param inc_dir:
:param out_cc_dir:
:return:
"""
if not inc_dir.endswith('/'):
inc_dir += '/'
if not out_cc_dir.endswith('/'):
out_cc_dir += '/'
for include_dir_key_word in include_dir_key_words:
generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir)


if __name__ == '__main__':
inc_dir = sys.argv[1]
out_cc_dir = sys.argv[2]
gen_code(inc_dir, out_cc_dir)
