Merge pull request !3 from taoxiangdong/masterpull/3/MERGE
| @@ -29,9 +29,9 @@ if (ENABLE_OPEN_SRC) | |||
| find_module(ge_common libge_common.so ${ASCEND_RUNTIME_DIR}) | |||
| find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) | |||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | |||
| #set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | |||
| add_subdirectory(metadef) | |||
| #add_subdirectory(metadef) | |||
| #add_subdirectory(metadef/graph) | |||
| #add_subdirectory(metadef/register) | |||
| @@ -12,7 +12,7 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||
| endif() | |||
| ExternalProject_Add(gflags_build | |||
| #URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz | |||
| URL https://github.com/gflags/gflags/archive/v2.2.2.tar.gz | |||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||
| SOURCE_DIR ${PARSER_DIR}/../third_party/gflags/src/gflags-2.2.2 | |||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> | |||
| @@ -6,7 +6,7 @@ include(ExternalProject) | |||
| set(JSON_SRC_DIR ${PARSER_DIR}/../third_party/json/include) | |||
| ExternalProject_Add(json_build | |||
| #URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip | |||
| URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip | |||
| #URL /home/txd/workspace/cloud_code/pkg/include.zip | |||
| SOURCE_DIR ${JSON_SRC_DIR} | |||
| CONFIGURE_COMMAND "" | |||
| @@ -7,7 +7,7 @@ set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) | |||
| file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) | |||
| ExternalProject_Add(onnx | |||
| #URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz | |||
| URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz | |||
| URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz | |||
| #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 | |||
| #SOURCE_DIR ${ONNX_SRC_DIR} | |||
| @@ -14,7 +14,7 @@ endif() | |||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 $<$<STREQUAL:${PRODUCT_SIDE},host>:-D_GLIBCXX_USE_CXX11_ABI=0> -O2") | |||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||
| ExternalProject_Add(protobuf_build | |||
| #URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz | |||
| URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||
| #SOURCE_DIR ${PARSER_DIR}/third_party/protobuf/src/protobuf-3.8.0 | |||
| DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E copy_directory ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 <SOURCE_DIR> | |||
| @@ -12,7 +12,7 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst | |||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||
| set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) | |||
| ExternalProject_Add(protobuf_static_build | |||
| #URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz | |||
| URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||
| SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | |||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | |||
| @@ -15,7 +15,7 @@ endif() | |||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | |||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||
| ExternalProject_Add(protoc_build | |||
| #URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz | |||
| URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||
| SOURCE_DIR ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | |||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | |||
| @@ -11,7 +11,7 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||
| endif() | |||
| ExternalProject_Add(c_sec_build | |||
| #URL http://tfk.inhuawei.com/api/containers/container1/download/protobuf-3.8.0.tar.gz | |||
| URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz | |||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||
| SOURCE_DIR ${PARSER_DIR}/../libc_sec | |||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | |||
| @@ -59,8 +59,11 @@ target_include_directories(fmk_parser PRIVATE | |||
| ${PARSER_DIR} | |||
| ${PARSER_DIR}/inc | |||
| ${PARSER_DIR}/parser | |||
| ${PARSER_DIR}/../ge | |||
| ${PARSER_DIR}/../inc | |||
| ${PARSER_DIR}/../inc/framework | |||
| ${PARSER_DIR}/../inc/common/util | |||
| ${PARSER_DIR}/../inc/external | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/graph | |||
| ${METADEF_DIR}/inc/register | |||
| @@ -1,27 +0,0 @@ | |||
| set(PROTO_LIST | |||
| "${METADEF_DIR}/proto/caffe/caffe.proto" | |||
| ) | |||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
| ############ lib_caffe_parser.so ############ | |||
| add_library(_caffe_parser SHARED ${PROTO_SRCS}) | |||
| target_include_directories(_caffe_parser PRIVATE | |||
| ${CMAKE_CURRENT_LIST_DIR} | |||
| ) | |||
| target_link_libraries(_caffe_parser PRIVATE | |||
| $<BUILD_INTERFACE:intf_pub> | |||
| -Wl,--no-as-needed | |||
| protobuf | |||
| -Wl,--as-needed | |||
| ) | |||
| ############ install ############ | |||
| set(INSTALL_BASE_DIR "") | |||
| set(INSTALL_LIBRARY_DIR lib) | |||
| install(TARGETS _caffe_parser OPTIONAL | |||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||
| ) | |||
| @@ -83,8 +83,8 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr | |||
| } | |||
| bool bias_en = false; | |||
| int start_pos = layer->bottom_size(); | |||
| bool update_in_turn = (static_cast<int64_t>(op->GetAllInputsSize()) == (layer->bottom_size() + layer->blobs_size())); | |||
| int start_pos = layer->bottom_size(); | |||
| for (int i = 0; i < layer->blobs_size(); ++i) { | |||
| ge::GeTensorPtr weight = ge::parser::MakeShared<ge::GeTensor>(); | |||
| GE_CHECK_NOTNULL(weight); | |||
| @@ -107,6 +107,10 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||
| return ret; | |||
| } | |||
| GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); | |||
| if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { | |||
| GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| return ge::SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -803,10 +807,6 @@ Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter | |||
| Status CaffeModelParser::AddBlobsToMap(const domi::caffe::LayerParameter &layer, | |||
| std::map<std::string, std::string> &inplace_blob_name_remapping) { | |||
| if (layer.type() == ge::parser::NETOUTPUT) { | |||
| return SUCCESS; | |||
| } | |||
| if (layer.top_size() <= 0) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19011", {"opname"}, {layer.name()}); | |||
| GELOGE(FAILED, "The output size of layer %s needs to be greater than zero.", layer.name().c_str()); | |||
| @@ -1085,34 +1085,6 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom | |||
| "while it's original input num is: %d", | |||
| layer.bottom_size()); | |||
| } | |||
| // Netoutput node processing | |||
| if (op_desc->GetType() == ge::parser::NETOUTPUT) { | |||
| size_t input_output_tensor_num = 0; | |||
| if (!ge::GetParserContext().user_out_nodes.empty()) { | |||
| // User specified output | |||
| input_output_tensor_num = ge::GetParserContext().user_out_nodes.size(); | |||
| } else { | |||
| for (auto t_iter = top_blobs_map_.begin(); t_iter != top_blobs_map_.end(); t_iter++) { | |||
| auto b_iter = bottom_blobs_map_.find(t_iter->first); | |||
| // Find the output node of the network | |||
| if (b_iter == bottom_blobs_map_.end()) { | |||
| input_output_tensor_num += top_blobs_map_[t_iter->first].size(); | |||
| } | |||
| } | |||
| } | |||
| // add tensordesc | |||
| GELOGD( | |||
| "Current op type is NETOUTPUT, add additional input&output num: %zu." | |||
| "while it's original input num is: %d, output num is: %d", | |||
| input_output_tensor_num, layer.bottom_size(), output_tensor_num); | |||
| for (size_t i = 0; i < input_output_tensor_num; i++) { | |||
| ge::GeTensorDesc input_tensor; | |||
| GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); | |||
| ge::GeTensorDesc output_tensor; | |||
| GE_RETURN_IF_ERROR(op_desc->AddOutputDesc(output_tensor)); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -1140,10 +1112,12 @@ Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| auto valid_input_size = layer.bottom_size(); | |||
| auto blob_size = layer.blobs_size(); | |||
| GELOGI("After GetOpDescFromOperator op[%s] type[%s] have all input size: %zu, caffe_input_size:%d output size: %zu", | |||
| GELOGI("After GetOpDescFromOperator op[%s] type[%s] have all input size: %zu, " | |||
| "caffe_input_size:%d blob_size %d output size: %zu", | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||
| op_desc->GetAllInputsSize(), valid_input_size, op_desc->GetOutputsSize()); | |||
| bool update_in_turn = (static_cast<int64_t>(op_desc->GetAllInputsSize()) == (valid_input_size + blob_size); | |||
| op_desc->GetAllInputsSize(), valid_input_size, | |||
| blob_size, op_desc->GetOutputsSize()); | |||
| bool update_in_turn = (static_cast<int64_t >(op_desc->GetAllInputsSize()) == (valid_input_size + blob_size)); | |||
| for (int i = 0; i < valid_input_size; i++) { | |||
| ge::GeTensorDesc input_tensor; | |||
| std::string input_name; | |||
| @@ -1151,8 +1125,8 @@ Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const | |||
| // Below cases are supported fow now when there are optional inputs | |||
| // x means optional, o means requierd input | |||
| // a. ooxxx, number of o and x>=layer.bottom_size+layer.blobs_size>=number of o | |||
| // b. oxoxoxox, layer.bottom_size+layer.blobs_size>=number of o | |||
| // c. oxoxoxox, layer.bottom_size+layer.blobs_size>=number of o and x | |||
| // b. oxoxoxox, layer.bottom_size+layer.blobs_size=number of o | |||
| // c. oxoxoxox, layer.bottom_size+layer.blobs_size=number of o and x | |||
| if (update_in_turn) { | |||
| ret = op_desc->UpdateInputDesc(op_desc->GetInputNameByIndex(static_cast<uint32_t>(i)), input_tensor); | |||
| } else { | |||
| @@ -1164,14 +1138,16 @@ Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const | |||
| } | |||
| } | |||
| GELOGI("op [%s], type[%s], update input(%d) with name %s %s", op_desc->GetName().c_str(), | |||
| op_desc->GetType().c_str(), i, input_name.c_str(), ret == ge::GRAPH_SUCCESS ? "success" : "failed"); | |||
| op_desc->GetType().c_str(), i, input_name.c_str(), ret == ge::GRAPH_SUCCESS ? "success" : "failed"); | |||
| } | |||
| for (int i = 0; i < layer.top_size(); i++) { | |||
| ge::GeTensorDesc output_tensor; | |||
| auto ret = op_desc->UpdateOutputDesc(op_desc->GetOutputNameByIndex(static_cast<uint32_t>(i)), output_tensor); | |||
| GELOGI("op [%s], type[%s], update output(%d) with name %s %s", op_desc->GetName().c_str(), | |||
| op_desc->GetType().c_str(), i, op_desc->GetOutputNameByIndex(i).c_str(), ret == ge::GRAPH_SUCCESS ? "success" : "failed"); | |||
| GELOGI("op [%s], type[%s], update output(%d) with name %s %s", | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||
| i, op_desc->GetOutputNameByIndex(i).c_str(), | |||
| ret == ge::GRAPH_SUCCESS ? "success" : "failed"); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| @@ -1266,46 +1242,33 @@ bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) { | |||
| return ret; | |||
| } | |||
| Status CaffeModelParser::AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| ge::NodePtr net_output_node = graph->FindFirstNodeMatchType(ge::parser::NETOUTPUT); | |||
| if (net_output_node == nullptr) { | |||
| GELOGE(INTERNAL_ERROR, "Can not find netoutput node."); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| uint32_t net_output_num = net_output_node->GetAllInDataAnchorsSize(); | |||
| Status CaffeModelParser::AddUserOutNodesTop() { | |||
| int32_t index = 0; | |||
| const std::vector<std::pair<std::string, int32_t>> &user_out_nodes = ge::GetParserContext().user_out_nodes; | |||
| int net_output_num = user_out_nodes.size(); | |||
| for (const auto &out_pair : user_out_nodes) { | |||
| auto node_iter = node_map.find(out_pair.first); | |||
| auto layer_iter = layer_tops_map_.find(out_pair.first); | |||
| GELOGI("Add to output, node name: %s", out_pair.first.c_str()); | |||
| if (node_iter != node_map.end()) { | |||
| if ((static_cast<uint32_t>(out_pair.second) >= node_iter->second->GetAllOutDataAnchorsSize()) || | |||
| (static_cast<uint32_t>(index) >= net_output_num)) { | |||
| if (layer_iter != layer_tops_map_.end()) { | |||
| if (static_cast<uint32_t>(out_pair.second) >= (layer_iter->second).size()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E11016", {"opname", "outputindex", "totlaloutputindex", "inputindex", "totlalinputindex"}, | |||
| {out_pair.first.c_str(), std::to_string(out_pair.second), | |||
| std::to_string(node_iter->second->GetAllOutDataAnchorsSize()), std::to_string(index), | |||
| std::to_string((layer_iter->second).size()), std::to_string(index), | |||
| std::to_string(net_output_num)}); | |||
| GELOGE(INTERNAL_ERROR, | |||
| "Add op %s to NetOutput faild, current node output index:%d should < %u. NetOutput" | |||
| "input_index:%d should < %u.", | |||
| out_pair.first.c_str(), out_pair.second, node_iter->second->GetAllOutDataAnchorsSize(), index, | |||
| out_pair.first.c_str(), out_pair.second, (layer_iter->second).size(), index, | |||
| net_output_num); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| GELOGD("Start add edge for user out node: From %s:%d To %s:%d.", node_iter->second->GetName().c_str(), | |||
| out_pair.second, net_output_node->GetName().c_str(), index); | |||
| ge::OutDataAnchorPtr out_archor_ptr = node_iter->second->GetOutDataAnchor(out_pair.second); | |||
| GE_CHECK_NOTNULL(out_archor_ptr); | |||
| ge::InDataAnchorPtr in_archor_ptr = net_output_node->GetInDataAnchor(index); | |||
| GE_CHECK_NOTNULL(in_archor_ptr); | |||
| if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11013", {"opname1", "opname2"}, | |||
| {node_iter->second->GetName(), net_output_node->GetName()}); | |||
| GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", node_iter->second->GetName().c_str(), | |||
| net_output_node->GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| string top_name = layer_iter->second[out_pair.second]; | |||
| auto top_node_iter = node_map.find(out_pair.first); | |||
| if (top_node_iter != node_map.end()) { | |||
| ge::GetParserContext().out_top_names.push_back(top_name); | |||
| GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str()); | |||
| } | |||
| ++index; | |||
| } else { | |||
| @@ -1317,13 +1280,7 @@ Status CaffeModelParser::AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph) { | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| ge::NodePtr node = graph->FindFirstNodeMatchType(ge::parser::NETOUTPUT); | |||
| GE_RETURN_WITH_LOG_IF_FALSE(node != nullptr, "Net without output, some phase failed in front."); | |||
| int32_t index = 0; | |||
| Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_message) { | |||
| for (int32_t i = 0; i < proto_message.layer_size(); i++) { | |||
| const domi::caffe::LayerParameter &layer = proto_message.layer(i); | |||
| @@ -1333,6 +1290,7 @@ Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_m | |||
| for (int i = 0; i < layer.top_size(); i++) { | |||
| string top = layer.top(i); | |||
| string top_origin = top; | |||
| // Handling 'inplace' scenarios | |||
| if (IsInplaceTopBlob(layer, top)) { | |||
| top = RemapTopNameByLayer(layer, top, i); | |||
| @@ -1354,21 +1312,9 @@ Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_m | |||
| auto top_node_iter = node_map.find(layer.name()); | |||
| GELOGI("output in top_blob: %s", layer.name().c_str()); | |||
| if (top_node_iter != node_map.end()) { | |||
| // add edge | |||
| // Output node, output index, input node, input index | |||
| GELOGD("Start add edge for out node: From %s:%d To %s:%d.", top_node_iter->second->GetName().c_str(), i, | |||
| node->GetName().c_str(), index); | |||
| ge::OutDataAnchorPtr out_archor_ptr = top_node_iter->second->GetOutDataAnchor(i); | |||
| GE_CHECK_NOTNULL(out_archor_ptr); | |||
| ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(index); | |||
| GE_CHECK_NOTNULL(in_archor_ptr); | |||
| GE_IF_BOOL_EXEC(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E11013", {"opname1", "opname2"}, {top_node_iter->second->GetName(), node->GetName()}); | |||
| GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to to op[%s].", | |||
| top_node_iter->second->GetName().c_str(), node->GetName().c_str()); | |||
| return INTERNAL_ERROR;); | |||
| index++; | |||
| ge::GetParserContext().out_top_names.push_back(top_origin); | |||
| ge::GetParserContext().default_out_nodes.push_back(std::make_pair(layer.name(), (int32_t)i)); | |||
| GELOGI("The top of out node [%s] is [%s]", layer.name().c_str(), top_origin.c_str()); | |||
| } | |||
| } | |||
| } | |||
| @@ -1482,12 +1428,6 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||
| CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; | |||
| GELOGE(FAILED, "ParseInput ret fail.")); | |||
| // build output layer | |||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
| GE_CHECK_NOTNULL(layer); | |||
| layer->set_name(graph->GetName() + "_" + ge::parser::NODE_NAME_NET_OUTPUT); | |||
| layer->set_type(ge::parser::NETOUTPUT); | |||
| int32_t layer_count = proto_message.layer_size(); | |||
| std::map<std::string, std::string> inplace_blob_name_remapping; | |||
| // Map of operator name and occurrence times | |||
| @@ -1553,9 +1493,9 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); | |||
| if (!(ge::GetParserContext().user_out_nodes.empty())) { | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed."); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed."); | |||
| } else { | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail."); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail."); | |||
| } | |||
| GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); | |||
| @@ -1657,12 +1597,6 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||
| CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; | |||
| GELOGE(FAILED, "ParseInput ret fail.")); | |||
| // build output layer | |||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
| GE_CHECK_NOTNULL(layer); | |||
| layer->set_name(graph->GetName() + "_" + ge::parser::NODE_NAME_NET_OUTPUT); | |||
| layer->set_type(ge::parser::NETOUTPUT); | |||
| int32_t layer_count = proto_message.layer_size(); | |||
| if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { | |||
| @@ -1735,12 +1669,11 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); | |||
| if (!(ge::GetParserContext().user_out_nodes.empty())) { | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed."); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed."); | |||
| } else { | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail."); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail."); | |||
| } | |||
| GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(GetLeafNodeTops(graph), "Caffe parser get out nodes top names failed."); | |||
| auto nodes = graph->GetDirectNode(); | |||
| GELOGI("graph node size = %zu.", nodes.size()); | |||
| @@ -2451,27 +2384,6 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter ¶m, ge::Co | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeModelParser::GetLeafNodeTops(ge::ComputeGraphPtr &graph) { | |||
| auto netout = graph->FindFirstNodeMatchType(ge::parser::NETOUTPUT); | |||
| GE_CHECK_NOTNULL(netout); | |||
| for (const auto &in_anchor : netout->GetAllInDataAnchors()) { | |||
| auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(peer_out_data_anchor); | |||
| auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); | |||
| GE_CHECK_NOTNULL(peer_out_data_node); | |||
| int idx = peer_out_data_anchor->GetIdx(); | |||
| string node_name = peer_out_data_node->GetName(); | |||
| auto layer_iter = layer_tops_map_.find(node_name); | |||
| if (layer_iter != layer_tops_map_.end()) { | |||
| ge::GetParserContext().out_top_names.push_back(layer_iter->second[idx]); | |||
| GELOGI("The top of out node [%s] is [%s]", node_name.c_str(), layer_iter->second[idx].c_str()); | |||
| } else { | |||
| GELOGW("The out node [%s] can not find its top.", node_name.c_str()); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { | |||
| return SUCCESS; | |||
| } | |||
| @@ -279,12 +279,12 @@ class CaffeModelParser : public domi::ModelParser { | |||
| /** | |||
| * @ingroup domi_omg | |||
| * @brief Add edge information to graph | |||
| * @param [in|out] graph graph for saving model information | |||
| * @brief Add top name information to graph | |||
| * @param [in|out] proto_message | |||
| * @return SUCCESS add successfully | |||
| * @return FAILED add failed | |||
| */ | |||
| Status AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph); | |||
| Status AddOutputTop(const domi::caffe::NetParameter &proto_message); | |||
| /** | |||
| * @ingroup domi_omg | |||
| @@ -324,7 +324,7 @@ class CaffeModelParser : public domi::ModelParser { | |||
| Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | |||
| const string &op_type); | |||
| Status AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph); | |||
| Status AddUserOutNodesTop(); | |||
| std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index); | |||
| @@ -335,8 +335,6 @@ class CaffeModelParser : public domi::ModelParser { | |||
| Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op, | |||
| std::shared_ptr<ge::OpParser> &op_parser); | |||
| Status GetLeafNodeTops(ge::ComputeGraphPtr &graph); | |||
| void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer); | |||
| Status ReorderInput(domi::caffe::NetParameter &net); | |||
| @@ -1,190 +1,190 @@ | |||
| syntax = "proto3"; | |||
| package ge.proto; | |||
| enum DataType | |||
| { | |||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
| DT_FLOAT = 1; // float type | |||
| DT_FLOAT16 = 2; // fp16 type | |||
| DT_INT8 = 3; // int8 type | |||
| DT_UINT8 = 4; // uint8 type | |||
| DT_INT16 = 5; // int16 type | |||
| DT_UINT16 = 6; // uint16 type | |||
| DT_INT32 = 7; // | |||
| DT_INT64 = 8; // int64 type | |||
| DT_UINT32 = 9; // unsigned int32 | |||
| DT_UINT64 = 10; // unsigned int64 | |||
| DT_BOOL = 11; // bool type | |||
| DT_DOUBLE = 12; // double type | |||
| DT_STRING = 13; // string type | |||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
| DT_COMPLEX64 = 16; // complex64 type | |||
| DT_COMPLEX128 = 17; // complex128 type | |||
| DT_QINT8 = 18; // qint8 type | |||
| DT_QINT16 = 19; // qint16 type | |||
| DT_QINT32 = 20; // qint32 type | |||
| DT_QUINT8 = 21; // quint8 type | |||
| DT_QUINT16 = 22; // quint16 type | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| } | |||
| message AttrDef | |||
| { | |||
| message ListValue | |||
| { | |||
| enum ListValueType{ | |||
| VT_LIST_NONE = 0; | |||
| VT_LIST_STRING = 1; | |||
| VT_LIST_INT = 2; | |||
| VT_LIST_FLOAT = 3; | |||
| VT_LIST_BOOL = 4; | |||
| VT_LIST_BYTES = 5; | |||
| VT_LIST_TENSOR_DESC = 6; | |||
| VT_LIST_TENSOR = 7; | |||
| VT_LIST_GRAPH = 8; | |||
| VT_LIST_NAMED_ATTRS = 9; | |||
| VT_LIST_DATA_TYPE = 10; | |||
| } | |||
| repeated bytes s = 2; // "list(string)" | |||
| repeated int64 i = 3; // "list(int)" | |||
| repeated float f = 4; // "list(float)" | |||
| repeated bool b = 5; // "list(bool)" | |||
| repeated bytes bt = 7; | |||
| repeated TensorDescriptor td = 8; | |||
| repeated TensorDef t = 9; | |||
| repeated GraphDef g = 10; | |||
| repeated NamedAttrs na = 11; | |||
| repeated int64 dt = 12; // list ge::DataType | |||
| ListValueType val_type = 20; | |||
| } | |||
| message ListListInt{ | |||
| message ListInt{ | |||
| repeated int64 list_i = 1; // list int | |||
| } | |||
| repeated ListInt list_list_i = 1; // list list int | |||
| } | |||
| oneof value | |||
| { | |||
| bytes s = 2; // "string" | |||
| int64 i = 3; // "int" | |||
| float f = 4; // "float" | |||
| bool b = 5; // "bool" | |||
| bytes bt = 7; | |||
| ListValue list = 1; // any "list(...)" | |||
| NamedAttrs func = 10; // Used to support attr nesting | |||
| TensorDescriptor td = 11; // GeTensorDesc type | |||
| TensorDef t = 12; // GeTensor type | |||
| GraphDef g = 13; // Graph type | |||
| ListListInt list_list_int = 14; // List List Int type | |||
| int64 dt = 15; // ge::DataType | |||
| } | |||
| } | |||
| // A list of attr names and their values. The whole list is attached | |||
| // with a string name. E.g., MatMul[T=float]. | |||
| message NamedAttrs | |||
| { | |||
| string name = 1; | |||
| map<string, AttrDef> attr = 2; | |||
| } | |||
| // Shape / dimension description, using row-major order | |||
| message ShapeDef | |||
| { | |||
| repeated int64 dim = 1; // Size of each dimension | |||
| } | |||
| // Multidimensional data description | |||
| message TensorDescriptor | |||
| { | |||
| string name = 1; // Optional parameter, tensor name | |||
| DataType dtype = 2; // tensor datatype | |||
| ShapeDef shape = 3; // Shape / dimension | |||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
| bool has_out_attr = 9; | |||
| int64 size = 10; | |||
| int64 weight_size = 11; | |||
| bool reuse_input = 12; | |||
| bool output_tensor = 13; | |||
| string device_type = 14; | |||
| bool input_tensor =15; | |||
| int64 real_dim_cnt = 16; | |||
| int64 reuse_input_index = 17; | |||
| int64 data_offset = 18; | |||
| int64 cmps_size = 19; | |||
| string cmps_tab = 20; | |||
| int64 cmps_tab_offset = 21; | |||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
| } | |||
| // GeTensor definition | |||
| message TensorDef | |||
| { | |||
| TensorDescriptor desc = 1; // Tensor description | |||
| bytes data = 2; // Tensor data | |||
| } | |||
| // Operator description | |||
| message OpDef | |||
| { | |||
| string name = 1; // name | |||
| string type = 2; // type | |||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
| bool has_out_attr = 20; | |||
| int64 id = 21; | |||
| int64 stream_id =22; | |||
| repeated string input_name = 23; | |||
| repeated string src_name = 24; | |||
| repeated int64 src_index = 25; | |||
| repeated string dst_name = 26; | |||
| repeated int64 dst_index = 27; | |||
| repeated int64 input_i = 28; | |||
| repeated int64 output_i = 29; | |||
| repeated int64 workspace = 30; | |||
| repeated int64 workspace_bytes = 31; | |||
| repeated bool is_input_const = 32; | |||
| repeated TensorDescriptor input_desc = 33; | |||
| repeated TensorDescriptor output_desc = 34; | |||
| repeated string subgraph_name = 35; | |||
| } | |||
| // Graph definition | |||
| message GraphDef | |||
| { | |||
| string name = 1; // name | |||
| repeated string input = 4; // Graph input | |||
| repeated string output = 5; // Graph output | |||
| repeated OpDef op = 6; // List of operators | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| // model definition | |||
| message ModelDef | |||
| { | |||
| string name = 1; // name | |||
| uint32 version = 2; // IR Proto verion | |||
| string custom_version = 3; // User model version number, passed in by user | |||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| syntax = "proto3"; | |||
| package ge.proto; | |||
| enum DataType | |||
| { | |||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
| DT_FLOAT = 1; // float type | |||
| DT_FLOAT16 = 2; // fp16 type | |||
| DT_INT8 = 3; // int8 type | |||
| DT_UINT8 = 4; // uint8 type | |||
| DT_INT16 = 5; // int16 type | |||
| DT_UINT16 = 6; // uint16 type | |||
| DT_INT32 = 7; // | |||
| DT_INT64 = 8; // int64 type | |||
| DT_UINT32 = 9; // unsigned int32 | |||
| DT_UINT64 = 10; // unsigned int64 | |||
| DT_BOOL = 11; // bool type | |||
| DT_DOUBLE = 12; // double type | |||
| DT_STRING = 13; // string type | |||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
| DT_COMPLEX64 = 16; // complex64 type | |||
| DT_COMPLEX128 = 17; // complex128 type | |||
| DT_QINT8 = 18; // qint8 type | |||
| DT_QINT16 = 19; // qint16 type | |||
| DT_QINT32 = 20; // qint32 type | |||
| DT_QUINT8 = 21; // quint8 type | |||
| DT_QUINT16 = 22; // quint16 type | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| } | |||
| message AttrDef | |||
| { | |||
| message ListValue | |||
| { | |||
| enum ListValueType{ | |||
| VT_LIST_NONE = 0; | |||
| VT_LIST_STRING = 1; | |||
| VT_LIST_INT = 2; | |||
| VT_LIST_FLOAT = 3; | |||
| VT_LIST_BOOL = 4; | |||
| VT_LIST_BYTES = 5; | |||
| VT_LIST_TENSOR_DESC = 6; | |||
| VT_LIST_TENSOR = 7; | |||
| VT_LIST_GRAPH = 8; | |||
| VT_LIST_NAMED_ATTRS = 9; | |||
| VT_LIST_DATA_TYPE = 10; | |||
| } | |||
| repeated bytes s = 2; // "list(string)" | |||
| repeated int64 i = 3; // "list(int)" | |||
| repeated float f = 4; // "list(float)" | |||
| repeated bool b = 5; // "list(bool)" | |||
| repeated bytes bt = 7; | |||
| repeated TensorDescriptor td = 8; | |||
| repeated TensorDef t = 9; | |||
| repeated GraphDef g = 10; | |||
| repeated NamedAttrs na = 11; | |||
| repeated int64 dt = 12; // list ge::DataType | |||
| ListValueType val_type = 20; | |||
| } | |||
| message ListListInt{ | |||
| message ListInt{ | |||
| repeated int64 list_i = 1; // list int | |||
| } | |||
| repeated ListInt list_list_i = 1; // list list int | |||
| } | |||
| oneof value | |||
| { | |||
| bytes s = 2; // "string" | |||
| int64 i = 3; // "int" | |||
| float f = 4; // "float" | |||
| bool b = 5; // "bool" | |||
| bytes bt = 7; | |||
| ListValue list = 1; // any "list(...)" | |||
| NamedAttrs func = 10; // Used to support attr nesting | |||
| TensorDescriptor td = 11; // GeTensorDesc type | |||
| TensorDef t = 12; // GeTensor type | |||
| GraphDef g = 13; // Graph type | |||
| ListListInt list_list_int = 14; // List List Int type | |||
| int64 dt = 15; // ge::DataType | |||
| } | |||
| } | |||
| // A list of attr names and their values. The whole list is attached | |||
| // with a string name. E.g., MatMul[T=float]. | |||
| message NamedAttrs | |||
| { | |||
| string name = 1; | |||
| map<string, AttrDef> attr = 2; | |||
| } | |||
| // Shape / dimension description, using row-major order | |||
| message ShapeDef | |||
| { | |||
| repeated int64 dim = 1; // Size of each dimension | |||
| } | |||
| // Multidimensional data description | |||
| message TensorDescriptor | |||
| { | |||
| string name = 1; // Optional parameter, tensor name | |||
| DataType dtype = 2; // tensor datatype | |||
| ShapeDef shape = 3; // Shape / dimension | |||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
| bool has_out_attr = 9; | |||
| int64 size = 10; | |||
| int64 weight_size = 11; | |||
| bool reuse_input = 12; | |||
| bool output_tensor = 13; | |||
| string device_type = 14; | |||
| bool input_tensor =15; | |||
| int64 real_dim_cnt = 16; | |||
| int64 reuse_input_index = 17; | |||
| int64 data_offset = 18; | |||
| int64 cmps_size = 19; | |||
| string cmps_tab = 20; | |||
| int64 cmps_tab_offset = 21; | |||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
| } | |||
| // GeTensor definition | |||
| message TensorDef | |||
| { | |||
| TensorDescriptor desc = 1; // Tensor description | |||
| bytes data = 2; // Tensor data | |||
| } | |||
| // Operator description | |||
| message OpDef | |||
| { | |||
| string name = 1; // name | |||
| string type = 2; // type | |||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
| bool has_out_attr = 20; | |||
| int64 id = 21; | |||
| int64 stream_id =22; | |||
| repeated string input_name = 23; | |||
| repeated string src_name = 24; | |||
| repeated int64 src_index = 25; | |||
| repeated string dst_name = 26; | |||
| repeated int64 dst_index = 27; | |||
| repeated int64 input_i = 28; | |||
| repeated int64 output_i = 29; | |||
| repeated int64 workspace = 30; | |||
| repeated int64 workspace_bytes = 31; | |||
| repeated bool is_input_const = 32; | |||
| repeated TensorDescriptor input_desc = 33; | |||
| repeated TensorDescriptor output_desc = 34; | |||
| repeated string subgraph_name = 35; | |||
| } | |||
| // Graph definition | |||
| message GraphDef | |||
| { | |||
| string name = 1; // name | |||
| repeated string input = 4; // Graph input | |||
| repeated string output = 5; // Graph output | |||
| repeated OpDef op = 6; // List of operators | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| // model definition | |||
| message ModelDef | |||
| { | |||
| string name = 1; // name | |||
| uint32 version = 2; // IR Proto verion | |||
| string custom_version = 3; // User model version number, passed in by user | |||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| @@ -41,8 +41,11 @@ target_include_directories(parser_common PRIVATE | |||
| ${CMAKE_CURRENT_LIST_DIR} | |||
| ${PARSER_DIR} | |||
| ${PARSER_DIR}/parser | |||
| ${PARSER_DIR}/../ge | |||
| ${PARSER_DIR}/../inc | |||
| ${PARSER_DIR}/../inc/framework | |||
| ${PARSER_DIR}/../inc/common/util | |||
| ${PARSER_DIR}/../inc/external | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/graph | |||
| ${METADEF_DIR}/inc/register | |||
| @@ -139,8 +139,8 @@ bool ValidateStr(const std::string &filePath, const std::string &mode); | |||
| std::string CurrentTimeInStr(); | |||
| template <typename T, typename... Args> | |||
| static inline std::shared_ptr<T> ComGraphMakeShared(Args &&... args) { | |||
| using T_nc = typename std::remove_const<T>::type; | |||
| static inline std::shared_ptr<T> MakeShared(Args &&... args) { | |||
| typedef typename std::remove_const<T>::type T_nc; | |||
| std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...)); | |||
| return ret; | |||
| } | |||
| @@ -150,63 +150,64 @@ static inline std::shared_ptr<T> ComGraphMakeShared(Args &&... args) { | |||
| /// @param [in] a multiplicator | |||
| /// @param [in] b multiplicator | |||
| /// @return Status | |||
| inline Status Int64MulCheckOverflow(int64_t a, int64_t b) { | |||
| inline domi::Status Int64MulCheckOverflow(int64_t a, int64_t b) { | |||
| if (a > 0) { | |||
| if (b > 0) { | |||
| if (a > (INT64_MAX / b)) { | |||
| return FAILED; | |||
| return domi::FAILED; | |||
| } | |||
| } else { | |||
| if (b < (INT64_MIN / a)) { | |||
| return FAILED; | |||
| return domi::FAILED; | |||
| } | |||
| } | |||
| } else { | |||
| if (b > 0) { | |||
| if (a < (INT64_MIN / b)) { | |||
| return FAILED; | |||
| return domi::FAILED; | |||
| } | |||
| } else { | |||
| if ((a != 0) && (b < (INT64_MAX / a))) { | |||
| return FAILED; | |||
| return domi::FAILED; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| return domi::SUCCESS; | |||
| } | |||
| /// @ingroup math_util | |||
| /// @brief check whether int64 multiplication can result in overflow | |||
| /// @param [in] a multiplicator | |||
| /// @param [in] b multiplicator | |||
| /// @return Status | |||
| inline Status CheckInt64Uint32MulOverflow(int64_t a, uint32_t b) { | |||
| inline domi::Status CheckInt64Uint32MulOverflow(int64_t a, uint32_t b) { | |||
| if (a == 0 || b == 0) { | |||
| return SUCCESS; | |||
| return domi::SUCCESS; | |||
| } | |||
| if (a > 0) { | |||
| if (a > (INT64_MAX / b)) { | |||
| return FAILED; | |||
| return domi::FAILED; | |||
| } | |||
| } else { | |||
| if (a < (INT64_MIN / b)) { | |||
| return FAILED; | |||
| return domi::FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| return domi::SUCCESS; | |||
| } | |||
| #define PARSER_INT64_MULCHECK(a, b) \ | |||
| if (ge::Int64MulCheckOverflow((a), (b)) != SUCCESS) { \ | |||
| #define PARSER_INT64_MULCHECK(a, b) \ | |||
| if (ge::parser::Int64MulCheckOverflow((a), (b)) != SUCCESS) { \ | |||
| GELOGW("Int64 %ld and %ld multiplication can result in overflow!", static_cast<int64_t>(a), \ | |||
| static_cast<int64_t>(b)); \ | |||
| return INTERNAL_ERROR; \ | |||
| static_cast<int64_t>(b)); \ | |||
| return INTERNAL_ERROR; \ | |||
| } | |||
| #define PARSER_INT64_UINT32_MULCHECK(a, b) \ | |||
| if (ge::CheckInt64Uint32MulOverflow((a), (b)) != SUCCESS) { \ | |||
| #define PARSER_INT64_UINT32_MULCHECK(a, b) \ | |||
| if (ge::parser::CheckInt64Uint32MulOverflow((a), (b)) != SUCCESS) { \ | |||
| GELOGW("Int64 %ld and UINT32 %u multiplication can result in overflow!", static_cast<uint32_t>(a), \ | |||
| static_cast<uint32_t>(b)); \ | |||
| return INTERNAL_ERROR; \ | |||
| static_cast<uint32_t>(b)); \ | |||
| return INTERNAL_ERROR; \ | |||
| } | |||
| } // namespace parser | |||
| } // namespace ge | |||
| @@ -1,190 +1,190 @@ | |||
| syntax = "proto3"; | |||
| package ge.proto; | |||
| enum DataType | |||
| { | |||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
| DT_FLOAT = 1; // float type | |||
| DT_FLOAT16 = 2; // fp16 type | |||
| DT_INT8 = 3; // int8 type | |||
| DT_UINT8 = 4; // uint8 type | |||
| DT_INT16 = 5; // int16 type | |||
| DT_UINT16 = 6; // uint16 type | |||
| DT_INT32 = 7; // | |||
| DT_INT64 = 8; // int64 type | |||
| DT_UINT32 = 9; // unsigned int32 | |||
| DT_UINT64 = 10; // unsigned int64 | |||
| DT_BOOL = 11; // bool type | |||
| DT_DOUBLE = 12; // double type | |||
| DT_STRING = 13; // string type | |||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
| DT_COMPLEX64 = 16; // complex64 type | |||
| DT_COMPLEX128 = 17; // complex128 type | |||
| DT_QINT8 = 18; // qint8 type | |||
| DT_QINT16 = 19; // qint16 type | |||
| DT_QINT32 = 20; // qint32 type | |||
| DT_QUINT8 = 21; // quint8 type | |||
| DT_QUINT16 = 22; // quint16 type | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| } | |||
| message AttrDef | |||
| { | |||
| message ListValue | |||
| { | |||
| enum ListValueType{ | |||
| VT_LIST_NONE = 0; | |||
| VT_LIST_STRING = 1; | |||
| VT_LIST_INT = 2; | |||
| VT_LIST_FLOAT = 3; | |||
| VT_LIST_BOOL = 4; | |||
| VT_LIST_BYTES = 5; | |||
| VT_LIST_TENSOR_DESC = 6; | |||
| VT_LIST_TENSOR = 7; | |||
| VT_LIST_GRAPH = 8; | |||
| VT_LIST_NAMED_ATTRS = 9; | |||
| VT_LIST_DATA_TYPE = 10; | |||
| } | |||
| repeated bytes s = 2; // "list(string)" | |||
| repeated int64 i = 3; // "list(int)" | |||
| repeated float f = 4; // "list(float)" | |||
| repeated bool b = 5; // "list(bool)" | |||
| repeated bytes bt = 7; | |||
| repeated TensorDescriptor td = 8; | |||
| repeated TensorDef t = 9; | |||
| repeated GraphDef g = 10; | |||
| repeated NamedAttrs na = 11; | |||
| repeated int64 dt = 12; // list ge::DataType | |||
| ListValueType val_type = 20; | |||
| } | |||
| message ListListInt{ | |||
| message ListInt{ | |||
| repeated int64 list_i = 1; // list int | |||
| } | |||
| repeated ListInt list_list_i = 1; // list list int | |||
| } | |||
| oneof value | |||
| { | |||
| bytes s = 2; // "string" | |||
| int64 i = 3; // "int" | |||
| float f = 4; // "float" | |||
| bool b = 5; // "bool" | |||
| bytes bt = 7; | |||
| ListValue list = 1; // any "list(...)" | |||
| NamedAttrs func = 10; // Used to support attr nesting | |||
| TensorDescriptor td = 11; // GeTensorDesc type | |||
| TensorDef t = 12; // GeTensor type | |||
| GraphDef g = 13; // Graph type | |||
| ListListInt list_list_int = 14; // List List Int type | |||
| int64 dt = 15; // ge::DataType | |||
| } | |||
| } | |||
| // A list of attr names and their values. The whole list is attached | |||
| // with a string name. E.g., MatMul[T=float]. | |||
| message NamedAttrs | |||
| { | |||
| string name = 1; | |||
| map<string, AttrDef> attr = 2; | |||
| } | |||
| // Shape / dimension description, using row-major order | |||
| message ShapeDef | |||
| { | |||
| repeated int64 dim = 1; // Size of each dimension | |||
| } | |||
| // Multidimensional data description | |||
| message TensorDescriptor | |||
| { | |||
| string name = 1; // Optional parameter, tensor name | |||
| DataType dtype = 2; // tensor datatype | |||
| ShapeDef shape = 3; // Shape / dimension | |||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
| bool has_out_attr = 9; | |||
| int64 size = 10; | |||
| int64 weight_size = 11; | |||
| bool reuse_input = 12; | |||
| bool output_tensor = 13; | |||
| string device_type = 14; | |||
| bool input_tensor =15; | |||
| int64 real_dim_cnt = 16; | |||
| int64 reuse_input_index = 17; | |||
| int64 data_offset = 18; | |||
| int64 cmps_size = 19; | |||
| string cmps_tab = 20; | |||
| int64 cmps_tab_offset = 21; | |||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
| } | |||
| // GeTensor definition | |||
| message TensorDef | |||
| { | |||
| TensorDescriptor desc = 1; // Tensor description | |||
| bytes data = 2; // Tensor data | |||
| } | |||
| // Operator description | |||
| message OpDef | |||
| { | |||
| string name = 1; // name | |||
| string type = 2; // type | |||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
| bool has_out_attr = 20; | |||
| int64 id = 21; | |||
| int64 stream_id =22; | |||
| repeated string input_name = 23; | |||
| repeated string src_name = 24; | |||
| repeated int64 src_index = 25; | |||
| repeated string dst_name = 26; | |||
| repeated int64 dst_index = 27; | |||
| repeated int64 input_i = 28; | |||
| repeated int64 output_i = 29; | |||
| repeated int64 workspace = 30; | |||
| repeated int64 workspace_bytes = 31; | |||
| repeated bool is_input_const = 32; | |||
| repeated TensorDescriptor input_desc = 33; | |||
| repeated TensorDescriptor output_desc = 34; | |||
| repeated string subgraph_name = 35; | |||
| } | |||
| // Graph definition | |||
| message GraphDef | |||
| { | |||
| string name = 1; // name | |||
| repeated string input = 4; // Graph input | |||
| repeated string output = 5; // Graph output | |||
| repeated OpDef op = 6; // List of operators | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| // model definition | |||
| message ModelDef | |||
| { | |||
| string name = 1; // name | |||
| uint32 version = 2; // IR Proto verion | |||
| string custom_version = 3; // User model version number, passed in by user | |||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| syntax = "proto3"; | |||
| package ge.proto; | |||
| enum DataType | |||
| { | |||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
| DT_FLOAT = 1; // float type | |||
| DT_FLOAT16 = 2; // fp16 type | |||
| DT_INT8 = 3; // int8 type | |||
| DT_UINT8 = 4; // uint8 type | |||
| DT_INT16 = 5; // int16 type | |||
| DT_UINT16 = 6; // uint16 type | |||
| DT_INT32 = 7; // | |||
| DT_INT64 = 8; // int64 type | |||
| DT_UINT32 = 9; // unsigned int32 | |||
| DT_UINT64 = 10; // unsigned int64 | |||
| DT_BOOL = 11; // bool type | |||
| DT_DOUBLE = 12; // double type | |||
| DT_STRING = 13; // string type | |||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
| DT_COMPLEX64 = 16; // complex64 type | |||
| DT_COMPLEX128 = 17; // complex128 type | |||
| DT_QINT8 = 18; // qint8 type | |||
| DT_QINT16 = 19; // qint16 type | |||
| DT_QINT32 = 20; // qint32 type | |||
| DT_QUINT8 = 21; // quint8 type | |||
| DT_QUINT16 = 22; // quint16 type | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| } | |||
| message AttrDef | |||
| { | |||
| message ListValue | |||
| { | |||
| enum ListValueType{ | |||
| VT_LIST_NONE = 0; | |||
| VT_LIST_STRING = 1; | |||
| VT_LIST_INT = 2; | |||
| VT_LIST_FLOAT = 3; | |||
| VT_LIST_BOOL = 4; | |||
| VT_LIST_BYTES = 5; | |||
| VT_LIST_TENSOR_DESC = 6; | |||
| VT_LIST_TENSOR = 7; | |||
| VT_LIST_GRAPH = 8; | |||
| VT_LIST_NAMED_ATTRS = 9; | |||
| VT_LIST_DATA_TYPE = 10; | |||
| } | |||
| repeated bytes s = 2; // "list(string)" | |||
| repeated int64 i = 3; // "list(int)" | |||
| repeated float f = 4; // "list(float)" | |||
| repeated bool b = 5; // "list(bool)" | |||
| repeated bytes bt = 7; | |||
| repeated TensorDescriptor td = 8; | |||
| repeated TensorDef t = 9; | |||
| repeated GraphDef g = 10; | |||
| repeated NamedAttrs na = 11; | |||
| repeated int64 dt = 12; // list ge::DataType | |||
| ListValueType val_type = 20; | |||
| } | |||
| message ListListInt{ | |||
| message ListInt{ | |||
| repeated int64 list_i = 1; // list int | |||
| } | |||
| repeated ListInt list_list_i = 1; // list list int | |||
| } | |||
| oneof value | |||
| { | |||
| bytes s = 2; // "string" | |||
| int64 i = 3; // "int" | |||
| float f = 4; // "float" | |||
| bool b = 5; // "bool" | |||
| bytes bt = 7; | |||
| ListValue list = 1; // any "list(...)" | |||
| NamedAttrs func = 10; // Used to support attr nesting | |||
| TensorDescriptor td = 11; // GeTensorDesc type | |||
| TensorDef t = 12; // GeTensor type | |||
| GraphDef g = 13; // Graph type | |||
| ListListInt list_list_int = 14; // List List Int type | |||
| int64 dt = 15; // ge::DataType | |||
| } | |||
| } | |||
| // A list of attr names and their values. The whole list is attached | |||
| // with a string name. E.g., MatMul[T=float]. | |||
| message NamedAttrs | |||
| { | |||
| string name = 1; | |||
| map<string, AttrDef> attr = 2; | |||
| } | |||
| // Shape / dimension description, using row-major order | |||
| message ShapeDef | |||
| { | |||
| repeated int64 dim = 1; // Size of each dimension | |||
| } | |||
| // Multidimensional data description | |||
| message TensorDescriptor | |||
| { | |||
| string name = 1; // Optional parameter, tensor name | |||
| DataType dtype = 2; // tensor datatype | |||
| ShapeDef shape = 3; // Shape / dimension | |||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
| bool has_out_attr = 9; | |||
| int64 size = 10; | |||
| int64 weight_size = 11; | |||
| bool reuse_input = 12; | |||
| bool output_tensor = 13; | |||
| string device_type = 14; | |||
| bool input_tensor =15; | |||
| int64 real_dim_cnt = 16; | |||
| int64 reuse_input_index = 17; | |||
| int64 data_offset = 18; | |||
| int64 cmps_size = 19; | |||
| string cmps_tab = 20; | |||
| int64 cmps_tab_offset = 21; | |||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
| } | |||
| // GeTensor definition | |||
| message TensorDef | |||
| { | |||
| TensorDescriptor desc = 1; // Tensor description | |||
| bytes data = 2; // Tensor data | |||
| } | |||
| // Operator description | |||
| message OpDef | |||
| { | |||
| string name = 1; // name | |||
| string type = 2; // type | |||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
| bool has_out_attr = 20; | |||
| int64 id = 21; | |||
| int64 stream_id =22; | |||
| repeated string input_name = 23; | |||
| repeated string src_name = 24; | |||
| repeated int64 src_index = 25; | |||
| repeated string dst_name = 26; | |||
| repeated int64 dst_index = 27; | |||
| repeated int64 input_i = 28; | |||
| repeated int64 output_i = 29; | |||
| repeated int64 workspace = 30; | |||
| repeated int64 workspace_bytes = 31; | |||
| repeated bool is_input_const = 32; | |||
| repeated TensorDescriptor input_desc = 33; | |||
| repeated TensorDescriptor output_desc = 34; | |||
| repeated string subgraph_name = 35; | |||
| } | |||
| // Graph definition | |||
| message GraphDef | |||
| { | |||
| string name = 1; // name | |||
| repeated string input = 4; // Graph input | |||
| repeated string output = 5; // Graph output | |||
| repeated OpDef op = 6; // List of operators | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| // model definition | |||
| message ModelDef | |||
| { | |||
| string name = 1; // name | |||
| uint32 version = 2; // IR Proto verion | |||
| string custom_version = 3; // User model version number, passed in by user | |||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| @@ -1,136 +1,136 @@ | |||
| syntax = "proto3"; | |||
| package domi; | |||
| message InsertNewOps { | |||
| repeated AippOpParams aipp_op = 1; | |||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||
| } | |||
| message AippOpParams { | |||
| enum InputFormat { | |||
| UNDEFINED = 0; | |||
| YUV420SP_U8 = 1; | |||
| XRGB8888_U8 = 2; | |||
| RGB888_U8 = 3; | |||
| YUV400_U8 = 4; | |||
| NC1HWC0DI_FP16 = 5; | |||
| NC1HWC0DI_S8 = 6; | |||
| ARGB8888_U8 = 7; | |||
| YUYV_U8 = 8; | |||
| YUV422SP_U8 = 9; | |||
| AYUV444_U8 = 10; | |||
| RAW10 = 11; | |||
| RAW12 = 12; | |||
| RAW16 = 13; | |||
| RAW24 = 14; | |||
| RGB16 = 15; | |||
| RGB20 = 16; | |||
| RGB24 = 17; | |||
| RGB8_IR = 18; | |||
| RGB16_IR = 19; | |||
| RGB24_IR = 20; | |||
| } | |||
| enum AippMode { | |||
| undefined = 0; | |||
| static = 1; | |||
| dynamic = 2; | |||
| } | |||
| // AIPP模式,区分静态AIPP和动态AIPP | |||
| AippMode aipp_mode = 1; | |||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||
| uint32 related_input_rank = 2; | |||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||
| // 配置值 <= Data算子输出边的个数。 | |||
| repeated uint32 input_edge_idx = 3; | |||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||
| uint32 max_src_image_size = 4; | |||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||
| bool support_rotation = 5; | |||
| // [End] 动态AIPP参数 | |||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||
| InputFormat input_format = 51; | |||
| bool csc_switch = 52; | |||
| float cpadding_value = 53; | |||
| bool rbuv_swap_switch = 54; | |||
| bool ax_swap_switch = 55; | |||
| bool single_line_mode = 56; | |||
| int32 src_image_size_w = 57; | |||
| int32 src_image_size_h = 58; | |||
| bool crop = 59; | |||
| int32 load_start_pos_w = 60; | |||
| int32 load_start_pos_h = 61; | |||
| int32 crop_size_w = 62; | |||
| int32 crop_size_h = 63; | |||
| bool resize = 64; | |||
| int32 resize_output_w = 65; | |||
| int32 resize_output_h = 66; | |||
| bool padding = 67; | |||
| int32 left_padding_size = 68; | |||
| int32 right_padding_size = 69; | |||
| int32 top_padding_size = 70; | |||
| int32 bottom_padding_size = 71; | |||
| int32 mean_chn_0 = 10; | |||
| int32 mean_chn_1 = 11; | |||
| int32 mean_chn_2 = 12; | |||
| int32 mean_chn_3 = 19; | |||
| float min_chn_0 = 13; | |||
| float min_chn_1 = 14; | |||
| float min_chn_2 = 15; | |||
| float min_chn_3 = 20; | |||
| repeated float var_reci_chn_0 = 16; | |||
| repeated float var_reci_chn_1 = 17; | |||
| repeated float var_reci_chn_2 = 18; | |||
| repeated float var_reci_chn_3 = 21; | |||
| repeated int32 matrix_r0c0 = 30; | |||
| repeated int32 matrix_r0c1 = 31; | |||
| repeated int32 matrix_r0c2 = 32; | |||
| repeated int32 matrix_r1c0 = 33; | |||
| repeated int32 matrix_r1c1 = 34; | |||
| repeated int32 matrix_r1c2 = 35; | |||
| repeated int32 matrix_r2c0 = 36; | |||
| repeated int32 matrix_r2c1 = 37; | |||
| repeated int32 matrix_r2c2 = 38; | |||
| repeated int32 output_bias_0 = 39; | |||
| repeated int32 output_bias_1 = 40; | |||
| repeated int32 output_bias_2 = 41; | |||
| repeated int32 input_bias_0 = 42; | |||
| repeated int32 input_bias_1 = 43; | |||
| repeated int32 input_bias_2 = 44; | |||
| // [End] 静态AIPP参数 | |||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||
| uint32 raw_rgbir_to_f16_n = 45; | |||
| } | |||
| message MultiShapeOpParams { | |||
| enum MultiShapeMode { | |||
| batch = 0; //动态batch | |||
| resolution = 1; //动态分辨率,扩展用 | |||
| } | |||
| MultiShapeMode mode = 1; //算子模式 | |||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||
| } | |||
| syntax = "proto3"; | |||
| package domi; | |||
| message InsertNewOps { | |||
| repeated AippOpParams aipp_op = 1; | |||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||
| } | |||
| message AippOpParams { | |||
| enum InputFormat { | |||
| UNDEFINED = 0; | |||
| YUV420SP_U8 = 1; | |||
| XRGB8888_U8 = 2; | |||
| RGB888_U8 = 3; | |||
| YUV400_U8 = 4; | |||
| NC1HWC0DI_FP16 = 5; | |||
| NC1HWC0DI_S8 = 6; | |||
| ARGB8888_U8 = 7; | |||
| YUYV_U8 = 8; | |||
| YUV422SP_U8 = 9; | |||
| AYUV444_U8 = 10; | |||
| RAW10 = 11; | |||
| RAW12 = 12; | |||
| RAW16 = 13; | |||
| RAW24 = 14; | |||
| RGB16 = 15; | |||
| RGB20 = 16; | |||
| RGB24 = 17; | |||
| RGB8_IR = 18; | |||
| RGB16_IR = 19; | |||
| RGB24_IR = 20; | |||
| } | |||
| enum AippMode { | |||
| undefined = 0; | |||
| static = 1; | |||
| dynamic = 2; | |||
| } | |||
| // AIPP模式,区分静态AIPP和动态AIPP | |||
| AippMode aipp_mode = 1; | |||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||
| uint32 related_input_rank = 2; | |||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||
| // 配置值 <= Data算子输出边的个数。 | |||
| repeated uint32 input_edge_idx = 3; | |||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||
| uint32 max_src_image_size = 4; | |||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||
| bool support_rotation = 5; | |||
| // [End] 动态AIPP参数 | |||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||
| InputFormat input_format = 51; | |||
| bool csc_switch = 52; | |||
| float cpadding_value = 53; | |||
| bool rbuv_swap_switch = 54; | |||
| bool ax_swap_switch = 55; | |||
| bool single_line_mode = 56; | |||
| int32 src_image_size_w = 57; | |||
| int32 src_image_size_h = 58; | |||
| bool crop = 59; | |||
| int32 load_start_pos_w = 60; | |||
| int32 load_start_pos_h = 61; | |||
| int32 crop_size_w = 62; | |||
| int32 crop_size_h = 63; | |||
| bool resize = 64; | |||
| int32 resize_output_w = 65; | |||
| int32 resize_output_h = 66; | |||
| bool padding = 67; | |||
| int32 left_padding_size = 68; | |||
| int32 right_padding_size = 69; | |||
| int32 top_padding_size = 70; | |||
| int32 bottom_padding_size = 71; | |||
| int32 mean_chn_0 = 10; | |||
| int32 mean_chn_1 = 11; | |||
| int32 mean_chn_2 = 12; | |||
| int32 mean_chn_3 = 19; | |||
| float min_chn_0 = 13; | |||
| float min_chn_1 = 14; | |||
| float min_chn_2 = 15; | |||
| float min_chn_3 = 20; | |||
| repeated float var_reci_chn_0 = 16; | |||
| repeated float var_reci_chn_1 = 17; | |||
| repeated float var_reci_chn_2 = 18; | |||
| repeated float var_reci_chn_3 = 21; | |||
| repeated int32 matrix_r0c0 = 30; | |||
| repeated int32 matrix_r0c1 = 31; | |||
| repeated int32 matrix_r0c2 = 32; | |||
| repeated int32 matrix_r1c0 = 33; | |||
| repeated int32 matrix_r1c1 = 34; | |||
| repeated int32 matrix_r1c2 = 35; | |||
| repeated int32 matrix_r2c0 = 36; | |||
| repeated int32 matrix_r2c1 = 37; | |||
| repeated int32 matrix_r2c2 = 38; | |||
| repeated int32 output_bias_0 = 39; | |||
| repeated int32 output_bias_1 = 40; | |||
| repeated int32 output_bias_2 = 41; | |||
| repeated int32 input_bias_0 = 42; | |||
| repeated int32 input_bias_1 = 43; | |||
| repeated int32 input_bias_2 = 44; | |||
| // [End] 静态AIPP参数 | |||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||
| uint32 raw_rgbir_to_f16_n = 45; | |||
| } | |||
| message MultiShapeOpParams { | |||
| enum MultiShapeMode { | |||
| batch = 0; //动态batch | |||
| resolution = 1; //动态分辨率,扩展用 | |||
| } | |||
| MultiShapeMode mode = 1; //算子模式 | |||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||
| } | |||
| @@ -1,17 +1,17 @@ | |||
| include $(BUILD_SYSTEM)/base_rules.mk | |||
| FUNCTION_TO_GRAPH_OUT_TIMESTAMP := $(HOST_OUT_ROOT)/func_to_graph/.timestamp | |||
| PROTO_SRC_DIR = framework/domi/parser/func_to_graph/proto | |||
| PY_PROTO_BUILD_DIR = $(HOST_OUT_ROOT)/tmp/function_to_graph/proto | |||
| $(warning PRIVATE_PROTOC is $(PRIVATE_PROTOC)) | |||
| $(warning protobuf_lib_dir is $(protobuf_lib_dir)) | |||
| $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP): $(PRIVATE_PROTOC) | |||
| mkdir -p $(PY_PROTO_BUILD_DIR) | |||
| LD_LIBRARY_PATH=$(protobuf_lib_dir):$$LD_LIBRARY_PATH $(PRIVATE_PROTOC) -I=$(PROTO_SRC_DIR) --python_out=$(PY_PROTO_BUILD_DIR) $(PROTO_SRC_DIR)/*.proto | |||
| $(LOCAL_BUILT_MODULE): $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP) | |||
| mkdir -p $@ | |||
| include $(BUILD_SYSTEM)/base_rules.mk | |||
| FUNCTION_TO_GRAPH_OUT_TIMESTAMP := $(HOST_OUT_ROOT)/func_to_graph/.timestamp | |||
| PROTO_SRC_DIR = framework/domi/parser/func_to_graph/proto | |||
| PY_PROTO_BUILD_DIR = $(HOST_OUT_ROOT)/tmp/function_to_graph/proto | |||
| $(warning PRIVATE_PROTOC is $(PRIVATE_PROTOC)) | |||
| $(warning protobuf_lib_dir is $(protobuf_lib_dir)) | |||
| $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP): $(PRIVATE_PROTOC) | |||
| mkdir -p $(PY_PROTO_BUILD_DIR) | |||
| LD_LIBRARY_PATH=$(protobuf_lib_dir):$$LD_LIBRARY_PATH $(PRIVATE_PROTOC) -I=$(PROTO_SRC_DIR) --python_out=$(PY_PROTO_BUILD_DIR) $(PROTO_SRC_DIR)/*.proto | |||
| $(LOCAL_BUILT_MODULE): $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP) | |||
| mkdir -p $@ | |||
| cp -rf $(PY_PROTO_BUILD_DIR)/* $@ | |||
| @@ -29,8 +29,11 @@ target_include_directories(fmk_onnx_parser PRIVATE | |||
| ${PARSER_DIR} | |||
| ${PARSER_DIR}/inc | |||
| ${PARSER_DIR}/parser | |||
| ${PARSER_DIR}/../ge | |||
| ${PARSER_DIR}/../inc | |||
| ${PARSER_DIR}/../inc/common/util | |||
| {PARSER_DIR}/../inc/framework | |||
| {PARSER_DIR}/../inc/external | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/graph | |||
| ${METADEF_DIR}/inc/register | |||
| @@ -43,7 +43,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||
| libparser_common \ | |||
| libgraph \ | |||
| libregister \ | |||
| libge_common \ | |||
| LOCAL_LDFLAGS := -lrt | |||
| @@ -1,5 +1,5 @@ | |||
| set(PROTO_LIST | |||
| "${TOP_DIR}/inc/register/proto/caffe/caffe.proto" | |||
| "${METADEF_DIR}/proto/caffe/caffe.proto" | |||
| ) | |||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
| @@ -1,190 +1,190 @@ | |||
| syntax = "proto3"; | |||
| package ge.proto; | |||
| enum DataType | |||
| { | |||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
| DT_FLOAT = 1; // float type | |||
| DT_FLOAT16 = 2; // fp16 type | |||
| DT_INT8 = 3; // int8 type | |||
| DT_UINT8 = 4; // uint8 type | |||
| DT_INT16 = 5; // int16 type | |||
| DT_UINT16 = 6; // uint16 type | |||
| DT_INT32 = 7; // | |||
| DT_INT64 = 8; // int64 type | |||
| DT_UINT32 = 9; // unsigned int32 | |||
| DT_UINT64 = 10; // unsigned int64 | |||
| DT_BOOL = 11; // bool type | |||
| DT_DOUBLE = 12; // double type | |||
| DT_STRING = 13; // string type | |||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
| DT_COMPLEX64 = 16; // complex64 type | |||
| DT_COMPLEX128 = 17; // complex128 type | |||
| DT_QINT8 = 18; // qint8 type | |||
| DT_QINT16 = 19; // qint16 type | |||
| DT_QINT32 = 20; // qint32 type | |||
| DT_QUINT8 = 21; // quint8 type | |||
| DT_QUINT16 = 22; // quint16 type | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| } | |||
| message AttrDef | |||
| { | |||
| message ListValue | |||
| { | |||
| enum ListValueType{ | |||
| VT_LIST_NONE = 0; | |||
| VT_LIST_STRING = 1; | |||
| VT_LIST_INT = 2; | |||
| VT_LIST_FLOAT = 3; | |||
| VT_LIST_BOOL = 4; | |||
| VT_LIST_BYTES = 5; | |||
| VT_LIST_TENSOR_DESC = 6; | |||
| VT_LIST_TENSOR = 7; | |||
| VT_LIST_GRAPH = 8; | |||
| VT_LIST_NAMED_ATTRS = 9; | |||
| VT_LIST_DATA_TYPE = 10; | |||
| } | |||
| repeated bytes s = 2; // "list(string)" | |||
| repeated int64 i = 3; // "list(int)" | |||
| repeated float f = 4; // "list(float)" | |||
| repeated bool b = 5; // "list(bool)" | |||
| repeated bytes bt = 7; | |||
| repeated TensorDescriptor td = 8; | |||
| repeated TensorDef t = 9; | |||
| repeated GraphDef g = 10; | |||
| repeated NamedAttrs na = 11; | |||
| repeated int64 dt = 12; // list ge::DataType | |||
| ListValueType val_type = 20; | |||
| } | |||
| message ListListInt{ | |||
| message ListInt{ | |||
| repeated int64 list_i = 1; // list int | |||
| } | |||
| repeated ListInt list_list_i = 1; // list list int | |||
| } | |||
| oneof value | |||
| { | |||
| bytes s = 2; // "string" | |||
| int64 i = 3; // "int" | |||
| float f = 4; // "float" | |||
| bool b = 5; // "bool" | |||
| bytes bt = 7; | |||
| ListValue list = 1; // any "list(...)" | |||
| NamedAttrs func = 10; // Used to support attr nesting | |||
| TensorDescriptor td = 11; // GeTensorDesc type | |||
| TensorDef t = 12; // GeTensor type | |||
| GraphDef g = 13; // Graph type | |||
| ListListInt list_list_int = 14; // List List Int type | |||
| int64 dt = 15; // ge::DataType | |||
| } | |||
| } | |||
| // A list of attr names and their values. The whole list is attached | |||
| // with a string name. E.g., MatMul[T=float]. | |||
| message NamedAttrs | |||
| { | |||
| string name = 1; | |||
| map<string, AttrDef> attr = 2; | |||
| } | |||
| // Shape / dimension description, using row-major order | |||
| message ShapeDef | |||
| { | |||
| repeated int64 dim = 1; // Size of each dimension | |||
| } | |||
| // Multidimensional data description | |||
| message TensorDescriptor | |||
| { | |||
| string name = 1; // Optional parameter, tensor name | |||
| DataType dtype = 2; // tensor datatype | |||
| ShapeDef shape = 3; // Shape / dimension | |||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
| bool has_out_attr = 9; | |||
| int64 size = 10; | |||
| int64 weight_size = 11; | |||
| bool reuse_input = 12; | |||
| bool output_tensor = 13; | |||
| string device_type = 14; | |||
| bool input_tensor =15; | |||
| int64 real_dim_cnt = 16; | |||
| int64 reuse_input_index = 17; | |||
| int64 data_offset = 18; | |||
| int64 cmps_size = 19; | |||
| string cmps_tab = 20; | |||
| int64 cmps_tab_offset = 21; | |||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
| } | |||
| // GeTensor definition | |||
| message TensorDef | |||
| { | |||
| TensorDescriptor desc = 1; // Tensor description | |||
| bytes data = 2; // Tensor data | |||
| } | |||
| // Operator description | |||
| message OpDef | |||
| { | |||
| string name = 1; // name | |||
| string type = 2; // type | |||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
| bool has_out_attr = 20; | |||
| int64 id = 21; | |||
| int64 stream_id =22; | |||
| repeated string input_name = 23; | |||
| repeated string src_name = 24; | |||
| repeated int64 src_index = 25; | |||
| repeated string dst_name = 26; | |||
| repeated int64 dst_index = 27; | |||
| repeated int64 input_i = 28; | |||
| repeated int64 output_i = 29; | |||
| repeated int64 workspace = 30; | |||
| repeated int64 workspace_bytes = 31; | |||
| repeated bool is_input_const = 32; | |||
| repeated TensorDescriptor input_desc = 33; | |||
| repeated TensorDescriptor output_desc = 34; | |||
| repeated string subgraph_name = 35; | |||
| } | |||
| // Graph definition | |||
| message GraphDef | |||
| { | |||
| string name = 1; // name | |||
| repeated string input = 4; // Graph input | |||
| repeated string output = 5; // Graph output | |||
| repeated OpDef op = 6; // List of operators | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| // model definition | |||
| message ModelDef | |||
| { | |||
| string name = 1; // name | |||
| uint32 version = 2; // IR Proto verion | |||
| string custom_version = 3; // User model version number, passed in by user | |||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| syntax = "proto3"; | |||
| package ge.proto; | |||
| enum DataType | |||
| { | |||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
| DT_FLOAT = 1; // float type | |||
| DT_FLOAT16 = 2; // fp16 type | |||
| DT_INT8 = 3; // int8 type | |||
| DT_UINT8 = 4; // uint8 type | |||
| DT_INT16 = 5; // int16 type | |||
| DT_UINT16 = 6; // uint16 type | |||
| DT_INT32 = 7; // | |||
| DT_INT64 = 8; // int64 type | |||
| DT_UINT32 = 9; // unsigned int32 | |||
| DT_UINT64 = 10; // unsigned int64 | |||
| DT_BOOL = 11; // bool type | |||
| DT_DOUBLE = 12; // double type | |||
| DT_STRING = 13; // string type | |||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
| DT_COMPLEX64 = 16; // complex64 type | |||
| DT_COMPLEX128 = 17; // complex128 type | |||
| DT_QINT8 = 18; // qint8 type | |||
| DT_QINT16 = 19; // qint16 type | |||
| DT_QINT32 = 20; // qint32 type | |||
| DT_QUINT8 = 21; // quint8 type | |||
| DT_QUINT16 = 22; // quint16 type | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| } | |||
| message AttrDef | |||
| { | |||
| message ListValue | |||
| { | |||
| enum ListValueType{ | |||
| VT_LIST_NONE = 0; | |||
| VT_LIST_STRING = 1; | |||
| VT_LIST_INT = 2; | |||
| VT_LIST_FLOAT = 3; | |||
| VT_LIST_BOOL = 4; | |||
| VT_LIST_BYTES = 5; | |||
| VT_LIST_TENSOR_DESC = 6; | |||
| VT_LIST_TENSOR = 7; | |||
| VT_LIST_GRAPH = 8; | |||
| VT_LIST_NAMED_ATTRS = 9; | |||
| VT_LIST_DATA_TYPE = 10; | |||
| } | |||
| repeated bytes s = 2; // "list(string)" | |||
| repeated int64 i = 3; // "list(int)" | |||
| repeated float f = 4; // "list(float)" | |||
| repeated bool b = 5; // "list(bool)" | |||
| repeated bytes bt = 7; | |||
| repeated TensorDescriptor td = 8; | |||
| repeated TensorDef t = 9; | |||
| repeated GraphDef g = 10; | |||
| repeated NamedAttrs na = 11; | |||
| repeated int64 dt = 12; // list ge::DataType | |||
| ListValueType val_type = 20; | |||
| } | |||
| message ListListInt{ | |||
| message ListInt{ | |||
| repeated int64 list_i = 1; // list int | |||
| } | |||
| repeated ListInt list_list_i = 1; // list list int | |||
| } | |||
| oneof value | |||
| { | |||
| bytes s = 2; // "string" | |||
| int64 i = 3; // "int" | |||
| float f = 4; // "float" | |||
| bool b = 5; // "bool" | |||
| bytes bt = 7; | |||
| ListValue list = 1; // any "list(...)" | |||
| NamedAttrs func = 10; // Used to support attr nesting | |||
| TensorDescriptor td = 11; // GeTensorDesc type | |||
| TensorDef t = 12; // GeTensor type | |||
| GraphDef g = 13; // Graph type | |||
| ListListInt list_list_int = 14; // List List Int type | |||
| int64 dt = 15; // ge::DataType | |||
| } | |||
| } | |||
| // A list of attr names and their values. The whole list is attached | |||
| // with a string name. E.g., MatMul[T=float]. | |||
| message NamedAttrs | |||
| { | |||
| string name = 1; | |||
| map<string, AttrDef> attr = 2; | |||
| } | |||
| // Shape / dimension description, using row-major order | |||
| message ShapeDef | |||
| { | |||
| repeated int64 dim = 1; // Size of each dimension | |||
| } | |||
| // Multidimensional data description | |||
| message TensorDescriptor | |||
| { | |||
| string name = 1; // Optional parameter, tensor name | |||
| DataType dtype = 2; // tensor datatype | |||
| ShapeDef shape = 3; // Shape / dimension | |||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
| bool has_out_attr = 9; | |||
| int64 size = 10; | |||
| int64 weight_size = 11; | |||
| bool reuse_input = 12; | |||
| bool output_tensor = 13; | |||
| string device_type = 14; | |||
| bool input_tensor =15; | |||
| int64 real_dim_cnt = 16; | |||
| int64 reuse_input_index = 17; | |||
| int64 data_offset = 18; | |||
| int64 cmps_size = 19; | |||
| string cmps_tab = 20; | |||
| int64 cmps_tab_offset = 21; | |||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
| } | |||
| // GeTensor definition | |||
| message TensorDef | |||
| { | |||
| TensorDescriptor desc = 1; // Tensor description | |||
| bytes data = 2; // Tensor data | |||
| } | |||
| // Operator description | |||
| message OpDef | |||
| { | |||
| string name = 1; // name | |||
| string type = 2; // type | |||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
| bool has_out_attr = 20; | |||
| int64 id = 21; | |||
| int64 stream_id =22; | |||
| repeated string input_name = 23; | |||
| repeated string src_name = 24; | |||
| repeated int64 src_index = 25; | |||
| repeated string dst_name = 26; | |||
| repeated int64 dst_index = 27; | |||
| repeated int64 input_i = 28; | |||
| repeated int64 output_i = 29; | |||
| repeated int64 workspace = 30; | |||
| repeated int64 workspace_bytes = 31; | |||
| repeated bool is_input_const = 32; | |||
| repeated TensorDescriptor input_desc = 33; | |||
| repeated TensorDescriptor output_desc = 34; | |||
| repeated string subgraph_name = 35; | |||
| } | |||
| // Graph definition | |||
| message GraphDef | |||
| { | |||
| string name = 1; // name | |||
| repeated string input = 4; // Graph input | |||
| repeated string output = 5; // Graph output | |||
| repeated OpDef op = 6; // List of operators | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| // model definition | |||
| message ModelDef | |||
| { | |||
| string name = 1; // name | |||
| uint32 version = 2; // IR Proto verion | |||
| string custom_version = 3; // User model version number, passed in by user | |||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| @@ -1,136 +1,136 @@ | |||
| syntax = "proto3"; | |||
| package domi; | |||
| message InsertNewOps { | |||
| repeated AippOpParams aipp_op = 1; | |||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||
| } | |||
| message AippOpParams { | |||
| enum InputFormat { | |||
| UNDEFINED = 0; | |||
| YUV420SP_U8 = 1; | |||
| XRGB8888_U8 = 2; | |||
| RGB888_U8 = 3; | |||
| YUV400_U8 = 4; | |||
| NC1HWC0DI_FP16 = 5; | |||
| NC1HWC0DI_S8 = 6; | |||
| ARGB8888_U8 = 7; | |||
| YUYV_U8 = 8; | |||
| YUV422SP_U8 = 9; | |||
| AYUV444_U8 = 10; | |||
| RAW10 = 11; | |||
| RAW12 = 12; | |||
| RAW16 = 13; | |||
| RAW24 = 14; | |||
| RGB16 = 15; | |||
| RGB20 = 16; | |||
| RGB24 = 17; | |||
| RGB8_IR = 18; | |||
| RGB16_IR = 19; | |||
| RGB24_IR = 20; | |||
| } | |||
| enum AippMode { | |||
| undefined = 0; | |||
| static = 1; | |||
| dynamic = 2; | |||
| } | |||
| // AIPP模式,区分静态AIPP和动态AIPP | |||
| AippMode aipp_mode = 1; | |||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||
| uint32 related_input_rank = 2; | |||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||
| // 配置值 <= Data算子输出边的个数。 | |||
| repeated uint32 input_edge_idx = 3; | |||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||
| uint32 max_src_image_size = 4; | |||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||
| bool support_rotation = 5; | |||
| // [End] 动态AIPP参数 | |||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||
| InputFormat input_format = 51; | |||
| bool csc_switch = 52; | |||
| float cpadding_value = 53; | |||
| bool rbuv_swap_switch = 54; | |||
| bool ax_swap_switch = 55; | |||
| bool single_line_mode = 56; | |||
| int32 src_image_size_w = 57; | |||
| int32 src_image_size_h = 58; | |||
| bool crop = 59; | |||
| int32 load_start_pos_w = 60; | |||
| int32 load_start_pos_h = 61; | |||
| int32 crop_size_w = 62; | |||
| int32 crop_size_h = 63; | |||
| bool resize = 64; | |||
| int32 resize_output_w = 65; | |||
| int32 resize_output_h = 66; | |||
| bool padding = 67; | |||
| int32 left_padding_size = 68; | |||
| int32 right_padding_size = 69; | |||
| int32 top_padding_size = 70; | |||
| int32 bottom_padding_size = 71; | |||
| int32 mean_chn_0 = 10; | |||
| int32 mean_chn_1 = 11; | |||
| int32 mean_chn_2 = 12; | |||
| int32 mean_chn_3 = 19; | |||
| float min_chn_0 = 13; | |||
| float min_chn_1 = 14; | |||
| float min_chn_2 = 15; | |||
| float min_chn_3 = 20; | |||
| repeated float var_reci_chn_0 = 16; | |||
| repeated float var_reci_chn_1 = 17; | |||
| repeated float var_reci_chn_2 = 18; | |||
| repeated float var_reci_chn_3 = 21; | |||
| repeated int32 matrix_r0c0 = 30; | |||
| repeated int32 matrix_r0c1 = 31; | |||
| repeated int32 matrix_r0c2 = 32; | |||
| repeated int32 matrix_r1c0 = 33; | |||
| repeated int32 matrix_r1c1 = 34; | |||
| repeated int32 matrix_r1c2 = 35; | |||
| repeated int32 matrix_r2c0 = 36; | |||
| repeated int32 matrix_r2c1 = 37; | |||
| repeated int32 matrix_r2c2 = 38; | |||
| repeated int32 output_bias_0 = 39; | |||
| repeated int32 output_bias_1 = 40; | |||
| repeated int32 output_bias_2 = 41; | |||
| repeated int32 input_bias_0 = 42; | |||
| repeated int32 input_bias_1 = 43; | |||
| repeated int32 input_bias_2 = 44; | |||
| // [End] 静态AIPP参数 | |||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||
| uint32 raw_rgbir_to_f16_n = 45; | |||
| } | |||
| message MultiShapeOpParams { | |||
| enum MultiShapeMode { | |||
| batch = 0; //动态batch | |||
| resolution = 1; //动态分辨率,扩展用 | |||
| } | |||
| MultiShapeMode mode = 1; //算子模式 | |||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||
| } | |||
| syntax = "proto3"; | |||
| package domi; | |||
| message InsertNewOps { | |||
| repeated AippOpParams aipp_op = 1; | |||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||
| } | |||
| message AippOpParams { | |||
| enum InputFormat { | |||
| UNDEFINED = 0; | |||
| YUV420SP_U8 = 1; | |||
| XRGB8888_U8 = 2; | |||
| RGB888_U8 = 3; | |||
| YUV400_U8 = 4; | |||
| NC1HWC0DI_FP16 = 5; | |||
| NC1HWC0DI_S8 = 6; | |||
| ARGB8888_U8 = 7; | |||
| YUYV_U8 = 8; | |||
| YUV422SP_U8 = 9; | |||
| AYUV444_U8 = 10; | |||
| RAW10 = 11; | |||
| RAW12 = 12; | |||
| RAW16 = 13; | |||
| RAW24 = 14; | |||
| RGB16 = 15; | |||
| RGB20 = 16; | |||
| RGB24 = 17; | |||
| RGB8_IR = 18; | |||
| RGB16_IR = 19; | |||
| RGB24_IR = 20; | |||
| } | |||
| enum AippMode { | |||
| undefined = 0; | |||
| static = 1; | |||
| dynamic = 2; | |||
| } | |||
| // AIPP模式,区分静态AIPP和动态AIPP | |||
| AippMode aipp_mode = 1; | |||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||
| uint32 related_input_rank = 2; | |||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||
| // 配置值 <= Data算子输出边的个数。 | |||
| repeated uint32 input_edge_idx = 3; | |||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||
| uint32 max_src_image_size = 4; | |||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||
| bool support_rotation = 5; | |||
| // [End] 动态AIPP参数 | |||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||
| InputFormat input_format = 51; | |||
| bool csc_switch = 52; | |||
| float cpadding_value = 53; | |||
| bool rbuv_swap_switch = 54; | |||
| bool ax_swap_switch = 55; | |||
| bool single_line_mode = 56; | |||
| int32 src_image_size_w = 57; | |||
| int32 src_image_size_h = 58; | |||
| bool crop = 59; | |||
| int32 load_start_pos_w = 60; | |||
| int32 load_start_pos_h = 61; | |||
| int32 crop_size_w = 62; | |||
| int32 crop_size_h = 63; | |||
| bool resize = 64; | |||
| int32 resize_output_w = 65; | |||
| int32 resize_output_h = 66; | |||
| bool padding = 67; | |||
| int32 left_padding_size = 68; | |||
| int32 right_padding_size = 69; | |||
| int32 top_padding_size = 70; | |||
| int32 bottom_padding_size = 71; | |||
| int32 mean_chn_0 = 10; | |||
| int32 mean_chn_1 = 11; | |||
| int32 mean_chn_2 = 12; | |||
| int32 mean_chn_3 = 19; | |||
| float min_chn_0 = 13; | |||
| float min_chn_1 = 14; | |||
| float min_chn_2 = 15; | |||
| float min_chn_3 = 20; | |||
| repeated float var_reci_chn_0 = 16; | |||
| repeated float var_reci_chn_1 = 17; | |||
| repeated float var_reci_chn_2 = 18; | |||
| repeated float var_reci_chn_3 = 21; | |||
| repeated int32 matrix_r0c0 = 30; | |||
| repeated int32 matrix_r0c1 = 31; | |||
| repeated int32 matrix_r0c2 = 32; | |||
| repeated int32 matrix_r1c0 = 33; | |||
| repeated int32 matrix_r1c1 = 34; | |||
| repeated int32 matrix_r1c2 = 35; | |||
| repeated int32 matrix_r2c0 = 36; | |||
| repeated int32 matrix_r2c1 = 37; | |||
| repeated int32 matrix_r2c2 = 38; | |||
| repeated int32 output_bias_0 = 39; | |||
| repeated int32 output_bias_1 = 40; | |||
| repeated int32 output_bias_2 = 41; | |||
| repeated int32 input_bias_0 = 42; | |||
| repeated int32 input_bias_1 = 43; | |||
| repeated int32 input_bias_2 = 44; | |||
| // [End] 静态AIPP参数 | |||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||
| uint32 raw_rgbir_to_f16_n = 45; | |||
| } | |||
| message MultiShapeOpParams { | |||
| enum MultiShapeMode { | |||
| batch = 0; //动态batch | |||
| resolution = 1; //动态分辨率,扩展用 | |||
| } | |||
| MultiShapeMode mode = 1; //算子模式 | |||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||
| } | |||
| @@ -807,7 +807,6 @@ Status CreateNodeDefBytes(ge::NodePtr n, string originalType, map<string, PIOLis | |||
| for (uint32_t j = 0; j < ge_desc->GetShape().GetDimNum(); ++j) { | |||
| tmp_dim = ge_desc->GetShape().GetDim(j); | |||
| GE_CHECK_GE(tmp_dim, 0); | |||
| PARSER_INT64_MULCHECK(real_size, tmp_dim); | |||
| real_size *= tmp_dim; | |||
| } | |||
| @@ -1,190 +1,190 @@ | |||
| syntax = "proto3"; | |||
| package ge.proto; | |||
| enum DataType | |||
| { | |||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
| DT_FLOAT = 1; // float type | |||
| DT_FLOAT16 = 2; // fp16 type | |||
| DT_INT8 = 3; // int8 type | |||
| DT_UINT8 = 4; // uint8 type | |||
| DT_INT16 = 5; // int16 type | |||
| DT_UINT16 = 6; // uint16 type | |||
| DT_INT32 = 7; // | |||
| DT_INT64 = 8; // int64 type | |||
| DT_UINT32 = 9; // unsigned int32 | |||
| DT_UINT64 = 10; // unsigned int64 | |||
| DT_BOOL = 11; // bool type | |||
| DT_DOUBLE = 12; // double type | |||
| DT_STRING = 13; // string type | |||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
| DT_COMPLEX64 = 16; // complex64 type | |||
| DT_COMPLEX128 = 17; // complex128 type | |||
| DT_QINT8 = 18; // qint8 type | |||
| DT_QINT16 = 19; // qint16 type | |||
| DT_QINT32 = 20; // qint32 type | |||
| DT_QUINT8 = 21; // quint8 type | |||
| DT_QUINT16 = 22; // quint16 type | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| } | |||
| message AttrDef | |||
| { | |||
| message ListValue | |||
| { | |||
| enum ListValueType{ | |||
| VT_LIST_NONE = 0; | |||
| VT_LIST_STRING = 1; | |||
| VT_LIST_INT = 2; | |||
| VT_LIST_FLOAT = 3; | |||
| VT_LIST_BOOL = 4; | |||
| VT_LIST_BYTES = 5; | |||
| VT_LIST_TENSOR_DESC = 6; | |||
| VT_LIST_TENSOR = 7; | |||
| VT_LIST_GRAPH = 8; | |||
| VT_LIST_NAMED_ATTRS = 9; | |||
| VT_LIST_DATA_TYPE = 10; | |||
| } | |||
| repeated bytes s = 2; // "list(string)" | |||
| repeated int64 i = 3; // "list(int)" | |||
| repeated float f = 4; // "list(float)" | |||
| repeated bool b = 5; // "list(bool)" | |||
| repeated bytes bt = 7; | |||
| repeated TensorDescriptor td = 8; | |||
| repeated TensorDef t = 9; | |||
| repeated GraphDef g = 10; | |||
| repeated NamedAttrs na = 11; | |||
| repeated int64 dt = 12; // list ge::DataType | |||
| ListValueType val_type = 20; | |||
| } | |||
| message ListListInt{ | |||
| message ListInt{ | |||
| repeated int64 list_i = 1; // list int | |||
| } | |||
| repeated ListInt list_list_i = 1; // list list int | |||
| } | |||
| oneof value | |||
| { | |||
| bytes s = 2; // "string" | |||
| int64 i = 3; // "int" | |||
| float f = 4; // "float" | |||
| bool b = 5; // "bool" | |||
| bytes bt = 7; | |||
| ListValue list = 1; // any "list(...)" | |||
| NamedAttrs func = 10; // Used to support attr nesting | |||
| TensorDescriptor td = 11; // GeTensorDesc type | |||
| TensorDef t = 12; // GeTensor type | |||
| GraphDef g = 13; // Graph type | |||
| ListListInt list_list_int = 14; // List List Int type | |||
| int64 dt = 15; // ge::DataType | |||
| } | |||
| } | |||
| // A list of attr names and their values. The whole list is attached | |||
| // with a string name. E.g., MatMul[T=float]. | |||
| message NamedAttrs | |||
| { | |||
| string name = 1; | |||
| map<string, AttrDef> attr = 2; | |||
| } | |||
| // Shape / dimension description, using row-major order | |||
| message ShapeDef | |||
| { | |||
| repeated int64 dim = 1; // Size of each dimension | |||
| } | |||
| // Multidimensional data description | |||
| message TensorDescriptor | |||
| { | |||
| string name = 1; // Optional parameter, tensor name | |||
| DataType dtype = 2; // tensor datatype | |||
| ShapeDef shape = 3; // Shape / dimension | |||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
| bool has_out_attr = 9; | |||
| int64 size = 10; | |||
| int64 weight_size = 11; | |||
| bool reuse_input = 12; | |||
| bool output_tensor = 13; | |||
| string device_type = 14; | |||
| bool input_tensor =15; | |||
| int64 real_dim_cnt = 16; | |||
| int64 reuse_input_index = 17; | |||
| int64 data_offset = 18; | |||
| int64 cmps_size = 19; | |||
| string cmps_tab = 20; | |||
| int64 cmps_tab_offset = 21; | |||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
| } | |||
| // GeTensor definition | |||
| message TensorDef | |||
| { | |||
| TensorDescriptor desc = 1; // Tensor description | |||
| bytes data = 2; // Tensor data | |||
| } | |||
| // Operator description | |||
| message OpDef | |||
| { | |||
| string name = 1; // name | |||
| string type = 2; // type | |||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
| bool has_out_attr = 20; | |||
| int64 id = 21; | |||
| int64 stream_id =22; | |||
| repeated string input_name = 23; | |||
| repeated string src_name = 24; | |||
| repeated int64 src_index = 25; | |||
| repeated string dst_name = 26; | |||
| repeated int64 dst_index = 27; | |||
| repeated int64 input_i = 28; | |||
| repeated int64 output_i = 29; | |||
| repeated int64 workspace = 30; | |||
| repeated int64 workspace_bytes = 31; | |||
| repeated bool is_input_const = 32; | |||
| repeated TensorDescriptor input_desc = 33; | |||
| repeated TensorDescriptor output_desc = 34; | |||
| repeated string subgraph_name = 35; | |||
| } | |||
| // Graph definition | |||
| message GraphDef | |||
| { | |||
| string name = 1; // name | |||
| repeated string input = 4; // Graph input | |||
| repeated string output = 5; // Graph output | |||
| repeated OpDef op = 6; // List of operators | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| // model definition | |||
| message ModelDef | |||
| { | |||
| string name = 1; // name | |||
| uint32 version = 2; // IR Proto verion | |||
| string custom_version = 3; // User model version number, passed in by user | |||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| syntax = "proto3"; | |||
| package ge.proto; | |||
| enum DataType | |||
| { | |||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||
| DT_FLOAT = 1; // float type | |||
| DT_FLOAT16 = 2; // fp16 type | |||
| DT_INT8 = 3; // int8 type | |||
| DT_UINT8 = 4; // uint8 type | |||
| DT_INT16 = 5; // int16 type | |||
| DT_UINT16 = 6; // uint16 type | |||
| DT_INT32 = 7; // | |||
| DT_INT64 = 8; // int64 type | |||
| DT_UINT32 = 9; // unsigned int32 | |||
| DT_UINT64 = 10; // unsigned int64 | |||
| DT_BOOL = 11; // bool type | |||
| DT_DOUBLE = 12; // double type | |||
| DT_STRING = 13; // string type | |||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||
| DT_COMPLEX64 = 16; // complex64 type | |||
| DT_COMPLEX128 = 17; // complex128 type | |||
| DT_QINT8 = 18; // qint8 type | |||
| DT_QINT16 = 19; // qint16 type | |||
| DT_QINT32 = 20; // qint32 type | |||
| DT_QUINT8 = 21; // quint8 type | |||
| DT_QUINT16 = 22; // quint16 type | |||
| DT_RESOURCE = 23; // resource type | |||
| DT_STRING_REF = 24; // string_ref type | |||
| DT_DUAL = 25; /**< dual output type */ | |||
| } | |||
| message AttrDef | |||
| { | |||
| message ListValue | |||
| { | |||
| enum ListValueType{ | |||
| VT_LIST_NONE = 0; | |||
| VT_LIST_STRING = 1; | |||
| VT_LIST_INT = 2; | |||
| VT_LIST_FLOAT = 3; | |||
| VT_LIST_BOOL = 4; | |||
| VT_LIST_BYTES = 5; | |||
| VT_LIST_TENSOR_DESC = 6; | |||
| VT_LIST_TENSOR = 7; | |||
| VT_LIST_GRAPH = 8; | |||
| VT_LIST_NAMED_ATTRS = 9; | |||
| VT_LIST_DATA_TYPE = 10; | |||
| } | |||
| repeated bytes s = 2; // "list(string)" | |||
| repeated int64 i = 3; // "list(int)" | |||
| repeated float f = 4; // "list(float)" | |||
| repeated bool b = 5; // "list(bool)" | |||
| repeated bytes bt = 7; | |||
| repeated TensorDescriptor td = 8; | |||
| repeated TensorDef t = 9; | |||
| repeated GraphDef g = 10; | |||
| repeated NamedAttrs na = 11; | |||
| repeated int64 dt = 12; // list ge::DataType | |||
| ListValueType val_type = 20; | |||
| } | |||
| message ListListInt{ | |||
| message ListInt{ | |||
| repeated int64 list_i = 1; // list int | |||
| } | |||
| repeated ListInt list_list_i = 1; // list list int | |||
| } | |||
| oneof value | |||
| { | |||
| bytes s = 2; // "string" | |||
| int64 i = 3; // "int" | |||
| float f = 4; // "float" | |||
| bool b = 5; // "bool" | |||
| bytes bt = 7; | |||
| ListValue list = 1; // any "list(...)" | |||
| NamedAttrs func = 10; // Used to support attr nesting | |||
| TensorDescriptor td = 11; // GeTensorDesc type | |||
| TensorDef t = 12; // GeTensor type | |||
| GraphDef g = 13; // Graph type | |||
| ListListInt list_list_int = 14; // List List Int type | |||
| int64 dt = 15; // ge::DataType | |||
| } | |||
| } | |||
| // A list of attr names and their values. The whole list is attached | |||
| // with a string name. E.g., MatMul[T=float]. | |||
| message NamedAttrs | |||
| { | |||
| string name = 1; | |||
| map<string, AttrDef> attr = 2; | |||
| } | |||
| // Shape / dimension description, using row-major order | |||
| message ShapeDef | |||
| { | |||
| repeated int64 dim = 1; // Size of each dimension | |||
| } | |||
| // Multidimensional data description | |||
| message TensorDescriptor | |||
| { | |||
| string name = 1; // Optional parameter, tensor name | |||
| DataType dtype = 2; // tensor datatype | |||
| ShapeDef shape = 3; // Shape / dimension | |||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||
| bool has_out_attr = 9; | |||
| int64 size = 10; | |||
| int64 weight_size = 11; | |||
| bool reuse_input = 12; | |||
| bool output_tensor = 13; | |||
| string device_type = 14; | |||
| bool input_tensor =15; | |||
| int64 real_dim_cnt = 16; | |||
| int64 reuse_input_index = 17; | |||
| int64 data_offset = 18; | |||
| int64 cmps_size = 19; | |||
| string cmps_tab = 20; | |||
| int64 cmps_tab_offset = 21; | |||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||
| } | |||
| // GeTensor definition | |||
| message TensorDef | |||
| { | |||
| TensorDescriptor desc = 1; // Tensor description | |||
| bytes data = 2; // Tensor data | |||
| } | |||
| // Operator description | |||
| message OpDef | |||
| { | |||
| string name = 1; // name | |||
| string type = 2; // type | |||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||
| bool has_out_attr = 20; | |||
| int64 id = 21; | |||
| int64 stream_id =22; | |||
| repeated string input_name = 23; | |||
| repeated string src_name = 24; | |||
| repeated int64 src_index = 25; | |||
| repeated string dst_name = 26; | |||
| repeated int64 dst_index = 27; | |||
| repeated int64 input_i = 28; | |||
| repeated int64 output_i = 29; | |||
| repeated int64 workspace = 30; | |||
| repeated int64 workspace_bytes = 31; | |||
| repeated bool is_input_const = 32; | |||
| repeated TensorDescriptor input_desc = 33; | |||
| repeated TensorDescriptor output_desc = 34; | |||
| repeated string subgraph_name = 35; | |||
| } | |||
| // Graph definition | |||
| message GraphDef | |||
| { | |||
| string name = 1; // name | |||
| repeated string input = 4; // Graph input | |||
| repeated string output = 5; // Graph output | |||
| repeated OpDef op = 6; // List of operators | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| // model definition | |||
| message ModelDef | |||
| { | |||
| string name = 1; // name | |||
| uint32 version = 2; // IR Proto verion | |||
| string custom_version = 3; // User model version number, passed in by user | |||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||
| map<string, AttrDef> attr = 11; // Extended field | |||
| } | |||
| @@ -1,136 +1,136 @@ | |||
| syntax = "proto3"; | |||
| package domi; | |||
| message InsertNewOps { | |||
| repeated AippOpParams aipp_op = 1; | |||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||
| } | |||
| message AippOpParams { | |||
| enum InputFormat { | |||
| UNDEFINED = 0; | |||
| YUV420SP_U8 = 1; | |||
| XRGB8888_U8 = 2; | |||
| RGB888_U8 = 3; | |||
| YUV400_U8 = 4; | |||
| NC1HWC0DI_FP16 = 5; | |||
| NC1HWC0DI_S8 = 6; | |||
| ARGB8888_U8 = 7; | |||
| YUYV_U8 = 8; | |||
| YUV422SP_U8 = 9; | |||
| AYUV444_U8 = 10; | |||
| RAW10 = 11; | |||
| RAW12 = 12; | |||
| RAW16 = 13; | |||
| RAW24 = 14; | |||
| RGB16 = 15; | |||
| RGB20 = 16; | |||
| RGB24 = 17; | |||
| RGB8_IR = 18; | |||
| RGB16_IR = 19; | |||
| RGB24_IR = 20; | |||
| } | |||
| enum AippMode { | |||
| undefined = 0; | |||
| static = 1; | |||
| dynamic = 2; | |||
| } | |||
| // AIPP模式,区分静态AIPP和动态AIPP | |||
| AippMode aipp_mode = 1; | |||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||
| uint32 related_input_rank = 2; | |||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||
| // 配置值 <= Data算子输出边的个数。 | |||
| repeated uint32 input_edge_idx = 3; | |||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||
| uint32 max_src_image_size = 4; | |||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||
| bool support_rotation = 5; | |||
| // [End] 动态AIPP参数 | |||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||
| InputFormat input_format = 51; | |||
| bool csc_switch = 52; | |||
| float cpadding_value = 53; | |||
| bool rbuv_swap_switch = 54; | |||
| bool ax_swap_switch = 55; | |||
| bool single_line_mode = 56; | |||
| int32 src_image_size_w = 57; | |||
| int32 src_image_size_h = 58; | |||
| bool crop = 59; | |||
| int32 load_start_pos_w = 60; | |||
| int32 load_start_pos_h = 61; | |||
| int32 crop_size_w = 62; | |||
| int32 crop_size_h = 63; | |||
| bool resize = 64; | |||
| int32 resize_output_w = 65; | |||
| int32 resize_output_h = 66; | |||
| bool padding = 67; | |||
| int32 left_padding_size = 68; | |||
| int32 right_padding_size = 69; | |||
| int32 top_padding_size = 70; | |||
| int32 bottom_padding_size = 71; | |||
| int32 mean_chn_0 = 10; | |||
| int32 mean_chn_1 = 11; | |||
| int32 mean_chn_2 = 12; | |||
| int32 mean_chn_3 = 19; | |||
| float min_chn_0 = 13; | |||
| float min_chn_1 = 14; | |||
| float min_chn_2 = 15; | |||
| float min_chn_3 = 20; | |||
| repeated float var_reci_chn_0 = 16; | |||
| repeated float var_reci_chn_1 = 17; | |||
| repeated float var_reci_chn_2 = 18; | |||
| repeated float var_reci_chn_3 = 21; | |||
| repeated int32 matrix_r0c0 = 30; | |||
| repeated int32 matrix_r0c1 = 31; | |||
| repeated int32 matrix_r0c2 = 32; | |||
| repeated int32 matrix_r1c0 = 33; | |||
| repeated int32 matrix_r1c1 = 34; | |||
| repeated int32 matrix_r1c2 = 35; | |||
| repeated int32 matrix_r2c0 = 36; | |||
| repeated int32 matrix_r2c1 = 37; | |||
| repeated int32 matrix_r2c2 = 38; | |||
| repeated int32 output_bias_0 = 39; | |||
| repeated int32 output_bias_1 = 40; | |||
| repeated int32 output_bias_2 = 41; | |||
| repeated int32 input_bias_0 = 42; | |||
| repeated int32 input_bias_1 = 43; | |||
| repeated int32 input_bias_2 = 44; | |||
| // [End] 静态AIPP参数 | |||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||
| uint32 raw_rgbir_to_f16_n = 45; | |||
| } | |||
| message MultiShapeOpParams { | |||
| enum MultiShapeMode { | |||
| batch = 0; //动态batch | |||
| resolution = 1; //动态分辨率,扩展用 | |||
| } | |||
| MultiShapeMode mode = 1; //算子模式 | |||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||
| } | |||
| syntax = "proto3"; | |||
| package domi; | |||
| message InsertNewOps { | |||
| repeated AippOpParams aipp_op = 1; | |||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||
| } | |||
| message AippOpParams { | |||
| enum InputFormat { | |||
| UNDEFINED = 0; | |||
| YUV420SP_U8 = 1; | |||
| XRGB8888_U8 = 2; | |||
| RGB888_U8 = 3; | |||
| YUV400_U8 = 4; | |||
| NC1HWC0DI_FP16 = 5; | |||
| NC1HWC0DI_S8 = 6; | |||
| ARGB8888_U8 = 7; | |||
| YUYV_U8 = 8; | |||
| YUV422SP_U8 = 9; | |||
| AYUV444_U8 = 10; | |||
| RAW10 = 11; | |||
| RAW12 = 12; | |||
| RAW16 = 13; | |||
| RAW24 = 14; | |||
| RGB16 = 15; | |||
| RGB20 = 16; | |||
| RGB24 = 17; | |||
| RGB8_IR = 18; | |||
| RGB16_IR = 19; | |||
| RGB24_IR = 20; | |||
| } | |||
| enum AippMode { | |||
| undefined = 0; | |||
| static = 1; | |||
| dynamic = 2; | |||
| } | |||
| // AIPP模式,区分静态AIPP和动态AIPP | |||
| AippMode aipp_mode = 1; | |||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||
| uint32 related_input_rank = 2; | |||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||
| // 配置值 <= Data算子输出边的个数。 | |||
| repeated uint32 input_edge_idx = 3; | |||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||
| uint32 max_src_image_size = 4; | |||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||
| bool support_rotation = 5; | |||
| // [End] 动态AIPP参数 | |||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||
| InputFormat input_format = 51; | |||
| bool csc_switch = 52; | |||
| float cpadding_value = 53; | |||
| bool rbuv_swap_switch = 54; | |||
| bool ax_swap_switch = 55; | |||
| bool single_line_mode = 56; | |||
| int32 src_image_size_w = 57; | |||
| int32 src_image_size_h = 58; | |||
| bool crop = 59; | |||
| int32 load_start_pos_w = 60; | |||
| int32 load_start_pos_h = 61; | |||
| int32 crop_size_w = 62; | |||
| int32 crop_size_h = 63; | |||
| bool resize = 64; | |||
| int32 resize_output_w = 65; | |||
| int32 resize_output_h = 66; | |||
| bool padding = 67; | |||
| int32 left_padding_size = 68; | |||
| int32 right_padding_size = 69; | |||
| int32 top_padding_size = 70; | |||
| int32 bottom_padding_size = 71; | |||
| int32 mean_chn_0 = 10; | |||
| int32 mean_chn_1 = 11; | |||
| int32 mean_chn_2 = 12; | |||
| int32 mean_chn_3 = 19; | |||
| float min_chn_0 = 13; | |||
| float min_chn_1 = 14; | |||
| float min_chn_2 = 15; | |||
| float min_chn_3 = 20; | |||
| repeated float var_reci_chn_0 = 16; | |||
| repeated float var_reci_chn_1 = 17; | |||
| repeated float var_reci_chn_2 = 18; | |||
| repeated float var_reci_chn_3 = 21; | |||
| repeated int32 matrix_r0c0 = 30; | |||
| repeated int32 matrix_r0c1 = 31; | |||
| repeated int32 matrix_r0c2 = 32; | |||
| repeated int32 matrix_r1c0 = 33; | |||
| repeated int32 matrix_r1c1 = 34; | |||
| repeated int32 matrix_r1c2 = 35; | |||
| repeated int32 matrix_r2c0 = 36; | |||
| repeated int32 matrix_r2c1 = 37; | |||
| repeated int32 matrix_r2c2 = 38; | |||
| repeated int32 output_bias_0 = 39; | |||
| repeated int32 output_bias_1 = 40; | |||
| repeated int32 output_bias_2 = 41; | |||
| repeated int32 input_bias_0 = 42; | |||
| repeated int32 input_bias_1 = 43; | |||
| repeated int32 input_bias_2 = 44; | |||
| // [End] 静态AIPP参数 | |||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||
| uint32 raw_rgbir_to_f16_n = 45; | |||
| } | |||
| message MultiShapeOpParams { | |||
| enum MultiShapeMode { | |||
| batch = 0; //动态batch | |||
| resolution = 1; //动态分辨率,扩展用 | |||
| } | |||
| MultiShapeMode mode = 1; //算子模式 | |||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||
| } | |||
| @@ -0,0 +1,6 @@ | |||
| inc_path := $(shell pwd)/inc/external/ | |||
| out_path := $(shell pwd)/out/graph/lib64/stub/ | |||
| stub_path := $(shell pwd)/common/graph/stub/ | |||
| mkdir_stub := $(shell mkdir -p $(out_path)) | |||
| graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) | |||
| @@ -0,0 +1,577 @@ | |||
| import os | |||
| import re | |||
| import sys | |||
| import logging | |||
| logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', | |||
| level=logging.INFO) | |||
| """ | |||
| this attr is used for symbol table visible | |||
| """ | |||
| GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' | |||
| """ | |||
| generate stub func body by return type | |||
| """ | |||
| RETURN_STATEMENTS = { | |||
| 'graphStatus': ' std::cout << "[ERROR]: stub library libgraph or libge_compiler cannot be used for execution, please check your "\n ' | |||
| ' << "environment variables and compilation options to make sure you use the correct library."\n' | |||
| ' << std::endl;\n' | |||
| ' return ACL_ERROR_COMPILING_STUB_MODE;', | |||
| 'Status': ' return SUCCESS;', | |||
| 'Graph': ' return Graph();', | |||
| 'Graph&': ' return *this;', | |||
| 'Format': ' return Format();', | |||
| 'Format&': ' return *this;', | |||
| 'Shape': ' return Shape();', | |||
| 'Shape&': ' return *this;', | |||
| 'TensorDesc': ' return TensorDesc();', | |||
| 'TensorDesc&': ' return *this;', | |||
| 'Tensor': ' return Tensor();', | |||
| 'Tensor&': ' return *this;', | |||
| 'Operator': ' return Operator();', | |||
| 'Operator&': ' return *this;', | |||
| 'Ptr': ' return nullptr;', | |||
| 'std::string': ' return "";', | |||
| 'std::string&': ' return "";', | |||
| 'string': ' return "";', | |||
| 'int': ' return 0;', | |||
| 'DataType': ' return DT_FLOAT;', | |||
| 'InferenceContextPtr': ' return nullptr;', | |||
| 'SubgraphBuilder': ' return nullptr;', | |||
| 'OperatorImplPtr': ' return nullptr;', | |||
| 'OutHandler': ' return nullptr;', | |||
| 'std::vector<std::string>': ' return {};', | |||
| 'std::vector<int64_t>': ' return {};', | |||
| 'std::map': ' return {};', | |||
| 'uint32_t': ' return 0;', | |||
| 'int64_t': ' return 0;', | |||
| 'uint64_t': ' return 0;', | |||
| 'size_t': ' return 0;', | |||
| 'float': ' return 0.0f;', | |||
| 'bool': ' return false;', | |||
| } | |||
| """ | |||
| max code len per line in hua_wei software programming specifications | |||
| """ | |||
| max_code_len_per_line = 100 | |||
| """ | |||
| white_list_for_debug, include_dir_key_words is to | |||
| determines which header files to generate cc files from | |||
| when DEBUG on | |||
| """ | |||
| white_list_for_debug = ["tensorflow_parser.h", "caffe_parser.h"] | |||
| include_dir_key_words = ["parser"] | |||
| DEBUG = True | |||
| def need_generate_func(func_line): | |||
| """ | |||
| :param func_line: | |||
| :return: | |||
| """ | |||
| if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ | |||
| or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): | |||
| return False | |||
| return True | |||
| def file_endswith_white_list_suffix(file): | |||
| """ | |||
| :param file: | |||
| :return: | |||
| """ | |||
| if DEBUG: | |||
| for suffix in white_list_for_debug: | |||
| if file.endswith(suffix): | |||
| return True | |||
| return False | |||
| else: | |||
| return True | |||
| """ | |||
| belows are patterns used for analyse .h file | |||
| """ | |||
| # pattern function | |||
| pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after | |||
| ([a-zA-Z~_] # void int likely | |||
| .* | |||
| [)] #we find ) | |||
| (?!.*{) # we do not want the case int abc() const | |||
| .*) | |||
| (;.*) #we want to find ; and after for we will replace these later | |||
| \n$ | |||
| """, re.VERBOSE | re.MULTILINE | re.DOTALL) | |||
| # pattern comment | |||
| pattern_comment = re.compile(r'^\s*//') | |||
| pattern_comment_2_start = re.compile(r'^\s*/[*]') | |||
| pattern_comment_2_end = re.compile(r'[*]/\s*$') | |||
| # pattern define | |||
| pattern_define = re.compile(r'^\s*#define') | |||
| pattern_define_return = re.compile(r'\\\s*$') | |||
| # blank line | |||
| pattern_blank_line = re.compile(r'^\s*$') | |||
| # virtual,explicit,friend,static | |||
| pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') | |||
| # lead space | |||
| pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') | |||
| # functions will have patterns such as func ( or func( | |||
| # but operator is an exception; the class name is preceded by an operator, and the above mode does not exist | |||
| # format like :"operator = ()" | |||
| pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') | |||
| # template | |||
| pattern_template = re.compile(r'^\s*template') | |||
| pattern_template_end = re.compile(r'>\s*$') | |||
| # namespace | |||
| pattern_namespace = re.compile(r'namespace.*{') | |||
| # class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with | |||
| pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR) | |||
| # {} | |||
| pattern_start = re.compile('{') | |||
| pattern_end = re.compile('}') | |||
| line_index = 0 | |||
| class H2CC(object): | |||
| def __init__(self, input_file, output_file, shared_includes_content): | |||
| """ | |||
| :param input_file: | |||
| :param output_file: | |||
| :param shared_includes_content: | |||
| """ | |||
| self.input_file = input_file | |||
| self.output_file = output_file | |||
| self.shared_includes_content = shared_includes_content | |||
| self.line_index = 0 | |||
| self.input_fd = open(self.input_file, 'r') | |||
| self.input_content = self.input_fd.readlines() | |||
| self.output_fd = open(self.output_file, 'w') | |||
| # The state may be normal_now(in the middle of {}),class_now,namespace_now | |||
| self.stack = [] | |||
| self.stack_class = [] | |||
| self.stack_template = [] | |||
| # record funcs generated by h2cc func | |||
| self.func_list_exist = [] | |||
| def __del__(self): | |||
| self.input_fd.close() | |||
| self.output_fd.close() | |||
| del self.stack | |||
| del self.stack_class | |||
| del self.stack_template | |||
| del self.func_list_exist | |||
| def just_skip(self): | |||
| # skip blank line or comment | |||
| if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( | |||
| self.input_content[self.line_index]): # /n or comment using // | |||
| self.line_index += 1 | |||
| if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /* | |||
| while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */ | |||
| self.line_index += 1 | |||
| self.line_index += 1 | |||
| # skip define | |||
| if pattern_define.search(self.input_content[self.line_index]): | |||
| while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search( | |||
| self.input_content[self.line_index]): | |||
| self.line_index += 1 | |||
| self.line_index += 1 | |||
| def write_inc_content(self): | |||
| for shared_include_content in self.shared_includes_content: | |||
| self.output_fd.write(shared_include_content) | |||
| def h2cc(self): | |||
| """ | |||
| :return: | |||
| """ | |||
| logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file) | |||
| global pattern_comment | |||
| global pattern_comment_2_start | |||
| global pattern_comment_2_end | |||
| global pattern_blank_line | |||
| global pattern_func | |||
| global pattern_keyword | |||
| global pattern_leading_space | |||
| global pattern_func_name | |||
| global pattern_template | |||
| global pattern_template_end | |||
| global pattern_namespace | |||
| global pattern_class | |||
| global pattern_start | |||
| global pattern_end | |||
| global line_index | |||
| # write inc content | |||
| self.write_inc_content() | |||
| # core processing cycle, process the input .h file by line | |||
| while self.line_index < len(self.input_content): | |||
| # handle comment and blank line | |||
| self.just_skip() | |||
| # match namespace | |||
| self.handle_namespace() | |||
| # match template | |||
| template_string = self.handle_template() | |||
| # match class | |||
| line = self.input_content[self.line_index] | |||
| match_class = pattern_class.search(line) | |||
| match_start = pattern_start.search(line) | |||
| handle_class_result = self.handle_class(template_string, line, match_start, match_class) | |||
| if handle_class_result == "continue": | |||
| continue | |||
| # match "}" | |||
| handle_stack_result = self.handle_stack(match_start) | |||
| if handle_stack_result == "continue": | |||
| continue | |||
| # handle func | |||
| handle_func1_result, line, start_i = self.handle_func1(line) | |||
| if handle_func1_result == "continue": | |||
| continue | |||
| # here means func is found | |||
| # delete key word | |||
| line = pattern_keyword.sub('', line) | |||
| logging.info("line[%s]", line) | |||
| # Class member function | |||
| # if friend we will not add class name | |||
| friend_match = re.search('friend ', line) | |||
| if len(self.stack_class) > 0 and not friend_match: | |||
| line, func_name = self.handle_class_member_func(line, template_string) | |||
| # Normal functions | |||
| else: | |||
| line, func_name = self.handle_normal_func(line, template_string) | |||
| need_generate = need_generate_func(line) | |||
| # func body | |||
| line += self.implement_function(line) | |||
| # comment | |||
| line = self.gen_comment(start_i) + line | |||
| # write to out file | |||
| self.write_func_content(line, func_name, need_generate) | |||
| # next loop | |||
| self.line_index += 1 | |||
| logging.info('Added %s functions', len(self.func_list_exist)) | |||
| logging.info('Successfully converted,please see ' + self.output_file) | |||
| def handle_func1(self, line): | |||
| """ | |||
| :param line: | |||
| :return: | |||
| """ | |||
| find1 = re.search('[(]', line) | |||
| if not find1: | |||
| self.line_index += 1 | |||
| return "continue", line, None | |||
| find2 = re.search('[)]', line) | |||
| start_i = self.line_index | |||
| space_match = pattern_leading_space.search(line) | |||
| # deal with | |||
| # int abc(int a, | |||
| # int b) | |||
| if find1 and (not find2): | |||
| self.line_index += 1 | |||
| line2 = self.input_content[self.line_index] | |||
| if space_match: | |||
| line2 = re.sub('^' + space_match.group(1), '', line2) | |||
| line += line2 | |||
| while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): | |||
| self.line_index += 1 | |||
| line2 = self.input_content[self.line_index] | |||
| line2 = re.sub('^' + space_match.group(1), '', line2) | |||
| line += line2 | |||
| match_start = pattern_start.search(self.input_content[self.line_index]) | |||
| match_end = pattern_end.search(self.input_content[self.line_index]) | |||
| if match_start: # like ) { or ) {} int the last line | |||
| if not match_end: | |||
| self.stack.append('normal_now') | |||
| ii = start_i | |||
| while ii <= self.line_index: | |||
| ii += 1 | |||
| self.line_index += 1 | |||
| return "continue", line, start_i | |||
| logging.info("line[%s]", line) | |||
| # ' int abc();'->'int abc()' | |||
| (line, match) = pattern_func.subn(r'\2\n', line) | |||
| logging.info("line[%s]", line) | |||
| # deal with case: | |||
| # 'int \n abc(int a, int b)' | |||
| if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): | |||
| line = self.input_content[start_i - 1] + line | |||
| line = line.lstrip() | |||
| if not match: | |||
| self.line_index += 1 | |||
| return "continue", line, start_i | |||
| return "pass", line, start_i | |||
| def handle_stack(self, match_start): | |||
| """ | |||
| :param match_start: | |||
| :return: | |||
| """ | |||
| line = self.input_content[self.line_index] | |||
| match_end = pattern_end.search(line) | |||
| if match_start: | |||
| self.stack.append('normal_now') | |||
| if match_end: | |||
| top_status = self.stack.pop() | |||
| if top_status == 'namespace_now': | |||
| self.output_fd.write(line + '\n') | |||
| elif top_status == 'class_now': | |||
| self.stack_class.pop() | |||
| self.stack_template.pop() | |||
| if match_start or match_end: | |||
| self.line_index += 1 | |||
| return "continue" | |||
| if len(self.stack) > 0 and self.stack[-1] == 'normal_now': | |||
| self.line_index += 1 | |||
| return "continue" | |||
| return "pass" | |||
| def handle_class(self, template_string, line, match_start, match_class): | |||
| """ | |||
| :param template_string: | |||
| :param line: | |||
| :param match_start: | |||
| :param match_class: | |||
| :return: | |||
| """ | |||
| if match_class: # we face a class | |||
| self.stack_template.append(template_string) | |||
| self.stack.append('class_now') | |||
| class_name = match_class.group(3) | |||
| # class template specializations: class A<u,Node<u> > | |||
| if '<' in class_name: | |||
| k = line.index('<') | |||
| fit = 1 | |||
| for ii in range(k + 1, len(line)): | |||
| if line[ii] == '<': | |||
| fit += 1 | |||
| if line[ii] == '>': | |||
| fit -= 1 | |||
| if fit == 0: | |||
| break | |||
| class_name += line[k + 1:ii + 1] | |||
| logging.info('class_name[%s]', class_name) | |||
| self.stack_class.append(class_name) | |||
| while not match_start: | |||
| self.line_index += 1 | |||
| line = self.input_content[self.line_index] | |||
| match_start = pattern_start.search(line) | |||
| self.line_index += 1 | |||
| return "continue" | |||
| return "pass" | |||
| def handle_template(self): | |||
| line = self.input_content[self.line_index] | |||
| match_template = pattern_template.search(line) | |||
| template_string = '' | |||
| if match_template: | |||
| match_template_end = pattern_template_end.search(line) | |||
| template_string = line | |||
| while not match_template_end: | |||
| self.line_index += 1 | |||
| line = self.input_content[self.line_index] | |||
| template_string += line | |||
| match_template_end = pattern_template_end.search(line) | |||
| self.line_index += 1 | |||
| return template_string | |||
| def handle_namespace(self): | |||
| line = self.input_content[self.line_index] | |||
| match_namespace = pattern_namespace.search(line) | |||
| if match_namespace: # we face namespace | |||
| self.output_fd.write(line + '\n') | |||
| self.stack.append('namespace_now') | |||
| self.line_index += 1 | |||
| def handle_normal_func(self, line, template_string): | |||
| template_line = '' | |||
| self.stack_template.append(template_string) | |||
| if self.stack_template[-1] != '': | |||
| template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) | |||
| # change '< class T = a, class U = A(3)>' to '<class T, class U>' | |||
| template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) | |||
| template_line = re.sub(r'\s*=.*,', ',', template_line) | |||
| template_line = re.sub(r'\s*=.*', '', template_line) | |||
| line = re.sub(r'\s*=.*,', ',', line) | |||
| line = re.sub(r'\s*=.*\)', ')', line) | |||
| line = template_line + line | |||
| self.stack_template.pop() | |||
| func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() | |||
| logging.info("line[%s]", line) | |||
| logging.info("func_name[%s]", func_name) | |||
| return line, func_name | |||
| def handle_class_member_func(self, line, template_string): | |||
| template_line = '' | |||
| x = '' | |||
| if template_string != '': | |||
| template_string = re.sub(r'\s*template', 'template', template_string) | |||
| template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) | |||
| template_string = re.sub(r'\s*=.*,', ',', template_string) | |||
| template_string = re.sub(r'\s*=.*', '', template_string) | |||
| if self.stack_template[-1] != '': | |||
| if not (re.search(r'<\s*>', stack_template[-1])): | |||
| template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) | |||
| if not (re.search(r'<.*>', self.stack_class[-1])): | |||
| # for x we get like template<class T, typename U> -> <T,U> | |||
| x = re.sub(r'template\s*<', '<', template_line) # remove template -> <class T, typename U> | |||
| x = re.sub(r'\n', '', x) | |||
| x = re.sub(r'\s*=.*,', ',', x) | |||
| x = re.sub(r'\s*=.*\>', '>', x) | |||
| x = x.rstrip() # remove \n | |||
| x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '', | |||
| x) # remove class,typename -> <T, U> | |||
| x = re.sub(r'<\s+', '<', x) | |||
| x = re.sub(r'\s+>', '>', x) | |||
| x = re.sub(r'\s+,', ',', x) | |||
| x = re.sub(r',\s+', ', ', x) | |||
| line = re.sub(r'\s*=\s+0', '', line) | |||
| line = re.sub(r'\s*=\s+.*,', ',', line) | |||
| line = re.sub(r'\s*=\s+.*\)', ')', line) | |||
| logging.info("x[%s]\nline[%s]", x, line) | |||
| # if the function is long, void ABC::foo() | |||
| # breaks into two lines void ABC::\n foo() | |||
| temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) | |||
| if len(temp_line) > max_code_len_per_line: | |||
| line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) | |||
| else: | |||
| line = temp_line | |||
| logging.info("line[%s]", line) | |||
| # add template as the above if there is one | |||
| template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) | |||
| template_line = re.sub(r'\s*=.*,', ',', template_line) | |||
| template_line = re.sub(r'\s*=.*', '', template_line) | |||
| line = template_line + template_string + line | |||
| func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() | |||
| logging.info("line[%s]", line) | |||
| logging.info("func_name[%s]", func_name) | |||
| return line, func_name | |||
| def write_func_content(self, content, func_name, need_generate): | |||
| if not (func_name in self.func_list_exist) and need_generate: | |||
| self.output_fd.write(content) | |||
| self.func_list_exist.append(func_name) | |||
| logging.info('add func:[%s]', func_name) | |||
| def gen_comment(self, start_i): | |||
| comment_line = '' | |||
| # Function comments are on top of function declarations, copy them over | |||
| k = start_i - 1 # one line before this func start | |||
| if pattern_template.search(self.input_content[k]): | |||
| k -= 1 | |||
| if pattern_comment_2_end.search(self.input_content[k]): | |||
| comment_line = self.input_content[k].lstrip() | |||
| while not pattern_comment_2_start.search(self.input_content[k]): | |||
| k -= 1 | |||
| comment_line = self.input_content[k].lstrip() + comment_line | |||
| else: | |||
| for j in range(k, 0, -1): | |||
| c_line = self.input_content[j] | |||
| if pattern_comment.search(c_line): | |||
| c_line = re.sub(r'\s*//', '//', c_line) | |||
| comment_line = c_line + comment_line | |||
| else: | |||
| break | |||
| return comment_line | |||
| @staticmethod | |||
| def implement_function(func): | |||
| function_def = '' | |||
| function_def += '{\n' | |||
| all_items = func.split() | |||
| start = 0 | |||
| return_type = all_items[start] | |||
| if return_type == "const": | |||
| start += 1 | |||
| return_type = all_items[start] | |||
| if return_type.startswith(('std::map', 'std::set', 'std::vector')): | |||
| return_type = "std::map" | |||
| if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): | |||
| return_type = "Ptr" | |||
| if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): | |||
| return_type += "&" | |||
| if RETURN_STATEMENTS.__contains__(return_type): | |||
| function_def += RETURN_STATEMENTS[return_type] | |||
| else: | |||
| logging.warning("Unhandled return type[%s]", return_type) | |||
| function_def += '\n' | |||
| function_def += '}\n' | |||
| function_def += '\n' | |||
| return function_def | |||
| def collect_header_files(path): | |||
| """ | |||
| :param path: | |||
| :return: | |||
| """ | |||
| header_files = [] | |||
| shared_includes_content = [] | |||
| for root, dirs, files in os.walk(path): | |||
| files.sort() | |||
| for file in files: | |||
| if file.find("git") >= 0: | |||
| continue | |||
| if not file.endswith('.h'): | |||
| continue | |||
| file_path = os.path.join(root, file) | |||
| file_path = file_path.replace('\\', '/') | |||
| header_files.append(file_path) | |||
| include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) | |||
| shared_includes_content.append(include_str) | |||
| # for acl error code | |||
| shared_includes_content.append('#include <iostream>\n') | |||
| shared_includes_content.append('const int ACL_ERROR_COMPILING_STUB_MODE = 100039;\n') | |||
| return header_files, shared_includes_content | |||
| def generate_stub_file(inc_dir, out_cc_dir): | |||
| """ | |||
| :param inc_dir: | |||
| :param out_cc_dir: | |||
| :return: | |||
| """ | |||
| target_header_files, shared_includes_content = collect_header_files(inc_dir) | |||
| for header_file in target_header_files: | |||
| if not file_endswith_white_list_suffix(header_file): | |||
| continue | |||
| cc_file = re.sub('.h*$', '.cc', header_file) | |||
| h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) | |||
| h_2_cc.h2cc() | |||
| def gen_code(inc_dir, out_cc_dir): | |||
| """ | |||
| :param inc_dir: | |||
| :param out_cc_dir: | |||
| :return: | |||
| """ | |||
| if not inc_dir.endswith('/'): | |||
| inc_dir += '/' | |||
| if not out_cc_dir.endswith('/'): | |||
| out_cc_dir += '/' | |||
| for include_dir_key_word in include_dir_key_words: | |||
| generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) | |||
| if __name__ == '__main__': | |||
| inc_dir = sys.argv[1] | |||
| out_cc_dir = sys.argv[2] | |||
| gen_code(inc_dir, out_cc_dir) | |||