diff --git a/.gitmodules b/.gitmodules index fe72c1e..03812e0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "metadef"] path = metadef url = https://gitee.com/ascend/metadef.git + branch = development diff --git a/build.sh b/build.sh index 655cd78..f3a54b3 100644 --- a/build.sh +++ b/build.sh @@ -86,6 +86,8 @@ checkopts() } checkopts "$@" +git submodule update --init metadef + mk_dir() { local create_dir="$1" # the target to make @@ -152,12 +154,41 @@ generate_package() cd "${BASEPATH}" PARSER_LIB_PATH="lib" + ACL_PATH="acllib/lib64" + FWK_PATH="fwkacllib/lib64" + ATC_PATH="atc/lib64" + + COMMON_LIB=("libgraph.so" "libregister.so") + PARSER_LIB=("lib_caffe_parser.so" "libfmk_onnx_parser.so" "libfmk_parser.so" "libparser_common.so") + + rm -rf ${OUTPUT_PATH:?}/${FWK_PATH}/ + rm -rf ${OUTPUT_PATH:?}/${ACL_PATH}/ + rm -rf ${OUTPUT_PATH:?}/${ATC_PATH}/ + + mk_dir "${OUTPUT_PATH}/${FWK_PATH}" + mk_dir "${OUTPUT_PATH}/${ATC_PATH}" + mk_dir "${OUTPUT_PATH}/${ACL_PATH}" find output/ -name parser_lib.tar -exec rm {} \; cd "${OUTPUT_PATH}" - tar -cf parser_lib.tar "${PARSER_LIB_PATH}" + for lib in "${PARSER_LIB[@]}"; + do + find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \; + find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; + done + + for lib in "${COMMON_LIB[@]}"; + do + find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \; + find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; + done + + find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "libc_sec.so" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; + find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "libregister.a" -exec cp -f {} ${OUTPUT_PATH}/${ACL_PATH} \; + + tar -cf parser_lib.tar fwkacllib acllib atc } if [[ "X$ENABLE_GE_UT" = "Xoff" ]]; then diff --git a/metadef b/metadef index 765d857..cc9de48 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 765d85777ec10fe819799cd2c1b60db49be7c749 +Subproject commit cc9de48a7779cf95cab90a23db608421a691fd12 diff --git a/parser/CMakeLists.txt b/parser/CMakeLists.txt index 3b17728..db446e1 100644 --- a/parser/CMakeLists.txt +++ b/parser/CMakeLists.txt @@ -59,24 +59,12 @@ target_include_directories(fmk_parser PRIVATE ${PARSER_DIR} ${PARSER_DIR}/inc ${PARSER_DIR}/parser - ${PARSER_DIR}/../ge - ${PARSER_DIR}/../inc - ${PARSER_DIR}/../inc/framework - ${PARSER_DIR}/../inc/common/util - ${PARSER_DIR}/../inc/external - ${PARSER_DIR}/../third_party/fwkacllib/inc ${METADEF_DIR}/inc ${METADEF_DIR}/inc/graph ${METADEF_DIR}/inc/register ${METADEF_DIR}/inc/external ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/external/register - #### independent compile ##### - ${METADEF_DIR}/third_party/graphengine/ge - ${METADEF_DIR}/third_party/graphengine/inc - ${METADEF_DIR}/third_party/graphengine/inc/framework - ${METADEF_DIR}/third_party/graphengine/inc/external - ${METADEF_DIR}/third_party/fwkacllib/inc #### temp #### ${PARSER_DIR}/../graphengine/inc/common/util ${PARSER_DIR}/../graphengine/inc/external @@ -84,7 +72,20 @@ target_include_directories(fmk_parser PRIVATE ${PARSER_DIR}/../graphengine/inc ${PARSER_DIR}/../graphengine/ge ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge + #### blue zone compile ##### + ${PARSER_DIR}/../ge + ${PARSER_DIR}/../inc + ${PARSER_DIR}/../inc/framework + ${PARSER_DIR}/../inc/common/util + ${PARSER_DIR}/../inc/external + ${PARSER_DIR}/../third_party/fwkacllib/inc + #### blue independent compile ##### + ${METADEF_DIR}/third_party/graphengine/ge + ${METADEF_DIR}/third_party/graphengine/inc + ${METADEF_DIR}/third_party/graphengine/inc/framework + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/fwkacllib/inc ) target_link_libraries(fmk_parser diff --git a/parser/caffe/caffe_custom_parser_adapter.cc b/parser/caffe/caffe_custom_parser_adapter.cc index d250f38..8cc1c97 100644 --- a/parser/caffe/caffe_custom_parser_adapter.cc +++ b/parser/caffe/caffe_custom_parser_adapter.cc @@ -18,17 +18,16 @@ #include #include #include "common/debug/log.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/ge/ge_util.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/omg/omg_inner_types.h" -#include "framework/omg/parser/parser_types.h" #include "graph/utils/graph_utils.h" #include "parser/common/op_parser_factory.h" #include "register/op_registry.h" -using domi::ParseParamByOpFunc; using domi::ParseParamFunc; +using domi::ParseParamByOpFunc; using std::vector; namespace ge { @@ -55,8 +54,8 @@ Status CaffeCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPt } Status CaffeCustomParserAdapter::ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest) { - GELOGI("Caffe custom op begin to params: layer name = %s, layer type= %s ", op_src.GetName().c_str(), - op_src.GetOpType().c_str()); + GELOGI("Caffe custom op begin to params: layer name = %s, layer type= %s ", + op_src.GetName().c_str(), op_src.GetOpType().c_str()); GE_CHECK_NOTNULL(op_dest); ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); @@ -86,7 +85,7 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr bool update_in_turn = (static_cast(op->GetAllInputsSize()) == (layer->bottom_size() + layer->blobs_size())); int start_pos = layer->bottom_size(); for (int i = 0; i < layer->blobs_size(); ++i) { - ge::GeTensorPtr weight = ge::parser::MakeShared(); + ge::GeTensorPtr weight = ge::MakeShared(); GE_CHECK_NOTNULL(weight); GE_CHK_STATUS_RET(ConvertWeight(layer->blobs(i), layer->name(), weight), "Convert blobs(%d) for layer %s failed", i, layer->name().c_str()); @@ -98,14 +97,14 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr bias_en = fc_params_src.bias_term();); auto bias_shape = weight->MutableTensorDesc().GetShape(); // The num 0, 1, 2, 3 represet the dim index. - bool matched = bias_en && bias_shape.GetDimNum() == static_cast(ge::parser::DIM_DEFAULT_SIZE) && + bool matched = bias_en && bias_shape.GetDimNum() == static_cast(ge::DIM_DEFAULT_SIZE) && bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1 && bias_shape.GetDim(2) == 1; if (matched) { weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(3)})); } matched = layer->type() == kInnerProduct && i == 0 && - bias_shape.GetDimNum() == static_cast(ge::parser::DIM_DEFAULT_SIZE) && - bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1; + bias_shape.GetDimNum() == static_cast(ge::DIM_DEFAULT_SIZE) && bias_shape.GetDim(0) == 1 && + bias_shape.GetDim(1) == 1; if (matched) { weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(2), bias_shape.GetDim(3)})); } diff --git a/parser/caffe/caffe_data_parser.cc b/parser/caffe/caffe_data_parser.cc index a155e7a..948ff99 100644 --- a/parser/caffe/caffe_data_parser.cc +++ b/parser/caffe/caffe_data_parser.cc @@ -18,15 +18,13 @@ #include #include #include "common/debug/log.h" -#include "framework/omg/parser/parser_types.h" +#include "common/types.h" #include "common/util.h" #include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" -#include "framework/omg/parser/parser_inner_ctx.h" +#include "omg/omg_inner_types.h" #include "parser/common/op_parser_factory.h" -using namespace ge::parser; - namespace ge { Status CaffeDataParser::GetOutputDesc(const string &name, int dim_size, const std::vector &input_dims, ge::OpDescPtr &op) { @@ -50,10 +48,10 @@ Status CaffeDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { GE_CHECK_NOTNULL(layer); GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); - if (layer->type() == ge::parser::INPUT_TYPE) { + if (layer->type() == ge::INPUT_TYPE) { GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", layer->name().c_str(), layer->type().c_str()); - } else if(layer->type() == ge::parser::DUMMY_DATA) { + } else if(layer->type() == ge::DUMMY_DATA) { GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", layer->name().c_str(), layer->type().c_str()); } else { @@ -77,12 +75,14 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l } for (int i = 0; i < input_param.shape_size(); i++) { const domi::caffe::BlobShape &blob_shape = input_param.shape(i); + vector shape; - unordered_map> &shape_map = GetParserContext().input_dims; + unordered_map> &shape_map = domi::GetContext().input_dims; std::vector model_dims; for (auto &blob_shape_dim_temp : blob_shape.dim()) { model_dims.push_back(blob_shape_dim_temp); } + string name = layer->name(); GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s", @@ -90,7 +90,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l } } else { // Get from external input - const ge::ParserContext &ctx = GetParserContext(); + const ge::OmgContext &ctx = domi::GetContext(); std::unordered_map> input_dims = ctx.input_dims; string name = layer->name(); auto search = input_dims.find(name); @@ -124,7 +124,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete const domi::caffe::BlobShape &blob_shape = dummy_data_param.shape(i); vector shape; - unordered_map> &shape_map = GetParserContext().input_dims; + unordered_map> &shape_map = domi::GetContext().input_dims; std::vector model_dims; for (auto &blob_shape_dim_temp : blob_shape.dim()) { model_dims.push_back(blob_shape_dim_temp); @@ -137,7 +137,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete } } else { // Get from external input - const ge::ParserContext &ctx = GetParserContext(); + const ge::OmgContext &ctx = domi::GetContext(); std::unordered_map> input_dims = ctx.input_dims; string name = layer->name(); auto search = input_dims.find(name); diff --git a/parser/caffe/caffe_op_parser.cc b/parser/caffe/caffe_op_parser.cc index 63d0b6e..b7df7da 100644 --- a/parser/caffe/caffe_op_parser.cc +++ b/parser/caffe/caffe_op_parser.cc @@ -18,9 +18,6 @@ #include #include "parser/common/op_parser_factory.h" #include "common/util/error_manager/error_manager.h" -#include "framework/omg/parser/parser_types.h" - -using namespace ge::parser; using domi::CAFFE; diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 36d07b3..0e97efd 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -20,16 +20,16 @@ #include #include #include -#include "parser/common/convert/pb2json.h" +#include "common/convert/pb2json.h" #include "common/debug/log.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/ge/ge_util.h" +#include "common/model_saver.h" #include "common/op_map.h" +#include "common/util.h" #include "common/util/error_manager/error_manager.h" -#include "common/ge_types.h" #include "common/string_util.h" #include "external/graph/operator_factory.h" #include "external/parser/caffe_parser.h" -#include "external/ge/ge_api_types.h" #include "framework/common/debug/ge_log.h" #include "graph/optimize/common/params.h" #include "graph/utils/graph_utils.h" @@ -46,8 +46,6 @@ #include "parser/caffe/caffe_op_parser.h" #include "parser/common/op_parser_factory.h" #include "parser/common/pre_checker.h" -#include "framework/omg/parser/parser_types.h" -#include "parser/common/model_saver.h" #include "parser/common/acl_graph_parser_util.h" #include "parser/common/proto_file_parser.h" #include "register/op_registry.h" @@ -57,7 +55,7 @@ using domi::caffe::NetParameter; using domi::ParseParamByOpFunc; using ge::caffe_op_map; using ge::CaffeOpParser; -using ge::parser::ModelSaver; +using ge::ModelSaver; using ge::OpParser; using ge::OpParserFactory; using ge::Pb2Json; @@ -76,7 +74,7 @@ using std::ifstream; namespace ge { graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph) { GE_CHECK_NOTNULL(model_file); - GetParserContext().type = domi::CAFFE; + domi::GetContext().type = domi::CAFFE; std::map options; options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::CAFFE))); @@ -85,7 +83,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, (void)acl_graph_parse_util.AclParserInitialize(options); // Create an empty computegraph - ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); + ge::ComputeGraphPtr compute_graph = ge::MakeShared("tmpGraph"); GE_CHECK_NOTNULL(compute_graph); graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); @@ -107,10 +105,6 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, return ret; } GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); - if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { - GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); - return ge::FAILED; - } return ge::SUCCESS; } } // namespace ge @@ -154,15 +148,14 @@ const std::string kRepeated = "repeated"; const std::string kRequired = "required"; const std::string kCustom = "custom"; const std::string kBuiltin = "built-in"; -std::vector kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, - ge::parser::NETOUTPUT}; +std::vector kAddTensorIrSkipNodes = {ge::DATA, ge::YOLODETECTIONOUTPUT, ge::NETOUTPUT}; const std::set kCustomProtoLayerCommonField = {"name", "type"}; const std::set kCaffeProtoLayerCommonField = {"name", "type", "bottom", "top", "phase", "loss_weight", "param", "blobs", "propagate_down", "include", "exclude"}; Status CheckPathValid(const char *model_path, const string &custom_proto, string &custom_proto_path, string &custom_proto_name) { - string path_model = ge::parser::RealPath(model_path); + string path_model = ge::RealPath(model_path); if (path_model.empty()) { ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {model_path, strerror(errno)}); GELOGE(FAILED, "Invalid path of model: %s", model_path); @@ -218,7 +211,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo domi::caffe::LayerParameter *layer = proto_message.add_layer(); GE_CHECK_NOTNULL(layer); layer->set_name(proto_message.input(i)); - layer->set_type(ge::parser::INPUT_TYPE); + layer->set_type(ge::INPUT_TYPE); layer->add_top(proto_message.input(i)); domi::caffe::InputParameter *input_param = layer->mutable_input_param(); @@ -247,7 +240,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo domi::caffe::LayerParameter *layer = proto_message.add_layer(); GE_CHECK_NOTNULL(layer); layer->set_name(proto_message.input(i)); - layer->set_type(ge::parser::INPUT_TYPE); + layer->set_type(ge::INPUT_TYPE); layer->add_top(proto_message.input(i)); domi::caffe::InputParameter *input_param = layer->mutable_input_param(); @@ -262,7 +255,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo input_data_flag = true; } } else { - const ge::ParserContext &ctx = ge::GetParserContext(); + const ge::OmgContext &ctx = domi::GetContext(); std::unordered_map> input_dims = ctx.input_dims; for (int i = 0; i < proto_message.input_size(); i++) { string name = proto_message.input(i); @@ -277,7 +270,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo domi::caffe::LayerParameter *layer = proto_message.add_layer(); GE_CHECK_NOTNULL(layer); layer->set_name(name); - layer->set_type(ge::parser::INPUT_TYPE); + layer->set_type(ge::INPUT_TYPE); layer->add_top(proto_message.input(i)); domi::caffe::InputParameter *input_param = layer->mutable_input_param(); @@ -342,7 +335,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons Status CaffeModelParser::CustomProtoParse(const char *model_path, const string &custom_proto, const string &caffe_proto, vector &operators) { - string custom_proto_path = ge::parser::RealPath(custom_proto.c_str()); + string custom_proto_path = ge::RealPath(custom_proto.c_str()); if (custom_proto_path.empty()) { GELOGW("Valid custom proto: %s does not exist, skip parsing custom proto", custom_proto.c_str()); return SUCCESS; @@ -748,27 +741,27 @@ Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection * } void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) { - auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name); - if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) { + auto iter_node_name = domi::GetContext().out_nodes_map.find(layer_name); + if (iter_node_name != domi::GetContext().out_nodes_map.end()) { iter_node_name->second.emplace_back(top_index); } else { std::vector index_v; index_v.emplace_back(top_index); - ge::GetParserContext().out_nodes_map.emplace(layer_name, index_v); + domi::GetContext().out_nodes_map.emplace(layer_name, index_v); } - ge::GetParserContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index)); + domi::GetContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index)); } Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) { - if (ge::GetParserContext().user_out_nodes_top_vec.empty()) { + if (domi::GetContext().user_out_nodes_top_vec.empty()) { return SUCCESS; } - ge::GetParserContext().out_nodes_map.clear(); - ge::GetParserContext().user_out_nodes.clear(); + domi::GetContext().out_nodes_map.clear(); + domi::GetContext().user_out_nodes.clear(); int32_t layer_count = proto_message.layer_size(); const std::vector &user_out_nodes_top_vec = - ge::GetParserContext().user_out_nodes_top_vec; + domi::GetContext().user_out_nodes_top_vec; for (const auto &top_name : user_out_nodes_top_vec) { bool find_node_falg = false; @@ -807,6 +800,10 @@ Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter Status CaffeModelParser::AddBlobsToMap(const domi::caffe::LayerParameter &layer, std::map &inplace_blob_name_remapping) { + if (layer.type() == ge::NETOUTPUT) { + return SUCCESS; + } + if (layer.top_size() <= 0) { ErrorManager::GetInstance().ATCReportErrMessage("E19011", {"opname"}, {layer.name()}); GELOGE(FAILED, "The output size of layer %s needs to be greater than zero.", layer.name().c_str()); @@ -965,9 +962,9 @@ Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::C } else { op_type = layer.type(); // User defined duplicate name operator processing - auto m_iter = ge::GetParserContext().op_conf_map.find(op_type); + auto m_iter = domi::GetContext().op_conf_map.find(op_type); // User specified configuration item found - if (m_iter != ge::GetParserContext().op_conf_map.end()) { + if (m_iter != domi::GetContext().op_conf_map.end()) { op_type = m_iter->second; } // General layer layer, search optype @@ -1056,7 +1053,7 @@ Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::C Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer) { GE_CHECK_NOTNULL(op_desc); // Data node input and output tensordesc added in parserparam - if (op_desc->GetType() == ge::parser::DATA) { + if (op_desc->GetType() == ge::DATA) { return SUCCESS; } @@ -1076,7 +1073,7 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom } // yolo v2 YoloDetectionOutput - if (op_desc->GetType() == ge::parser::YOLODETECTIONOUTPUT) { + if (op_desc->GetType() == ge::YOLODETECTIONOUTPUT) { ge::GeTensorDesc input_tensor; GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); @@ -1085,13 +1082,41 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom "while it's original input num is: %d", layer.bottom_size()); } + + // Netoutput node processing + if (op_desc->GetType() == ge::NETOUTPUT) { + size_t input_output_tensor_num = 0; + if (!domi::GetContext().user_out_nodes.empty()) { + // User specified output + input_output_tensor_num = domi::GetContext().user_out_nodes.size(); + } else { + for (auto t_iter = top_blobs_map_.begin(); t_iter != top_blobs_map_.end(); t_iter++) { + auto b_iter = bottom_blobs_map_.find(t_iter->first); + // Find the output node of the network + if (b_iter == bottom_blobs_map_.end()) { + input_output_tensor_num += top_blobs_map_[t_iter->first].size(); + } + } + } + // add tensordesc + GELOGD( + "Current op type is NETOUTPUT, add additional input&output num: %zu." + "while it's original input num is: %d, output num is: %d", + input_output_tensor_num, layer.bottom_size(), output_tensor_num); + for (size_t i = 0; i < input_output_tensor_num; i++) { + ge::GeTensorDesc input_tensor; + GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); + ge::GeTensorDesc output_tensor; + GE_RETURN_IF_ERROR(op_desc->AddOutputDesc(output_tensor)); + } + } return SUCCESS; } Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, const string &op_type) { if (std::find(kAddTensorIrSkipNodes.begin(), kAddTensorIrSkipNodes.end(), op_type) != kAddTensorIrSkipNodes.end()) { - op_desc = ge::parser::MakeShared(layer.name(), op_type); + op_desc = ge::MakeShared(layer.name(), op_type); GE_CHECK_NOTNULL(op_desc); Status ret = AddTensorDescToOpDesc(op_desc, layer); if (ret != SUCCESS) { @@ -1223,8 +1248,8 @@ Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) { bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) { bool ret = false; - auto iter = ge::GetParserContext().out_nodes_map.find(op_name); - if (iter != ge::GetParserContext().out_nodes_map.end()) { + auto iter = domi::GetContext().out_nodes_map.find(op_name); + if (iter != domi::GetContext().out_nodes_map.end()) { std::vector tmp_index_v; for (int32_t id : iter->second) { if (index == id) { @@ -1235,40 +1260,53 @@ bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) { } // To prevent specifying network output again in the build phase, need to delete the output node in the map list. if (ret) { - ge::GetParserContext().out_nodes_map.erase(op_name); - ge::GetParserContext().out_nodes_map.emplace(op_name, tmp_index_v); + domi::GetContext().out_nodes_map.erase(op_name); + domi::GetContext().out_nodes_map.emplace(op_name, tmp_index_v); } } return ret; } -Status CaffeModelParser::AddUserOutNodesTop() { +Status CaffeModelParser::AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + ge::NodePtr net_output_node = graph->FindFirstNodeMatchType(ge::NETOUTPUT); + if (net_output_node == nullptr) { + GELOGE(INTERNAL_ERROR, "Can not find netoutput node."); + return INTERNAL_ERROR; + } + uint32_t net_output_num = net_output_node->GetAllInDataAnchorsSize(); int32_t index = 0; - const std::vector> &user_out_nodes = ge::GetParserContext().user_out_nodes; - int net_output_num = user_out_nodes.size(); - for (const auto &out_pair : user_out_nodes) { - auto layer_iter = layer_tops_map_.find(out_pair.first); + std::vector> &user_out_nodes = domi::GetContext().user_out_nodes; + for (auto &out_pair : user_out_nodes) { + auto node_iter = node_map.find(out_pair.first); GELOGI("Add to output, node name: %s", out_pair.first.c_str()); - if (layer_iter != layer_tops_map_.end()) { - if (static_cast(out_pair.second) >= (layer_iter->second).size()) { + if (node_iter != node_map.end()) { + if ((static_cast(out_pair.second) >= node_iter->second->GetAllOutDataAnchorsSize()) || + (static_cast(index) >= net_output_num)) { ErrorManager::GetInstance().ATCReportErrMessage( "E11016", {"opname", "outputindex", "totlaloutputindex", "inputindex", "totlalinputindex"}, {out_pair.first.c_str(), std::to_string(out_pair.second), - std::to_string((layer_iter->second).size()), std::to_string(index), + std::to_string(node_iter->second->GetAllOutDataAnchorsSize()), std::to_string(index), std::to_string(net_output_num)}); GELOGE(INTERNAL_ERROR, "Add op %s to NetOutput faild, current node output index:%d should < %u. NetOutput" "input_index:%d should < %u.", - out_pair.first.c_str(), out_pair.second, (layer_iter->second).size(), index, + out_pair.first.c_str(), out_pair.second, node_iter->second->GetAllOutDataAnchorsSize(), index, net_output_num); return INTERNAL_ERROR; } - - string top_name = layer_iter->second[out_pair.second]; - auto top_node_iter = node_map.find(out_pair.first); - if (top_node_iter != node_map.end()) { - ge::GetParserContext().out_top_names.push_back(top_name); - GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str()); + GELOGD("Start add edge for user out node: From %s:%d To %s:%d.", node_iter->second->GetName().c_str(), + out_pair.second, net_output_node->GetName().c_str(), index); + ge::OutDataAnchorPtr out_archor_ptr = node_iter->second->GetOutDataAnchor(out_pair.second); + GE_CHECK_NOTNULL(out_archor_ptr); + ge::InDataAnchorPtr in_archor_ptr = net_output_node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_archor_ptr); + if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E11013", {"opname1", "opname2"}, + {node_iter->second->GetName(), net_output_node->GetName()}); + GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", node_iter->second->GetName().c_str(), + net_output_node->GetName().c_str()); + return INTERNAL_ERROR; } ++index; } else { @@ -1280,7 +1318,13 @@ Status CaffeModelParser::AddUserOutNodesTop() { return SUCCESS; } -Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_message) { +Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + ge::NodePtr node = graph->FindFirstNodeMatchType(ge::NETOUTPUT); + + GE_RETURN_WITH_LOG_IF_FALSE(node != nullptr, "Net without output, some phase failed in front."); + + int32_t index = 0; for (int32_t i = 0; i < proto_message.layer_size(); i++) { const domi::caffe::LayerParameter &layer = proto_message.layer(i); @@ -1290,7 +1334,6 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes for (int i = 0; i < layer.top_size(); i++) { string top = layer.top(i); - string top_origin = top; // Handling 'inplace' scenarios if (IsInplaceTopBlob(layer, top)) { top = RemapTopNameByLayer(layer, top, i); @@ -1312,9 +1355,21 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes auto top_node_iter = node_map.find(layer.name()); GELOGI("output in top_blob: %s", layer.name().c_str()); if (top_node_iter != node_map.end()) { - ge::GetParserContext().out_top_names.push_back(top_origin); - ge::GetParserContext().default_out_nodes.push_back(std::make_pair(layer.name(), (int32_t)i)); - GELOGI("The top of out node [%s] is [%s]", layer.name().c_str(), top_origin.c_str()); + // add edge + // Output node, output index, input node, input index + GELOGD("Start add edge for out node: From %s:%d To %s:%d.", top_node_iter->second->GetName().c_str(), i, + node->GetName().c_str(), index); + ge::OutDataAnchorPtr out_archor_ptr = top_node_iter->second->GetOutDataAnchor(i); + GE_CHECK_NOTNULL(out_archor_ptr); + ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_archor_ptr); + GE_IF_BOOL_EXEC(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, + ErrorManager::GetInstance().ATCReportErrMessage( + "E11013", {"opname1", "opname2"}, {top_node_iter->second->GetName(), node->GetName()}); + GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to to op[%s].", + top_node_iter->second->GetName().c_str(), node->GetName().c_str()); + return INTERNAL_ERROR;); + index++; } } } @@ -1369,7 +1424,7 @@ Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) { // validate opname string mode = "^[A-Za-z0-9./_-]+$"; - if (!ge::parser::ValidateStr(layer.name(), mode)) { + if (!ge::ValidateStr(layer.name(), mode)) { ErrorManager::GetInstance().ATCReportErrMessage("E11018", {"opname"}, {layer.name()}); GELOGE(ge::FAILED, "Parse caffe pbtxt validate op[%s] failed, opname can only contain " @@ -1398,7 +1453,7 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co domi::caffe::NetParameter proto_message; // Get Caffe network model information - if (!ge::parser::ReadProtoFromMem(data, static_cast(size), &proto_message)) { + if (!ge::ReadProtoFromMem(data, static_cast(size), &proto_message)) { GELOGE(FAILED, "read proto from text ret fail"); return FAILED; } @@ -1428,6 +1483,12 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; GELOGE(FAILED, "ParseInput ret fail.")); + // build output layer + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(graph->GetName() + "_" + ge::NODE_NAME_NET_OUTPUT); + layer->set_type(ge::NETOUTPUT); + int32_t layer_count = proto_message.layer_size(); std::map inplace_blob_name_remapping; // Map of operator name and occurrence times @@ -1443,7 +1504,7 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.", layer.name().c_str(), layer.type().c_str()); - CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && (input_data_flag == true)), has_error = true; + CHECK_FALSE_EXEC(!((layer.type() == ge::DATA_TYPE) && (input_data_flag == true)), has_error = true; GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str())); // All layer names cannot be duplicate @@ -1492,10 +1553,10 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); - if (!(ge::GetParserContext().user_out_nodes.empty())) { - GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed."); + if (!(domi::GetContext().user_out_nodes.empty())) { + GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed."); } else { - GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail."); + GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail."); } GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); @@ -1597,13 +1658,19 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; GELOGE(FAILED, "ParseInput ret fail.")); + // build output layer + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(graph->GetName() + "_" + ge::NODE_NAME_NET_OUTPUT); + layer->set_type(ge::NETOUTPUT); + int32_t layer_count = proto_message.layer_size(); - if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { + if (!domi::GetContext().user_out_nodes_top_vec.empty()) { GELOGW("The out_put info has top_name items."); GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputNodeTopInfo(proto_message), "Caffe parser parse output node-top info failed."); - ge::GetParserContext().user_out_nodes_top_vec.clear(); + domi::GetContext().user_out_nodes_top_vec.clear(); } std::map inplace_blob_name_remapping; @@ -1619,7 +1686,7 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.", layer.name().c_str(), layer.type().c_str()); - CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && (input_data_flag == true)), has_error = true; + CHECK_FALSE_EXEC(!((layer.type() == ge::DATA_TYPE) && (input_data_flag == true)), has_error = true; GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str())); // All layer names cannot be duplicate @@ -1657,6 +1724,7 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap GE_RETURN_WITH_LOG_IF_ERROR(AddBlobsToMap(layer, inplace_blob_name_remapping), "Caffe parser add blobs to map ret fail."); } + // Find a layer with the same param name and save it to graph GE_RETURN_WITH_LOG_IF_ERROR(FindShareParamLayers(layer_params_map), "Caffe parser find share param layers map ret fail."); @@ -1668,12 +1736,13 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); - if (!(ge::GetParserContext().user_out_nodes.empty())) { - GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed."); + if (!(domi::GetContext().user_out_nodes.empty())) { + GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed."); } else { - GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail."); + GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail."); } GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); + GE_RETURN_WITH_LOG_IF_ERROR(GetLeafNodeTops(graph), "Caffe parser get out nodes top names failed."); auto nodes = graph->GetDirectNode(); GELOGI("graph node size = %zu.", nodes.size()); @@ -1766,7 +1835,7 @@ Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge:: // Resolve proto file to netparameter NetParameter proto; - bool success = ge::parser::ReadProtoFromArray(reinterpret_cast(data), static_cast(size), &proto); + bool success = ge::ReadProtoFromArray(reinterpret_cast(data), static_cast(size), &proto); if (!success) { GELOGE(domi::PARSE_WEIGHTS_FAILED, "ReadProto from Memory fail"); return domi::PARSE_WEIGHTS_FAILED; @@ -1814,7 +1883,7 @@ Status CaffeWeightsParser::Parse(const char *file, ge::ComputeGraphPtr &graph) { GELOGD("caffe_proto_path:%s custom_proto_path:%s", caffe_proto_path.c_str(), custom_proto_path.c_str()); string fusion_proto_file; - string custom_proto_file = ge::parser::RealPath(custom_proto_path.c_str()); + string custom_proto_file = ge::RealPath(custom_proto_path.c_str()); if (custom_proto_file.empty()) { GELOGW("custom_proto_path:%s is not existed", custom_proto_path.c_str()); fusion_proto_file = caffe_proto_path; @@ -1826,7 +1895,7 @@ Status CaffeWeightsParser::Parse(const char *file, ge::ComputeGraphPtr &graph) { } } - string fusion_proto_path = ge::parser::RealPath(fusion_proto_file.c_str()); + string fusion_proto_path = ge::RealPath(fusion_proto_file.c_str()); GELOGI("Get fusion proto file[%s]-[%s].", fusion_proto_file.c_str(), fusion_proto_path.c_str()); if (fusion_proto_path.empty()) { GELOGE(FAILED, "Fusion proto file path [%s]-[%s] is not real existed.", fusion_proto_file.c_str(), @@ -1879,7 +1948,7 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con google::protobuf::Message *message = proto->New(); GE_CHECK_NOTNULL(message); - if (!ge::parser::ReadProtoFromBinaryFile(weight_path, message)) { + if (!ge::ReadProtoFromBinaryFile(weight_path, message)) { delete message; message = nullptr; ErrorManager::GetInstance().ATCReportErrMessage( @@ -2269,7 +2338,7 @@ Status CaffeWeightsParser::CheckNodes(ge::ComputeGraphPtr &graph) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); for (const auto &in_anchor_ptr : node->GetAllInDataAnchors()) { - if (op_desc->GetType() == ge::parser::DATA || op_desc->GetType() == ge::parser::CONSTANT) { + if (op_desc->GetType() == ge::DATA || op_desc->GetType() == ge::CONSTANT) { continue; } auto index = in_anchor_ptr->GetIdx(); @@ -2384,6 +2453,27 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter ¶m, ge::Co return SUCCESS; } +Status CaffeModelParser::GetLeafNodeTops(ge::ComputeGraphPtr &graph) { + auto netout = graph->FindFirstNodeMatchType(ge::NETOUTPUT); + GE_CHECK_NOTNULL(netout); + for (const auto &in_anchor : netout->GetAllInDataAnchors()) { + auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_data_anchor); + auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_out_data_node); + int idx = peer_out_data_anchor->GetIdx(); + string node_name = peer_out_data_node->GetName(); + auto layer_iter = layer_tops_map_.find(node_name); + if (layer_iter != layer_tops_map_.end()) { + domi::GetContext().out_top_names.push_back(layer_iter->second[idx]); + GELOGI("The top of out node [%s] is [%s]", node_name.c_str(), layer_iter->second[idx].c_str()); + } else { + GELOGW("The out node [%s] can not find its top.", node_name.c_str()); + } + } + return SUCCESS; +} + Status CaffeModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { return SUCCESS; } diff --git a/parser/caffe/caffe_parser.h b/parser/caffe/caffe_parser.h index bd1520e..ef3d1f1 100644 --- a/parser/caffe/caffe_parser.h +++ b/parser/caffe/caffe_parser.h @@ -279,12 +279,12 @@ class CaffeModelParser : public domi::ModelParser { /** * @ingroup domi_omg - * @brief Add top name information to graph - * @param [in|out] proto_message + * @brief Add edge information to graph + * @param [in|out] graph graph for saving model information * @return SUCCESS add successfully * @return FAILED add failed */ - Status AddOutputTop(const domi::caffe::NetParameter &proto_message); + Status AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph); /** * @ingroup domi_omg @@ -324,7 +324,7 @@ class CaffeModelParser : public domi::ModelParser { Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, const string &op_type); - Status AddUserOutNodesTop(); + Status AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph); std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index); @@ -335,6 +335,8 @@ class CaffeModelParser : public domi::ModelParser { Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op, std::shared_ptr &op_parser); + Status GetLeafNodeTops(ge::ComputeGraphPtr &graph); + void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer); Status ReorderInput(domi::caffe::NetParameter &net); diff --git a/parser/caffe/caffe_reshape_parser.cc b/parser/caffe/caffe_reshape_parser.cc index 2edabbf..010f764 100644 --- a/parser/caffe/caffe_reshape_parser.cc +++ b/parser/caffe/caffe_reshape_parser.cc @@ -17,16 +17,14 @@ #include "parser/caffe/caffe_reshape_parser.h" #include #include "common/debug/log.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/ge/ge_util.h" #include "common/op/op_parser_util.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "graph/utils/graph_utils.h" #include "parser/common/op_parser_factory.h" -#include "framework/omg/parser/parser_types.h" #include "proto/om.pb.h" -using namespace ge::parser; using domi::CAFFE; namespace ge { @@ -109,7 +107,7 @@ Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) { } // construct GeTensorPtr - ge::GeTensorPtr constTensor = ge::parser::MakeShared(); + ge::GeTensorPtr constTensor = ge::MakeShared(); GE_CHECK_NOTNULL(constTensor); constTensor->SetTensorDesc(const_desc); diff --git a/parser/common/CMakeLists.txt b/parser/common/CMakeLists.txt index 9c3be13..276d2c2 100644 --- a/parser/common/CMakeLists.txt +++ b/parser/common/CMakeLists.txt @@ -41,24 +41,12 @@ target_include_directories(parser_common PRIVATE ${CMAKE_CURRENT_LIST_DIR} ${PARSER_DIR} ${PARSER_DIR}/parser - ${PARSER_DIR}/../ge - ${PARSER_DIR}/../inc - ${PARSER_DIR}/../inc/framework - ${PARSER_DIR}/../inc/common/util - ${PARSER_DIR}/../inc/external - ${PARSER_DIR}/../third_party/fwkacllib/inc ${METADEF_DIR}/inc ${METADEF_DIR}/inc/graph ${METADEF_DIR}/inc/register ${METADEF_DIR}/inc/external ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/external/register - #### independent compile ##### - ${METADEF_DIR}/third_party/graphengine/ge - ${METADEF_DIR}/third_party/graphengine/inc - ${METADEF_DIR}/third_party/graphengine/inc/framework - ${METADEF_DIR}/third_party/graphengine/inc/external - ${METADEF_DIR}/third_party/fwkacllib/inc #### temp #### ${PARSER_DIR}/../graphengine/inc/common/util ${PARSER_DIR}/../graphengine/inc/external @@ -67,6 +55,19 @@ target_include_directories(parser_common PRIVATE ${PARSER_DIR}/../graphengine/ge ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/ge + #### blue zone compile ##### + ${PARSER_DIR}/../ge + ${PARSER_DIR}/../inc + ${PARSER_DIR}/../inc/framework + ${PARSER_DIR}/../inc/common/util + ${PARSER_DIR}/../inc/external + ${PARSER_DIR}/../third_party/fwkacllib/inc + #### independent compile ##### + ${METADEF_DIR}/third_party/graphengine/ge + ${METADEF_DIR}/third_party/graphengine/inc + ${METADEF_DIR}/third_party/graphengine/inc/framework + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/fwkacllib/inc ) target_link_libraries(parser_common PRIVATE diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index 0a16d38..a2b90da 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -18,37 +18,20 @@ #include #include -#include -#include -#include - #include "common/string_util.h" +#include "common/types.h" #include "common/debug/log.h" +#include "common/ge/tbe_plugin_manager.h" #include "common/op/ge_op_utils.h" +#include "common/util.h" + #include "ge/ge_api_types.h" #include "graph/opsproto_manager.h" #include "omg/parser/parser_inner_ctx.h" -#include "tbe_plugin_loader.h" #include "framework/common/debug/ge_log.h" #include "parser/common/register_tbe.h" -#include "framework/omg/parser/parser_types.h" -#include "common/util/error_manager/error_manager.h" -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" - -using google::protobuf::io::CodedInputStream; -using google::protobuf::io::FileInputStream; -using google::protobuf::io::ZeroCopyInputStream; -using namespace ge::parser; namespace { -/// The maximum length of the file. -/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 -const int kMaxFileSizeLimit = INT_MAX; -const int kMaxBuffSize = 256; -const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. -const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M - static string GetSoPath() { Dl_info dl_info; if (dladdr(reinterpret_cast(&GetSoPath), &dl_info) == 0) { @@ -77,7 +60,7 @@ static void GetOpsProtoPath(string &opsproto_path) { const char *path_env = std::getenv("ASCEND_OPP_PATH"); if (path_env != nullptr) { string path = path_env; - string file_path = ge::parser::RealPath(path.c_str()); + string file_path = ge::RealPath(path.c_str()); if (file_path.empty()) { GELOGE(ge::FAILED, "File path %s is invalid.", path.c_str()); return; @@ -125,7 +108,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, std::vector &output_nodes_name) { output_nodes_name.clear(); - if (ge::GetParserContext().out_top_names.empty()) { + if (domi::GetContext().out_top_names.empty()) { // tf process, no top name. for (const auto output_node_info : output_nodes_info) { std::string node_name = output_node_info.first->GetName(); @@ -159,7 +142,7 @@ domi::Status AclGrphParseUtil::SetDefaultOutputNode(ge::Graph &graph) { AclGrphParseUtil::GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); compute_graph->SetGraphOutNodesInfo(output_nodes_info); - ge::GetParserContext().net_out_nodes = output_nodes_name; + domi::GetContext().net_out_nodes = output_nodes_name; GELOGI("Set graph %s default output node success.", graph.GetName().c_str()); return SUCCESS; } @@ -211,7 +194,7 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map= PATH_MAX) { - ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)}); - GELOGE(ge::FAILED, "Path[%s] len is too long, it must be less than %d", path, PATH_MAX); - return ""; - } - // Nullptr is returned when the path does not exist or there is no permission - // Return absolute path when path is accessible - std::string res; - char resolved_path[PATH_MAX] = {0}; - if (realpath(path, resolved_path) != nullptr) { - res = resolved_path; - } - - return res; -} - -// Get file length -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), return -1, "input_file path is null."); - - std::string real_path = RealPath(input_file.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); - unsigned long long file_length = 0; - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, - ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, - {input_file, strerror(errno)}); - return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), - ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); - return -1, "File[%s] size is 0, not valid.", input_file.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit, - ErrorManager::GetInstance().ATCReportErrMessage( - "E19016", {"filepath", "filesize", "maxlen"}, - {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); - return -1, "File[%s] size %lld is out of limit: %d.", - input_file.c_str(), file_length, kMaxFileSizeLimit); - return static_cast(file_length); -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { - struct timeval tv{}; - int ret = gettimeofday(&tv, nullptr); - GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret); - auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds - return static_cast(total_use_time); -} - -static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, - return false, "incorrect parameter. nullptr == proto"); - - coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit, kWarningThreshold); - return proto->ParseFromCodedStream(&coded_stream); -} - -/** @ingroup domi_common - * @brief Read all data from binary file - * @param [in] file_name File path - * @param [out] buffer The address of the output memory, which needs to be released by the caller - * @param [out] length Output memory size - * @return false fail - * @return true success - */ -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, - int &length) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. buffer is nullptr"); - - std::string real_path = RealPath(file_name); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "file path '%s' not valid", file_name); - - std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate); - if (!file.is_open()) { - GELOGE(ge::FAILED, "Read file %s failed.", file_name); - return false; - } - - length = static_cast(file.tellg()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); return false, "file length <= 0"); - - file.seekg(0, std::ios::beg); - - *buffer = new(std::nothrow) char[length](); - GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(), "new an object failed."); - - file.read(*buffer, length); - file.close(); - return true; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), - return false, - "Input parameter file or proto is nullptr!"); - - std::string real_path = RealPath(file); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), - return false, "pb file path '%s' not valid", file); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); - - std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); - if (!fs.is_open()) { - ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"}); - GELOGE(ge::FAILED, "Open real path[%s] failed.", file); - return false; - } - - google::protobuf::io::IstreamInputStream istream(&fs); - google::protobuf::io::CodedInputStream coded_stream(&istream); - - bool ret = ReadProtoFromCodedInputStream(coded_stream, proto); - - fs.close(); - - if (!ret) { - ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file}); - GELOGE(ge::FAILED, "Parse file[%s] failed.", file); - return ret; - } - - return ret; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), return false, - "incorrect parameter. proto is nullptr || data is nullptr || size is 0"); - - google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast(const_cast(data)), size); - return ReadProtoFromCodedInputStream(coded_stream, proto); -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, - google::protobuf::Message *message) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false, - "incorrect parameter. nullptr == file || nullptr == message"); - - std::string real_path = RealPath(file); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), - ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, - {file, strerror(errno)}); - return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, - strerror(errno)); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); - - std::ifstream fs(real_path.c_str(), std::ifstream::in); - - if (!fs.is_open()) { - ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file}); - GELOGE(ge::FAILED, - "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), file); - return false; - } - - google::protobuf::io::IstreamInputStream input(&fs); - bool ret = google::protobuf::TextFormat::Parse(&input, message); - GE_IF_BOOL_EXEC(!ret, - ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file}); - GELOGE(ret, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, " - "please check whether the file is a valid protobuf format file.", file)); - fs.close(); - - return ret; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, - google::protobuf::Message *message) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false, - "incorrect parameter. data is nullptr || message is nullptr"); - std::string str(data, static_cast(size)); - std::istringstream fs(str); - - google::protobuf::io::IstreamInputStream input(&fs); - bool ret = google::protobuf::TextFormat::Parse(&input, message); - GE_IF_BOOL_EXEC( - !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file.")); - - return ret; -} - -/// -/// @brief get the Original Type of FrameworkOp -/// @param [in] node -/// @param [out] type -/// @return Status -/// -Status GetOriginalType(const ge::NodePtr &node, string &type) { - GE_CHECK_NOTNULL(node); - type = node->GetType(); - GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS); - GE_CHECK_NOTNULL(node->GetOpDesc()); - bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); - if (!ret) { - GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str()); - return INTERNAL_ERROR; - } - GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) { - char ebuff[kMaxBuffSize]; - regex_t reg; - int cflags = REG_EXTENDED | REG_NOSUB; - int ret = regcomp(®, mode.c_str(), cflags); - if (ret) { - regerror(ret, ®, ebuff, kMaxBuffSize); - GELOGW("regcomp failed, reason: %s", ebuff); - regfree(®); - return true; - } - - ret = regexec(®, str.c_str(), 0, nullptr, 0); - if (ret) { - regerror(ret, ®, ebuff, kMaxBuffSize); - GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff); - regfree(®); - return false; - } - - regfree(®); - return true; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() { - std::time_t now = std::time(nullptr); - std::tm *ptm = std::localtime(&now); - if (ptm == nullptr) { - GELOGE(ge::FAILED, "Localtime failed."); - return ""; - } - - const int kTimeBufferLen = 32; - char buffer[kTimeBufferLen + 1] = {0}; - // format: 20171122042550 - std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm); - return std::string(buffer); -} -} // namespace parser } // namespace ge diff --git a/parser/common/acl_graph_parser_util.h b/parser/common/acl_graph_parser_util.h index 8c182e1..f9a5ae6 100644 --- a/parser/common/acl_graph_parser_util.h +++ b/parser/common/acl_graph_parser_util.h @@ -19,17 +19,10 @@ #include #include -#include -#include - -#include "framework/omg/parser/parser_types.h" -#include "register/register_error_codes.h" +#include "common/types.h" #include "graph/utils/graph_utils.h" namespace ge { - -using google::protobuf::Message; - class AclGrphParseUtil { public: AclGrphParseUtil() {} @@ -45,189 +38,6 @@ class AclGrphParseUtil { void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, std::vector &output_nodes_name); }; - -namespace parser { -/// -/// @ingroup: domi_common -/// @brief: get length of file -/// @param [in] input_file: path of file -/// @return long: File length. If the file length fails to be obtained, the value -1 is returned. -/// -extern long GetFileLength(const std::string &input_file); - -/// -/// @ingroup domi_common -/// @brief Absolute path for obtaining files. -/// @param [in] path of input file -/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned -/// -std::string RealPath(const char *path); - -/// -/// @ingroup domi_common -/// @brief Obtains the absolute time (timestamp) of the current system. -/// @return Timestamp, in microseconds (US) -/// -/// -uint64_t GetCurrentTimestamp(); - -/// -/// @ingroup domi_common -/// @brief Reads all data from a binary file. -/// @param [in] file_name path of file -/// @param [out] buffer Output memory address, which needs to be released by the caller. -/// @param [out] length Output memory size -/// @return false fail -/// @return true success -/// -bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length); - -/// -/// @ingroup domi_common -/// @brief proto file in bianary format -/// @param [in] file path of proto file -/// @param [out] proto memory for storing the proto file -/// @return true success -/// @return false fail -/// -bool ReadProtoFromBinaryFile(const char *file, Message *proto); - -/// -/// @ingroup domi_common -/// @brief Reads the proto structure from an array. -/// @param [in] data proto data to be read -/// @param [in] size proto data size -/// @param [out] proto Memory for storing the proto file -/// @return true success -/// @return false fail -/// -bool ReadProtoFromArray(const void *data, int size, Message *proto); - -/// -/// @ingroup domi_proto -/// @brief Reads the proto file in the text format. -/// @param [in] file path of proto file -/// @param [out] message Memory for storing the proto file -/// @return true success -/// @return false fail -/// -bool ReadProtoFromText(const char *file, google::protobuf::Message *message); - -bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message); - -/// -/// @brief get the Original Type of FrameworkOp -/// @param [in] node -/// @param [out] type -/// @return Status -/// -domi::Status GetOriginalType(const ge::NodePtr &node, string &type); - -/// -/// @ingroup domi_common -/// @brief Check whether the file path meets the whitelist verification requirements. -/// @param [in] filePath file path -/// @param [out] result -/// -bool ValidateStr(const std::string &filePath, const std::string &mode); - -/// -/// @ingroup domi_common -/// @brief Obtains the current time string. -/// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555 -/// -std::string CurrentTimeInStr(); - -template -static inline std::shared_ptr MakeShared(Args &&... args) { - typedef typename std::remove_const::type T_nc; - std::shared_ptr ret(new (std::nothrow) T_nc(std::forward(args)...)); - return ret; -} - -/// @ingroup math_util -/// @brief check whether int64 multiplication can result in overflow -/// @param [in] a multiplicator -/// @param [in] b multiplicator -/// @return Status -inline domi::Status Int64MulCheckOverflow(int64_t a, int64_t b) { - if (a > 0) { - if (b > 0) { - if (a > (INT64_MAX / b)) { - return domi::FAILED; - } - } else { - if (b < (INT64_MIN / a)) { - return domi::FAILED; - } - } - } else { - if (b > 0) { - if (a < (INT64_MIN / b)) { - return domi::FAILED; - } - } else { - if ((a != 0) && (b < (INT64_MAX / a))) { - return domi::FAILED; - } - } - } - return domi::SUCCESS; -} - -/// @ingroup math_util -/// @brief check whether int64 multiplication can result in overflow -/// @param [in] a multiplicator -/// @param [in] b multiplicator -/// @return Status -inline domi::Status CheckInt64Uint32MulOverflow(int64_t a, uint32_t b) { - if (a == 0 || b == 0) { - return domi::SUCCESS; - } - if (a > 0) { - if (a > (INT64_MAX / b)) { - return domi::FAILED; - } - } else { - if (a < (INT64_MIN / b)) { - return domi::FAILED; - } - } - return domi::SUCCESS; -} - -#define PARSER_INT64_MULCHECK(a, b) \ - if (ge::parser::Int64MulCheckOverflow((a), (b)) != SUCCESS) { \ - GELOGW("Int64 %ld and %ld multiplication can result in overflow!", static_cast(a), \ - static_cast(b)); \ - return INTERNAL_ERROR; \ - } - -#define PARSER_INT64_UINT32_MULCHECK(a, b) \ - if (ge::parser::CheckInt64Uint32MulOverflow((a), (b)) != SUCCESS) { \ - GELOGW("Int64 %ld and UINT32 %u multiplication can result in overflow!", static_cast(a), \ - static_cast(b)); \ - return INTERNAL_ERROR; \ - } -} // namespace parser } // namespace ge -/*lint --emacro((773),GE_TIMESTAMP_START)*/ -/*lint -esym(773,GE_TIMESTAMP_START)*/ -#define PARSER_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::parser::GetCurrentTimestamp() - -#define PARSER_TIMESTAMP_END(stage, stage_name) \ - do { \ - uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \ - GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ - (endUsec_##stage - startUsec_##stage)); \ - } while (0); - -#define PARSER_TIMESTAMP_EVENT_END(stage, stage_name) \ - do { \ - uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \ - GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ - (endUsec_##stage - startUsec_##stage)); \ - } while (0); - -#endif // ACL_GRAPH_PARSE_UTIL_ +#endif // ACL_GRAPH_PARSE_UTIL_ \ No newline at end of file diff --git a/parser/common/convert/pb2json.cc b/parser/common/convert/pb2json.cc deleted file mode 100644 index af13ed2..0000000 --- a/parser/common/convert/pb2json.cc +++ /dev/null @@ -1,248 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// File: pb2json.h -// Description: This imply file for protobuf message and json interconversion - -#include "common/convert/pb2json.h" -#include -#include -#include "securec.h" -#include "framework/common/fmk_types.h" -#include "framework/common/debug/ge_log.h" - -using std::set; -using std::string; - -namespace ge { -namespace { -const int kSignificantDigits = 10; -} -// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, - const set &black_fields, Json &json, - bool enum2str) { - auto descriptor = message.GetDescriptor(); - auto reflection = message.GetReflection(); - if (descriptor == nullptr || reflection == nullptr) { - return; - } - - auto count = descriptor->field_count(); - - for (auto i = 0; i < count; ++i) { - const auto field = descriptor->field(i); - if (field == nullptr) { - return; - } - - // Do not display weight data - if (black_fields.find(field->name()) != black_fields.end()) { - continue; - } - - if (field->is_repeated()) { - if (reflection->FieldSize(message, field) > 0) { - RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str); - } - continue; - } - - if (!reflection->HasField(message, field)) { - continue; - } - - OneField2Json(message, field, reflection, black_fields, json, enum2str); - } -} - -void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, - const ProtobufReflection *reflection, const set &black_fields, Json &json, - bool enum2str) { - switch (field->type()) { - case ProtobufFieldDescriptor::TYPE_MESSAGE: { - const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); - if (0 != tmp_message.ByteSize()) { - Message2Json(tmp_message, black_fields, json[field->name()], enum2str); - } - break; - } - - case ProtobufFieldDescriptor::TYPE_BOOL: - json[field->name()] = reflection->GetBool(message, field); - break; - - case ProtobufFieldDescriptor::TYPE_ENUM: { - auto *enum_value_desc = reflection->GetEnum(message, field); - Enum2Json(enum_value_desc, field, enum2str, json); - break; - } - - case ProtobufFieldDescriptor::TYPE_INT32: - case ProtobufFieldDescriptor::TYPE_SINT32: - case ProtobufFieldDescriptor::TYPE_SFIXED32: - json[field->name()] = reflection->GetInt32(message, field); - break; - - case ProtobufFieldDescriptor::TYPE_UINT32: - case ProtobufFieldDescriptor::TYPE_FIXED32: - json[field->name()] = reflection->GetUInt32(message, field); - break; - - case ProtobufFieldDescriptor::TYPE_INT64: - case ProtobufFieldDescriptor::TYPE_SINT64: - case ProtobufFieldDescriptor::TYPE_SFIXED64: - json[field->name()] = reflection->GetInt64(message, field); - break; - - case ProtobufFieldDescriptor::TYPE_UINT64: - case ProtobufFieldDescriptor::TYPE_FIXED64: - json[field->name()] = reflection->GetUInt64(message, field); - break; - - case ProtobufFieldDescriptor::TYPE_FLOAT: - char str[kSignificantDigits]; - if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1){ - json[field->name()] = str; - } else { - json[field->name()] = reflection->GetFloat(message, field); - } - - break; - - case ProtobufFieldDescriptor::TYPE_STRING: - json[field->name()] = reflection->GetString(message, field); - break; - - case ProtobufFieldDescriptor::TYPE_BYTES: { - string field_name = field->name(); - string type_bytes = reflection->GetString(message, field); - json[field_name] = TypeBytes2String(field_name, type_bytes); - break; - } - - default: - break; - } -} - -string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { - if (field_name != "offset") { - return type_bytes; - } - string result = ""; - for (char temp_value : type_bytes) { - uint8_t *value = 0; - value = reinterpret_cast(&temp_value); - char str[kSignificantDigits]; - if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1){ - GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); - continue; - } - result += str; - } - return result; -} - -void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, - const ProtobufReflection *reflection, const set &black_fields, Json &json, - bool enum2str) { - if ((field == nullptr) || (reflection == nullptr)) { - Message2Json(message, black_fields, json, enum2str); - return; - } - - for (auto i = 0; i < reflection->FieldSize(message, field); ++i) { - Json tmp_json; - switch (field->type()) { - case ProtobufFieldDescriptor::TYPE_MESSAGE: { - const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); - if (0 != tmp_message.ByteSize()) { - Message2Json(tmp_message, black_fields, tmp_json, enum2str); - } - } break; - - case ProtobufFieldDescriptor::TYPE_BOOL: - tmp_json = reflection->GetRepeatedBool(message, field, i); - break; - - case ProtobufFieldDescriptor::TYPE_ENUM: { - auto *enum_value_desc = reflection->GetRepeatedEnum(message, field, i); - RepeatedEnum2Json(enum_value_desc, enum2str, tmp_json); - } break; - - case ProtobufFieldDescriptor::TYPE_INT32: - case ProtobufFieldDescriptor::TYPE_SINT32: - case ProtobufFieldDescriptor::TYPE_SFIXED32: - tmp_json = reflection->GetRepeatedInt32(message, field, i); - break; - - case ProtobufFieldDescriptor::TYPE_UINT32: - case ProtobufFieldDescriptor::TYPE_FIXED32: - tmp_json = reflection->GetRepeatedUInt32(message, field, i); - break; - - case ProtobufFieldDescriptor::TYPE_INT64: - case ProtobufFieldDescriptor::TYPE_SINT64: - case ProtobufFieldDescriptor::TYPE_SFIXED64: - tmp_json = reflection->GetRepeatedInt64(message, field, i); - break; - - case ProtobufFieldDescriptor::TYPE_UINT64: - case ProtobufFieldDescriptor::TYPE_FIXED64: - tmp_json = reflection->GetRepeatedUInt64(message, field, i); - break; - - case ProtobufFieldDescriptor::TYPE_FLOAT: - tmp_json = reflection->GetRepeatedFloat(message, field, i); - break; - - case ProtobufFieldDescriptor::TYPE_STRING: - case ProtobufFieldDescriptor::TYPE_BYTES: - tmp_json = reflection->GetRepeatedString(message, field, i); - break; - - default: - break; - } - json += tmp_json; - } -} - -void Pb2Json::Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, - bool enum2str, Json &json) { - if (enum_value_desc != nullptr) { - if (field == nullptr) { - return; - } - if (enum2str) { - json[field->name()] = enum_value_desc->name(); - } else { - json[field->name()] = enum_value_desc->number(); - } - } -} - -void Pb2Json::RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json) { - if (enum_value_desc != nullptr) { - if (enum2str) { - json = enum_value_desc->name(); - } else { - json = enum_value_desc->number(); - } - } -} -} // namespace ge diff --git a/parser/common/convert/pb2json.h b/parser/common/convert/pb2json.h deleted file mode 100644 index 7bc55b1..0000000 --- a/parser/common/convert/pb2json.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// File: pb2json.h -// Description: This header file for protobuf message and json interconversion - -#ifndef PARSER_COMMON_CONVERT_PB2JSON_H_ -#define PARSER_COMMON_CONVERT_PB2JSON_H_ -#include -#include -#include -#include -#include "google/protobuf/descriptor.h" -#include "google/protobuf/message.h" -#include "nlohmann/json.hpp" - -namespace ge { -using Json = nlohmann::json; -using ProtobufMsg = ::google::protobuf::Message; -using ProtobufReflection = ::google::protobuf::Reflection; -using ProtobufFieldDescriptor = ::google::protobuf::FieldDescriptor; -using ProtobufDescriptor = ::google::protobuf::Descriptor; -using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor; - -class Pb2Json { - public: - /** - * @ingroup domi_omg - * @brief Transfer protobuf object to JSON object - * @param [out] json Converted JSON object - * @return void success - * @author - */ - static void Message2Json(const ProtobufMsg &message, const std::set &black_fields, Json &json, - bool enum2str = false); - - protected: - static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, - const ProtobufReflection *reflection, const std::set &black_fields, - Json &json, bool enum2str); - - static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, - bool enum2str, Json &json); - - static void RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json); - - static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, - const ProtobufReflection *reflection, const std::set &black_fields, Json &json, - bool enum2str); - - static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); -}; -} // namespace ge - -#endif // PARSER_COMMON_CONVERT_PB2JSON_H_ diff --git a/parser/common/data_op_parser.cc b/parser/common/data_op_parser.cc index 1078b0e..c79d8a7 100644 --- a/parser/common/data_op_parser.cc +++ b/parser/common/data_op_parser.cc @@ -18,7 +18,7 @@ #include #include "common/debug/log.h" #include "common/op/ge_op_utils.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/math/math_util.h" #include "common/util.h" #include "graph/utils/type_utils.h" #include "omg/omg.h" @@ -36,7 +36,7 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector & GE_RETURN_WITH_LOG_IF_FALSE(op != nullptr, "ParseShape failed for data_op, op is null"); const string &data_op_name = op->GetName(); - GetParserContext().input_dims.emplace(data_op_name, shape); + domi::GetContext().input_dims.emplace(data_op_name, shape); int64_t attr_type = 0; ge::DataType data_type; @@ -51,7 +51,7 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector & ge::GeTensorDesc i_tensor_desc; ge::GeTensorDesc o_tensor_desc; - const unordered_map &input_nodes_format_map = GetParserContext().input_nodes_format_map; + const unordered_map &input_nodes_format_map = domi::GetContext().input_nodes_format_map; auto map_iter = input_nodes_format_map.find(data_op_name); if (map_iter != input_nodes_format_map.end() && map_iter->second == domi::DOMI_TENSOR_NC1HWC0) { // Input 5D NC1HWC0 @@ -80,9 +80,9 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector & "Init ND Output Tensor failed"); } } - i_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); - i_tensor_desc.SetOriginFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); - o_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); + i_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format)); + i_tensor_desc.SetOriginFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format)); + o_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format)); if (op->AddInputDesc(i_tensor_desc) != ge::GRAPH_SUCCESS) { GELOGE(domi::INTERNAL_ERROR, "AddInputDesc failed for op %s.", op->GetName().c_str()); return FAILED; @@ -128,10 +128,10 @@ Status DataOpParser::InitNDTensor(const vector &shape, ge::DataType dat } uint32_t type_size = 0; if (ge::TypeUtils::GetDataTypeLength(data_type, type_size)) { - PARSER_INT64_UINT32_MULCHECK(size, type_size); + FMK_INT64_UINT32_MULCHECK(size, type_size); size *= type_size; } else { - PARSER_INT64_UINT32_MULCHECK(size, static_cast(sizeof(float))); + FMK_INT64_UINT32_MULCHECK(size, static_cast(sizeof(float))); size *= sizeof(float); } ge::TensorUtils::SetSize(tensor_desc, size); @@ -169,7 +169,7 @@ Status DataOpParser::InitInputTensor(const vector &shape, ge::GeTensorD if (input.GetShape().GetDim(0) != -1) { size = input.GetShape().GetShapeSize(); } - PARSER_INT64_UINT32_MULCHECK(size, static_cast(sizeof(float))); + FMK_INT64_UINT32_MULCHECK(size, static_cast(sizeof(float))); ge::TensorUtils::SetSize(input, size * sizeof(float)); return SUCCESS; diff --git a/parser/common/data_op_parser.h b/parser/common/data_op_parser.h index 53bab18..9528853 100644 --- a/parser/common/data_op_parser.h +++ b/parser/common/data_op_parser.h @@ -21,7 +21,7 @@ #include #include "common/debug/log.h" #include "common/op/attr_value_util.h" -#include "framework/omg/parser/parser_types.h" +#include "common/types.h" #include "omg/omg_inner_types.h" #include "proto/om.pb.h" diff --git a/parser/common/model_saver.cc b/parser/common/model_saver.cc deleted file mode 100644 index fc810ad..0000000 --- a/parser/common/model_saver.cc +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "parser/common/model_saver.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "common/util/error_manager/error_manager.h" -#include "mmpa/mmpa_api.h" - -namespace { -const int kFileOpSuccess = 0; -} // namespace - -namespace ge { -namespace parser { -const uint32_t kInteval = 2; - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFile(const char *file_path, - const Json &model) { - Status ret = SUCCESS; - if (file_path == nullptr || SUCCESS != CheckPath(file_path)) { - GELOGE(FAILED, "Check output file failed."); - return FAILED; - } - std::string model_str; - try { - model_str = model.dump(kInteval, ' ', false, Json::error_handler_t::ignore); - } catch (std::exception &e) { - ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()}); - GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what()); - return FAILED; - } catch (...) { - ErrorManager::GetInstance().ATCReportErrMessage("E19008"); - GELOGE(FAILED, "Failed to convert JSON to string."); - return FAILED; - } - - char real_path[PATH_MAX] = {0}; - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path) >= PATH_MAX, return FAILED, "file path is too long!"); - if (realpath(file_path, real_path) == nullptr) { - GELOGI("File %s does not exit, it will be created.", file_path); - } - - // Open file - mode_t mode = S_IRUSR | S_IWUSR; - int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode); - if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { - ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)}); - GELOGE(FAILED, "Open file[%s] failed. %s", file_path, strerror(errno)); - return FAILED; - } - const char *model_char = model_str.c_str(); - uint32_t len = static_cast(model_str.length()); - // Write data to file - mmSsize_t mmpa_ret = mmWrite(fd, const_cast((const void *)model_char), len); - if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19004", {"file", "errmsg"}, {file_path, strerror(errno)}); - // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose - GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); - ret = FAILED; - } - // Close file - if (mmClose(fd) != EN_OK) { - GELOGE(FAILED, "Close file failed."); - ret = FAILED; - } - return ret; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::CheckPath(const std::string &file_path) { - // Determine file path length - if (file_path.size() >= PATH_MAX) { - GELOGE(FAILED, "Path is too long:%zu", file_path.size()); - return FAILED; - } - - // Find the last separator - int path_split_pos = static_cast(file_path.size() - 1); - for (; path_split_pos >= 0; path_split_pos--) { - if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') { - break; - } - } - - if (path_split_pos == 0) { - return SUCCESS; - } - - // If there is a path before the file name, create the path - if (path_split_pos != -1) { - if (CreateDirectory(std::string(file_path).substr(0, static_cast(path_split_pos))) != kFileOpSuccess) { - GELOGE(FAILED, "CreateDirectory failed, file path:%s.", file_path.c_str()); - return FAILED; - } - } - - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int ModelSaver::CreateDirectory(const std::string &directory_path) { - GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); - auto dir_path_len = directory_path.length(); - if (dir_path_len >= PATH_MAX) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19002", {"filepath", "size"}, {directory_path, std::to_string(PATH_MAX)}); - GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), PATH_MAX); - return -1; - } - char tmp_dir_path[PATH_MAX] = {0}; - for (size_t i = 0; i < dir_path_len; i++) { - tmp_dir_path[i] = directory_path[i]; - if ((tmp_dir_path[i] == '\\') || (tmp_dir_path[i] == '/')) { - if (access(tmp_dir_path, F_OK) != 0) { - int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 - if (ret != 0) { - if (errno != EEXIST) { - ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); - GELOGW("Can not create directory %s. Make sure the directory exists and writable.", - directory_path.c_str()); - return ret; - } - } - } - } - } - int32_t ret = mmMkdir(const_cast(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 - if (ret != 0) { - if (errno != EEXIST) { - ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); - GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); - return ret; - } - } - return 0; -} - -} // namespace parser -} // namespace ge \ No newline at end of file diff --git a/parser/common/model_saver.h b/parser/common/model_saver.h deleted file mode 100644 index bc31dba..0000000 --- a/parser/common/model_saver.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARSER_COMMON_FILE_SAVER_H_ -#define PARSER_COMMON_FILE_SAVER_H_ - -#include - -#include "ge/ge_api_error_codes.h" -#include "register/register_types.h" -#include "nlohmann/json.hpp" - -namespace ge { -namespace parser { -using Json = nlohmann::json; -using std::string; - -class ModelSaver { -public: - /** - * @ingroup domi_common - * @brief Save JSON object to file - * @param [in] file_path File output path - * @param [in] model json object - * @return Status result - */ - static Status SaveJsonToFile(const char *file_path, const Json &model); - -private: - /// - /// @ingroup domi_common - /// @brief Check validity of the file path - /// @return Status result - /// - static Status CheckPath(const string &file_path); - - static int CreateDirectory(const std::string &directory_path); -}; -} // namespace parser -} // namespace ge - -#endif //PARSER_COMMON_FILE_SAVER_H_ diff --git a/parser/common/module.mk b/parser/common/module.mk index 4eb3e74..5cb51ac 100644 --- a/parser/common/module.mk +++ b/parser/common/module.mk @@ -18,40 +18,34 @@ COMMON_LOCAL_SRC_FILES := \ register_tbe.cc \ parser_api.cc \ parser_inner_ctx.cc \ + acl_graph_parser_util.cc\ proto_file_parser.cc \ - acl_graph_parser_util.cc \ - tbe_plugin_loader.cc \ - model_saver.cc \ + ../../graph/passes/pass_manager.cc \ + ../../graph/common/omg_util.cc \ ../tensorflow/tensorflow_custom_parser_adapter.cc \ ../tensorflow/tensorflow_fusion_custom_parser_adapter.cc \ ../tensorflow/tensorflow_fusion_op_parser.cc \ ../tensorflow/tensorflow_util.cc \ - convert/pb2json.cc \ + ../../common/convert/pb2json.cc \ op_def/ir_pb_converter.cc \ op_def/defs.cc \ op_def/op_schema.cc \ op_def/operator.cc \ op_map.cc \ - parser_types.cc \ - pass_manager.cc \ - parser_fp16_t.cc \ - thread_pool.cc \ FMK_COMMON_SRC_FILES := \ -# ../../common/fmk_error_codes.cc \ - ../../common/auth/cipher.cc \ - ../../common/context/ctx.cc \ - ../../graph/passes/pass_manager.cc \ - ../../graph/common/omg_util.cc \ ../../common/types.cc \ - ../../common/auth/file_saver.cc \ ../../common/util.cc \ ../../common/model_saver.cc \ + ../../common/fmk_error_codes.cc \ ../../common/fp16_t.cc \ ../../common/thread_pool.cc \ + ../../common/auth/file_saver.cc \ + ../../common/auth/cipher.cc \ + ../../common/context/ctx.cc \ LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) -LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES) +#LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES) LOCAL_C_INCLUDES := \ proto/om.proto \ @@ -73,10 +67,9 @@ LOCAL_C_INCLUDES := \ $(TOPDIR)inc/external/graph \ $(TOPDIR)inc/framework \ $(TOPDIR)inc/common/util \ - $(TOPDIR)graphengine/ge \ - $(TOPDIR)graphengine/ge/common \ - $(TOPDIR)parser/parser \ - $(TOPDIR)parser \ + $(TOPDIR)framework/domi \ + $(TOPDIR)framework/domi/common \ + $(TOPDIR)framework/domi/parser \ $(TOPDIR)third_party/json/include \ $(TOPDIR)third_party/protobuf/include \ libc_sec/include \ @@ -90,6 +83,7 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ liberror_manager \ libregister \ + libge_common \ LOCAL_LDFLAGS := -lrt -ldl diff --git a/parser/common/op_def/constant_op.h b/parser/common/op_def/constant_op.h index 29549e5..f2d7fa9 100644 --- a/parser/common/op_def/constant_op.h +++ b/parser/common/op_def/constant_op.h @@ -18,7 +18,7 @@ #ifndef DOMI_OP_CONSTANT_OP_H_ #define DOMI_OP_CONSTANT_OP_H_ #include "parser/common/op_def/operator.h" -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" namespace ge { class ConstantOperator : public ParserOperator { diff --git a/parser/common/op_def/ir_pb_converter.cc b/parser/common/op_def/ir_pb_converter.cc index 8b5a8d4..17f5825 100644 --- a/parser/common/op_def/ir_pb_converter.cc +++ b/parser/common/op_def/ir_pb_converter.cc @@ -23,7 +23,7 @@ #include "graph/ge_tensor.h" #include "graph/buffer.h" #include "framework/common/debug/ge_log.h" -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" #include "framework/common/util.h" namespace ge { @@ -98,7 +98,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(co GE_CHK_BOOL_RET_STATUS(op.GetSchema(), domi::PARAM_INVALID, "Op schema is null, op type: %s", op.GetType().c_str()); op_def->SetName(op.GetName()); op_def->SetType(op.GetType()); - GE_IF_BOOL_EXEC(op.GetType() == ge::parser::YOLO, op_def->SetType(ge::parser::REGION)); + GE_IF_BOOL_EXEC(op.GetType() == ge::YOLO, op_def->SetType(ge::REGION)); UpdateTensorForOpDesc(op, op_def); GELOGD("Convert to op desc: name:%s, input size: %zu, output size:%zu", op_def->GetName().c_str(), diff --git a/parser/common/op_def/no_op_op.h b/parser/common/op_def/no_op_op.h index 0208c90..75f84d0 100644 --- a/parser/common/op_def/no_op_op.h +++ b/parser/common/op_def/no_op_op.h @@ -18,7 +18,7 @@ #ifndef DOMI_OP_NO_OP_OP_H_ #define DOMI_OP_NO_OP_OP_H_ #include "parser/common/op_def/operator.h" -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" namespace ge { class NoOpOperator : public ParserOperator { diff --git a/parser/common/op_def/ref_switch_op.h b/parser/common/op_def/ref_switch_op.h index baf2167..ec8baca 100644 --- a/parser/common/op_def/ref_switch_op.h +++ b/parser/common/op_def/ref_switch_op.h @@ -18,7 +18,7 @@ #ifndef DOMI_OP_REF_SWITCH_H_ #define DOMI_OP_REF_SWITCH_H_ #include "parser/common/op_def/operator.h" -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" namespace ge { class RefSwitchOperator : public ParserOperator { diff --git a/parser/common/op_def/shape_n_op.cc b/parser/common/op_def/shape_n_op.cc index 0e6e14f..682ee00 100644 --- a/parser/common/op_def/shape_n_op.cc +++ b/parser/common/op_def/shape_n_op.cc @@ -17,7 +17,7 @@ // AUTO GEN PLEASE DO NOT MODIFY IT #include "common/op_def/shape_n_op.h" #include "graph/debug/ge_attr_define.h" -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" namespace ge { FMK_FUNC_HOST_VISIBILITY ShapeNOperator::ShapeNOperator() : ParserOperator("ShapeN") {} diff --git a/parser/common/op_def/shape_n_op.h b/parser/common/op_def/shape_n_op.h index bb69235..ce22e69 100644 --- a/parser/common/op_def/shape_n_op.h +++ b/parser/common/op_def/shape_n_op.h @@ -18,7 +18,7 @@ #ifndef DOMI_OP_SHAPE_N_OP_H_ #define DOMI_OP_SHAPE_N_OP_H_ #include "parser/common/op_def/operator.h" -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" namespace ge { class ShapeNOperator : public ParserOperator { diff --git a/parser/common/op_def/var_is_initialized_op_op.cc b/parser/common/op_def/var_is_initialized_op_op.cc index e0e3d62..d6fb259 100644 --- a/parser/common/op_def/var_is_initialized_op_op.cc +++ b/parser/common/op_def/var_is_initialized_op_op.cc @@ -20,7 +20,7 @@ #include namespace ge { -VarIsInitializedOpOperator::VarIsInitializedOpOperator() : ParserOperator(ge::parser::VARISINITIALIZEDOP) {} +VarIsInitializedOpOperator::VarIsInitializedOpOperator() : ParserOperator(ge::VARISINITIALIZEDOP) {} VarIsInitializedOpOperator::~VarIsInitializedOpOperator() {} diff --git a/parser/common/op_def/var_is_initialized_op_op.h b/parser/common/op_def/var_is_initialized_op_op.h index 88b649f..0749a20 100644 --- a/parser/common/op_def/var_is_initialized_op_op.h +++ b/parser/common/op_def/var_is_initialized_op_op.h @@ -18,7 +18,7 @@ #ifndef DOMI_OP_VARISINITIALIZEDOP_H_ #define DOMI_OP_VARISINITIALIZEDOP_H_ #include "parser/common/op_def/operator.h" -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" namespace ge { class VarIsInitializedOpOperator : public ParserOperator { diff --git a/parser/common/op_def/variable_op.cc b/parser/common/op_def/variable_op.cc index 2cf294e..66e1a13 100644 --- a/parser/common/op_def/variable_op.cc +++ b/parser/common/op_def/variable_op.cc @@ -19,7 +19,7 @@ #include "graph/debug/ge_attr_define.h" namespace ge { -VariableOperator::VariableOperator() : ParserOperator(ge::parser::VARIABLE) {} +VariableOperator::VariableOperator() : ParserOperator(ge::VARIABLE) {} VariableOperator::~VariableOperator() {} diff --git a/parser/common/op_def/variable_op.h b/parser/common/op_def/variable_op.h index c9b85d3..2f83e95 100644 --- a/parser/common/op_def/variable_op.h +++ b/parser/common/op_def/variable_op.h @@ -19,7 +19,7 @@ #define DOMI_OP_VARIABLE_H_ #include #include "parser/common/op_def/operator.h" -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" namespace ge { class VariableOperator : public ParserOperator { diff --git a/parser/common/op_map.cc b/parser/common/op_map.cc index 486b462..7bfc927 100644 --- a/parser/common/op_map.cc +++ b/parser/common/op_map.cc @@ -20,13 +20,12 @@ #include #include -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" #include "register/op_registry.h" using std::map; using std::string; using std::vector; -using namespace ge::parser; namespace ge { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map caffe_op_map = { @@ -98,7 +97,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY map tensorflow_ {"VarHandleOp", VARHANDLEOP}, {"VarIsInitializedOp", VARISINITIALIZEDOP}, {"IsVariableInitialized", ISVARIABLEINITIALIZED}, - {"ReadVariableOp", READVARIABLEOP}, {"Reshape", RESHAPE}, {"Squeeze", SQUEEZE}, {"NoOp", NOOP}, diff --git a/parser/common/op_parser_factory.h b/parser/common/op_parser_factory.h index 112345e..d124303 100644 --- a/parser/common/op_parser_factory.h +++ b/parser/common/op_parser_factory.h @@ -23,8 +23,8 @@ #include #include #include -#include "parser/common/acl_graph_parser_util.h" -#include "framework/omg/parser/parser_types.h" +#include "common/ge/ge_util.h" +#include "common/types.h" #include "framework/common/debug/ge_log.h" #include "omg/omg_inner_types.h" #include "external/register/register.h" @@ -162,7 +162,7 @@ class CustomParserAdapterRegistrar { */ #define REGISTER_OP_PARSER_CREATOR(framework, op_type, clazz) \ std::shared_ptr Creator_##framework##_##op_type##_Op_Parser() { \ - std::shared_ptr ptr = ge::parser::MakeShared(); \ + std::shared_ptr ptr = ge::MakeShared(); \ if (ptr == nullptr) { \ GELOGW("MakeShared failed, result is nullptr."); \ } \ @@ -173,7 +173,7 @@ class CustomParserAdapterRegistrar { #define REGISTER_FUSION_OP_PARSER_CREATOR(framework, op_type, clazz) \ std::shared_ptr Creator_##framework##_##op_type##_Fusion_Op_Parser() { \ - std::shared_ptr ptr = ge::parser::MakeShared(); \ + std::shared_ptr ptr = ge::MakeShared(); \ if (ptr == nullptr) { \ GELOGW("MakeShared failed, result is nullptr."); \ } \ @@ -187,7 +187,7 @@ class CustomParserAdapterRegistrar { /// @param [in] clazz CaffeCustomParserAdapter adaptation class #define REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(framework, clazz) \ std::shared_ptr Creator_##framework##_Op_Parser_Adapter() { \ - std::shared_ptr ptr = ge::parser::MakeShared(); \ + std::shared_ptr ptr = ge::MakeShared(); \ if (ptr == nullptr) { \ GELOGW("MakeShared failed, result is nullptr."); \ } \ diff --git a/parser/common/parser_api.cc b/parser/common/parser_api.cc index d582ed7..65f6061 100644 --- a/parser/common/parser_api.cc +++ b/parser/common/parser_api.cc @@ -17,7 +17,7 @@ #include "framework/omg/parser/parser_api.h" #include "common/debug/log.h" -#include "tbe_plugin_loader.h" +#include "common/ge/tbe_plugin_manager.h" #include "framework/common/debug/ge_log.h" #include "parser/common/register_tbe.h" #include "framework/omg/parser/parser_inner_ctx.h" @@ -36,7 +36,7 @@ Status ParserInitialize(const std::map &options) { } // load custom op plugin - TBEPluginLoader::Instance().LoadPluginSo(options); + TBEPluginManager::Instance().LoadPluginSo(options); std::vector registrationDatas = domi::OpRegistry::Instance()->registrationDatas; GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); @@ -67,7 +67,7 @@ Status ParserFinalize() { return SUCCESS; } - GE_CHK_STATUS(TBEPluginLoader::Instance().Finalize()); + GE_CHK_STATUS(TBEPluginManager::Instance().Finalize()); if (parser_initialized) { parser_initialized = false; } diff --git a/parser/common/parser_fp16_t.cc b/parser/common/parser_fp16_t.cc deleted file mode 100644 index 044eb5c..0000000 --- a/parser/common/parser_fp16_t.cc +++ /dev/null @@ -1,1270 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parser/common/parser_fp16_t.h" - -#include "external/register/register_types.h" - -namespace { -constexpr uint16_t kManBitLength = 11; -} -namespace ge { -namespace parser { -/// @ingroup fp16_t global filed -/// @brief round mode of last valid digital -enum TagFp16RoundMode g_round_mode = kRoundToNearest; - -void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m) { - // 1.Extract - s = static_cast(FP16_EXTRAC_SIGN(val)); - e = static_cast(FP16_EXTRAC_EXP(val)); - m = static_cast(FP16_EXTRAC_MAN(val)); - // Denormal - if (e == 0) { - e = 1; - } -} - -/// @ingroup fp16_t static method -/// @param [in] man truncated mantissa -/// @param [in] shift_out left shift bits based on ten bits -/// @brief judge whether to add one to the result while converting fp16_t to other datatype -/// @return Return true if add one, otherwise false -static bool IsRoundOne(uint64_t man, uint16_t trunc_len) { - uint64_t mask0 = 0x4; - uint64_t mask1 = 0x2; - uint64_t mask2; - uint16_t shift_out = static_cast(trunc_len - kDim2); - mask0 = mask0 << shift_out; - mask1 = mask1 << shift_out; - mask2 = mask1 - 1; - - bool last_bit = ((man & mask0) > 0); - bool trunc_high = false; - bool trunc_left = false; - if (g_round_mode == kRoundToNearest) { - trunc_high = ((man & mask1) > 0); - trunc_left = ((man & mask2) > 0); - } - return (trunc_high && (trunc_left || last_bit)); -} - -/// @ingroup fp16_t public method -/// @param [in] exp exponent of fp16_t value -/// @param [in] man exponent of fp16_t value -/// @brief normalize fp16_t value -/// @return -static void Fp16Normalize(int16_t &exp, uint16_t &man) { - // set to invalid data - if (exp >= kFp16MaxExp) { - exp = static_cast(kFp16MaxExp); - man = static_cast(kFp16MaxMan); - } else if (exp == 0 && man == kFp16ManHideBit) { - exp++; - man = 0; - } -} - -/// @ingroup fp16_t math conversion static method -/// @param [in] fp_val uint16_t value of fp16_t object -/// @brief Convert fp16_t to float/fp32 -/// @return Return float/fp32 value of fp_val which is the value of fp16_t object -static float Fp16ToFloat(const uint16_t &fp_val) { - uint16_t hf_sign; - uint16_t hf_man; - int16_t hf_exp; - ExtractFp16(fp_val, hf_sign, hf_exp, hf_man); - - while (hf_man && !(hf_man & kFp16ManHideBit)) { - hf_man <<= 1; - hf_exp--; - } - - uint32_t e_ret, m_ret; - uint32_t s_ret = hf_sign; - if (hf_man == 0) { - e_ret = 0; - m_ret = 0; - } else { - e_ret = hf_exp - kFp16ExpBias + kFp32ExpBias; - m_ret = hf_man & kFp16ManMask; - m_ret = m_ret << (kFp32ManLen - kFp16ManLen); - } - uint32_t f_val = FP32_CONSTRUCTOR(s_ret, e_ret, m_ret); - auto p_ret_v = reinterpret_cast(&f_val); - - return *p_ret_v; -} - -/// @ingroup fp16_t math conversion static method -/// @param [in] fp_val uint16_t value of fp16_t object -/// @brief Convert fp16_t to double/fp64 -/// @return Return double/fp64 value of fp_val which is the value of fp16_t object -static double Fp16ToDouble(const uint16_t &fp_val) { - uint16_t hf_sign; - uint16_t hf_man; - int16_t hf_exp; - ExtractFp16(fp_val, hf_sign, hf_exp, hf_man); - - while (hf_man && !(hf_man & kFp16ManHideBit)) { - hf_man <<= 1; - hf_exp--; - } - - uint64_t e_ret; - uint64_t m_ret; - uint64_t s_ret = hf_sign; - if (!hf_man) { - e_ret = 0; - m_ret = 0; - } else { - e_ret = hf_exp - kFp16ExpBias + kFp64ExpBias; - m_ret = hf_man & kFp16ManMask; - m_ret = m_ret << (kFp64ManLen - kFp16ManLen); - } - uint64_t f_val = (s_ret << kFp64SignIndex) | (e_ret << kFp64ManLen) | (m_ret); - auto p_ret_v = reinterpret_cast(&f_val); - - return *p_ret_v; -} - -/// @ingroup fp16_t static method -/// @param [in] s_ret sign of fp16_t value -/// @param [in] long_int_m man uint64_t value of fp16_t object -/// @param [in] shift_out shift offset -/// @brief calculate uint8 value by sign,man and shift offset -/// @return Return uint8 value of fp16_t object -static uint8_t GetUint8ValByMan(uint8_t s_ret, const uint64_t &long_int_m, const uint16_t &shift_out) { - bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); - auto m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen8Max); - need_round = need_round && ((s_ret == 0 && m_ret < kInt8Max) || (s_ret == 1 && m_ret <= kInt8Max)); - if (need_round) { - m_ret++; - } - if (s_ret) { - m_ret = (~m_ret) + 1; - } - if (m_ret == 0) { - s_ret = 0; - } - return static_cast((s_ret << kBitShift7) | (m_ret)); -} - -/// @ingroup fp16_t math conversion static method -/// @param [in] fp_val uint16_t value of fp16_t object -/// @brief Convert fp16_t to int8_t -/// @return Return int8_t value of fp_val which is the value of fp16_t object -static int8_t Fp16ToInt8(const uint16_t &fp_val) { - int8_t ret; - uint8_t ret_v; - // 1.get s_ret and shift it to bit0. - uint8_t s_ret = FP16_EXTRAC_SIGN(fp_val); - // 2.get hf_e and hf_m - uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); - uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); - - if (FP16_IS_DENORM(fp_val)) { // Denormalized number - ret_v = 0; - ret = *(reinterpret_cast(&ret_v)); - return ret; - } - - uint64_t long_int_m = hf_m; - uint8_t overflow_flag = 0; - uint16_t shift_out = 0; - if (FP16_IS_INVALID(fp_val)) { // Inf or NaN - overflow_flag = 1; - } else { - while (hf_e != kFp16ExpBias) { - if (hf_e > kFp16ExpBias) { - hf_e--; - long_int_m = long_int_m << 1; - if (s_ret == 1 && long_int_m >= 0x20000u) { // sign=1,negative number(<0) - long_int_m = 0x20000u; // 10 0000 0000 0000 0000 10(fp16_t-man)+7(int8)=17bit - overflow_flag = 1; - break; - } else if (s_ret != 1 && long_int_m >= 0x1FFFFu) { // sign=0,positive number(>0) - long_int_m = 0x1FFFFu; // 01 1111 1111 1111 1111 10(fp16_t-man)+7(int8) - overflow_flag = 1; - break; - } - } else { - hf_e++; - shift_out++; - } - } - } - if (overflow_flag) { - ret_v = kInt8Max + s_ret; - } else { - // Generate final result - ret_v = GetUint8ValByMan(s_ret, long_int_m, shift_out); - } - - ret = *(reinterpret_cast(&ret_v)); - return ret; -} - -/// @ingroup fp16_t math conversion static method -/// @param [in] fp_val uint16_t value of fp16_t object -/// @brief Convert fp16_t to uint8_t -/// @return Return uint8_t value of fp_val which is the value of fp16_t object -static uint8_t Fp16ToUInt8(const uint16_t &fp_val) { - uint8_t m_ret = 0; - // 1.get s_ret and shift it to bit0. - uint8_t s_ret = FP16_EXTRAC_SIGN(fp_val); - // 2.get hf_e and hf_m - uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); - uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); - - if (FP16_IS_DENORM(fp_val)) { // Denormalized number - return 0; - } - - if (FP16_IS_INVALID(fp_val)) { // Inf or NaN - m_ret = ~0; - } else { - uint64_t long_int_m = hf_m; - uint8_t overflow_flag = 0; - uint16_t shift_out = 0; - while (hf_e != kFp16ExpBias) { - if (hf_e > kFp16ExpBias) { - hf_e--; - long_int_m = long_int_m << 1; - if (long_int_m >= 0x40000Lu) { // overflow 0100 0000 0000 0000 0000 - long_int_m = 0x3FFFFLu; // 11 1111 1111 1111 1111 10(fp16_t-man)+8(uint8)=18bit - overflow_flag = 1; - m_ret = ~0; - break; - } - } else { - hf_e++; - shift_out++; - } - } - if (!overflow_flag) { - bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); - m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen8Max); - if (need_round && m_ret != kBitLen8Max) { - m_ret++; - } - } - } - - if (s_ret == 1) { // Negative number - m_ret = 0; - } - // m_ret equal to final result - return m_ret; -} - -/// @ingroup fp16_t static method -/// @param [in] s_ret sign of fp16_t value -/// @param [in] long_int_m man uint64_t value of fp16_t object -/// @param [in] shift_out shift offset -/// @brief calculate uint16 value by sign,man and shift offset -/// @return Return uint16 value of fp16_t object -static uint16_t GetUint16ValByMan(uint16_t s_ret, const uint64_t &long_int_m, const uint16_t &shift_out) { - bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); - auto m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen16Max); - if (need_round && m_ret < kInt16Max) { - m_ret++; - } - if (s_ret) { - m_ret = (~m_ret) + 1; - } - if (m_ret == 0) { - s_ret = 0; - } - return static_cast((s_ret << kBitShift15) | (m_ret)); -} - -/// @ingroup fp16_t math conversion static method -/// @param [in] fp_val uint16_t value of fp16_t object -/// @brief Convert fp16_t to int16_t -/// @return Return int16_t value of fp_val which is the value of fp16_t object -static int16_t Fp16ToInt16(const uint16_t &fp_val) { - int16_t ret; - uint16_t ret_v; - // 1.get s_ret and shift it to bit0. - uint16_t s_ret = FP16_EXTRAC_SIGN(fp_val); - // 2.get hf_e and hf_m - uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); - uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); - - if (FP16_IS_DENORM(fp_val)) { // Denormalized number - ret_v = 0; - ret = *(reinterpret_cast(&ret_v)); - return ret; - } - - uint64_t long_int_m = hf_m; - uint8_t overflow_flag = 0; - uint16_t shift_out = 0; - if (FP16_IS_INVALID(fp_val)) { // Inf or NaN - overflow_flag = 1; - } else { - while (hf_e != kFp16ExpBias) { - if (hf_e > kFp16ExpBias) { - hf_e--; - long_int_m = long_int_m << 1; - if (s_ret == 1 && long_int_m > 0x2000000Lu) { // sign=1,negative number(<0) - long_int_m = 0x2000000Lu; // 10(fp16_t-man)+15(int16)=25bit - overflow_flag = 1; - break; - } else if (s_ret != 1 && long_int_m >= 0x1FFFFFFLu) { // sign=0,positive number(>0) Overflow - long_int_m = 0x1FFFFFFLu; // 10(fp16_t-man)+15(int16)=25bit - overflow_flag = 1; - break; - } - } else { - hf_e++; - shift_out++; - } - } - } - if (overflow_flag) { - ret_v = kInt16Max + s_ret; - } else { - // Generate final result - ret_v = GetUint16ValByMan(s_ret, long_int_m, shift_out); - } - ret = *(reinterpret_cast(&ret_v)); - return ret; -} - -/// @ingroup fp16_t math conversion static method -/// @param [in] fp_val uint16_t value of fp16_t object -/// @brief Convert fp16_t to uint16_t -/// @return Return uint16_t value of fp_val which is the value of fp16_t object -static uint16_t Fp16ToUInt16(const uint16_t &fp_val) { - uint16_t m_ret = 0; - // 1.get s_ret and shift it to bit0. - uint16_t s_ret = FP16_EXTRAC_SIGN(fp_val); - // 2.get hf_e and hf_m - uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); - uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); - - if (FP16_IS_DENORM(fp_val)) { // Denormalized number - return 0; - } - - if (FP16_IS_INVALID(fp_val)) { // Inf or NaN - m_ret = ~0; - } else { - uint64_t long_int_m = hf_m; - uint16_t shift_out = 0; - while (hf_e != kFp16ExpBias) { - if (hf_e > kFp16ExpBias) { - hf_e--; - long_int_m = long_int_m << 1; - } else { - hf_e++; - shift_out++; - } - } - bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); - m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen16Max); - if (need_round && m_ret != kBitLen16Max) { - m_ret++; - } - } - - if (s_ret == 1) { // Negative number - m_ret = 0; - } - // m_ret equal to final result - return m_ret; -} - -/// @ingroup fp16_t math convertion static method -/// @param [in] fp_val uint16_t value of fp16_t object -/// @brief Convert fp16_t to int32_t -/// @return Return int32_t value of fp_val which is the value of fp16_t object -static int32_t Fp16ToInt32(const uint16_t &fp_val) { - uint32_t ret_v; - // 1.get s_ret and shift it to bit0. - uint32_t s_ret = FP16_EXTRAC_SIGN(fp_val); - // 2.get hf_e and hf_m - uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); - uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); - - if (FP16_IS_INVALID(fp_val)) { // Inf or NaN - ret_v = kInt32Max + s_ret; - } else { - uint64_t long_int_m = hf_m; - uint16_t shift_out = 0; - - while (hf_e != kFp16ExpBias) { - if (hf_e > kFp16ExpBias) { - hf_e--; - long_int_m = long_int_m << 1; - } else { - hf_e++; - shift_out++; - } - } - bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); - auto m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen32Max); - if (need_round && m_ret < kInt32Max) { - m_ret++; - } - - if (s_ret == 1) { - m_ret = (~m_ret) + 1; - } - if (m_ret == 0) { - s_ret = 0; - } - // Generate final result - ret_v = (s_ret << kBitShift31) | (m_ret); - } - - return *(reinterpret_cast(&ret_v)); -} - -/// @ingroup fp16_t math conversion static method -/// @param [in] fp_val uint16_t value of fp16_t object -/// @brief Convert fp16_t to uint32_t -/// @return Return uint32_t value of fp_val which is the value of fp16_t object -static uint32_t Fp16ToUInt32(const uint16_t &fp_val) { - uint32_t m_ret; - // 1.get s_ret and shift it to bit0. - uint32_t s_ret = FP16_EXTRAC_SIGN(fp_val); - // 2.get hf_e and hf_m - uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); - uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); - - if (FP16_IS_DENORM(fp_val)) { // Denormalized number - return 0u; - } - - if (FP16_IS_INVALID(fp_val)) { // Inf or NaN - m_ret = ~0u; - } else { - uint64_t long_int_m = hf_m; - uint16_t shift_out = 0; - while (hf_e != kFp16ExpBias) { - if (hf_e > kFp16ExpBias) { - hf_e--; - long_int_m = long_int_m << 1; - } else { - hf_e++; - shift_out++; - } - } - bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); - m_ret = static_cast(long_int_m >> (kFp16ManLen + shift_out)) & kBitLen32Max; - if (need_round && m_ret != kBitLen32Max) { - m_ret++; - } - } - - if (s_ret == 1) { // Negative number - m_ret = 0; - } - // m_ret equal to final result - return m_ret; -} - -static uint16_t Fp16AddCalVal(uint16_t &s_ret, int16_t e_ret, uint16_t m_ret, uint32_t m_trunc, uint16_t shift_out) { - uint16_t m_min = kFp16ManHideBit << shift_out; - uint16_t m_max = m_min << 1; - // Denormal - while (m_ret < m_min && e_ret > 0) { // the value of m_ret should not be smaller than 2^23 - m_ret = m_ret << 1; - m_ret += (kFp32SignMask & m_trunc) >> kFp32SignIndex; - m_trunc = m_trunc << 1; - e_ret = e_ret - 1; - } - while (m_ret >= m_max) { // the value of m_ret should be smaller than 2^24 - m_trunc = m_trunc >> 1; - m_trunc = m_trunc | (kFp32SignMask * (m_ret & 1)); - m_ret = m_ret >> 1; - e_ret = e_ret + 1; - } - - bool b_last_bit = ((m_ret & 1) > 0); - bool b_trunc_high = 0; - bool b_trunc_left = 0; - b_trunc_high = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32SignMask) > 0); - b_trunc_left = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32AbsMax) > 0); - m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret, shift_out); - while (m_ret >= m_max) { - m_ret = m_ret >> 1; - e_ret = e_ret + 1; - } - - if (e_ret == 0 && m_ret <= m_max) { - m_ret = m_ret >> 1; - } - Fp16Normalize(e_ret, m_ret); - uint16_t ret = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); - return ret; -} - -/// @ingroup fp16_t math operator -/// @param [in] v_1 left operator value of fp16_t object -/// @param [in] v_2 right operator value of fp16_t object -/// @brief Performing fp16_t addition -/// @return Return fp16_t result of adding this and fp -static uint16_t Fp16Add(uint16_t v_1, uint16_t v_2) { - uint16_t s_a; - uint16_t s_b; - int16_t e_a; - int16_t e_b; - uint32_t m_a; - uint32_t m_b; - uint16_t m_a_tmp; - uint16_t m_b_tmp; - uint16_t shift_out = 0; - // 1.Extract - ExtractFp16(v_1, s_a, e_a, m_a_tmp); - ExtractFp16(v_2, s_b, e_b, m_b_tmp); - m_a = m_a_tmp; - m_b = m_b_tmp; - - uint16_t sum; - uint16_t s_ret; - if (s_a != s_b) { - ReverseMan(s_a > 0, m_a); - ReverseMan(s_b > 0, m_b); - sum = static_cast(GetManSum(e_a, m_a, e_b, m_b)); - s_ret = (sum & kFp16SignMask) >> kFp16SignIndex; - ReverseMan(s_ret > 0, m_a); - ReverseMan(s_ret > 0, m_b); - } else { - sum = static_cast(GetManSum(e_a, m_a, e_b, m_b)); - s_ret = s_a; - } - - if (sum == 0) { - shift_out = 3; // shift to left 3 bits - m_a = m_a << shift_out; - m_b = m_b << shift_out; - } - - uint32_t m_trunc = 0; - int16_t e_ret = std::max(e_a, e_b); - int16_t e_tmp = std::abs(e_a - e_b); - if (e_a > e_b) { - m_trunc = (m_b << (kBitShift32 - static_cast(e_tmp))); - m_b = RightShift(m_b, e_tmp); - } else if (e_a < e_b) { - m_trunc = (m_a << (kBitShift32 - static_cast(e_tmp))); - m_a = RightShift(m_a, e_tmp); - } - // calculate mantissav - auto m_ret = static_cast(m_a + m_b); - return Fp16AddCalVal(s_ret, e_ret, m_ret, m_trunc, shift_out); -} - -/// @ingroup fp16_t math operator -/// @param [in] v_1 left operator value of fp16_t object -/// @param [in] v_2 right operator value of fp16_t object -/// @brief Performing fp16_t subtraction -/// @return Return fp16_t result of subtraction fp from this -static uint16_t Fp16Sub(uint16_t v_1, uint16_t v_2) { - // Reverse - uint16_t tmp = ((~(v_2)) & kFp16SignMask) | (v_2 & kFp16AbsMax); - return Fp16Add(v_1, tmp); -} - -/// @ingroup fp16_t math operator -/// @param [in] v_1 left operator value of fp16_t object -/// @param [in] v_2 right operator value of fp16_t object -/// @brief Performing fp16_t multiplication -/// @return Return fp16_t result of multiplying this and fp -static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) { - uint16_t s_a, s_b; - int16_t e_a, e_b; - uint32_t m_a, m_b; - uint16_t s_ret, m_ret; - int16_t e_ret; - uint32_t mul_m; - uint16_t m_a_tmp, m_b_tmp; - // 1.Extract - ExtractFp16(v_1, s_a, e_a, m_a_tmp); - ExtractFp16(v_2, s_b, e_b, m_b_tmp); - m_a = m_a_tmp; - m_b = m_b_tmp; - - e_ret = e_a + e_b - kFp16ExpBias - kDim10; - mul_m = m_a * m_b; - s_ret = s_a ^ s_b; - - uint32_t m_min = kFp16ManHideBit; - uint32_t m_max = m_min << 1; - uint32_t m_trunc = 0; - // the value of m_ret should not be smaller than 2^23 - while (mul_m < m_min && e_ret > 1) { - mul_m = mul_m << 1; - e_ret = e_ret - 1; - } - while (mul_m >= m_max || e_ret < 1) { - m_trunc = m_trunc >> 1; - m_trunc = m_trunc | (kFp32SignMask * (mul_m & 1)); - mul_m = mul_m >> 1; - e_ret = e_ret + 1; - } - bool b_last_bit = ((mul_m & 1) > 0); - bool b_trunc_high = 0; - bool b_trunc_left = 0; - b_trunc_high = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32SignMask) > 0); - b_trunc_left = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32AbsMax) > 0); - mul_m = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, mul_m); - - while (mul_m >= m_max || e_ret < 0) { - mul_m = mul_m >> 1; - e_ret = e_ret + 1; - } - - if (e_ret == 1 && mul_m < kFp16ManHideBit) { - e_ret = 0; - } - m_ret = static_cast(mul_m); - - Fp16Normalize(e_ret, m_ret); - - uint16_t ret = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); - return ret; -} - -/// @ingroup fp16_t math operator divided -/// @param [in] v_1 left operator value of fp16_t object -/// @param [in] v_2 right operator value of fp16_t object -/// @brief Performing fp16_t division -/// @return Return fp16_t result of division this by fp -static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) { - uint16_t ret; - if (FP16_IS_ZERO(v_2)) { // result is inf - // throw "fp16_t division by zero."; - uint16_t s_a, s_b; - uint16_t s_ret; - s_a = FP16_EXTRAC_SIGN(v_1); - s_b = FP16_EXTRAC_SIGN(v_2); - s_ret = s_a ^ s_b; - ret = FP16_CONSTRUCTOR(s_ret, kFp16MaxExp, 0u); - } else if (FP16_IS_ZERO(v_1)) { - ret = 0u; - } else { - uint16_t s_a, s_b; - int16_t e_a, e_b; - uint64_t m_a, m_b; - float m_div; - uint16_t m_a_tmp, m_b_tmp; - // 1.Extract - ExtractFp16(v_1, s_a, e_a, m_a_tmp); - ExtractFp16(v_2, s_b, e_b, m_b_tmp); - m_a = m_a_tmp; - m_b = m_b_tmp; - - uint64_t m_tmp; - if (e_a > e_b) { - m_tmp = m_a; - uint16_t tmp; - tmp = e_a - e_b; - for (int i = 0; i < tmp; i++) { - m_tmp = m_tmp << 1; - } - m_a = m_tmp; - } else if (e_a < e_b) { - m_tmp = m_b; - uint16_t tmp = e_b - e_a; - for (int i = 0; i < tmp; i++) { - m_tmp = m_tmp << 1; - } - m_b = m_tmp; - } - m_div = static_cast(m_a * 1.0f / m_b); - fp16_t fp_div; - fp_div = m_div; - ret = fp_div.val; - if (s_a != s_b) { - ret |= kFp16SignMask; - } - } - return ret; -} - -// operate -fp16_t fp16_t::operator+(const fp16_t fp) { - uint16_t ret_val = Fp16Add(val, fp.val); - fp16_t ret(ret_val); - return ret; -} - -fp16_t fp16_t::operator-(const fp16_t fp) { - uint16_t ret_val = Fp16Sub(val, fp.val); - fp16_t ret(ret_val); - return ret; -} - -fp16_t fp16_t::operator*(const fp16_t fp) { - uint16_t ret_val = Fp16Mul(val, fp.val); - fp16_t ret(ret_val); - return ret; -} - -fp16_t fp16_t::operator/(const fp16_t fp) { - uint16_t ret_val = Fp16Div(val, fp.val); - fp16_t ret(ret_val); - return ret; -} - -fp16_t fp16_t::operator+=(const fp16_t fp) { - val = Fp16Add(val, fp.val); - return *this; -} - -fp16_t fp16_t::operator-=(const fp16_t fp) { - val = Fp16Sub(val, fp.val); - return *this; -} - -fp16_t fp16_t::operator*=(const fp16_t fp) { - val = Fp16Mul(val, fp.val); - return *this; -} - -fp16_t fp16_t::operator/=(const fp16_t fp) { - val = Fp16Div(val, fp.val); - return *this; -} - -// compare -bool fp16_t::operator==(const fp16_t &fp) const { - bool result = true; - if (FP16_IS_ZERO(val) && FP16_IS_ZERO(fp.val)) { - result = true; - } else { - result = ((val & kBitLen16Max) == (fp.val & kBitLen16Max)); // bit compare - } - return result; -} - -bool fp16_t::operator!=(const fp16_t &fp) const { - bool result = true; - if (FP16_IS_ZERO(val) && FP16_IS_ZERO(fp.val)) { - result = false; - } else { - result = ((val & kBitLen16Max) != (fp.val & kBitLen16Max)); // bit compare - } - return result; -} - -bool fp16_t::operator>(const fp16_t &fp) const { - uint16_t s_a, s_b; - uint16_t e_a, e_b; - uint16_t m_a, m_b; - bool result = true; - - // 1.Extract - s_a = FP16_EXTRAC_SIGN(val); - s_b = FP16_EXTRAC_SIGN(fp.val); - e_a = FP16_EXTRAC_EXP(val); - e_b = FP16_EXTRAC_EXP(fp.val); - m_a = FP16_EXTRAC_MAN(val); - m_b = FP16_EXTRAC_MAN(fp.val); - - // Compare - if ((s_a == 0) && (s_b > 0)) { // + - - // -0=0 - result = !(FP16_IS_ZERO(val) && FP16_IS_ZERO(fp.val)); - } else if ((s_a == 0) && (s_b == 0)) { // + + - if (e_a > e_b) { // e_a - e_b >= 1; Va always larger than Vb - result = true; - } else if (e_a == e_b) { - result = m_a > m_b; - } else { - result = false; - } - } else if ((s_a > 0) && (s_b > 0)) { // - - opposite to + + - if (e_a < e_b) { - result = true; - } else if (e_a == e_b) { - result = m_a < m_b; - } else { - result = false; - } - } else { // - + - result = false; - } - - return result; -} - -bool fp16_t::operator>=(const fp16_t &fp) const { - bool result = true; - if ((*this) > fp) { - result = true; - } else if ((*this) == fp) { - result = true; - } else { - result = false; - } - - return result; -} - -bool fp16_t::operator<(const fp16_t &fp) const { - bool result = true; - if ((*this) >= fp) { - result = false; - } else { - result = true; - } - - return result; -} - -bool fp16_t::operator<=(const fp16_t &fp) const { - bool result = true; - if ((*this) > fp) { - result = false; - } else { - result = true; - } - - return result; -} - -// evaluation -fp16_t &fp16_t::operator=(const fp16_t &fp) { - if (&fp == this) { - return *this; - } - val = fp.val; - return *this; -} - -fp16_t &fp16_t::operator=(const float &f_val) { - uint16_t s_ret, m_ret; - int16_t e_ret; - uint32_t e_f, m_f; - const uint32_t ui32_v = *(reinterpret_cast(&f_val)); // 1:8:23bit sign:exp:man - uint32_t m_len_delta; - - s_ret = static_cast((ui32_v & kFp32SignMask) >> kFp32SignIndex); // 4Byte->2Byte - e_f = (ui32_v & kFp32ExpMask) >> kFp32ManLen; // 8 bit exponent - m_f = (ui32_v & kFp32ManMask); // 23 bit mantissa dont't need to care about denormal - m_len_delta = kFp32ManLen - kFp16ManLen; - - bool need_round = false; - // Exponent overflow/NaN converts to signed inf/NaN - if (e_f > 0x8Fu) { // 0x8Fu:142=127+15 - e_ret = kFp16MaxExp - 1; - m_ret = kFp16MaxMan; - } else if (e_f <= 0x70u) { // 0x70u:112=127-15 Exponent underflow converts to denormalized half or signed zero - e_ret = 0; - if (e_f >= 0x67) { // 0x67:103=127-24 Denormal - m_f = (m_f | kFp32ManHideBit); - uint16_t shift_out = kFp32ManLen; - uint64_t m_tmp = (static_cast(m_f)) << (e_f - 0x67); - - need_round = IsRoundOne(m_tmp, shift_out); - m_ret = static_cast(m_tmp >> shift_out); - if (need_round) { - m_ret++; - } - } else if (e_f == 0x66 && m_f > 0) { // 0x66:102 Denormal 0(e_f - 0x70u); - - need_round = IsRoundOne(m_f, static_cast(m_len_delta)); - m_ret = static_cast(m_f >> m_len_delta); - if (need_round) { - m_ret++; - } - if (m_ret & kFp16ManHideBit) { - e_ret++; - } - } - - Fp16Normalize(e_ret, m_ret); - val = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); - return *this; -} - -fp16_t &fp16_t::operator=(const int8_t &i_val) { - uint16_t s_ret, e_ret, m_ret; - - s_ret = static_cast(((static_cast(i_val)) & 0x80) >> kDim7); - m_ret = static_cast(((static_cast(i_val)) & kInt8Max)); - - if (m_ret == 0) { - e_ret = 0; - } else { - if (s_ret) { // negative number(<0) - m_ret = static_cast(std::abs(i_val)); // complement - } - - e_ret = kFp16ManLen; - while ((m_ret & kFp16ManHideBit) == 0) { - m_ret = m_ret << 1; - e_ret = e_ret - 1; - } - e_ret = e_ret + kFp16ExpBias; - } - - val = FP16_CONSTRUCTOR(s_ret, e_ret, m_ret); - return *this; -} - -fp16_t &fp16_t::operator=(const uint8_t &ui_val) { - uint16_t s_ret, e_ret, m_ret; - s_ret = 0; - e_ret = 0; - m_ret = ui_val; - if (m_ret) { - e_ret = kFp16ManLen; - while ((m_ret & kFp16ManHideBit) == 0) { - m_ret = m_ret << 1; - e_ret = e_ret - 1; - } - e_ret = e_ret + kFp16ExpBias; - } - - val = FP16_CONSTRUCTOR(s_ret, e_ret, m_ret); - return *this; -} - -static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, uint16_t &ret_val) { - uint32_t m_tmp = (input_val & kFp32AbsMax); - uint16_t m_min = kFp16ManHideBit; - uint16_t m_max = m_min << 1; - uint16_t len = static_cast(GetManBitLength(m_tmp)); - if (m_tmp) { - int16_t e_ret; - if (len > kDim11) { - e_ret = kFp16ExpBias + kFp16ManLen; - uint16_t e_tmp = len - kDim11; - uint32_t trunc_mask = 1; - for (int i = 1; i < e_tmp; i++) { - trunc_mask = (trunc_mask << 1) + 1; - } - uint32_t m_trunc = (m_tmp & trunc_mask) << (kBitShift32 - e_tmp); - for (int i = 0; i < e_tmp; i++) { - m_tmp = (m_tmp >> 1); - e_ret = e_ret + 1; - } - bool b_last_bit = ((m_tmp & 1) > 0); - bool b_trunc_high = 0; - bool b_trunc_left = 0; - if (kRoundToNearest == g_round_mode) { // trunc - b_trunc_high = ((m_trunc & kFp32SignMask) > 0); - b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); - } - m_tmp = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_tmp); - while (m_tmp >= m_max || e_ret < 0) { - m_tmp = m_tmp >> 1; - e_ret = e_ret + 1; - } - } else { - e_ret = kFp16ExpBias; - m_tmp = m_tmp << (kManBitLength - len); - e_ret = e_ret + (len - 1); - } - auto m_ret = static_cast(m_tmp); - ret_val = FP16_CONSTRUCTOR(sign, static_cast(e_ret), m_ret); - } -} - -fp16_t &fp16_t::operator=(const int16_t &i_val) { - if (i_val == 0) { - val = 0; - } else { - uint16_t ui_val = *(reinterpret_cast(&i_val)); - auto s_ret = static_cast(ui_val >> kBitShift15); - if (s_ret) { - int16_t iValM = -i_val; - ui_val = *(reinterpret_cast(&iValM)); - } - SetValByUint16Val(ui_val, s_ret, val); - } - return *this; -} - -fp16_t &fp16_t::operator=(const uint16_t &ui_val) { - if (ui_val == 0) { - val = 0; - } else { - int16_t e_ret; - uint16_t m_ret = ui_val; - uint16_t m_min = kFp16ManHideBit; - uint16_t m_max = m_min << 1; - uint16_t len = static_cast(GetManBitLength(m_ret)); - if (len > kManBitLength) { - e_ret = kFp16ExpBias + kFp16ManLen; - uint32_t m_trunc; - uint32_t trunc_mask = 1; - uint16_t e_tmp = len - kManBitLength; - for (int i = 1; i < e_tmp; i++) { - trunc_mask = (trunc_mask << 1) + 1; - } - m_trunc = (m_ret & trunc_mask) << (kBitShift32 - e_tmp); - for (int i = 0; i < e_tmp; i++) { - m_ret = (m_ret >> 1); - e_ret = e_ret + 1; - } - bool b_last_bit = ((m_ret & 1) > 0); - bool b_trunc_high = 0; - bool b_trunc_left = 0; - if (kRoundToNearest == g_round_mode) { // trunc - b_trunc_high = ((m_trunc & kFp32SignMask) > 0); - b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); - } - m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret); - while (m_ret >= m_max || e_ret < 0) { - m_ret = m_ret >> 1; - e_ret = e_ret + 1; - } - if (FP16_IS_INVALID(val)) { - val = kFp16Max; - } - } else { - e_ret = kFp16ExpBias; - m_ret = m_ret << (kDim11 - len); - e_ret = e_ret + (len - 1); - } - val = FP16_CONSTRUCTOR(0u, static_cast(e_ret), m_ret); - } - return *this; -} - -static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, uint16_t &ret_val) { - int16_t e_ret; - uint32_t m_tmp = (input_val & kFp32AbsMax); - uint32_t m_min = kFp16ManHideBit; - uint32_t m_max = m_min << 1; - uint16_t len = static_cast(GetManBitLength(m_tmp)); - if (len > kDim11) { - e_ret = kFp16ExpBias + kFp16ManLen; - uint32_t m_trunc = 0; - uint32_t trunc_mask = 1; - uint16_t e_tmp = len - kDim11; - for (int i = 1; i < e_tmp; i++) { - trunc_mask = (trunc_mask << 1) + 1; - } - m_trunc = (m_tmp & trunc_mask) << (kBitShift32 - e_tmp); - for (int i = 0; i < e_tmp; i++) { - m_tmp = (m_tmp >> 1); - e_ret = e_ret + 1; - } - bool b_last_bit = ((m_tmp & 1) > 0); - bool b_trunc_high = 0; - bool b_trunc_left = 0; - if (kRoundToNearest == g_round_mode) { // trunc - b_trunc_high = ((m_trunc & kFp32SignMask) > 0); - b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); - } - m_tmp = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_tmp); - while (m_tmp >= m_max || e_ret < 0) { - m_tmp = m_tmp >> 1; - e_ret = e_ret + 1; - } - if (e_ret >= kFp16MaxExp) { - e_ret = kFp16MaxExp - 1; - m_tmp = kFp16MaxMan; - } - } else { - e_ret = kFp16ExpBias; - m_tmp = m_tmp << (kDim11 - len); - e_ret = e_ret + (len - 1); - } - auto m_ret = static_cast(m_tmp); - ret_val = FP16_CONSTRUCTOR(sign, static_cast(e_ret), m_ret); -} - -fp16_t &fp16_t::operator=(const int32_t &i_val) { - if (i_val == 0) { - val = 0; - } else { - uint32_t ui_val = *(reinterpret_cast(&i_val)); - auto s_ret = static_cast(ui_val >> kBitShift31); - if (s_ret) { - int32_t iValM = -i_val; - ui_val = *(reinterpret_cast(&iValM)); - } - SetValByUint32Val(ui_val, s_ret, val); - } - return *this; -} - -fp16_t &fp16_t::operator=(const uint32_t &ui_val) { - if (ui_val == 0) { - val = 0; - } else { - int16_t e_ret; - uint32_t m_tmp = ui_val; - uint32_t m_min = kFp16ManHideBit; - uint32_t m_max = m_min << 1; - uint16_t len = static_cast(GetManBitLength(m_tmp)); - if (len > kDim11) { - e_ret = kFp16ExpBias + kFp16ManLen; - uint32_t m_trunc = 0; - uint32_t trunc_mask = 1; - uint16_t e_tmp = len - kDim11; - for (int i = 1; i < e_tmp; i++) { - trunc_mask = (trunc_mask << 1) + 1; - } - m_trunc = (m_tmp & trunc_mask) << static_cast(kBitShift32 - e_tmp); - for (uint16_t i = 0; i < e_tmp; i++) { - m_tmp = (m_tmp >> 1); - e_ret = e_ret + 1; - } - bool b_last_bit = ((m_tmp & 1) > 0); - bool b_trunc_high = false; - bool b_trunc_left = false; - if (g_round_mode == kRoundToNearest) { // trunc - b_trunc_high = ((m_trunc & kFp32SignMask) > 0); - b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); - } - m_tmp = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_tmp); - while (m_tmp >= m_max || e_ret < 0) { - m_tmp = m_tmp >> 1; - e_ret = e_ret + 1; - } - if (e_ret >= kFp16MaxExp) { - e_ret = kFp16MaxExp - 1; - m_tmp = kFp16MaxMan; - } - } else { - e_ret = kFp16ExpBias; - m_tmp = m_tmp << (kDim11 - len); - e_ret = e_ret + (len - 1); - } - auto m_ret = static_cast(m_tmp); - val = FP16_CONSTRUCTOR(0u, static_cast(e_ret), m_ret); - } - return *this; -} - -fp16_t &fp16_t::operator=(const double &d_val) { - uint16_t s_ret; - uint16_t m_ret; - int16_t e_ret; - uint64_t e_d; - uint64_t m_d; - uint64_t ui64_v = *(reinterpret_cast(&d_val)); // 1:11:52bit sign:exp:man - uint32_t m_len_delta; - - s_ret = static_cast((ui64_v & kFp64SignMask) >> kFp64SignIndex); // 4Byte - e_d = (ui64_v & kFp64ExpMask) >> kFp64ManLen; // 10 bit exponent - m_d = (ui64_v & kFp64ManMask); // 52 bit mantissa - m_len_delta = kFp64ManLen - kFp16ManLen; - - bool need_round = false; - // Exponent overflow/NaN converts to signed inf/NaN - if (e_d >= 0x410u) { // 0x410:1040=1023+16 - e_ret = kFp16MaxExp - 1; - m_ret = kFp16MaxMan; - val = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); - } else if (e_d <= 0x3F0u) { // Exponent underflow converts to denormalized half or signed zero - // 0x3F0:1008=1023-15 - // Signed zeros, denormalized floats, and floats with small - // exponents all convert to signed zero half precision. - e_ret = 0; - if (e_d >= 0x3E7u) { // 0x3E7u:999=1023-24 Denormal - // Underflows to a denormalized value - m_d = (kFp64ManHideBit | m_d); - uint16_t shift_out = kFp64ManLen; - uint64_t m_tmp = (static_cast(m_d)) << (e_d - 0x3E7u); - - need_round = IsRoundOne(m_tmp, shift_out); - m_ret = static_cast(m_tmp >> shift_out); - if (need_round) { - m_ret++; - } - } else if (e_d == 0x3E6u && m_d > 0) { - m_ret = 1; - } else { - m_ret = 0; - } - } else { // Regular case with no overflow or underflow - e_ret = static_cast(e_d - 0x3F0u); - - need_round = IsRoundOne(m_d, m_len_delta); - m_ret = static_cast(m_d >> m_len_delta); - if (need_round) { - m_ret++; - } - if (m_ret & kFp16ManHideBit) { - e_ret++; - } - } - - Fp16Normalize(e_ret, m_ret); - val = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); - return *this; -} - -// convert -fp16_t::operator float() const { return Fp16ToFloat(val); } - -fp16_t::operator double() const { return Fp16ToDouble(val); } - -fp16_t::operator int8_t() const { return Fp16ToInt8(val); } - -fp16_t::operator uint8_t() const { return Fp16ToUInt8(val); } - -fp16_t::operator int16_t() const { return Fp16ToInt16(val); } - -fp16_t::operator uint16_t() const { return Fp16ToUInt16(val); } - -fp16_t::operator int32_t() const { return Fp16ToInt32(val); } - -fp16_t::operator uint32_t() const { return Fp16ToUInt32(val); } - -// Cannot be used, just in order to solve the compile error -fp16_t::operator int64_t() const { return 0; } - -// Cannot be used, just in order to solve the compile error -fp16_t::operator uint64_t() const { return 0; } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() { - if ((val & kFp16AbsMax) == kFp16ExpMask) { - if (val & kFp16SignMask) { - return -1; - } else { - return 1; - } - } else { - return 0; - } -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY float fp16_t::ToFloat() const { return Fp16ToFloat(val); } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY double fp16_t::ToDouble() const { return Fp16ToDouble(val); } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int8_t fp16_t::ToInt8() const { return Fp16ToInt8(val); } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint8_t fp16_t::ToUInt8() const { return Fp16ToUInt8(val); } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int16_t fp16_t::ToInt16() const { return Fp16ToInt16(val); } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint16_t fp16_t::ToUInt16() const { return Fp16ToUInt16(val); } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int32_t fp16_t::ToInt32() const { return Fp16ToInt32(val); } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t fp16_t::ToUInt32() const { return Fp16ToUInt32(val); } -} // namespace parser -} // namespace ge diff --git a/parser/common/parser_fp16_t.h b/parser/common/parser_fp16_t.h deleted file mode 100644 index 6c361e8..0000000 --- a/parser/common/parser_fp16_t.h +++ /dev/null @@ -1,653 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARSER_COMMON_FP16_T_H_ -#define PARSER_COMMON_FP16_T_H_ - -#include -#include -#include - -namespace ge { -namespace parser { -using DimIndex = enum { - kDim0 = 0, - kDim1, - kDim2, - kDim3, - kDim4, - kDim5, - kDim6, - kDim7, - kDim8, - kDim9, - kDim10, - kDim11, - kDim12, - kDim13, - kDim14, - kDim15, - kDim16, -}; - -using BitShift = enum { - kBitShift2 = 2, - kBitShift3 = 3, - kBitShift4 = 4, - kBitShift5 = 5, - kBitShift6 = 6, - kBitShift7 = 7, - kBitShift8 = 8, - kBitShift9 = 9, - kBitShift10 = 10, - kBitShift11 = 11, - kBitShift12 = 12, - kBitShift13 = 13, - kBitShift14 = 14, - kBitShift15 = 15, - kBitShift16 = 16, - kBitShift20 = 20, - kBitShift24 = 24, - kBitShift27 = 27, - kBitShift28 = 28, - kBitShift31 = 31, - kBitShift32 = 32, - kBitShift36 = 36, - kBitShift40 = 40, - kBitShift44 = 44, - kBitShift48 = 48, - kBitShift52 = 52, - kBitShift56 = 56, - kBitShift59 = 59, - kBitShift60 = 60, - kBitShift63 = 63, - kBitShift64 = 64, - kBitShift128 = 128, - kBitShift255 = 255, - kBitShift256 = 256, - kBitShift512 = 512, - kBitShift768 = 768, - kBitShift784 = 784, - kBitShift1020 = 1020, - kBitShift1024 = 1024, - kBitShift3136 = 3136, - kBitShift4096 = 4096, - kBitShift6144 = 6144, - kBitShift10240 = 10240, - kBitShift65536 = 65536 -}; -/// @ingroup fp16 basic parameter -/// @brief fp16 exponent bias -constexpr uint16_t kFp16ExpBias = 15; -/// @ingroup fp16 basic parameter -/// @brief the exponent bit length of fp16 is 5 -constexpr uint16_t kFp16ExpLen = 5; -/// @ingroup fp16 basic parameter -/// @brief the mantissa bit length of fp16 is 10 -constexpr uint16_t kFp16ManLen = 10; -/// @ingroup fp16 basic parameter -/// @brief bit index of sign in fp16 -constexpr uint16_t kFp16SignIndex = 15; -/// @ingroup fp16 basic parameter -/// @brief sign mask of fp16 (1 00000 00000 00000) -constexpr uint16_t kFp16SignMask = 0x8000; -/// @ingroup fp16 basic parameter -/// @brief exponent mask of fp16 ( 11111 00000 00000) -constexpr uint16_t kFp16ExpMask = 0x7C00; -/// @ingroup fp16 basic parameter -/// @brief mantissa mask of fp16 ( 11111 11111) -constexpr uint16_t kFp16ManMask = 0x03FF; -/// @ingroup fp16 basic parameter -/// @brief hide bit of mantissa of fp16( 1 00000 00000) -constexpr uint16_t kFp16ManHideBit = 0x0400; -/// @ingroup fp16 basic parameter -/// @brief maximum value (0111 1011 1111 1111) -constexpr uint16_t kFp16Max = 0x7BFF; -/// @ingroup fp16 basic parameter -/// @brief minimum value (1111 1011 1111 1111) -constexpr uint16_t kFp16Min = 0xFBFF; -/// @ingroup fp16 basic parameter -/// @brief absolute maximum value (0111 1111 1111 1111) -constexpr uint16_t kFp16AbsMax = 0x7FFF; -/// @ingroup fp16 basic parameter -/// @brief maximum exponent value of fp16 is 15(11111) -constexpr uint16_t kFp16MaxExp = 0x001F; -/// @ingroup fp16 basic parameter -/// @brief maximum valid exponent value of fp16 is 14(11110) -constexpr uint16_t kFp16MaxValidExp = 0x001E; -/// @ingroup fp16 basic parameter -/// @brief maximum mantissa value of fp16(11111 11111) -constexpr uint16_t kFp16MaxMan = 0x03FF; -/// @ingroup fp16 basic parameter -/// @brief absolute minimum normal value of fp16 -/// (E=1,M=0 D=2^(-14)=0.00006103515625) -constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14); -/// @ingroup fp16 basic operator -/// @brief get sign of fp16 -#define FP16_EXTRAC_SIGN(x) (((x) >> 15) & 1) -/// @ingroup fp16 basic operator -/// @brief get exponent of fp16 -#define FP16_EXTRAC_EXP(x) (((x) >> 10) & kFp16MaxExp) -/// @ingroup fp16 basic operator -/// @brief get mantissa of fp16 -#define FP16_EXTRAC_MAN(x) ((((x) >> 0) & 0x3FF) | (((((x) >> 10) & 0x1F) > 0 ? 1 : 0) * 0x400)) -/// @ingroup fp16 basic operator -/// @brief constructor of fp16 from sign exponent and mantissa -#define FP16_CONSTRUCTOR(s, e, m) (((s) << kFp16SignIndex) | ((e) << kFp16ManLen) | ((m)&kFp16MaxMan)) -/// @ingroup fp16 special value judgment -/// @brief whether a fp16 is zero -#define FP16_IS_ZERO(x) (((x)&kFp16AbsMax) == 0) -/// @ingroup fp16 special value judgment -/// @brief whether a fp16 is a denormalized value -#define FP16_IS_DENORM(x) ((((x)&kFp16ExpMask) == 0)) -/// @ingroup fp16 special value judgment -/// @brief whether a fp16 is infinite -#define FP16_IS_INF(x) (((x)&kFp16AbsMax) == kFp16ExpMask) -/// @ingroup fp16 special value judgment -/// @brief whether a fp16 is NaN -#define FP16_IS_NAN(x) (((x & kFp16ExpMask) == kFp16ExpMask) && (x & kFp16ManMask)) -/// @ingroup fp16 special value judgment -/// @brief whether a fp16 is invalid -#define FP16_IS_INVALID(x) ((x & kFp16ExpMask) == kFp16ExpMask) -/// @ingroup fp32 basic parameter -/// @brief fp32 exponent bias -constexpr uint16_t kFp32ExpBias = 127; -/// @ingroup fp32 basic parameter -/// @brief the exponent bit length of float/fp32 is 8 -constexpr uint16_t kFp32ExpLen = 8; -/// @ingroup fp32 basic parameter -/// @brief the mantissa bit length of float/fp32 is 23 -constexpr uint16_t kFp32ManLen = 23; -/// @ingroup fp32 basic parameter -/// @brief bit index of sign in float/fp32 -constexpr uint16_t kFp32SignIndex = 31; -/// @ingroup fp32 basic parameter -/// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000) -constexpr uint32_t kFp32SignMask = 0x80000000u; -/// @ingroup fp32 basic parameter -/// @brief exponent mask of fp32 ( 1111 1111 0000 0000 0000 0000 000) -constexpr uint32_t kFp32ExpMask = 0x7F800000u; -/// @ingroup fp32 basic parameter -/// @brief mantissa mask of fp32 ( 1111 1111 1111 1111 111) -constexpr uint32_t kFp32ManMask = 0x007FFFFFu; -/// @ingroup fp32 basic parameter -/// @brief hide bit of mantissa of fp32 ( 1 0000 0000 0000 0000 000) -constexpr uint32_t kFp32ManHideBit = 0x00800000u; -/// @ingroup fp32 basic parameter -/// @brief absolute maximum value (0 1111 1111 1111 1111 1111 1111 111) -constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu; -/// @ingroup fp32 basic parameter -/// @brief maximum exponent value of fp32 is 255(1111 1111) -constexpr uint32_t kFp32MaxExp = 0xFF; -/// @ingroup fp32 basic parameter -/// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111) -constexpr uint32_t kFp32MaxMan = 0x7FFFFF; -/// @ingroup fp32 special value judgment -/// @brief whether a fp32 is NaN -#define FP32_IS_NAN(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (x & kFp32ManMask)) -/// @ingroup fp32 special value judgment -/// @brief whether a fp32 is infinite -#define FP32_IS_INF(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (!(x & kFp32ManMask))) -/// @ingroup fp32 special value judgment -/// @brief whether a fp32 is a denormalized value -#define FP32_IS_DENORM(x) ((((x)&kFp32ExpMask) == 0)) -/// @ingroup fp32 basic operator -/// @brief get sign of fp32 -#define FP32_EXTRAC_SIGN(x) (((x) >> kFp32SignIndex) & 1) -/// @ingroup fp32 basic operator -/// @brief get exponent of fp16 -#define FP32_EXTRAC_EXP(x) (((x)&kFp32ExpMask) >> kFp32ManLen) -/// @ingroup fp32 basic operator -/// @brief get mantissa of fp16 -#define FP32_EXTRAC_MAN(x) (((x)&kFp32ManMask) | (((((x) >> kFp32ManLen) & kFp32MaxExp) > 0 ? 1 : 0) * kFp32ManHideBit)) -/// @ingroup fp32 basic operator -/// @brief constructor of fp32 from sign exponent and mantissa -#define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m)&kFp32MaxMan)) -/// @ingroup fp64 basic parameter -/// @brief fp64 exponent bias -constexpr uint16_t kFp64ExpBias = 1023; -/// @ingroup fp64 basic parameter -/// @brief the exponent bit length of double/fp64 is 11 -constexpr uint16_t kFp64ExpLen = 11; -/// @ingroup fp64 basic parameter -/// @brief the mantissa bit length of double/fp64 is 52 -constexpr uint16_t kFp64ManLen = 52; -/// @ingroup fp64 basic parameter -/// @brief bit index of sign in double/fp64 is 63 -constexpr uint16_t kFp64SignIndex = 63; -/// @ingroup fp64 basic parameter -/// @brief sign mask of fp64 (1 000 (total 63bits 0)) -constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu; -/// @ingroup fp64 basic parameter -/// @brief exponent mask of fp64 (0 1 11111 11111 0000?-?-(total 52bits 0)) -constexpr uint64_t kFp64ExpMask = 0x7FF0000000000000LLu; -/// @ingroup fp64 basic parameter -/// @brief mantissa mask of fp64 ( 1111?-?-(total 52bits 1)) -constexpr uint64_t kFp64ManMask = 0x000FFFFFFFFFFFFFLLu; -/// @ingroup fp64 basic parameter -/// @brief hide bit of mantissa of fp64 ( 1 0000?-?-(total 52bits 0)) -constexpr uint64_t kFp64ManHideBit = 0x0010000000000000LLu; -/// @ingroup fp64 basic parameter -/// @brief absolute maximum value (0 111?-?-(total 63bits 1)) -constexpr uint64_t kFp64AbsMax = 0x7FFFFFFFFFFFFFFFLLu; -/// @ingroup fp64 basic parameter -/// @brief maximum exponent value of fp64 is 2047(1 11111 11111) -constexpr uint64_t kFp64MaxExp = 0x07FF; -/// @ingroup fp64 basic parameter -/// @brief maximum mantissa value of fp64 (111?-?-(total 52bits 1)) -constexpr uint64_t kFp64MaxMan = 0xFFFFFFFFFFFLLu; -/// @ingroup fp64 special value judgment -/// @brief whether a fp64 is NaN -#define FP64_IS_NAN(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (x & kFp64ManMask)) -/// @ingroup fp64 special value judgment -/// @brief whether a fp64 is infinite -#define FP64_IS_INF(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (!(x & kFp64ManMask))) -/// @ingroup integer special value judgment -/// @brief maximum positive value of int8_t (0111 1111) -constexpr int8_t kInt8Max = 0x7F; -/// @ingroup integer special value judgment -/// @brief maximum value of a data with 8 bits length (1111 111) -constexpr uint8_t kBitLen8Max = 0xFF; -/// @ingroup integer special value judgment -/// @brief maximum positive value of int16_t (0111 1111 1111 1111) -constexpr int16_t kInt16Max = 0x7FFF; -/// @ingroup integer special value judgment -/// @brief maximum value of a data with 16 bits length (1111 1111 1111 1111) -constexpr uint16_t kBitLen16Max = 0xFFFF; -/// @ingroup integer special value judgment -/// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111) -constexpr int32_t kInt32Max = 0x7FFFFFFFu; -/// @ingroup integer special value judgment -/// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111) -constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu; -/// @ingroup integer special value judgment -/// @brief maximum positive value of int64_t -/// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) -constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFFu; -/// @ingroup integer special value judgment -/// @brief maximum value of a data with 64 bits length -/// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) -constexpr uint64_t kBitLen64Max = 0xFFFFFFFFFFFFFFFFu; - -/// @ingroup fp16_t enum -/// @brief round mode of last valid digital -enum TagFp16RoundMode { - kRoundToNearest = 0, // < round to nearest even - kRoundByTruncated, // < round by truncated - kRoundModeReserved, -}; - -/// @ingroup fp16_t -/// @brief Half precision float -/// bit15: 1 bit SIGN +---+-----+------------+ -/// bit14-10: 5 bit EXP | S |EEEEE|MM MMMM MMMM| -/// bit0-9: 10bit MAN +---+-----+------------+ -using fp16_t = struct TagFp16 { - uint16_t val; - -public: - /// @ingroup fp16_t constructor - /// @brief Constructor without any param(default constructor) - TagFp16(void) { val = 0x0u; } - - /// @ingroup fp16_t constructor - /// @brief Constructor with an uint16_t value - TagFp16(const uint16_t &ui_val) : val(ui_val) {} - - /// @ingroup fp16_t constructor - /// @brief Constructor with a fp16_t object(copy constructor) - TagFp16(const TagFp16 &fp) : val(fp.val) {} - - /// @ingroup fp16_t math operator - /// @param [in] fp fp16_t object to be added - /// @brief Override addition operator to performing fp16_t addition - /// @return Return fp16_t result of adding this and fp - TagFp16 operator+(const TagFp16 fp); - - /// @ingroup fp16_t math operator - /// @param [in] fp fp16_t object to be subtracted - /// @brief Override addition operator to performing fp16_t subtraction - /// @return Return fp16_t result of subtraction fp from this - TagFp16 operator-(const TagFp16 fp); - - /// @ingroup fp16_t math operator - /// @param [in] fp fp16_t object to be multiplied - /// @brief Override multiplication operator to performing fp16_t multiplication - /// @return Return fp16_t result of multiplying this and fp - TagFp16 operator*(const TagFp16 fp); - - /// @ingroup fp16_t math operator divided - /// @param [in] fp fp16_t object to be divided - /// @brief Override division operator to performing fp16_t division - /// @return Return fp16_t result of division this by fp - TagFp16 operator/(const TagFp16 fp); - - /// @ingroup fp16_t math operator - /// @param [in] fp fp16_t object to be added - /// @brief Override addition operator to performing fp16_t addition - /// @return Return fp16_t result of adding this and fp - TagFp16 operator+=(const TagFp16 fp); - - /// @ingroup fp16_t math operator - /// @param [in] fp fp16_t object to be subtracted - /// @brief Override addition operator to performing fp16_t subtraction - /// @return Return fp16_t result of subtraction fp from this - TagFp16 operator-=(const TagFp16 fp); - - /// @ingroup fp16_t math operator - /// @param [in] fp fp16_t object to be multiplied - /// @brief Override multiplication operator to performing fp16_t multiplication - /// @return Return fp16_t result of multiplying this and fp - TagFp16 operator*=(const TagFp16 fp); - - /// @ingroup fp16_t math operator divided - /// @param [in] fp fp16_t object to be divided - /// @brief Override division operator to performing fp16_t division - /// @return Return fp16_t result of division this by fp - TagFp16 operator/=(const TagFp16 fp); - - /// @ingroup fp16_t math compare operator - /// @param [in] fp fp16_t object to be compared - /// @brief Override basic comparison operator to performing fp16_t if-equal comparison - /// @return Return boolean result of if-equal comparison of this and fp. - bool operator==(const TagFp16 &fp) const; - - /// @ingroup fp16_t math compare operator - /// @param [in] fp fp16_t object to be compared - /// @brief Override basic comparison operator to performing fp16_t not-equal comparison - /// @return Return boolean result of not-equal comparison of this and fp. - bool operator!=(const TagFp16 &fp) const; - - /// @ingroup fp16_t math compare operator - /// @param [in] fp fp16_t object to be compared - /// @brief Override basic comparison operator to performing fp16_t greater-than comparison - /// @return Return boolean result of greater-than comparison of this and fp. - bool operator>(const TagFp16 &fp) const; - - /// @ingroup fp16_t math compare operator - /// @param [in] fp fp16_t object to be compared - /// @brief Override basic comparison operator to performing fp16_t greater-equal comparison - /// @return Return boolean result of greater-equal comparison of this and fp. - bool operator>=(const TagFp16 &fp) const; - - /// @ingroup fp16_t math compare operator - /// @param [in] fp fp16_t object to be compared - /// @brief Override basic comparison operator to performing fp16_t less-than comparison - /// @return Return boolean result of less-than comparison of this and fp. - bool operator<(const TagFp16 &fp) const; - - /// @ingroup fp16_t math compare operator - /// @param [in] fp fp16_t object to be compared - /// @brief Override basic comparison operator to performing fp16_t less-equal comparison - /// @return Return boolean result of less-equal comparison of this and fp. - bool operator<=(const TagFp16 &fp) const; - - /// @ingroup fp16_t math evaluation operator - /// @param [in] fp fp16_t object to be copy to fp16_t - /// @brief Override basic evaluation operator to copy fp16_t to a new fp16_t - /// @return Return fp16_t result from fp - TagFp16 &operator=(const TagFp16 &fp); - - /// @ingroup fp16_t math evaluation operator - /// @param [in] f_val float object to be converted to fp16_t - /// @brief Override basic evaluation operator to convert float to fp16_t - /// @return Return fp16_t result from f_val - TagFp16 &operator=(const float &f_val); - - /// @ingroup fp16_t math evaluation operator - /// @param [in] d_val double object to be converted to fp16_t - /// @brief Override basic evaluation operator to convert double to fp16_t - /// @return Return fp16_t result from d_val - TagFp16 &operator=(const double &d_val); - - /// @ingroup fp16_t math evaluation operator - /// @param [in] i_val float object to be converted to fp16_t - /// @brief Override basic evaluation operator to convert float to fp16_t - /// @return Return fp16_t result from i_val - TagFp16 &operator=(const int8_t &i_val); - - /// @ingroup fp16_t math evaluation operator - /// @param [in] ui_val uint8_t object to be converted to fp16_t - /// @brief Override basic evaluation operator to convert uint8_t to fp16_t - /// @return Return fp16_t result from ui_val - TagFp16 &operator=(const uint8_t &ui_val); - - /// @ingroup fp16_t math evaluation operator - /// @param [in] i_val int16_t object to be converted to fp16_t - /// @brief Override basic evaluation operator to convert int16_t to fp16_t - /// @return Return fp16_t result from i_val - TagFp16 &operator=(const int16_t &i_val); - - /// @ingroup fp16_t math evaluation operator - /// @param [in] ui_val uint16_t object to be converted to fp16_t - /// @brief Override basic evaluation operator to convert uint16_t to fp16_t - /// @return Return fp16_t result from ui_val - TagFp16 &operator=(const uint16_t &ui_val); - - /// @ingroup fp16_t math evaluation operator - /// @param [in] i_val int32_t object to be converted to fp16_t - /// @brief Override basic evaluation operator to convert int32_t to fp16_t - /// @return Return fp16_t result from i_val - TagFp16 &operator=(const int32_t &i_val); - - /// @ingroup fp16_t math evaluation operator - /// @param [in] ui_val uint32_t object to be converted to fp16_t - /// @brief Override basic evaluation operator to convert uint32_t to fp16_t - /// @return Return fp16_t result from ui_val - TagFp16 &operator=(const uint32_t &ui_val); - - /// @ingroup fp16_t math conversion - /// @brief Override convert operator to convert fp16_t to float/fp32 - /// @return Return float/fp32 value of fp16_t - operator float() const; - - /// @ingroup fp16_t math conversion - /// @brief Override convert operator to convert fp16_t to double/fp64 - /// @return Return double/fp64 value of fp16_t - operator double() const; - - /// @ingroup fp16_t math conversion - /// @brief Override convert operator to convert fp16_t to int8_t - /// @return Return int8_t value of fp16_t - operator int8_t() const; - - /// @ingroup fp16_t math conversion - /// @brief Override convert operator to convert fp16_t to uint8_t - /// @return Return uint8_t value of fp16_t - operator uint8_t() const; - - /// @ingroup fp16_t conversion - /// @brief Override convert operator to convert fp16_t to int16_t - /// @return Return int16_t value of fp16_t - operator int16_t() const; - - /// @ingroup fp16_t math conversion - /// @brief Override convert operator to convert fp16_t to uint16_t - /// @return Return uint16_t value of fp16_t - operator uint16_t() const; - - /// @ingroup fp16_t math conversion - /// @brief Override convert operator to convert fp16_t to int32_t - /// @return Return int32_t value of fp16_t - operator int32_t() const; - - /// @ingroup fp16_t math conversion - /// @brief Override convert operator to convert fp16_t to uint32_t - /// @return Return uint32_t value of fp16_t - operator uint32_t() const; - - /// @ingroup fp16_t math conversion - /// @brief Override convert operator to convert fp16_t to int64_t - /// @return Return int64_t value of fp16_t - operator int64_t() const; - - /// @ingroup fp16_t math conversion - /// @brief Override convert operator to convert fp16_t to uint64_t - /// @return Return uint64_t value of fp16_t - operator uint64_t() const; - - /// @ingroup fp16_t judgment method - /// @param [in] fp fp16_t object to be judgement - /// @brief whether a fp16_t is inifinite - /// @return Returns 1:+INF -1:-INF 0:not INF - int IsInf(); - - /// @ingroup fp16_t math conversion - /// @brief Convert fp16_t to float/fp32 - /// @return Return float/fp32 value of fp16_t - float ToFloat() const; - - /// @ingroup fp16_t math conversion - /// @brief Convert fp16_t to double/fp64 - /// @return Return double/fp64 value of fp16_t - double ToDouble() const; - - /// @ingroup fp16_t math conversion - /// @brief Convert fp16_t to int8_t - /// @return Return int8_t value of fp16_t - int8_t ToInt8() const; - - /// @ingroup fp16_t math conversion - /// @brief Convert fp16_t to uint8_t - /// @return Return uint8_t value of fp16_t - uint8_t ToUInt8() const; - - /// @ingroup fp16_t conversion - /// @brief Convert fp16_t to int16_t - /// @return Return int16_t value of fp16_t - int16_t ToInt16() const; - - /// @ingroup fp16_t math conversion - /// @brief Convert fp16_t to uint16_t - /// @return Return uint16_t value of fp16_t - uint16_t ToUInt16() const; - - /// @ingroup fp16_t math conversion - /// @brief Convert fp16_t to int32_t - /// @return Return int32_t value of fp16_t - int32_t ToInt32() const; - - /// @ingroup fp16_t math conversion - /// @brief Convert fp16_t to uint32_t - /// @return Return uint32_t value of fp16_t - uint32_t ToUInt32() const; -}; - -/// @ingroup fp16_t public method -/// @param [in] val signature is negative -/// @param [in|out] s sign of fp16_t object -/// @param [in|out] e exponent of fp16_t object -/// @param [in|out] m mantissa of fp16_t object -/// @brief Extract the sign, exponent and mantissa of a fp16_t object -void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m); - -/// @ingroup fp16_t public method -/// @param [in] negative sign is negative -/// @param [in|out] man mantissa to be reverse -/// @brief Calculate a mantissa's complement (add ont to it's radix-minus-one complement) -/// @return Return complement of man -template -void ReverseMan(bool negative, T &man) { - if (negative) { - man = (~(man)) + 1; - } -} - -/// @ingroup fp16_t public method -/// @param [in] e_a exponent of one fp16_t/float number -/// @param [in] m_a mantissa of one fp16_t/float number -/// @param [in] e_b exponent of another fp16_t/float number -/// @param [in] m_b mantissa of another fp16_t/float number -/// @brief choose mantissa to be shift right whoes exponent is less than another one -/// @return Return mantissawhoes exponent is less than another one -template -T MinMan(const int16_t &e_a, T &m_a, const int16_t &e_b, T &m_b) { - return (e_a > e_b) ? m_b : m_a; -} - -/// @ingroup fp16_t public method -/// @param [in] man mantissa to be operate -/// @param [in] shift right shift bits -/// @brief right shift a mantissa -/// @return Return right-shift mantissa -template -T RightShift(T man, int16_t shift) { - int bits = sizeof(T) * 8; // one byte have 8 bits - T mask = (((T) 1u) << ((unsigned int) (bits - 1))); - for (int i = 0; i < shift; i++) { - man = ((man & mask) | (man >> 1)); - } - return man; -} - -/// @ingroup fp16_t public method -/// @param [in] e_a exponent of one temp fp16_t number -/// @param [in] m_a mantissa of one temp fp16_t number -/// @param [in] e_b exponent of another temp fp16_t number -/// @param [in] m_b mantissa of another temp fp16_t number -/// @brief Get mantissa sum of two temp fp16_t numbers, T support types: uint16_t/uint32_t/uint64_t -/// @return Return mantissa sum -template -T GetManSum(int16_t e_a, const T &m_a, int16_t e_b, const T &m_b) { - T sum = 0; - if (e_a != e_b) { - T m_tmp = 0; - int16_t e_tmp = std::abs(e_a - e_b); - if (e_a > e_b) { - m_tmp = m_b; - m_tmp = RightShift(m_tmp, e_tmp); - sum = m_a + m_tmp; - } else { - m_tmp = m_a; - m_tmp = RightShift(m_tmp, e_tmp); - sum = m_tmp + m_b; - } - } else { - sum = m_a + m_b; - } - return sum; -} - -/// @ingroup fp16_t public method -/// @param [in] bit0 whether the last preserved bit is 1 before round -/// @param [in] bit1 whether the abbreviation's highest bit is 1 -/// @param [in] bitLeft whether the abbreviation's bits which not contain highest bit grater than 0 -/// @param [in] man mantissa of a fp16_t or float number, support types: uint16_t/uint32_t/uint64_t -/// @param [in] shift abbreviation bits -/// @brief Round fp16_t or float mantissa to nearest value -/// @return Returns true if round 1,otherwise false; -template -T ManRoundToNearest(bool bit0, bool bit1, bool bitLeft, T man, uint16_t shift = 0) { - man = (man >> shift) + ((bit1 && (bitLeft || bit0)) ? 1 : 0); - return man; -} - -/// @ingroup fp16_t public method -/// @param [in] man mantissa of a float number, support types: uint16_t/uint32_t/uint64_t -/// @brief Get bit length of a uint32_t number -/// @return Return bit length of man -template -int16_t GetManBitLength(T man) { - int16_t len = 0; - while (man) { - man >>= 1; - len++; - } - return len; -} -} // namespace parser -} // namespace ge -#endif // GE_PARSER_COMMON_FP16_T_H_ diff --git a/parser/common/parser_types.cc b/parser/common/parser_types.cc deleted file mode 100644 index 440e884..0000000 --- a/parser/common/parser_types.cc +++ /dev/null @@ -1,494 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "framework/omg/parser/parser_types.h" - - -namespace ge{ -namespace parser { -const char *DATA = "Data"; -const char *AIPPDATA = "AippData"; -const char *CONVOLUTION = "Convolution"; -const char *CORRELATION = "Correlation"; -const char *CORRELATIONV2 = "Correlation_V2"; -const char *DECONVOLUTION = "Deconvolution"; -const char *POOLING = "Pooling"; -const char *ELTWISE = "Eltwise"; -const char *RELU = "ReLU"; -const char *RELU6 = "ReLU6"; -const char *SIGMOID = "Sigmoid"; -const char *ABSVAL = "AbsVal"; -const char *TANH = "TanH"; -const char *PRELU = "PReLU"; -const char *BATCHNORM = "BatchNorm"; -const char *FUSIONBATCHNORM = "FusionBatchNorm"; -const char *SCALE = "Scale"; -const char *FULL_CONNECTION = "FullConnection"; -const char *SOFTMAX = "Softmax"; -const char *PLUS = "Plus"; -const char *ACTIVATION = "Activation"; -const char *FLATTEN = "Flatten"; -const char *ADD = "Add"; -const char *SUB = "Sub"; -const char *MUL = "Mul"; -const char *MATMUL = "MatMul"; -const char *RSQRT = "Rsqrt"; -const char *BIASADD = "BiasAdd"; -const char *RESHAPE = "Reshape"; -const char *REFORMAT = "ReFormat"; -const char *DEPCONVOLUTION = "ConvolutionDepthwise"; -const char *DROPOUT = "Dropout"; -const char *DROPOUTGENMASK = "DropOutGenMask"; -const char *DROPOUTDOMASK = "DropOutDoMask"; -const char *CONCAT = "Concat"; -const char *ROIPOOLING = "ROIPooling"; -const char *PROPOSAL = "Proposal"; -const char *FSRDETECTIONOUTPUT = "FSRDetectionOutput"; -const char *DETECTIONPOSTPROCESS = "Detectpostprocess"; -const char *LRN = "LRN"; -const char *TRANSDATA = "TransData"; -const char *PERMUTE = "Permute"; -const char *SSDNORMALIZE = "SSDNormalize"; -const char *SSDPRIORBOX = "SSDPriorBox"; -const char *NETOUTPUT = "NetOutput"; -const char *SSDDETECTIONOUTPUT = "SSDDetectionOutput"; -const char *REFINEDETDETECTIONOUTPUT = "RefinedetDetectionOutput"; -const char *CHANNELAXPY = "ChannelAxpy"; -const char *PSROIPOOLING = "PSROIPooling"; -const char *POWER = "Power"; -const char *POW = "Pow"; -const char *ROIALIGN = "ROIAlign"; -const char *PYTHON = "Python"; -const char *FREESPACEEXTRACT = "FreespaceExtract"; -const char *SPATIALTF = "SpatialTransform"; -const char *SHAPE = "Shape"; -const char *SHAPEN = "ShapeN"; -const char *ARGMAX = "ArgMax"; -const char *GATHERND = "GatherNd"; -const char *GATHER = "Gather"; -const char *REALDIV = "RealDiv"; -const char *PACK = "Pack"; -const char *SLICE = "Slice"; -const char *SLICED = "SliceD"; -const char *FLOORDIV = "FloorDiv"; -const char *SQUEEZE = "Squeeze"; -const char *UNSQUEEZE = "Unsqueeze"; -const char *STRIDEDSLICE = "StridedSlice"; -const char *RANGE = "Range"; -const char *RPNPROPOSALS = "RpnProposals"; -const char *DECODEBBOX = "DecodeBbox"; -const char *PAD = "Pad"; -const char *PADV2 = "PadV2"; -const char *MIRRORPAD = "MirrorPad"; -const char *TILE = "Tile"; -const char *SIZE = "Size"; -const char *CLIPBOXES = "ClipBoxes"; -const char *FASTRCNNPREDICTIONS = "FastrcnnPredictions"; -const char *SPLIT = "Split"; -const char *SPLITV = "SplitV"; -const char *EXPANDDIMS = "ExpandDims"; -const char *EMPTY = "Empty"; -const char *MEAN = "Mean"; -const char *GREATER = "Greater"; -const char *SWITCH = "Switch"; -const char *SWITCHN = "SwitchN"; -const char *MERGE = "Merge"; -const char *SYMBOLICGRADIENT = "SymbolicGradient"; -const char *REMOTECALL = "RemoteCall"; -const char *_IF = "_If"; -const char *STATELESSIF = "StatelessIf"; -const char *IF = "If"; -const char *CASE = "Case"; -const char *_WHILE = "_While"; -const char *WHILE = "While"; -const char *STATELESSWHILE = "StatelessWhile"; -const char *FOR = "For"; -const char *PARTITIONEDCALL = "PartitionedCall"; -const char *STATEFULPARTITIONEDCALL = "StatefulPartitionedCall"; -const char *FAKEPARAM = "FakeParam"; -const char *TRANSPOSE = "Transpose"; -const char *TRANSPOSED = "TransposeD"; -const char *CAST = "Cast"; -const char *REGION = "Region"; -const char *YOLO = "Yolo"; -const char *YOLODETECTIONOUTPUT = "YoloDetectionOutput"; -const char *FILL = "Fill"; -const char *REVERSE = "Reverse"; -const char *UNPACK = "Unpack"; -const char *YOLO2REORG = "Yolo2Reorg"; -const char *REDUCESUM = "ReduceSum"; -const char *SUM = "Sum"; -const char *CONSTANT = "Const"; -const char *RESIZEBILINEAR = "ResizeBilinear"; -const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad"; -const char *MAXIMUM = "Maximum"; -const char *FRAMEWORKOP = "FrameworkOp"; -const char *ARG = "_Arg"; -const char *FUSEDBATCHNORMGRAD = "FusedBatchNormGrad"; -const char *LSTM = "LSTM"; -const char *HIGHWAY = "HighWay"; -const char *RNN = "RNN"; -const char *ATTENTIONDECODER = "AttentionDecoder"; -const char *LOGICAL_NOT = "LogicalNot"; -const char *LOGICAL_AND = "LogicalAnd"; -const char *LOGICAL_OR = "LogicalOr"; -const char *EQUAL = "Equal"; -const char *NOTEQUAL = "NotEqual"; -const char *INTERP = "Interp"; -const char *SHUFFLECHANNEL = "ShuffleChannel"; -const char *AIPP = "Aipp"; -const char *MULTISHAPE = "MultiShape"; -const char *RECIPROCAL = "Reciprocal"; -const char *SELU = "Selu"; -const char *ELU = "Elu"; -const char *ACOSH = "Acosh"; -const char *ASINH = "Asinh"; -const char *MINIMUM = "Minimum"; -const char *CLIP = "Clip"; -const char *L2NORMALIZE = "L2Normalize"; -const char *CROPANDRESIZE = "CropAndResize"; -const char *UNUSEDCONST = "UnusedConst"; -const char *SPARSETODENSE = "SparseToDense"; -const char *NONMAXSUPPRESSION = "NonMaxSuppression"; -const char *TOPKV2 = "TopKV2"; -const char *INVERTPERMUTATION = "InvertPermutation"; -const char *MULTINOMIAL = "Multinomial"; -const char *REVERSESEQUENCE = "ReverseSequence"; -const char *REDUCEPROD = "ReduceProd"; -const char *REDUCEMAX = "ReduceMax"; -const char *REDUCEMIN = "ReduceMin"; -const char *EXTRACTIMAGEPATCHES = "ExtractImagePatches"; -const char *SQRT = "Sqrt"; -const char *REDUCEALL = "ReduceAll"; -const char *RESIZENEARESTNEIGHBOR = "ResizeNearestNeighbor"; -const char *SPACETOBATCHND = "SpaceToBatchND"; -const char *BATCHTOSPACEND = "BatchToSpaceND"; -const char *ASSERT = "Assert"; -const char *GREATEREQUAL = "GreaterEqual"; -const char *FLOOR = "Floor"; -const char *RANDOMUNIFORM = "RandomUniform"; -const char *BATCHMATMUL = "BatchMatMul"; -const char *SPACETODEPTH = "SpaceToDepth"; -const char *DEPTHTOSPACE = "DepthToSpace"; -const char *RINT = "Rint"; -const char *ATAN = "Atan"; -const char *ATAN2 = "Atan2"; -const char *ATANH = "Atanh"; -const char *ACOS = "Acos"; -const char *ASIN = "Asin"; -const char *NEG = "Neg"; -const char *LOG = "Log"; -const char *TAN = "Tan"; -const char *ROUND = "Round"; -const char *UPSAMPLE = "Upsample"; -const char *FLOORMOD = "FloorMod"; -const char *LESS = "Less"; -const char *LESSEQUAL = "LessEqual"; -const char *ONEHOT = "OneHot"; -const char *REFSWITCH = "RefSwitch"; -const char *REFMERGE = "RefMerge"; -const char *ENTER = "Enter"; -const char *REFENTER = "RefEnter"; -const char *LOOPCOND = "LoopCond"; -const char *NEXTITERATION = "NextIteration"; -const char *REFNEXTITERATION = "RefNextIteration"; -const char *EXIT = "Exit"; -const char *REFEXIT = "RefExit"; -const char *CONTROLTRIGGER = "ControlTrigger"; -const char *ZEROSLIKE = "ZerosLike"; -const char *EXP = "Exp"; -const char *WHERE = "Where"; -const char *FAKEQUANTWITHMINMAXVARS = "FakeQuantWithMinMaxVars"; -const char *SOFTPLUS = "Softplus"; -const char *SOFTSIGN = "Softsign"; -const char *COSH = "Cosh"; -const char *SINH = "Sinh"; -const char *SQUAREDDIFFERENCE = "SquaredDifference"; -const char *REQUIREDSPACETOBATCHPADDINGS = "RequiredSpaceToBatchPaddings"; // for retinanet scope fusion -const char *SSDPOSTPROCESSOR = "SSDPostProcessor"; -const char *RETINANETBOXES = "RetinanetBoxes"; -const char *RETINAMULTIANCHORS = "RetinaMultiAnchor"; -const char *RETINANETCLIPPEDBOXES = "RetinanetClippedBoxes"; -const char *RETINANETFILTEREDDETECTIONS = "RetinanetFilteredDetections"; -const char *RETINANETPOSTPROCESSOR = "RetinanetPostProcessor"; -const char *RETINANETANCHORS = "RetinanetAnchors"; -const char *FASTERRCNNMAP = "FasterRCNNMap"; -const char *FASTERRCNNMAP1 = "FasterRCNNMap1"; -const char *FASTERRCNNSECONDSTAGEPOSTPROCESSOR = "FasterRCNNSecondStagePostprocessor"; -const char *FASTERRCNNROIINTERPOOLING = "FasterRCNNROIInterPooling"; -const char *FASTERRCNNFIRSTSTAGEPOSTPROCESSOR = "FasterRCNNFirstStagePostprocessor"; -const char *FASTERRCNNGRIDANCHORGENERATOR = "FasterRCNNGridAnchorGenerator"; -const char *ROIINTERPOOLING = "ROIInterPooling"; -const char *FASTERRCNNCLIPTOWINDOW = "FasterRCNNClipToWindow"; -const char *EMBEDLOOKUP = "EmbedLookup"; -const char *HASHLOOKUP = "HashLookup"; -const char *LSH_PROJ = "LshProject"; -const char *SVDF = "SVDF"; -const char *SSDANCHORGENERATOR = "SSDAnchorGenerator"; -const char *IDENTITY = "Identity"; -const char *IDENTITYN = "IdentityN"; -const char *PLACEHOLDERWITHDEFAULT = "PlaceholderWithDefault"; -const char *SELECT = "Select"; -const char *GETSPAN = "GetSpan"; -const char *STOPGRADIENT = "StopGradient"; -const char *PREVENTGRADIENT = "PreventGradient"; -const char *GUARANTEECONST = "GuaranteeConst"; -const char *BROADCASTGRADIENTARGS = "BroadcastGradientArgs"; -const char *BROADCASTARGS = "BroadcastArgs"; -const char *CONFUSIONMATRIX = "ConfusionMatrix"; -const char *RANK = "Rank"; -const char *PLACEHOLDER = "PlaceHolder"; -const char *END = "End"; -const char *BASICLSTMCELL = "BasicLSTMCell"; -const char *GETNEXT = "GetNext"; -const char *INITDATA = "InitData"; -const char *REFIDENTITY = "RefIdentity"; -const char *BITCAST = "Bitcast"; - -/***************Ann special operator*************************/ -const char *ANN_MEAN = "AnnMean"; -const char *ANN_CONVOLUTION = "AnnConvolution"; -const char *ANN_DEPCONVOLUTION = "AnnDepthConv"; -const char *ANN_FULLCONNECTION = "AnnFullConnection"; -const char *ANN_NETOUTPUT = "AnnNetOutput"; -const char *ANN_DATA = "AnnData"; -const char *ANN_RESHAPE = "AnnReshape"; -const char *ANN_ADD = "AnnAdd"; -const char *ANN_MUL = "AnnMul"; -const char *ANN_SUB = "AnnSub"; -const char *ANN_DIV = "AnnDiv"; -const char *ANN_DEQUANTIZE = "AnnDequant"; -const char *ANN_QUANTIZE = "AnnQuant"; -const char *ANN_PAD = "AnnPad"; -const char *ANN_RESIZE_BILINEAR = "AnnResizeBilinear"; - -/***************************************************/ -/******************Training operator*************************/ -const char *GATHERV2 = "GatherV2"; -const char *CONVGRADFILTER = "Conv2DBackpropFilter"; -const char *CONV2D = "Conv2D"; -const char *CONV2DBACKPROPINPUT = "Conv2DBackpropInput"; -const char *FUSEDBATCHNORM = "FusedBatchNorm"; -const char *BIASADDGRAD = "BiasAddGrad"; -const char *ACTIVATIONGRAD = "ReluGrad"; -const char *MAXPOOLWITHARGMAX = "MaxPoolWithArgmax"; -const char *MAXPOOLGRADWITHARGMAX = "MaxPoolGradWithArgmax"; -const char *SPARSESOFTMAXCROSSENTROPYWITHLOGITS = "SparseSoftmaxCrossEntropyWithLogits"; -const char *SNAPSHOT = "Snapshot"; -const char *VAR = "Var"; -const char *MEANGRAD = "MeanGrad"; -const char *TRANSLATE = "Translate"; -const char *ADDN = "AddN"; -const char *L2LOSS = "L2Loss"; -const char *MULTIPLY = "Multiply"; -const char *HUBERLOSSGRAD = "HuberLossGrad"; -const char *HUBERLOSS = "HuberLoss"; -const char *NEGATIVE = "Negative"; -const char *SSDCAST = "SSDCast"; -const char *SPARSESOFTMAXCROSSENTROPY = "SsdSparseSoftmaxCrossEntropy"; -const char *SPARSESOFTMAXCROSSENTROPYGRAD = "SsdSparseSoftmaxCrossEntropyGrad"; -const char *SSDSQUEEZEFUSION = "SsdSqueezeFusion"; -const char *CONCATFOUR2FIVE = "ConcatFour2Five"; -const char *CONCATFIVE2FOUR = "ConcatFive2Four"; -const char *SSDREALDIVTILEMUL = "SSDRealdivTileMul"; -const char *SSDSUMMULREALDIVMEAN = "SSDSumMulRealdivMean"; - -const char *VARIABLEV2 = "VariableV2"; -const char *VARHANDLEOP = "VarHandleOp"; -const char *TEMPORARYVARIABLE = "TemporaryVariable"; -const char *DESTROYTEMPORARYVARIABLE = "DestroyTemporaryVariable"; -const char *VARIABLE = "Variable"; -const char *ASSIGN = "Assign"; -const char *ASSIGNVARIABLEOP = "AssignVariableOp"; -const char *ASSIGNADD = "AssignAdd"; -const char *ASSIGNADDVARIABLEOP = "AssignAddVariableOp"; -const char *ASSIGNSUB = "AssignSub"; -const char *ASSIGNSUBVARIABLEOP = "AssignSubVariableOp"; -const char *APPLYMOMENTUM = "ApplyMomentum"; -const char *RESOURCEAPPLYMOMENTUM = "ResourceApplyMomentum"; -const char *SGD = "SGD"; -const char *NOOP = "NoOp"; -const char *READVARIABLEOP = "ReadVariableOp"; -const char *PARALLELCONCATSTART = "_ParallelConcatStart"; -const char *CONSTANTOP = "Constant"; -const char *DEPTHWISECONV2DBACKPROPFILTER = "DepthwiseConv2dNativeBackpropFilter"; -const char *DEPTHWISECONV2DBACKPORPINPUT = "DepthwiseConv2dNativeBackpropInput"; -const char *DEPTHWISECONV2DFORWARDNATIVE = "DepthwiseConv2dNative"; -const char *DROPOUTGRAD = "DropOutGrad"; -const char *APPLYRMSPROPMIXEDPRECISION = "apply_rms_prop_mixed_precision"; -const char *APPLYRMSPROP = "ApplyRMSProp"; -const char *RELU6GRAD = "Relu6Grad"; -const char *AVGPOOLGRAD = "AvgPoolGrad"; -const char *CONCATV2 = "ConcatV2"; -const char *CONCATOFFSET = "ConcatOffset"; -const char *LAYERNORMGRAD = "LayerNormGrad"; -const char *LAYERNORM = "LayerNorm"; -const char *LARS = "Lars"; -const char *DYNAMICSTITCH = "DynamicStitch"; - -/***************************************************/ -const char *SQUARE = "Square"; -const char *HCOMBROADCAST = "HcomBroadcast"; -const char *HCOMALLGATHER = "HcomAllGather"; -const char *HCOMALLREDUCE = "HcomAllReduce"; -const char *HCOMREDUCESCATTER = "HcomReduceScatter"; -const char *HCOMSEND = "HcomSend"; -const char *HCOMRECEIVE = "HcomReceive"; -const char *HCOMREMOTEREAD = "HcomRemoteRead"; -const char *HCOMREMOTEWRITE = "HcomRemoteWrite"; - -const char *VARASSIGN = "VarAssign"; -const char *VARISINITIALIZEDOP = "VarIsInitializedOp"; -const char *LogTimeStamp = "LogTimeStamp"; -const char *ISVARIABLEINITIALIZED = "IsVariableInitialized"; -const char *STREAMSWITCH = "StreamSwitch"; -const char *STREAMSWITCHN = "StreamSwitchN"; -const char *STREAMACTIVE = "StreamActive"; -const char *MEMCPYASYNC = "MemcpyAsync"; -const char *MEMCPYADDRASYNC = "MemcpyAddrAsync"; -const char *STREAMMERGE = "StreamMerge"; -const char *ENDGRAPH = "EndGraph"; -const char *SEND = "Send"; -const char *RECV = "Recv"; -const char *ENDOFSEQUENCE = "EndOfSequence"; - -const char *LABELSET = "LabelSet"; -const char *LABELGOTO = "LabelGoto"; -const char *LABELGOTOEX = "LabelGotoEx"; -const char *LABELSWITCH = "LabelSwitch"; -const char *LABELSWITCHBYINDEX = "LabelSwitchByIndex"; - -const char *ATOMICADDRCLEAN = "AtomicAddrClean"; - -const char *ABS_GRAD = "AbsGrad"; -const char *ACCUMULATE_N_V2 = "AccumulateNV2"; -const char *ACOS_GRAD = "AcosGrad"; -const char *ACOSH_GRAD = "AcoshGrad"; -const char *ANY = "Any"; -const char *APPROXIMATE_EQUAL = "ApproximateEqual"; -const char *ASIN_GRAD = "AsinGrad"; -const char *ASINH_GRAD = "AsinhGrad"; -const char *ATAN_GRAD = "AtanGrad"; -const char *BROADCAST_TO = "BroadcastTo"; -const char *ELU_GRAD = "EluGrad"; -const char *ADD_V2 = "AddV2"; -const char *DATAFORMATDIMMAP = "DataFormatDimMap"; -const char *DATAFORMATVECPERMUTE = "DataFormatVecPermute"; -const char *BESSELI0E = "BesselI0e"; -const char *BESSELI1E = "BesselI1e"; -const char *APPLYADADELTA = "ApplyAdadelta"; -const char *APPLYADAGRAD = "ApplyAdagrad"; -const char *APPLYADAGRADDA = "ApplyAdagradDA"; -const char *APPLYADAM = "ApplyAdam"; -const char *APPLYADAMAX = "ApplyAdaMax"; -const char *APPLYADDSIGN = "ApplyAddSign"; -const char *APPLYCENTEREDRMSPROP = "ApplyCenteredRMSProp"; -const char *APPLYFTRL = "ApplyFtrl"; -const char *APPLYFTRLV2 = "ApplyFtrlV2"; -const char *APPLYGRADIENTDESCENT = "ApplyGradientDescent"; -const char *APPLYPOWERSIGN = "ApplyPowerSign"; -const char *APPLYPROXIMALADAGRAD = "ApplyProximalAdagrad"; -const char *APPLYPROXIMALGRADIENTDESCENT = "ApplyProximalGradientDescent"; -const char *DEQUANTIZE = "Dequantize"; - -const char *FOCAL_LOSS = "FocalLoss"; -const char *FOCAL_LOSS_GRAD = "FocalLossGrad"; -const char *SMOOTHL1_LOSS = "SmoothL1Loss"; -const char *SMOOTHL1_LOSS_grad = "SmoothL1LossGrad"; -const char *REDUCEMEAN = "ReduceMean"; -const char *CONCAT_V2 = "ConcatV2"; -const char *ONEHOT_V2 = "OneHotV2"; -const char *SLICE_V2 = "SliceV2"; -const char *TILE_V2 = "TileV2"; -const char *SUM_V2 = "SumV2"; -// Common type when the operator has the same name -const char *DETECTIONOUTPUT = "DetectionOutput"; -// Custom operator -const char *CUSTOMOP = "CustomOp"; -const char *CUSTOMOP_NCHW = "CustomOpNchw"; -const char *CUSTOMOP_NHWC = "CustomOpNhwc"; -const char *CUSTOMOP_NC1HWC0 = "CustomOpNc1hwc0"; - -// Depthwise 4d_2_6d,6d_2_4d -const char *DEPTHWISEWEIGHT4D26D = "depthwise_weight_4d_2_6d"; -const char *DEPTHWISEWEIGHT6D24D = "depthwise_weight_6d_2_4d"; - -const char *SQRTGRAD = "SqrtGrad"; -const char *SIGMOIDGRAD = "SigmoidGrad"; - -const char *TRANSSHAPE = "TransShape"; - -// Horovod operator -const char *HVDCALLBACKALLREDUCE = "HorovodAllreduce"; -const char *HVDCALLBACKALLGATHER = "HorovodAllgather"; -const char *HVDCALLBACKBROADCAST = "HorovodBroadcast"; -const char *HVDWAIT = "HorovodWait"; - -/// -/// @brief Magic number of model file -/// -const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number - -/// -/// @brief Model head length -/// -const uint32_t MODEL_FILE_HEAD_LEN = 256; - -const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// - -/// -/// @ingroup domi_omg -/// @brief alpha default value -/// -const float ALPHA_DEFAULT_VALUE = 1.0; - -/// -/// @ingroup domi_omg -/// @brief beta default value -/// -const float BETA_DEFAULT_VALUE = 0.0; - -/// -/// @ingroup domi_omg -/// @brief Input node type -/// -const std::string INPUT_TYPE = "Input"; -const std::string DUMMY_DATA = "DummyData"; - -// for fusion op plugin -const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; - -const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; -const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; - -/// -/// @ingroup domi_omg -/// @brief DATA node type -/// -const std::string DATA_TYPE = "Data"; - -/// -/// @ingroup domi_omg -/// @brief Frame operator type -/// -const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; - -/// -/// @ingroup domi_omg -/// @brief Convolution node type -/// -const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; -} // namespace parser -} // namespace ge diff --git a/parser/common/pass_manager.cc b/parser/common/pass_manager.cc deleted file mode 100644 index 0c28572..0000000 --- a/parser/common/pass_manager.cc +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parser/common/pass_manager.h" -#include "framework/omg/parser/parser_types.h" -#include "parser/common/acl_graph_parser_util.h" -#include "common/debug/log.h" -#include "graph/utils/node_utils.h" -#include "omg/omg_inner_types.h" - -namespace ge { -namespace parser { -const vector> &PassManager::GraphPasses() const { return names_to_graph_passes_; } - -Status PassManager::AddPass(const string &pass_name, GraphPass *pass) { - GE_CHECK_NOTNULL(pass); - names_to_graph_passes_.emplace_back(pass_name, pass); - return SUCCESS; -} - -Status PassManager::Run(const ComputeGraphPtr &graph) { - GE_CHECK_NOTNULL(graph); - return Run(graph, names_to_graph_passes_); -} - -Status PassManager::Run(const ComputeGraphPtr &graph, vector> &names_to_passes) { - GE_CHECK_NOTNULL(graph); - bool not_changed = true; - - for (auto &pass_pair : names_to_passes) { - const auto &pass = pass_pair.second; - const auto &pass_name = pass_pair.first; - GE_CHECK_NOTNULL(pass); - - PARSER_TIMESTAMP_START(PassRun); - Status status = pass->Run(graph); - if (status == SUCCESS) { - not_changed = false; - } else if (status != NOT_CHANGED) { - GELOGE(status, "Pass Run failed on graph %s", graph->GetName().c_str()); - return status; - } - for (const auto &subgraph :graph->GetAllSubgraphs()) { - GE_CHECK_NOTNULL(subgraph); - GE_CHK_STATUS_RET(pass->ClearStatus(), "pass clear status failed for subgraph %s", subgraph->GetName().c_str()); - string subgraph_pass_name = pass_name + "::" + graph->GetName(); - PARSER_TIMESTAMP_START(PassRunSubgraph); - status = pass->Run(subgraph); - PARSER_TIMESTAMP_END(PassRunSubgraph, subgraph_pass_name.c_str()); - if (status == SUCCESS) { - not_changed = false; - } else if (status != NOT_CHANGED) { - GELOGE(status, "Pass Run failed on subgraph %s", subgraph->GetName().c_str()); - return status; - } - } - PARSER_TIMESTAMP_END(PassRun, pass_name.c_str()); - } - - return not_changed ? NOT_CHANGED : SUCCESS; -} - -PassManager::~PassManager() { - for (auto &pass_pair : names_to_graph_passes_) { - auto &pass = pass_pair.second; - GE_DELETE_NEW_SINGLE(pass); - } -} -} // namespace parser -} // namespace ge diff --git a/parser/common/pass_manager.h b/parser/common/pass_manager.h deleted file mode 100644 index b260248..0000000 --- a/parser/common/pass_manager.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARSER_COMMON_PASS_MANAGER_H_ -#define PARSER_COMMON_PASS_MANAGER_H_ - -#include - -#include "inc/graph_pass.h" - -using std::vector; - -namespace ge { -namespace parser { -/// -/// @ingroup domi_omg -/// @brief pass manager -/// @author -/// -class PassManager { -public: - /// - /// get graph passes - /// @author - /// - const vector> &GraphPasses() const; - - /// - /// Add graph pass - /// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys. - /// @author - /// - Status AddPass(const string &pass_name, GraphPass *pass); - - /// - /// Optimize graph with added pass - /// @param [inout] graph graph to be optimized - /// @return SUCCESS optimize successfully - /// @return NOT_CHANGED not optimized - /// @return others optimize failed - /// @author - /// - Status Run(const ge::ComputeGraphPtr &graph); - - /// - /// Optimize graph with specified pass - /// @param [inout] graph graph to be optimized - /// @param [in] passes passes to be used - /// @return SUCCESS optimize successfully - /// @return NOT_CHANGED not optimized - /// @return others optimized failed - /// @author - /// - static Status Run(const ge::ComputeGraphPtr &graph, vector> &passes); - - ~PassManager(); - -private: - vector> names_to_graph_passes_; -}; -} // namespace parser -} // namespace ge -#endif // PARSER_COMMON_PASS_MANAGER_H_ diff --git a/parser/common/pre_checker.cc b/parser/common/pre_checker.cc index 91ea192..d4b4245 100644 --- a/parser/common/pre_checker.cc +++ b/parser/common/pre_checker.cc @@ -23,7 +23,6 @@ #include "framework/common/debug/ge_log.h" #include "omg/omg.h" #include "parser/common/op_parser_factory.h" -#include "parser/common/model_saver.h" #include "register/op_registry.h" namespace ge { @@ -56,7 +55,7 @@ void PreChecker::Init() { fmk_op_types_ = nullptr; // Currently only Caffe and tensorflow are supported - domi::FrameworkType fmk_type = GetParserContext().type; + domi::FrameworkType fmk_type = domi::GetContext().type; if (fmk_type == domi::CAFFE) fmk_op_types_ = &caffe_op_map; else if (fmk_type == domi::TENSORFLOW) @@ -119,8 +118,8 @@ FMK_FUNC_HOST_VISIBILITY Status PreChecker::CheckType(OpId id, bool is_tensorflo // If the user explicitly specifies the mapping relationship of the operator type through // the -- OP_name_map parameter, the type specified by the user is used. - auto op_map_iter = GetParserContext().op_conf_map.find(type); - if (op_map_iter != GetParserContext().op_conf_map.end()) { + auto op_map_iter = domi::GetContext().op_conf_map.find(type); + if (op_map_iter != domi::GetContext().op_conf_map.end()) { type = op_map_iter->second; } @@ -233,7 +232,7 @@ Status PreChecker::Save(string file) { } // Save JSON data to a file - GE_RETURN_WITH_LOG_IF_ERROR(ge::parser::ModelSaver::SaveJsonToFile(file.c_str(), model), "Save failed."); + GE_RETURN_WITH_LOG_IF_ERROR(ModelSaver::SaveJsonToFile(file.c_str(), model), "Save failed."); return SUCCESS; } diff --git a/parser/common/pre_checker.h b/parser/common/pre_checker.h index 12d3323..4b0c6fb 100644 --- a/parser/common/pre_checker.h +++ b/parser/common/pre_checker.h @@ -19,7 +19,7 @@ #include #include -#include "framework/omg/parser/parser_types.h" +#include "common/types.h" #include "omg/omg_inner_types.h" namespace ge { diff --git a/parser/common/proto_file_parser.cc b/parser/common/proto_file_parser.cc index 731ac8c..6eecf2b 100644 --- a/parser/common/proto_file_parser.cc +++ b/parser/common/proto_file_parser.cc @@ -27,7 +27,6 @@ #include "common/types.h" #include "common/util.h" #include "common/debug/log.h" -#include "parser/common/acl_graph_parser_util.h" #include "ge/ge_api_types.h" #include "framework/common/debug/ge_log.h" @@ -159,7 +158,7 @@ bool SaveIdentifierOpMapInfo(const string &line, std::map #include #include "common/debug/log.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/ge/ge_util.h" #include "common/op/ge_op_utils.h" #include "common/op_map.h" #include "common/util.h" @@ -38,6 +38,8 @@ FMK_FUNC_HOST_VISIBILITY OpRegistrationTbe *OpRegistrationTbe::Instance() { } bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_train) { + ge::OpTypeContainer::Instance()->Register(reg_data.GetOmOptype()); + static std::map *> op_map = {{CAFFE, &caffe_op_map}}; if (is_train) { op_map[domi::TENSORFLOW] = &tensorflow_train_op_map; @@ -55,7 +57,8 @@ bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_tra continue; } else { (*fmk_op_map)[tmp] = reg_data.GetOmOptype(); - GELOGD("First register in parser initialize, original type: %s, om_optype: %s, imply type: %s.", tmp.c_str(), + GELOGD("First register in parser initilize, original type: %s, om_optype: %s, imply type: %s.", + tmp.c_str(), reg_data.GetOmOptype().c_str(), TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); } } @@ -79,7 +82,7 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { return false; } std::shared_ptr tf_parser_adapter = - ge::parser::MakeShared(); + ge::MakeShared(); if (tf_parser_adapter == nullptr) { GELOGE(PARAM_INVALID, "Create tf parser adapter failed."); return false; @@ -94,20 +97,22 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { return false; } GELOGI("Register fusion custom op parser: %s", reg_data.GetOmOptype().c_str()); - std::shared_ptr tf_fusion_parser_adapter = - ge::parser::MakeShared(); + std::shared_ptr + tf_fusion_parser_adapter = ge::MakeShared(); if (tf_fusion_parser_adapter == nullptr) { GELOGE(PARAM_INVALID, "Create tf fusion parser adapter failed."); return false; } OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( domi::TENSORFLOW, reg_data.GetOmOptype(), - [=]() -> std::shared_ptr { return tf_fusion_parser_adapter; }, true); + [=]() -> std::shared_ptr { return tf_fusion_parser_adapter; }, + true); } } else { std::shared_ptr factory = OpParserFactory::Instance(reg_data.GetFrameworkType()); if (factory == nullptr) { - GELOGE(INTERNAL_ERROR, "Get op parser factory for %s failed.", + GELOGE(INTERNAL_ERROR, + "Get op parser factory for %s failed.", TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); return false; } @@ -119,12 +124,14 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { PARSER_CREATOR_FN func = CustomParserAdapterRegistry::Instance()->GetCreateFunc(reg_data.GetFrameworkType()); if (func == nullptr) { - GELOGE(INTERNAL_ERROR, "Get custom parser adapter failed for fmk type %s.", + GELOGE(INTERNAL_ERROR, + "Get custom parser adapter failed for fmk type %s.", TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); return false; } OpParserFactory::Instance(reg_data.GetFrameworkType())->RegisterCreator(reg_data.GetOmOptype(), func); - GELOGD("Register custom parser adapter for op %s of fmk type %s success.", reg_data.GetOmOptype().c_str(), + GELOGD("Register custom parser adapter for op %s of fmk type %s success.", + reg_data.GetOmOptype().c_str(), TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); } return true; diff --git a/parser/common/tbe_plugin_loader.cc b/parser/common/tbe_plugin_loader.cc deleted file mode 100644 index 82c06eb..0000000 --- a/parser/common/tbe_plugin_loader.cc +++ /dev/null @@ -1,212 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tbe_plugin_loader.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/util/error_manager/error_manager.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/string_util.h" -#include "framework/omg/parser/parser_inner_ctx.h" -#include "graph/utils/type_utils.h" -#include "parser/common/acl_graph_parser_util.h" - -namespace ge { -std::map TBEPluginLoader::options_ = {}; - -namespace { -const std::string FRAMEWORK_TYPE = "ge.frameworkType"; -} - -// Get Singleton Instance -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEPluginLoader &TBEPluginLoader::Instance() { - static TBEPluginLoader instance_ptr_; - return instance_ptr_; -} - -Status TBEPluginLoader::ClearHandles_() { - Status ret = SUCCESS; - for (const auto &handle : handles_vec_) { - if (dlclose(handle) != 0) { - ret = FAILED; - GELOGW("Failed to close handle: %s", dlerror()); - } - } - handles_vec_.clear(); - return ret; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status TBEPluginLoader::Finalize() { - Status ret = ClearHandles_(); - return ret; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginLoader::LoadPluginSo( - const std::map &options) { - vector file_list; - string caffe_parser_path; - std::string plugin_path; - - options_ = options; - GetCustomOpPath(plugin_path); - - // Whether there are files in the plugin so path - GetPluginSoFileList(plugin_path, file_list, caffe_parser_path); - - // No file - if (file_list.empty()) { - // Print log - GELOGW("Can not find any plugin file in plugin_path: %s", plugin_path.c_str()); - } - - GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted."); - - // Load other so files except lib_caffe_parser.so in the plugin so path - for (auto elem : file_list) { - StringUtils::Trim(elem); - - void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE); - if (handle == nullptr) { - GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); - } else if (find(handles_vec_.begin(), handles_vec_.end(), handle) == handles_vec_.end()) { - // Close dl when the program exist, not close here - GELOGI("Plugin load %s success.", elem.c_str()); - handles_vec_.push_back(handle); - } else { - GELOGI("Plugin so has already been loaded, no need to load again."); - } - } -} - -void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) { - GELOGI("Enter get custom op path schedule"); - std::string fmk_type; - domi::FrameworkType type = domi::TENSORFLOW; - auto it = options_.find(FRAMEWORK_TYPE); - if (it != options_.end()) { - type = static_cast(std::strtol(it->second.c_str(), nullptr, 10)); - } - fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); - GELOGI("Framework type is %s.", fmk_type.c_str()); - - const char *path_env = std::getenv("ASCEND_OPP_PATH"); - if (path_env != nullptr) { - std::string path = path_env; - customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); - GELOGI("Get custom so path from env : %s", path_env); - return; - } - std::string path_base = GetPath(); - GELOGI("path_base is %s", path_base.c_str()); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); - customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type); -} - -string TBEPluginLoader::GetPath() { - Dl_info dl_info; - if (dladdr(reinterpret_cast(&TBEPluginLoader::GetPath), &dl_info) == 0) { - GELOGW("Failed to read so path!"); - return string(); - } else { - string so_path = dl_info.dli_fname; - char path[PATH_MAX] = {0}; - if (so_path.length() >= PATH_MAX) { - GELOGW("File path is too long!"); - return string(); - } - if (realpath(so_path.c_str(), path) == nullptr) { - GELOGW("Failed to get realpath of %s", so_path.c_str()); - return string(); - } - - so_path = path; - so_path = so_path.substr(0, so_path.rfind('/') + 1); - return so_path; - } -} - -void TBEPluginLoader::GetPluginSoFileList(const string &path, vector &file_list, string &caffe_parser_path) { - // Support to split multiple so directories by ":" - vector v_path = StringUtils::Split(path, ':'); - for (size_t i = 0; i < v_path.size(); ++i) { - FindParserSo(v_path[i], file_list, caffe_parser_path); - GELOGI("CustomOpLib full name = %s", v_path[i].c_str()); - } -} - -void TBEPluginLoader::FindParserSo(const string &path, vector &file_list, string &caffe_parser_path) { - // Path, change to absolute path - string real_path = ge::parser::RealPath(path.c_str()); - // Plugin path does not exist - if (real_path.empty()) { - GELOGW("RealPath is empty."); - return; - } - struct stat stat_buf; - if ((stat(real_path.c_str(), &stat_buf) != 0) || (!S_ISDIR(stat_buf.st_mode))) { - GELOGW("%s is not a dir.", real_path.c_str()); - return; - } - struct dirent *dent(0); - DIR *dir = opendir(real_path.c_str()); - // Plugin path does not exist - if (dir == nullptr) { - GELOGW("Open directory %s failed.", real_path.c_str()); - return; - } - - while ((dent = readdir(dir)) != nullptr) { - if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) continue; - string name = dent->d_name; - string full_name = real_path + "/" + name; - const string so_suff = ".so"; - const string caffe_parser_so_suff = "lib_caffe_parser.so"; - const string aicpu_so_suff = "_aicpu.so"; - const string aicpu_host_so_suff = "_online.so"; - if (name.size() >= so_suff.size() && name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { - ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff, - aicpu_host_so_suff); - } else { - FindParserSo(full_name, file_list, caffe_parser_path); - } - } - closedir(dir); -} - -void TBEPluginLoader::ProcessSoFullName(vector &file_list, string &caffe_parser_path, string &full_name, - const string &caffe_parser_so_suff, const string &aicpu_so_suff, - const string &aicpu_host_so_suff) { - if (full_name.size() >= caffe_parser_so_suff.size() && - full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(), - caffe_parser_so_suff) == 0) { - caffe_parser_path = full_name; - } else { - // Save parser so path into file_list vector - file_list.push_back(full_name); - } -} -} // namespace ge diff --git a/parser/common/tbe_plugin_loader.h b/parser/common/tbe_plugin_loader.h deleted file mode 100644 index 1cd6f6b..0000000 --- a/parser/common/tbe_plugin_loader.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARSER_COMMON_TBE_PLUGIN_LOADER_H_ -#define PARSER_COMMON_TBE_PLUGIN_LOADER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "external/ge/ge_api_error_codes.h" -#include "external/register/register.h" - -namespace ge { -using SoHandlesVec = std::vector; -class TBEPluginLoader { -public: - Status Finalize(); - - // Get TBEPluginManager singleton instance - static TBEPluginLoader& Instance(); - - void LoadPluginSo(const std::map &options); - - static string GetPath(); - -private: - TBEPluginLoader() = default; - ~TBEPluginLoader() = default; - Status ClearHandles_(); - static void ProcessSoFullName(vector &file_list, string &caffe_parser_path, string &full_name, - const string &caffe_parser_so_suff, const string &aicpu_so_suff, - const string &aicpu_host_so_suff); - static void GetCustomOpPath(std::string &customop_path); - static void GetPluginSoFileList(const string &path, vector &file_list, string &caffe_parser_path); - static void FindParserSo(const string &path, vector &file_list, string &caffe_parser_path); - - SoHandlesVec handles_vec_; - static std::map options_; -}; -} // namespace ge - -#endif //PARSER_COMMON_TBE_PLUGIN_LOADER_H_ diff --git a/parser/common/thread_pool.cc b/parser/common/thread_pool.cc deleted file mode 100644 index dead012..0000000 --- a/parser/common/thread_pool.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/thread_pool.h" - -#include -#include -#include -#include -#include -#include - -#include "register/register_types.h" - -namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) { - idle_thrd_num_ = size < 1 ? 1 : size; - - for (uint32_t i = 0; i < idle_thrd_num_; ++i) { - pool_.emplace_back(ThreadFunc, this); - } -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::~ThreadPool() { - is_stoped_.store(true); - { - std::unique_lock lock{m_lock_}; - cond_var_.notify_all(); - } - - for (std::thread &thd : pool_) { - if (thd.joinable()) { - try { - thd.join(); - } catch (const std::system_error &) { - GELOGW("system_error"); - } catch (...) { - GELOGW("exception"); - } - } - } -} - -void ThreadPool::ThreadFunc(ThreadPool *thread_pool) { - if (thread_pool == nullptr) { - return; - } - while (!thread_pool->is_stoped_) { - std::function task; - { - std::unique_lock lock{thread_pool->m_lock_}; - thread_pool->cond_var_.wait( - lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); }); - if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) { - return; - } - task = std::move(thread_pool->tasks_.front()); - thread_pool->tasks_.pop(); - } - --thread_pool->idle_thrd_num_; - task(); - ++thread_pool->idle_thrd_num_; - } -} -} // namespace ge diff --git a/parser/common/thread_pool.h b/parser/common/thread_pool.h deleted file mode 100644 index 2662a9f..0000000 --- a/parser/common/thread_pool.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARSER_COMMON_THREAD_POOL_H_ -#define PARSER_COMMON_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "framework/common/debug/ge_log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "external/ge/ge_api_error_codes.h" -#include "graph/types.h" -#include "parser/common/acl_graph_parser_util.h" - -namespace ge { -using ThreadTask = std::function; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool { - public: - explicit ThreadPool(uint32_t size = 4); - ~ThreadPool(); - - template - auto commit(Func &&func, Args &&... args) -> std::future { - GELOGD("commit run task enter."); - using retType = decltype(func(args...)); - std::future fail_future; - if (is_stoped_.load()) { - GELOGE(ge::FAILED, "thread pool has been stopped."); - return fail_future; - } - - auto bindFunc = std::bind(std::forward(func), std::forward(args)...); - auto task = ge::parser::MakeShared>(bindFunc); - if (task == nullptr) { - GELOGE(ge::FAILED, "Make shared failed."); - return fail_future; - } - std::future future = task->get_future(); - { - std::lock_guard lock{m_lock_}; - tasks_.emplace([task]() { (*task)(); }); - } - cond_var_.notify_one(); - GELOGD("commit run task end"); - return future; - } - - static void ThreadFunc(ThreadPool *thread_pool); - - private: - std::vector pool_; - std::queue tasks_; - std::mutex m_lock_; - std::condition_variable cond_var_; - std::atomic is_stoped_; - std::atomic idle_thrd_num_; -}; -} // namespace ge - -#endif // PARSER_COMMON_THREAD_POOL_H_ diff --git a/parser/func_to_graph/proto_python_rule.mk b/parser/func_to_graph/proto_python_rule.mk index cf442d5..ece9732 100644 --- a/parser/func_to_graph/proto_python_rule.mk +++ b/parser/func_to_graph/proto_python_rule.mk @@ -2,7 +2,7 @@ include $(BUILD_SYSTEM)/base_rules.mk FUNCTION_TO_GRAPH_OUT_TIMESTAMP := $(HOST_OUT_ROOT)/func_to_graph/.timestamp -PROTO_SRC_DIR = parser/parser/func_to_graph/proto +PROTO_SRC_DIR = framework/domi/parser/func_to_graph/proto PY_PROTO_BUILD_DIR = $(HOST_OUT_ROOT)/tmp/function_to_graph/proto $(warning PRIVATE_PROTOC is $(PRIVATE_PROTOC)) @@ -14,4 +14,4 @@ $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP): $(PRIVATE_PROTOC) $(LOCAL_BUILT_MODULE): $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP) mkdir -p $@ - cp -rf $(PY_PROTO_BUILD_DIR)/* $@ + cp -rf $(PY_PROTO_BUILD_DIR)/* $@ \ No newline at end of file diff --git a/parser/module.mk b/parser/module.mk index 502d44b..3a6348a 100644 --- a/parser/module.mk +++ b/parser/module.mk @@ -1,6 +1,6 @@ LOCAL_PATH := $(call my-dir) -include $(LOCAL_PATH)/stub/Makefile +include $(LOCAL_PATH)/../stub/Makefile COMMON_LOCAL_C_INCLUDES := \ proto/om.proto \ proto/insert_op.proto \ @@ -39,9 +39,7 @@ COMMON_LOCAL_C_INCLUDES := \ $(TOPDIR)inc/external/graph \ $(TOPDIR)inc/external/parser \ $(TOPDIR)inc/framework \ - $(TOPDIR)parser/parser \ - $(TOPDIR)parser \ - $(TOPDIR)graphengine/ge \ + $(TOPDIR)framework/domi/parser \ libc_sec/include \ third_party/protobuf/include \ third_party/json/include \ @@ -115,6 +113,7 @@ LOCAL_SHARED_LIBRARIES := \ libparser_common \ libgraph \ libregister \ + libge_common \ lib_caffe_parser \ LOCAL_LDFLAGS := -lrt @@ -134,8 +133,8 @@ endif LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := ../../out/parser/lib64/stub/tensorflow_parser.cc -LOCAL_SRC_FILES += ../../out/parser/lib64/stub/caffe_parser.cc +LOCAL_SRC_FILES := ../../../out/ge/lib64/stub/tensorflow_parser.cc +LOCAL_SRC_FILES += ../../../out/ge/lib64/stub/caffe_parser.cc LOCAL_SHARED_LIBRARIES := diff --git a/parser/onnx/CMakeLists.txt b/parser/onnx/CMakeLists.txt index 26fcce4..84d1178 100644 --- a/parser/onnx/CMakeLists.txt +++ b/parser/onnx/CMakeLists.txt @@ -29,24 +29,12 @@ target_include_directories(fmk_onnx_parser PRIVATE ${PARSER_DIR} ${PARSER_DIR}/inc ${PARSER_DIR}/parser - ${PARSER_DIR}/../ge - ${PARSER_DIR}/../inc - ${PARSER_DIR}/../inc/common/util - ${PARSER_DIR}/../inc/framework - ${PARSER_DIR}/../inc/external - ${PARSER_DIR}/../third_party/fwkacllib/inc ${METADEF_DIR}/inc ${METADEF_DIR}/inc/graph ${METADEF_DIR}/inc/register ${METADEF_DIR}/inc/external ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/external/register - #### independent compile ##### - ${METADEF_DIR}/third_party/graphengine/ge - ${METADEF_DIR}/third_party/graphengine/inc - ${METADEF_DIR}/third_party/graphengine/inc/framework - ${METADEF_DIR}/third_party/graphengine/inc/external - ${METADEF_DIR}/third_party/fwkacllib/inc #### temp #### ${PARSER_DIR}/../graphengine/inc/common/util ${PARSER_DIR}/../graphengine/inc/external @@ -55,6 +43,19 @@ target_include_directories(fmk_onnx_parser PRIVATE ${PARSER_DIR}/../graphengine/ge ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/ge + #### blue zone compile ##### + ${PARSER_DIR}/../ge + ${PARSER_DIR}/../inc + ${PARSER_DIR}/../inc/common/util + ${PARSER_DIR}/../inc/framework + ${PARSER_DIR}/../inc/external + ${PARSER_DIR}/../third_party/fwkacllib/inc + #### independent compile ##### + ${METADEF_DIR}/third_party/graphengine/ge + ${METADEF_DIR}/third_party/graphengine/inc + ${METADEF_DIR}/third_party/graphengine/inc/framework + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/fwkacllib/inc ) target_link_libraries(fmk_onnx_parser PRIVATE diff --git a/parser/onnx/module.mk b/parser/onnx/module.mk index 531be98..aee731f 100644 --- a/parser/onnx/module.mk +++ b/parser/onnx/module.mk @@ -29,9 +29,7 @@ LOCAL_C_INCLUDES := \ $(TOPDIR)inc/external \ $(TOPDIR)inc/external/graph \ $(TOPDIR)inc/framework \ - $(TOPDIR)parser \ - $(TOPDIR)parser/parser \ - $(TOPDIR)graphengine/ge \ + $(TOPDIR)framework/domi/parser \ libc_sec/include \ third_party/protobuf/include \ third_party/json/include \ @@ -45,6 +43,7 @@ LOCAL_SHARED_LIBRARIES := \ libparser_common \ libgraph \ libregister \ + libge_common \ LOCAL_LDFLAGS := -lrt diff --git a/parser/onnx/onnx_constant_parser.cc b/parser/onnx/onnx_constant_parser.cc index 479e52f..302aef9 100644 --- a/parser/onnx/onnx_constant_parser.cc +++ b/parser/onnx/onnx_constant_parser.cc @@ -17,7 +17,7 @@ #include "onnx_constant_parser.h" #include #include -#include "parser/common/acl_graph_parser_util.h" +#include "common/ge/ge_util.h" #include "common/util.h" #include "framework/omg/parser/parser_inner_ctx.h" #include "graph/ge_tensor.h" @@ -30,7 +30,6 @@ using ge::onnx::TensorProto; using domi::ONNX; using GeShape = ge::GeShape; using GeTensorDesc = ge::GeTensorDesc; -using namespace ge::parser; namespace ge { Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) { diff --git a/parser/onnx/onnx_data_parser.cc b/parser/onnx/onnx_data_parser.cc index 7b396b7..47097fa 100644 --- a/parser/onnx/onnx_data_parser.cc +++ b/parser/onnx/onnx_data_parser.cc @@ -22,7 +22,6 @@ #include "parser/onnx/onnx_util.h" using domi::ONNX; -using namespace ge::parser; namespace ge { Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 25e6f8c..627e158 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -18,25 +18,23 @@ #include #include #include "common/convert/pb2json.h" +#include "common/model_saver.h" #include "common/util.h" #include "external/graph/operator_factory.h" #include "external/register/register_error_codes.h" #include "framework/omg/parser/parser_inner_ctx.h" -#include "framework/omg/parser/parser_types.h" #include "omg/parser/parser_factory.h" #include "onnx_op_parser.h" #include "onnx_util.h" #include "parser/common/op_parser_factory.h" #include "parser/common/pre_checker.h" -#include "parser/common/acl_graph_parser_util.h" -#include "parser/common/model_saver.h" #include "parser/onnx/onnx_util.h" #include "register/op_registry.h" namespace ge { namespace { std::map kOnnxOpMap = { - {ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, + {ge::kOpTypeInput, ge::DATA}, {ge::kOpTypeConstant, ge::CONSTANT}, }; } @@ -438,7 +436,7 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { // 1. Get graph from onnx model file. ge::onnx::ModelProto onnx_model; - if (!ge::parser::ReadProtoFromBinaryFile(file, &onnx_model)) { + if (!ge::ReadProtoFromBinaryFile(file, &onnx_model)) { GELOGE(PARAM_INVALID, "Read onnx model file failed."); return FAILED; } @@ -552,12 +550,12 @@ Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) { } ge::onnx::ModelProto onnx_model; - GE_RETURN_WITH_LOG_IF_FALSE(ge::parser::ReadProtoFromBinaryFile(model_file, &onnx_model), + GE_RETURN_WITH_LOG_IF_FALSE(ge::ReadProtoFromBinaryFile(model_file, &onnx_model), "ReadProtoFromBinaryFile failed, file:%s.", model_file); ge::onnx::GraphProto graph_proto = onnx_model.graph(); nlohmann::json j; ge::Pb2Json::Message2Json(graph_proto, std::set(), j, true); - return ge::parser::ModelSaver::SaveJsonToFile(json_file, j); + return ge::ModelSaver::SaveJsonToFile(json_file, j); } ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) { diff --git a/parser/tensorflow/graph_functiondef.cc b/parser/tensorflow/graph_functiondef.cc index 0a242f8..3bba92b 100644 --- a/parser/tensorflow/graph_functiondef.cc +++ b/parser/tensorflow/graph_functiondef.cc @@ -18,8 +18,7 @@ #include #include "common/fmk_error_codes.h" #include "graph/debug/ge_attr_define.h" -#include "framework/omg/parser/parser_types.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/types.h" #include "common/types_map.h" #include "common/util.h" #include "graph/op_desc.h" @@ -219,7 +218,7 @@ domi::Status GraphToFunctionDef::RecordResult(ge::ComputeGraphPtr graph, string op_name = anchor->GetOwnerNode()->GetName() + "_" + to_string(anchor->GetIdx()) + "_retval"; ge::OpDescPtr op = nullptr; - GE_MAKE_SHARED(op = std::make_shared(op_name, ge::parser::NETOUTPUT), return FAILED); + GE_MAKE_SHARED(op = std::make_shared(op_name, ge::NETOUTPUT), return FAILED); graphStatus status = op->AddInputDesc(ge::GeTensorDesc()); if (status != GRAPH_SUCCESS) { GELOGE(FAILED, "Add input desc for op:%s failed.", op->GetName().c_str()); @@ -282,7 +281,7 @@ domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vect string op_name = anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName() + "_" + to_string(anchor->GetPeerOutAnchor()->GetIdx()) + "_arg"; ge::OpDescPtr op = nullptr; - GE_MAKE_SHARED(op = std::make_shared(op_name, ge::parser::DATA), return FAILED); + GE_MAKE_SHARED(op = std::make_shared(op_name, ge::DATA), return FAILED); graphStatus status = op->AddOutputDesc(ge::GeTensorDesc()); if (status != GRAPH_SUCCESS) { GELOGE(FAILED, "Add output desc for op:%s failed.", op->GetName().c_str()); @@ -330,7 +329,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph for (const ge::NodePtr &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); - if (node->GetOpDesc()->GetType() == ge::parser::DATA) { + if (node->GetOpDesc()->GetType() == ge::DATA) { int64_t index = 0; int64_t type = 1; @@ -351,7 +350,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph continue; } - if (node->GetOpDesc()->GetType() == ge::parser::NETOUTPUT) { + if (node->GetOpDesc()->GetType() == ge::NETOUTPUT) { int64_t index = 0; int64_t type = 1; @@ -475,7 +474,7 @@ domi::Status GraphToFunctionDef::BuildFunctionDef(ge::ComputeGraphPtr &graph, co GE_CHECK_NOTNULL(library); GE_CHECK_NOTNULL(call_node_def); // Current date / time base on the current system - string now_time = ge::parser::CurrentTimeInStr(); + string now_time = ge::CurrentTimeInStr(); static int i = 0; const string name = name_in + now_time + to_string(i); i++; diff --git a/parser/tensorflow/graph_insert_trans_op.h b/parser/tensorflow/graph_insert_trans_op.h index abd6c2a..baf37df 100644 --- a/parser/tensorflow/graph_insert_trans_op.h +++ b/parser/tensorflow/graph_insert_trans_op.h @@ -21,7 +21,7 @@ #include #include "common/fmk_types.h" #include "common/op/ge_op_utils.h" -#include "framework/omg/parser/parser_types.h" +#include "common/types.h" #include "graph/compute_graph.h" #include "graph/node.h" #include "graph/types.h" diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc index 373344f..86c041b 100644 --- a/parser/tensorflow/graph_optimizer.cc +++ b/parser/tensorflow/graph_optimizer.cc @@ -23,13 +23,13 @@ #include "cce/cce.h" #include "cce/dnn.h" #include "common/debug/log.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/math/math_util.h" #include "common/op/ge_op_utils.h" #include "common/op_map.h" +#include "common/types.h" #include "common/types_map.h" +#include "common/util.h" #include "framework/common/debug/ge_log.h" -#include "framework/omg/parser/parser_inner_ctx.h" -#include "framework/omg/parser/parser_types.h" #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" @@ -39,7 +39,6 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "graph_functiondef.h" -#include "parser/common/acl_graph_parser_util.h" #include "proto/tensorflow/attr_value.pb.h" #include "register/op_registry.h" @@ -92,137 +91,117 @@ const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map g_OpSupportTranInfo = {}; -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::MUL, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::MUL, std::vector({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::L2LOSS, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::L2LOSS, std::vector({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC, ge::FORMAT_HWCN}), // inputformats ge::DT_FLOAT, ge::FORMAT_NC1HWC0, ge::DT_FLOAT) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONVGRADFILTER, InFmtSupportUndefined, InDtSupportUndefined, - ge::FORMAT_FRACTAL_Z, ge::DT_FLOAT) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2DBACKPROPINPUT, InFmtSupportUndefined, InDtSupportUndefined, - ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::BIASADDGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CONVGRADFILTER, InFmtSupportUndefined, InDtSupportUndefined, ge::FORMAT_FRACTAL_Z, ge::DT_FLOAT) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::BIASADD, ge::FORMAT_NCHW, ge::DT_FLOAT, ge::FORMAT_NCHW, ge::DT_FLOAT) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CONV2DBACKPROPINPUT, InFmtSupportUndefined, InDtSupportUndefined, + ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::BIASADDGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::BIASADD, ge::FORMAT_NCHW, ge::DT_FLOAT, ge::FORMAT_NCHW, ge::DT_FLOAT) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATION, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, - ge::DT_FLOAT16) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATIONGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, - ge::DT_FLOAT16) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SOFTMAX, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ACTIVATION, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ACTIVATIONGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SOFTMAX, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::SOFTMAX, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SOFTMAX, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPROPFILTER, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DEPTHWISECONV2DBACKPROPFILTER, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_C1HWNCoC0, ge::DT_FLOAT) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPORPINPUT, InFmtSupportUndefined, InDtSupportUndefined, - OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DFORWARDNATIVE, InFmtSupportUndefined, InDtSupportUndefined, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DEPTHWISECONV2DBACKPORPINPUT, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORM, InFmtSupportUndefined, InDtSupportUndefined, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DEPTHWISECONV2DFORWARDNATIVE, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORMGRAD, InFmtSupportUndefined, InDtSupportUndefined, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::FUSEDBATCHNORM, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, + OutDtSupportUndefined) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::FUSEDBATCHNORMGRAD, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2D, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CONV2D, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::RESHAPE, ge::FORMAT_NHWC, InDtSupportAll, ge::FORMAT_NHWC, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, InFmtSupport5D, ge::DT_FLOAT16, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::RESHAPE, ge::FORMAT_NHWC, InDtSupportAll, ge::FORMAT_NHWC, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, InFmtSupport5D, ge::DT_FLOAT16, OutFmtSupportAsInput, OutDtSupportAsInput) TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MAXIMUM_GRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::APPLYRMSPROP, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::APPLYRMSPROP, std::vector({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DROPOUTDOMASK, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::LOG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRTGRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOIDGRAD, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOID, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ARGMAX, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::AVGPOOLGRAD, InFmtSupport5D, ge::DT_FLOAT16, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::NEG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RECIPROCAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DROPOUTDOMASK, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQUARE, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::LOG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SQRTGRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SIGMOIDGRAD, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUB, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SIGMOID, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::ARGMAX, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::AVGPOOLGRAD, InFmtSupport5D, ge::DT_FLOAT16, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUM, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::NEG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::RECIPROCAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SQUARE, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SUB, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SUM, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MATMUL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GATHERV2, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::GATHERV2, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::GREATEREQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATEREQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::REALDIV, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::STRIDEDSLICE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TILE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::REALDIV, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SQRT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::STRIDEDSLICE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TILE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TFRELU6, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RELU6GRAD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::EQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATER, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SELECT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::RELU6GRAD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::EQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::GREATER, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SELECT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_BATCH_MATMUL, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TRANSPOSE, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TRANSPOSE, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::STREAMMERGE, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::STREAMMERGE, std::vector({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::MEMCPYASYNC, +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::MEMCPYASYNC, std::vector({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) bool GetCceTbeTransInfo(string opType, OpSupportTranInfo &opSupportInfo) { static bool fmtInited = false; GE_IF_BOOL_EXEC( - !fmtInited, fmtInited = true; - if (domi::OpRegistry().Instance()->GetImplyType(ge::parser::DEPTHWISEWEIGHT4D26D) == domi::ImplyType::TVM) { - auto it = g_OpSupportTranInfo.find(string("TBE:") + ge::parser::MUL); - if (it != g_OpSupportTranInfo.end()) { - auto &fmts = it->second.inputFormats; - auto itFmt = std::find(fmts.begin(), fmts.end(), ge::FORMAT_NC1HWC0); - fmts.erase(itFmt); - } - }) + !fmtInited, fmtInited = true; + if (domi::OpRegistry().Instance()->GetImplyType(ge::DEPTHWISEWEIGHT4D26D) == domi::ImplyType::TVM) { + auto it = g_OpSupportTranInfo.find(string("TBE:") + ge::MUL); + if (it != g_OpSupportTranInfo.end()) { + auto &fmts = it->second.inputFormats; + auto itFmt = std::find(fmts.begin(), fmts.end(), ge::FORMAT_NC1HWC0); + fmts.erase(itFmt); + } + }) string cceTbeOpType = "TBE"; GE_IF_BOOL_EXEC(domi::OpRegistry().Instance()->GetImplyType(opType) == domi::ImplyType::BUILDIN, cceTbeOpType = "CCE";) @@ -807,7 +786,7 @@ Status CreateNodeDefBytes(ge::NodePtr n, string originalType, mapGetShape().GetDimNum(); ++j) { tmp_dim = ge_desc->GetShape().GetDim(j); GE_CHECK_GE(tmp_dim, 0); - PARSER_INT64_MULCHECK(real_size, tmp_dim); + FMK_INT64_MULCHECK(real_size, tmp_dim); real_size *= tmp_dim; } ge::TensorUtils::SetSize(*ge_desc, real_size * size_type); @@ -1198,7 +1177,7 @@ Status CreateFuncDefBytes(ge::NodePtr n, string original_type, string func_bin_p char *buf = nullptr; int32_t len = 0; - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::parser::ReadBytesFromBinaryFile(file.c_str(), &buf, len), return false, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::ReadBytesFromBinaryFile(file.c_str(), &buf, len), return false, "read bytes file error!"); GELOGI("len =%d\n", len); @@ -1229,7 +1208,7 @@ Status ParserGraphOptimizer::MakeTfProtoDef() { CreateIOListFuncMap(mOpIOListFuncMap); for (ge::NodePtr n : graph_->GetDirectNode()) { - if (n->GetType() != ge::parser::FRAMEWORKOP) continue; + if (n->GetType() != ge::FRAMEWORKOP) continue; std::string original_type; GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type) != true, "get original type failed."); @@ -1290,9 +1269,9 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map bool hasGetNext = false; for (auto node : graph_->GetDirectNode()) { GE_CHECK_NOTNULL(node); - GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); + GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::FRAMEWORK_OP_TYPE, continue); string type = ""; - GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); + GE_CHK_STATUS_RET(GetOriginalType(node, type)); if (type == "IteratorGetNext") { hasGetNext = true; break; @@ -1300,9 +1279,9 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map } for (auto node : graph_->GetDirectNode()) { GE_CHECK_NOTNULL(node); - GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) + GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::FRAMEWORK_OP_TYPE, continue) string type = ""; - GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); + GE_CHK_STATUS_RET(GetOriginalType(node, type)); if (type == "IteratorGetNext") { vector temp_node_cluser; for (auto in_anchor : node->GetAllInDataAnchors()) { @@ -1338,9 +1317,9 @@ Status ParserGraphOptimizer::FindFmkNodeCluser(unordered_mapGetOpDesc(); GE_CHECK_NOTNULL(temp_node_desc_ptr); - GE_IF_BOOL_EXEC(temp_node_desc_ptr->GetType() == ge::parser::DATA_TYPE, continue); + GE_IF_BOOL_EXEC(temp_node_desc_ptr->GetType() == ge::DATA_TYPE, continue); - if (temp_node_desc_ptr->GetType() == ge::parser::FRAMEWORK_OP_TYPE && + if (temp_node_desc_ptr->GetType() == ge::FRAMEWORK_OP_TYPE && (temp_node_desc_ptr->GetName().find(RRTVAL_NODE_NAME_SUFFIX) == string::npos)) { temp_node_cluser.push_back(node); } else { @@ -1421,7 +1400,7 @@ Status ParserGraphOptimizer::UpdateGraph(vector &nodes) { return FAILED); std::string type = ""; - GE_CHK_STATUS_RET(ge::parser::GetOriginalType(nodes[0], type)); + GE_CHK_STATUS_RET(GetOriginalType(nodes[0], type)); (void)AttrUtils::SetStr(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); (void)AttrUtils::SetZeroCopyBytes( @@ -1431,7 +1410,7 @@ Status ParserGraphOptimizer::UpdateGraph(vector &nodes) { fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); - (void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type); + (void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, domi::GetContext().type); // reconstruct fusion_node and edges GE_CHK_STATUS_RET(RebuildOutputAnchors(output_anchors, fusion_node_opdef), @@ -1481,19 +1460,17 @@ Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vectorGetInControlAnchor(); - GE_IF_BOOL_EXEC( - node_in_control != nullptr, for (auto peer_out_anchor - : node_in_control->GetPeerOutControlAnchors()) { - vector::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode()); - GE_IF_BOOL_EXEC(iter == nodes.end(), input_control_anchors.emplace_back(node_in_control)); - }); + GE_IF_BOOL_EXEC(node_in_control != nullptr, for (auto peer_out_anchor + : node_in_control->GetPeerOutControlAnchors()) { + vector::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode()); + GE_IF_BOOL_EXEC(iter == nodes.end(), input_control_anchors.emplace_back(node_in_control)); + }); OutControlAnchorPtr node_out_control = node->GetOutControlAnchor(); - GE_IF_BOOL_EXEC( - node_out_control != nullptr, for (auto peer_in_control_anchor - : node_out_control->GetPeerInControlAnchors()) { - vector::iterator iter = find(nodes.begin(), nodes.end(), peer_in_control_anchor->GetOwnerNode()); - GE_IF_BOOL_EXEC(iter == nodes.end(), output_control_anchors.emplace_back(node_out_control)); - }); + GE_IF_BOOL_EXEC(node_out_control != nullptr, for (auto peer_in_control_anchor + : node_out_control->GetPeerInControlAnchors()) { + vector::iterator iter = find(nodes.begin(), nodes.end(), peer_in_control_anchor->GetOwnerNode()); + GE_IF_BOOL_EXEC(iter == nodes.end(), output_control_anchors.emplace_back(node_out_control)); + }); } return SUCCESS; } @@ -1518,19 +1495,18 @@ Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map } InControlAnchorPtr node_in_control = node->GetInControlAnchor(); - GE_IF_BOOL_EXEC( - node_in_control != nullptr, for (auto peer_out_ctl_anchor - : node_in_control->GetPeerOutControlAnchors()) { - GE_IF_BOOL_EXEC(node_map.count(peer_out_ctl_anchor->GetOwnerNode()->GetName()) == 0, continue); - NodePtr src_ctrl = node_map[peer_out_ctl_anchor->GetOwnerNode()->GetName()]; - GE_IF_BOOL_EXEC( - ge::GraphUtils::AddEdge(src_ctrl->GetOutControlAnchor(), dst->GetInControlAnchor()) != GRAPH_SUCCESS, - GELOGE(FAILED, - "LinkInnerAnchor Link control anchor failed, src node: " - "%s, dst node: %s.", - src_ctrl->GetName().c_str(), dst->GetName().c_str()); - return FAILED); - }); + GE_IF_BOOL_EXEC(node_in_control != nullptr, for (auto peer_out_ctl_anchor + : node_in_control->GetPeerOutControlAnchors()) { + GE_IF_BOOL_EXEC(node_map.count(peer_out_ctl_anchor->GetOwnerNode()->GetName()) == 0, continue); + NodePtr src_ctrl = node_map[peer_out_ctl_anchor->GetOwnerNode()->GetName()]; + GE_IF_BOOL_EXEC( + ge::GraphUtils::AddEdge(src_ctrl->GetOutControlAnchor(), dst->GetInControlAnchor()) != GRAPH_SUCCESS, + GELOGE(FAILED, + "LinkInnerAnchor Link control anchor failed, src node: " + "%s, dst node: %s.", + src_ctrl->GetName().c_str(), dst->GetName().c_str()); + return FAILED); + }); } return SUCCESS; } @@ -1881,24 +1857,24 @@ OpDescPtr ParserGraphOptimizer::CreateTranslateOp(enum ge::Format inFormat, enum static uint32_t transop_count = 0; OpDescPtr op_def = nullptr; std::stringstream sstmp; - sstmp << "translate_" << ge::parser::TRANSDATA << "_" << transop_count++; - GE_MAKE_SHARED(op_def = std::make_shared(sstmp.str().c_str(), ge::parser::TRANSLATE), op_def = nullptr; + sstmp << "translate_" << ge::TRANSDATA << "_" << transop_count++; + GE_MAKE_SHARED(op_def = std::make_shared(sstmp.str().c_str(), ge::TRANSLATE), op_def = nullptr; return op_def); GELOGI( - "create translate op:%s, input format:%s, input datatype:%s, output " - "format:%s, output datatype:%s.", - op_def->GetName().c_str(), ge::TypeUtils::FormatToSerialString(inFormat).c_str(), - ge::TypeUtils::DataTypeToSerialString(inDatatype).c_str(), ge::TypeUtils::FormatToSerialString(outFormat).c_str(), - ge::TypeUtils::DataTypeToSerialString(outDatatype).c_str()); - - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_INPUT_FORMAT, inFormat), return nullptr, - "SetInt ATTR_NAME_INPUT_FORMAT failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_INPUT_DATATYPE, inDatatype), return nullptr, - "SetInt ATTR_NAME_INPUT_DATATYPE failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_OUTPUT_FORMAT, outFormat), return nullptr, - "SetInt ATTR_NAME_INPUT_DATATYPE failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_OUTPUT_DATATYPE, outDatatype), return nullptr, - "SetInt ATTR_NAME_INPUT_DATATYPE failed."); + "create translate op:%s, input format:%s, input datatype:%s, output " + "format:%s, output datatype:%s.", + op_def->GetName().c_str(), ge::TypeUtils::FormatToSerialString(inFormat).c_str(), + ge::TypeUtils::DataTypeToSerialString(inDatatype).c_str(), ge::TypeUtils::FormatToSerialString(outFormat).c_str(), + ge::TypeUtils::DataTypeToSerialString(outDatatype).c_str()); + + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_INPUT_FORMAT, inFormat), + return nullptr, "SetInt ATTR_NAME_INPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_INPUT_DATATYPE, inDatatype), + return nullptr, "SetInt ATTR_NAME_INPUT_DATATYPE failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_OUTPUT_FORMAT, outFormat), + return nullptr, "SetInt ATTR_NAME_INPUT_DATATYPE failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_OUTPUT_DATATYPE, outDatatype), + return nullptr, "SetInt ATTR_NAME_INPUT_DATATYPE failed."); if (inDatatype != ge::DT_FLOAT16) { GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddInputDesc(GeTensorDesc(GeShape(), inFormat)), return nullptr, "create translate op:add input desc fail."); @@ -1920,17 +1896,17 @@ OpDescPtr ParserGraphOptimizer::CreatePermuteOp(enum ge::Format input_format, en static uint32_t transop_count = 0; std::stringstream sstmp; - sstmp << "transdata_" << ge::parser::PERMUTE << "_" << transop_count++; + sstmp << "transdata_" << ge::PERMUTE << "_" << transop_count++; OpDescPtr op_desc = nullptr; - GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::parser::PERMUTE), op_desc = nullptr; + GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::PERMUTE), op_desc = nullptr; return op_desc); GELOGI("create permute op:%s", op_desc->GetName().c_str()); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), return nullptr, - "SetInt ATTR_NAME_INPUT_FORMAT failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), return nullptr, - "SetInt ATTR_NAME_OUTPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), + return nullptr, "SetInt ATTR_NAME_INPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), + return nullptr, "SetInt ATTR_NAME_OUTPUT_FORMAT failed."); GE_IF_BOOL_EXEC(input_format == FORMAT_NCHW, (void)AttrUtils::SetInt(op_desc, "NCHW_to_NHWC", (int64_t)1)); GE_IF_BOOL_EXEC(input_format == FORMAT_NHWC, (void)AttrUtils::SetInt(op_desc, "NHWC_to_NCHW", (int64_t)1)); @@ -1947,11 +1923,10 @@ OpDescPtr ParserGraphOptimizer::CreateCastOp(enum ge::DataType input_data_type, enum ge::Format format) { static uint32_t transop_count = 0; std::stringstream sstmp; - sstmp << "transdata_" << ge::parser::CAST << "_" << transop_count++; + sstmp << "transdata_" << ge::CAST << "_" << transop_count++; OpDescPtr op_desc = nullptr; - GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::parser::CAST), op_desc = nullptr; - return op_desc); + GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::CAST), op_desc = nullptr; return op_desc); GELOGI("create cast op:%s, input datatype:%s, out datatype:%s", op_desc->GetName().c_str(), ge::TypeUtils::DataTypeToSerialString(input_data_type).c_str(), ge::TypeUtils::DataTypeToSerialString(output_data_type).c_str()); @@ -1975,10 +1950,10 @@ OpDescPtr ParserGraphOptimizer::CreateCastOp(enum ge::DataType input_data_type, OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) { static uint32_t transop_count = 0; std::stringstream sstmp; - sstmp << "transdata_" << ge::parser::TRANSDATA << "_" << transop_count++; + sstmp << "transdata_" << ge::TRANSDATA << "_" << transop_count++; OpDescPtr op_desc = nullptr; - GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::parser::TRANSDATA), op_desc = nullptr; + GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::TRANSDATA), op_desc = nullptr; return op_desc); GELOGI("create transdata op:%s, input format:%s.", op_desc->GetName().c_str(), @@ -1989,10 +1964,10 @@ OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) output_format = FORMAT_NCHW; } - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), return nullptr, - "SetInt of ATTR_NAME_INPUT_FORMAT failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), return nullptr, - "SetInt of ATTR_NAME_OUTPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), + return nullptr, "SetInt of ATTR_NAME_INPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), + return nullptr, "SetInt of ATTR_NAME_OUTPUT_FORMAT failed."); GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), input_format)), return nullptr, "create transdata op:add input desc fail."); GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), output_format)), return nullptr, @@ -2000,4 +1975,4 @@ OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) return op_desc; } -} // namespace ge +} // namespace domi diff --git a/parser/tensorflow/graph_optimizer.h b/parser/tensorflow/graph_optimizer.h index 9f73d69..6b79deb 100644 --- a/parser/tensorflow/graph_optimizer.h +++ b/parser/tensorflow/graph_optimizer.h @@ -20,7 +20,7 @@ #include #include #include -#include "framework/omg/parser/parser_types.h" +#include "common/types.h" #include "graph/anchor.h" #include "graph/compute_graph.h" #include "graph/node.h" @@ -46,9 +46,8 @@ class ParserGraphOptimizer { domi::Status FusionFmkop(); inline bool IsHCOMOp(const string &op_type) { - return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) || - (op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) || - (op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter"); + return (op_type == ge::HCOMALLREDUCE) || (op_type == ge::HCOMALLGATHER) || (op_type == ge::HCOMBROADCAST) || + (op_type == ge::HCOMSEND) || (op_type == ge::HCOMRECEIVE) || (op_type == "HcomReduceScatter"); } void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; } @@ -104,11 +103,11 @@ class ParserGraphOptimizer { domi::Status UpdateGraph(vector &nodes); domi::Status InsertNode(ge::ComputeGraphPtr sub_graph, vector &nodes, - vector &input_anchors, vector &output_anchors, - map> &output_in_map, - vector &input_control_anchors, - vector &output_control_anchors, - unordered_map &node_map); + vector &input_anchors, vector &output_anchors, + map> &output_in_map, + vector &input_control_anchors, + vector &output_control_anchors, + unordered_map &node_map); domi::Status LinkInnerAnchor(unordered_map &node_map); @@ -124,5 +123,5 @@ class ParserGraphOptimizer { domi::Status MakeTfProtoDef(); }; -} // namespace ge +} // namespace domi #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ diff --git a/parser/tensorflow/iterator_fusion_pass.cc b/parser/tensorflow/iterator_fusion_pass.cc index 0324050..8e265d7 100644 --- a/parser/tensorflow/iterator_fusion_pass.cc +++ b/parser/tensorflow/iterator_fusion_pass.cc @@ -19,7 +19,7 @@ #include #include "common/debug/log.h" -#include "framework/omg/parser/parser_types.h" +#include "common/types.h" #include "common/util.h" #include "graph_optimizer.h" #include "framework/common/ge_inner_error_codes.h" diff --git a/parser/tensorflow/scope/scope_pass_manager.cc b/parser/tensorflow/scope/scope_pass_manager.cc index db2f1ca..84ecbe0 100644 --- a/parser/tensorflow/scope/scope_pass_manager.cc +++ b/parser/tensorflow/scope/scope_pass_manager.cc @@ -15,7 +15,7 @@ */ #include "parser/tensorflow/scope/scope_pass_manager.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/ge/ge_util.h" #include "common/util.h" #include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" @@ -25,7 +25,7 @@ namespace ge { shared_ptr ScopePassManager::BuildScopeGraph(domi::tensorflow::GraphDef *graph_def) { GE_CHK_BOOL_EXEC(graph_def != nullptr, return nullptr, "graph_def is nullptr"); - scope_graph_ = ge::parser::MakeShared(); + scope_graph_ = ge::MakeShared(); if (scope_graph_ == nullptr) { GELOGE(FAILED, "Scope graph make shared failed."); return nullptr; diff --git a/parser/tensorflow/tensorflow_arg_parser.cc b/parser/tensorflow/tensorflow_arg_parser.cc index 35a34c8..4088a1e 100644 --- a/parser/tensorflow/tensorflow_arg_parser.cc +++ b/parser/tensorflow/tensorflow_arg_parser.cc @@ -17,7 +17,6 @@ #include "common/debug/log.h" #include "parser/common/op_def/arg_op.h" #include "framework/common/debug/ge_log.h" -#include "framework/omg/parser/parser_inner_ctx.h" #include "graph/compute_graph.h" #include "graph/ge_tensor.h" #include "parser/common/op_parser_factory.h" @@ -45,7 +44,7 @@ Status ParseParams(const Message *op_src, ArgOpOperator *op) { "trans output_attr_value failed, op: %s", node->name().c_str()); domi::tensorflow::AttrValue_ListValue attr_list = output_attr_value.list(); - GetParserContext().format = + domi::GetContext().format = static_cast(attr_list.func(0).attr().at(kSerializeFormat).i()); } else { /// _Arg constructed from inference function do not has input_tensor_dec @@ -65,5 +64,5 @@ Status ParseParams(const Message *op_src, ArgOpOperator *op) { return SUCCESS; } -DOMI_REGISTER_TENSORFLOW_PARSER(ge::parser::ARG, ArgOpOperator).SetParseParamsFn(ParseParams); +DOMI_REGISTER_TENSORFLOW_PARSER(ge::ARG, ArgOpOperator).SetParseParamsFn(ParseParams); } // namespace ge diff --git a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc index e9fe078..9e11697 100644 --- a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc +++ b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc @@ -16,7 +16,6 @@ #include "tensorflow_auto_mapping_parser_adapter.h" -#include "framework/omg/parser/parser_types.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "parser/common/op_parser_factory.h" @@ -25,9 +24,6 @@ using domi::TENSORFLOW; -using namespace ge::parser; - -using ge::parser::PLACEHOLDERWITHDEFAULT; namespace ge { namespace { diff --git a/parser/tensorflow/tensorflow_constant_parser.cc b/parser/tensorflow/tensorflow_constant_parser.cc index 4e4244d..1c90cf1 100644 --- a/parser/tensorflow/tensorflow_constant_parser.cc +++ b/parser/tensorflow/tensorflow_constant_parser.cc @@ -19,7 +19,7 @@ #include #include #include "common/debug/log.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/ge/ge_util.h" #include "common/op/ge_op_utils.h" #include "parser/common/op_def/constant_op.h" #include "parser/common/op_def/ir_pb_converter.h" @@ -27,12 +27,10 @@ #include "graph/ge_tensor.h" #include "graph/utils/attr_utils.h" #include "parser/common/op_parser_factory.h" -#include "framework/omg/parser/parser_types.h" #include "register/tensor_assign.h" using domi::tensorflow::NodeDef; using domi::TENSORFLOW; -using ge::parser::CONSTANTOP; namespace ge { Status TensorFlowConstantParser::ParseDType(const domi::tensorflow::NodeDef *node, ConstantOperator *op) { @@ -68,7 +66,7 @@ Status TensorFlowConstantParser::ParseValue(const domi::tensorflow::NodeDef *nod const domi::tensorflow::TensorProto &tensor = attr_value.tensor(); - GeTensorPtr weight = ge::parser::MakeShared(); + GeTensorPtr weight = ge::MakeShared(); GE_CHECK_NOTNULL(weight); int64_t dataType = 0; GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(opDesc, TENSORFLOW_ATTR_DTYPE, dataType), INTERNAL_ERROR, diff --git a/parser/tensorflow/tensorflow_data_parser.cc b/parser/tensorflow/tensorflow_data_parser.cc index 8de0575..bf2f301 100644 --- a/parser/tensorflow/tensorflow_data_parser.cc +++ b/parser/tensorflow/tensorflow_data_parser.cc @@ -19,14 +19,11 @@ #include "common/debug/log.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" -#include "framework/omg/parser/parser_inner_ctx.h" #include "parser/common/op_parser_factory.h" -#include "framework/omg/parser/parser_types.h" using domi::tensorflow::AttrValue; using domi::tensorflow::NodeDef; using domi::TENSORFLOW; -using ge::parser::DATA; namespace ge { namespace { @@ -100,7 +97,7 @@ Status TensorFlowDataParser::ParseInputFromModel(const Message *op_src, ge::OpDe Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge::OpDescPtr &op_def) { GE_CHECK_NOTNULL(op_def); (void)op_src; - const ge::ParserContext &ctx = GetParserContext(); + const ge::OmgContext &ctx = domi::GetContext(); std::unordered_map> input_dims = ctx.input_dims; // User not designate the input_shape std::string name = op_def->GetName(); @@ -134,7 +131,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge: } Status TensorFlowDataParser::CheckInputShape(const std::string &name) { - const ge::ParserContext &ctx = GetParserContext(); + const ge::OmgContext &ctx = domi::GetContext(); if (!ctx.is_dynamic_input) { for (uint32_t i = 0; i < user_input_dims_v.size(); i++) { // if input_shape has some placeholders, user should designate them. diff --git a/parser/tensorflow/tensorflow_enter_parser.cc b/parser/tensorflow/tensorflow_enter_parser.cc index af065bd..51466da 100644 --- a/parser/tensorflow/tensorflow_enter_parser.cc +++ b/parser/tensorflow/tensorflow_enter_parser.cc @@ -19,11 +19,8 @@ #include "framework/common/debug/log.h" #include "graph/debug/ge_attr_define.h" #include "parser/common/op_parser_factory.h" -#include "framework/omg/parser/parser_types.h" using domi::TENSORFLOW; -using ge::parser::ENTER; -using ge::parser::REFENTER; namespace ge { Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) { diff --git a/parser/tensorflow/tensorflow_fill_parser.cc b/parser/tensorflow/tensorflow_fill_parser.cc index d467ed3..2687906 100644 --- a/parser/tensorflow/tensorflow_fill_parser.cc +++ b/parser/tensorflow/tensorflow_fill_parser.cc @@ -20,11 +20,6 @@ #include "parser/common/op_def/fill_op.h" #include "common/util.h" #include "parser/tensorflow/tensorflow_parser_register.h" -#include "framework/omg/parser/parser_types.h" - -using ge::parser::ALPHA_DEFAULT_VALUE; -using ge::parser::BETA_DEFAULT_VALUE; -using ge::parser::FILL; namespace ge { /* @@ -58,8 +53,8 @@ domi::Status ParseParams(const NodeDef *node, FillOperator *op) { op->DataType(type); - op->Alpha(ge::parser::ALPHA_DEFAULT_VALUE); - op->Beta(ge::parser::BETA_DEFAULT_VALUE); + op->Alpha(ge::ALPHA_DEFAULT_VALUE); + op->Beta(ge::BETA_DEFAULT_VALUE); return domi::SUCCESS; } diff --git a/parser/tensorflow/tensorflow_frameworkop_parser.cc b/parser/tensorflow/tensorflow_frameworkop_parser.cc index 343c579..79688a8 100644 --- a/parser/tensorflow/tensorflow_frameworkop_parser.cc +++ b/parser/tensorflow/tensorflow_frameworkop_parser.cc @@ -18,15 +18,14 @@ #include "parser/common/op_def/frameworkop_op.h" #include "framework/common/debug/ge_log.h" #include "parser/common/op_parser_factory.h" -#include "framework/omg/parser/parser_types.h" #include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_parser_register.h" #include "proto/tensorflow/tensor_shape.pb.h" using domi::tensorflow::TensorShapeProto; using domi::tensorflow::AttrValue; +using ge::FRAMEWORKOP; using domi::TENSORFLOW; -using ge::parser::FRAMEWORKOP; namespace ge { Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { diff --git a/parser/tensorflow/tensorflow_fusion_op_parser.cc b/parser/tensorflow/tensorflow_fusion_op_parser.cc index ee50672..d5f3f2b 100644 --- a/parser/tensorflow/tensorflow_fusion_op_parser.cc +++ b/parser/tensorflow/tensorflow_fusion_op_parser.cc @@ -17,11 +17,11 @@ #include "parser/tensorflow/tensorflow_fusion_op_parser.h" #include #include "common/debug/log.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/fp16_t.h" +#include "common/ge/ge_util.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "omg/omg.h" -#include "parser/common/parser_fp16_t.h" #include "parser/tensorflow/tensorflow_op_parser.h" #include "register/tensor_assign.h" @@ -115,7 +115,7 @@ Status TensorFlowFusionOpParser::ParseHalfFromConst(const NodeDef *node_def, flo auto val_vec = tensor.half_val(); int32_t val_size = val_vec.size(); if (index < val_size) { - ge::parser::fp16_t fp16_value = static_cast(val_vec.Get(index)); + fp16_t fp16_value = static_cast(val_vec.Get(index)); param = fp16_value.ToFloat(); } else { GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index:%d, not supported.", index); @@ -132,7 +132,7 @@ Status TensorFlowFusionOpParser::ParseWeightFromConst(const NodeDef *node_def, g GE_CHECK_NOTNULL(node_def); TensorProto tensor; GE_CHK_STATUS_RET(GetTensorFromNode(node_def, tensor), "get tensor failed."); - weight = ge::parser::MakeShared(); + weight = ge::MakeShared(); GE_CHECK_NOTNULL(weight); domi::tensorflow::DataType data_type = tensor.dtype(); GE_CHK_STATUS_RET( diff --git a/parser/tensorflow/tensorflow_fusionop_util.cc b/parser/tensorflow/tensorflow_fusionop_util.cc index 404f7e4..82c8489 100644 --- a/parser/tensorflow/tensorflow_fusionop_util.cc +++ b/parser/tensorflow/tensorflow_fusionop_util.cc @@ -20,7 +20,6 @@ #include "common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "parser/tensorflow/tensorflow_parser.h" -#include "framework/omg/parser/parser_types.h" #include #include @@ -114,21 +113,21 @@ static map tensorflow_fusionop_map = { // static map> tensorflow_fusionop_children_nums_map = { - {ge::parser::CLIPBOXES, {8}}, - {ge::parser::FASTRCNNPREDICTIONS, {118, 119, 120, 123, 125}}, - {ge::parser::RPNPROPOSALS, {75, 85, 97}}, - {ge::parser::DECODEBBOX, {24, 28}}, - {ge::parser::ROIALIGN, {82, 83, 84}}, - {ge::parser::FUSIONBATCHNORM, {8}}, - {ge::parser::GETSPAN, {81, 71, 91}}, // The pbtxt only has 62 nodes when test GetSpan sub net. However the - {ge::parser::HUBERLOSSGRAD, {8, 9, 10, 20, 21}}, + {CLIPBOXES, {8}}, + {FASTRCNNPREDICTIONS, {118, 119, 120, 123, 125}}, + {RPNPROPOSALS, {75, 85, 97}}, + {DECODEBBOX, {24, 28}}, + {ROIALIGN, {82, 83, 84}}, + {FUSIONBATCHNORM, {8}}, + {GETSPAN, {81, 71, 91}}, // The pbtxt only has 62 nodes when test GetSpan sub net. However the + {HUBERLOSSGRAD, {8, 9, 10, 20, 21}}, }; // static map> tensorflow_fusionop_children_names_map = { - {ge::parser::FUSIONBATCHNORM, {"add/y", "add", "Rsqrt", "mul", "mul_1", "mul_2", "sub", "add_1"}}, - {ge::parser::GETSPAN, {}}, - {ge::parser::HUBERLOSSGRAD, {}}, + {FUSIONBATCHNORM, {"add/y", "add", "Rsqrt", "mul", "mul_1", "mul_2", "sub", "add_1"}}, + {GETSPAN, {}}, + {HUBERLOSSGRAD, {}}, }; // ----------------------------Index table of input and output of fusion operator-------------- @@ -138,23 +137,23 @@ static map> tensorflow_fusionop_children_names_map = { // Generally, the old index is 0. If the new index value is kFusionDisableIndex, the edge can be ignored. // If it is control edge input, the index is graph::kControlSlot(-1). static map>>> tensorflow_fusionop_inputs_map = { - {ge::parser::FUSIONBATCHNORM, + {FUSIONBATCHNORM, {{"mul_1", {0, kFusionDisableIndex}}, {"mul", {1, 1}}, {"sub", {2, kFusionDisableIndex}}, {"mul_2", {3, kFusionDisableIndex}}, {"add", {4, kFusionDisableIndex}}}}, - {ge::parser::GETSPAN, {{"transpose", {0}}, {"TensorArray", {1}}, {"transpose_1", {2}}}}, - {ge::parser::HUBERLOSSGRAD, {{"Sub_1_grad/Neg", {1}}, {"Abs_grad/Sign", {0}}}}, + {GETSPAN, {{"transpose", {0}}, {"TensorArray", {1}}, {"transpose_1", {2}}}}, + {HUBERLOSSGRAD, {{"Sub_1_grad/Neg", {1}}, {"Abs_grad/Sign", {0}}}}, }; static map>>> tensorflow_fusionop_outputs_map = { - {ge::parser::FUSIONBATCHNORM, {{"add_1", {0}}}}, - {ge::parser::GETSPAN, {{"while/Exit_1", {0}}, {"while/Exit_2", {1}}}}, - {ge::parser::HUBERLOSSGRAD, {{"Abs_grad/mul", {0}}}}, + {FUSIONBATCHNORM, {{"add_1", {0}}}}, + {GETSPAN, {{"while/Exit_1", {0}}, {"while/Exit_2", {1}}}}, + {HUBERLOSSGRAD, {{"Abs_grad/mul", {0}}}}, }; map>> tensorflow_fusionop_input_const_weight_index_map = { - {ge::parser::FUSIONBATCHNORM, {{"mul", 0}, {"sub", 1}, {"mul_2", 2}, {"add", 3}}}, + {FUSIONBATCHNORM, {{"mul", 0}, {"sub", 1}, {"mul_2", 2}, {"add", 3}}}, }; // Can a string be converted to an integer diff --git a/parser/tensorflow/tensorflow_fusionop_util.h b/parser/tensorflow/tensorflow_fusionop_util.h index f08ccf9..0187932 100644 --- a/parser/tensorflow/tensorflow_fusionop_util.h +++ b/parser/tensorflow/tensorflow_fusionop_util.h @@ -22,7 +22,7 @@ #include #include "common/debug/log.h" #include "common/string_util.h" -#include "framework/omg/parser/parser_types.h" +#include "common/types.h" #include "common/util.h" #include "omg/omg_inner_types.h" #include "proto/tensorflow/graph.pb.h" diff --git a/parser/tensorflow/tensorflow_identity_parser.cc b/parser/tensorflow/tensorflow_identity_parser.cc index 50f6277..ebf369f 100644 --- a/parser/tensorflow/tensorflow_identity_parser.cc +++ b/parser/tensorflow/tensorflow_identity_parser.cc @@ -17,15 +17,11 @@ #include "common/op/ge_op_utils.h" #include "common/op_def/ir_pb_converter.h" #include "parser/common/op_parser_factory.h" -#include "framework/omg/parser/parser_types.h" #include "parser/tensorflow/tensorflow_identity_parser.h" using domi::TENSORFLOW; -using ge::parser::IDENTITY; -using ge::parser::READVARIABLEOP; namespace ge { REGISTER_OP_PARSER_CREATOR(TENSORFLOW, IDENTITY, TensorFlowIdentityParser); -REGISTER_OP_PARSER_CREATOR(TENSORFLOW, READVARIABLEOP, TensorFlowIdentityParser); } // namespace ge diff --git a/parser/tensorflow/tensorflow_merge_parser.cc b/parser/tensorflow/tensorflow_merge_parser.cc index 6f1dedb..dc8f07f 100644 --- a/parser/tensorflow/tensorflow_merge_parser.cc +++ b/parser/tensorflow/tensorflow_merge_parser.cc @@ -20,10 +20,8 @@ #include "framework/common/util.h" #include "graph/debug/ge_attr_define.h" #include "parser/common/op_parser_factory.h" -#include "framework/omg/parser/parser_types.h" using domi::TENSORFLOW; -using ge::parser::MERGE; namespace ge { Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) { diff --git a/parser/tensorflow/tensorflow_no_op_parser.cc b/parser/tensorflow/tensorflow_no_op_parser.cc index 633e921..b0c27d5 100644 --- a/parser/tensorflow/tensorflow_no_op_parser.cc +++ b/parser/tensorflow/tensorflow_no_op_parser.cc @@ -22,7 +22,6 @@ #include "parser/common/op_parser_factory.h" using domi::TENSORFLOW; -using namespace ge::parser; namespace ge { Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index a486462..aff5027 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -17,45 +17,49 @@ #include "parser/tensorflow/tensorflow_parser.h" #include #include -#include "parser/common/convert/pb2json.h" +#include "common/convert/pb2json.h" #include "common/debug/log.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/fp16_t.h" +#include "common/ge/ge_util.h" +#include "common/ge/ge_util.h" +#include "common/model_saver.h" +#include "common/thread_pool.h" +#include "common/util.h" #include "common/util/error_manager/error_manager.h" #include "external/graph/operator_factory.h" #include "external/parser/tensorflow_parser.h" -#include "external/register/scope/scope_fusion_pass_register.h" #include "framework/common/debug/ge_log.h" -#include "framework/omg/parser/parser_api.h" -#include "framework/omg/parser/parser_inner_ctx.h" #include "graph/debug/ge_attr_define.h" #include "graph/optimize/common/params.h" #include "graph/passes/variable_format_pass.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" +#include "graph/common/ge_call_wrapper.h" +#include "inc/pass_manager.h" #include "iterator_fusion_pass.h" #include "omg/omg.h" #include "omg/parser/op_parser.h" #include "omg/parser/parser_factory.h" -#include "parser/common/acl_graph_parser_util.h" -#include "parser/common/model_saver.h" +#include "framework/omg/parser/parser_inner_ctx.h" #include "parser/common/op_map.h" #include "parser/common/op_parser_factory.h" -#include "parser/common/parser_fp16_t.h" -#include "parser/common/pass_manager.h" #include "parser/common/pre_checker.h" -#include "parser/common/thread_pool.h" -#include "parser/tensorflow/tensorflow_custom_parser_adapter.h" -#include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" +#include "parser/common/acl_graph_parser_util.h" #include "parser/tensorflow/tensorflow_fusion_op_parser.h" #include "parser/tensorflow/tensorflow_fusionop_util.h" #include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_util.h" +#include "parser/tensorflow/tensorflow_custom_parser_adapter.h" +#include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" #include "register/op_registry.h" +#include "external/register/scope/scope_fusion_pass_register.h" #include "register/scope/scope_graph_impl.h" #include "register/scope/scope_pass_registry_impl.h" using ge::const_op_update_vec; +using ge::fp16_t; +using ge::ModelSaver; using ge::OpParserFactory; using ge::Pb2Json; using ge::PreChecker; @@ -80,17 +84,15 @@ using ge::TENSORFLOWF_NODE_OP_TRANSPOSE; using ge::TENSORFLOWF_TENSOR_NCHW; using ge::TENSORFLOWF_TENSOR_NHWC; using ge::TensorFlowFunsionOPUtil; -using ge::TensorFlowFusionCustomParserAdapter; using ge::TensorFlowFusionOpParser; using ge::TensorFlowOpParser; using ge::ThreadPool; -using ge::parser::fp16_t; -using ge::parser::ModelSaver; +using ge::TensorFlowFusionCustomParserAdapter; namespace ge { graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { GE_CHECK_NOTNULL(model_file); - GetParserContext().type = domi::TENSORFLOW; + domi::GetContext().type = domi::TENSORFLOW; std::map options; options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::TENSORFLOW))); @@ -99,7 +101,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { (void)acl_graph_parse_util.AclParserInitialize(options); // Create an empty computegraph - ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); + ge::ComputeGraphPtr compute_graph = ge::MakeShared("tmpGraph"); GE_CHECK_NOTNULL(compute_graph); graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); @@ -120,7 +122,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { GELOGI("Parser graph %s success.", graph.GetName().c_str()); return ge::SUCCESS; } -} // namespace ge +} namespace ge { namespace { @@ -131,21 +133,20 @@ const int kInputNumInt = 2; const int32_t kControlSlot = -1; const size_t kSoftmaxMultiple = 2; const set kTfBlackFields = {"tensor_content"}; -const std::vector kSkipCheckoutInputSizeNodes = {ge::parser::DATA, ge::parser::VARIABLE, - ge::parser::FRAMEWORKOP, ge::parser::LAYERNORM}; -const std::vector kMakeOperatorNotByIr = {ge::parser::ARG, ge::parser::VARIABLE, ge::parser::VARHANDLEOP, - ge::parser::FRAMEWORKOP, ge::parser::DATA}; -const std::map kNeedMarkFormatNodes = { - {"ExtractImagePatches", domi::DOMI_TENSOR_NHWC}, - {"ExtractVolumePatches", domi::DOMI_TENSOR_NHWC}, - {"LogSoftmax", domi::DOMI_TENSOR_NHWC}, - {"ResizeBilinear", domi::DOMI_TENSOR_NHWC}, - {"ResizeBilinearGrad", domi::DOMI_TENSOR_NHWC}, - {"ResizeNearestNeighbor", domi::DOMI_TENSOR_NHWC}, - {"Softmax", domi::DOMI_TENSOR_NHWC}, - {"SoftmaxCrossEntropyWithLogits", domi::DOMI_TENSOR_NHWC}, - {"SoftmaxGrad", domi::DOMI_TENSOR_NHWC}, - {"SpaceToBatch", domi::DOMI_TENSOR_NHWC}}; +const std::vector kSkipCheckoutInputSizeNodes = {ge::DATA, ge::VARIABLE, ge::FRAMEWORKOP, ge::LAYERNORM}; +const std::vector kMakeOperatorNotByIr = {ge::ARG, ge::VARIABLE, ge::VARHANDLEOP, ge::FRAMEWORKOP, + ge::DATA}; +const std::map kNeedMarkFormatNodes = {{"ExtractImagePatches", domi::DOMI_TENSOR_NHWC}, + {"ExtractVolumePatches", domi::DOMI_TENSOR_NHWC}, + {"LogSoftmax", domi::DOMI_TENSOR_NHWC}, + {"ResizeBilinear", domi::DOMI_TENSOR_NHWC}, + {"ResizeBilinearGrad", domi::DOMI_TENSOR_NHWC}, + {"ResizeNearestNeighbor", domi::DOMI_TENSOR_NHWC}, + {"Softmax", domi::DOMI_TENSOR_NHWC}, + {"SoftmaxCrossEntropyWithLogits", + domi::DOMI_TENSOR_NHWC}, + {"SoftmaxGrad", domi::DOMI_TENSOR_NHWC}, + {"SpaceToBatch", domi::DOMI_TENSOR_NHWC}}; const char *const kDpop = "DPOP"; const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt"; const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node"; @@ -172,7 +173,7 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque // A function may be referenced multiple times in TF, change the graph name to ensure it is unique in GE auto unique_name = node->GetName() + std::to_string(i) + subgraph_iname; - auto subgraph = ge::parser::MakeShared(unique_name); + auto subgraph = ge::MakeShared(unique_name); if (subgraph == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to alloc subgraph %s", subgraph_iname.c_str()); return OUT_OF_MEMORY; @@ -240,13 +241,13 @@ Status TensorFlowModelParser::DefunToPartitionedCall(const domi::tensorflow::Nod domi::tensorflow::AttrValue attr_call_inference; if (!ge::TensorFlowUtil::FindAttrValue(node_def, "_disable_call_shape_inference", attr_call_inference)) { ErrorManager::GetInstance().ATCReportErrMessage( - "E19014", {"opname", "value", "reason"}, - {node_def->name(), "attr [_disable_call_shape_inference]", "is not exist in nodedef"}); + "E19014", {"opname", "value", "reason"}, {node_def->name(), "attr [_disable_call_shape_inference]", + "is not exist in nodedef"}); GELOGE(FAILED, "In NodeDef %s attr [_disable_call_shape_inference] not exist.", op_name.c_str()); return FAILED; } - op = ge::parser::MakeShared(op_name, ge::parser::PARTITIONEDCALL); + op = ge::MakeShared(op_name, ge::PARTITIONEDCALL); GE_CHECK_NOTNULL(op); size_t input_tensor_num = 0; @@ -282,9 +283,9 @@ Status TensorFlowModelParser::TransNodeToOpDesc(const domi::tensorflow::NodeDef GE_CHECK_NOTNULL(node_def); string node_name = node_def->name(); ge::Operator op_factory = ge::OperatorFactory::CreateOperator(node_name, op_type); - if (op_factory.GetName() != node_name || op_type == ge::parser::DATA) { + if (op_factory.GetName() != node_name || op_type == ge::DATA) { if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) { - op = ge::parser::MakeShared(node_name, op_type); + op = ge::MakeShared(node_name, op_type); GE_CHECK_NOTNULL(op); } else if (node_name == op_type) { // Trans @tensorflow.python.framework.Defun(...) to PartitionedCall. @@ -309,8 +310,8 @@ Status TensorFlowModelParser::TransNodeToOpDesc(const domi::tensorflow::NodeDef return SUCCESS; } -Status TensorFlowModelParser::ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, - shared_ptr &op_parser) { +Status TensorFlowModelParser::ParseOpParams(const domi::tensorflow::NodeDef *node_def, + ge::OpDescPtr &op, shared_ptr &op_parser) { GE_CHECK_NOTNULL(node_def); GE_CHECK_NOTNULL(op); GE_CHECK_NOTNULL(op_parser); @@ -336,7 +337,7 @@ Status TensorFlowModelParser::ParseOpParams(const domi::tensorflow::NodeDef *nod return status; } std::shared_ptr tf_custom_op_parser = - std::dynamic_pointer_cast(op_parser); + std::dynamic_pointer_cast(op_parser); GE_CHECK_NOTNULL(tf_custom_op_parser); status = tf_custom_op_parser->ParseParams(op_src, op); if (status != SUCCESS) { @@ -501,10 +502,10 @@ Status TensorFlowModelParser::CheckoutInputNum(ge::OpDescPtr &op_desc, const dom // get input and output tensor number from op desc size_t factory_input_size = op_desc->GetInputsSize(); if (input_tensor_num != factory_input_size) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19014", {"opname", "value", "reason"}, - {op_desc->GetName(), "input number of tensorflow[" + std::to_string(input_tensor_num) + "]", - "should be equal to factory size[" + std::to_string(factory_input_size) + "]"}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19014", {"opname", "value", "reason"}, {op_desc->GetName(), + "input number of tensorflow[" + std::to_string(input_tensor_num) + "]", + "should be equal to factory size[" + std::to_string(factory_input_size) + "]"}); GELOGE(FAILED, "op [%s], type[%s], The input number of tensorflow[%zu] should be equal to factory size[%zu]", op_desc->GetName().c_str(), op_desc->GetType().c_str(), input_tensor_num, factory_input_size); return FAILED; @@ -570,7 +571,7 @@ Status TensorFlowModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, cons domi::tensorflow::AttrValue input_attr_value; domi::tensorflow::AttrValue output_attr_value; ParserOperator temp_op; - if (ge::TensorFlowUtil::FindAttrValue(node, ge::parser::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { + if (ge::TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { GE_CHK_STATUS_RET(ge::TensorFlowUtil::TransTensorDescriptor(input_attr_value, &temp_op, TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG, type), "trans input_attr_value failed, op: %s", node->name().c_str()); @@ -713,8 +714,7 @@ Status TensorFlowModelParser::CheckOpShapeDim(const domi::tensorflow::NodeDef *n bool &valid) { GE_CHECK_NOTNULL(node_def); domi::tensorflow::AttrValue input_attr_value; - bool is_attr_exist = - ge::TensorFlowUtil::FindAttrValue(node_def, ge::parser::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value); + bool is_attr_exist = ge::TensorFlowUtil::FindAttrValue(node_def, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value); GE_IF_BOOL_EXEC(!is_attr_exist, return SUCCESS); GE_CHK_BOOL_EXEC(input_attr_value.has_list(), return PARAM_INVALID, "output attr value vector is empty"); @@ -741,26 +741,24 @@ Status TensorFlowModelParser::CheckOpType(const domi::tensorflow::NodeDef *node_ string node_name = node_def->name(); std::map> check_dims = { - {ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, {10}}, + {ge::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, {10}}, }; GE_IF_BOOL_EXEC( - op_type == ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, + op_type == ge::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape"); - GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP; GELOGI("Set op %s to frameworkop", node_name.c_str()); + GE_IF_BOOL_EXEC(!valid, op_type = ge::FRAMEWORKOP; GELOGI("Set op %s to frameworkop", node_name.c_str()); framework_ops_[node_name] = node_def;);); - GE_IF_BOOL_EXEC( - op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN, - for (const string &input_name - : node_def->input()) { - string tmp_input_name; - GE_RETURN_IF_ERROR(CheckInputNodeName(input_name, &tmp_input_name, nullptr, nullptr)); - GELOGD("Add or Mul op %s input name is %s", node_name.c_str(), input_name.c_str()); - GE_IF_BOOL_EXEC(framework_ops_.find(tmp_input_name) != framework_ops_.end(), - GELOGI("Set op %s to frameworkop", node_name.c_str()); - op_type = ge::parser::FRAMEWORKOP;); - }); + GE_IF_BOOL_EXEC(op_type == ge::ADD || op_type == ge::MULTIPLY || op_type == ge::MEAN, for (const string &input_name + : node_def->input()) { + string tmp_input_name; + GE_RETURN_IF_ERROR(CheckInputNodeName(input_name, &tmp_input_name, nullptr, nullptr)); + GELOGD("Add or Mul op %s input name is %s", node_name.c_str(), input_name.c_str()); + GE_IF_BOOL_EXEC(framework_ops_.find(tmp_input_name) != framework_ops_.end(), + GELOGI("Set op %s to frameworkop", node_name.c_str()); + op_type = ge::FRAMEWORKOP;); + }); return SUCCESS; } @@ -795,13 +793,13 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co string op_type = iterator->second; // Log printing for determining operator type domi::ImplyType implyType = domi::OpRegistry::Instance()->GetImplyType(op_type); - GE_IF_BOOL_EXEC((implyType == domi::ImplyType::TVM) && (op_type != ge::parser::FRAMEWORKOP), + GE_IF_BOOL_EXEC((implyType == domi::ImplyType::TVM) && (op_type != ge::FRAMEWORKOP), GELOGD("TBE %s parsering", node_op.c_str());); - GE_IF_BOOL_EXEC((implyType == domi::ImplyType::CCE) && (op_type != ge::parser::FRAMEWORKOP), + GE_IF_BOOL_EXEC((implyType == domi::ImplyType::CCE) && (op_type != ge::FRAMEWORKOP), GELOGD("CCE %s parsering", node_op.c_str());); - GE_IF_BOOL_EXEC((implyType == domi::ImplyType::HCCL) && (op_type != ge::parser::FRAMEWORKOP), + GE_IF_BOOL_EXEC((implyType == domi::ImplyType::HCCL) && (op_type != ge::FRAMEWORKOP), GELOGD("HCCL %s parsering", node_op.c_str());); - GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP, GELOGD("FRAMEWORKOP %s parsering", node_op.c_str());); + GE_IF_BOOL_EXEC(op_type == ge::FRAMEWORKOP, GELOGD("FRAMEWORKOP %s parsering", node_op.c_str());); GELOGD("TF op node name = %s, op type= %s, trans to op type %s", node_name.c_str(), node_op.c_str(), op_type.c_str()); // Construct operator by IR @@ -809,7 +807,7 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co ge::Operator op_factory = ge::OperatorFactory::CreateOperator(node_name, op_type); if (op_factory.GetName() != node_name) { if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) { - op = ge::parser::MakeShared(node_name, op_type); + op = ge::MakeShared(node_name, op_type); GE_CHECK_NOTNULL(op); } else if (node_name == op_type) { GE_RETURN_IF_ERROR(parser->DefunToPartitionedCall(node_def, op)); @@ -906,7 +904,7 @@ Status TensorFlowModelParser::AdaptOpType(const domi::tensorflow::NodeDef *node_ op_type = tensorflow_train_op_map.at(node_op); GE_CHK_STATUS_RET(CheckOpType(node_def, op_type), "Failed to check op type"); } else { - op_type = ge::parser::FRAMEWORKOP; + op_type = ge::FRAMEWORKOP; domi::tensorflow::AttrValue attr_call_inference; if ((node_name == node_op) && ge::TensorFlowUtil::FindAttrValue(node_def, "_disable_call_shape_inference", attr_call_inference)) { @@ -914,7 +912,7 @@ Status TensorFlowModelParser::AdaptOpType(const domi::tensorflow::NodeDef *node_ } } - GE_IF_BOOL_EXEC(isDatasetInit, op_type = ge::parser::FRAMEWORKOP); + GE_IF_BOOL_EXEC(isDatasetInit, op_type = ge::FRAMEWORKOP); adaptedOpTypeMap_[node_name] = op_type; return SUCCESS; @@ -939,7 +937,7 @@ Status TensorFlowModelParser::AddFmkNode(ge::ComputeGraphPtr &graph, shared_ptr< ThreadPool executor(kThreadNum); std::mutex graphMutex; std::vector> vectorFuture(op_node_list_size); - ge::ComputeGraphPtr graph_tmp = ge::parser::MakeShared("tmpGraph"); + ge::ComputeGraphPtr graph_tmp = ge::MakeShared("tmpGraph"); GE_CHECK_NOTNULL(graph_tmp); for (size_t j = 0; j < op_node_list_size; j++) { const string op_node_name = op_node_name_list[j]; @@ -1007,11 +1005,11 @@ Status TensorFlowModelParser::ExcuteScopeFusionPasses(domi::tensorflow::GraphDef // Identifying scope fusion operators based on scope rules GE_CHECK_NOTNULL(graph_def); ScopePassManager passmanager; - PARSER_TIMESTAMP_START(BuildScopeGraph); + GE_TIMESTAMP_START(BuildScopeGraph); scope_graph = passmanager.BuildScopeGraph(graph_def); GE_CHECK_NOTNULL(scope_graph); - PARSER_TIMESTAMP_END(BuildScopeGraph, "TensorFlowModelParser::BuildScopeGraph"); - PARSER_TIMESTAMP_START(ScopeGraphPass); + GE_TIMESTAMP_END(BuildScopeGraph, "TensorFlowModelParser::BuildScopeGraph"); + GE_TIMESTAMP_START(ScopeGraphPass); // Validate the non-general scope fusion pass. // The parameter is set to the name of the fusion rule. // Multiple names can be set and separated by ",". @@ -1037,7 +1035,7 @@ Status TensorFlowModelParser::ExcuteScopeFusionPasses(domi::tensorflow::GraphDef GELOGE(ret, "Run scope fusion failed, ret:%u.", ret); return ret; } - PARSER_TIMESTAMP_END(ScopeGraphPass, "TensorFlowModelParser::ScopeGraphPass"); + GE_TIMESTAMP_END(ScopeGraphPass, "TensorFlowModelParser::ScopeGraphPass"); return SUCCESS; } @@ -1049,11 +1047,11 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g // Store objects parsed from pb files domi::tensorflow::GraphDef OriDef; - bool read = ge::parser::ReadProtoFromArray(data, static_cast(size), &OriDef); + bool read = ge::ReadProtoFromArray(data, static_cast(size), &OriDef); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, return INTERNAL_ERROR, "read_proto_from_binary failed."); domi::tensorflow::GraphDef graph_def; - if (ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) { + if (domi::GetContext().input_dims.empty() && domi::GetContext().out_nodes_map.empty()) { graph_def = OriDef; } else { GELOGI("Before Trim, the Graph Node size is:%d", OriDef.node_size()); @@ -1115,14 +1113,14 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g GELOGD("[TF ParseFromMemory] get op nodes context from graph success"); // Infer input formats - ge::GetParserContext().format = InferInputFormats(); + domi::GetContext().format = InferInputFormats(); GELOGD("[TF ParseFromMemory] infer input formats success"); // Building input-output relationship between fusionop and common op GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, graph_def, op_node_name_list)); ret = AddFusionNodeDef(scope_graph, op_node_name_list); - if (ret != SUCCESS) { + if(ret != SUCCESS) { GELOGE(ret, "Add fusion NodeDef failed."); DeleteFuisonNodeDef(); return ret; @@ -1170,7 +1168,7 @@ Status TensorFlowModelParser::GetFunctionProto(const string &file, string graph_def_path = (pos == -1) ? kFuncDefLibraryFilePath : file.substr(0, pos) + "/" + kFuncDefLibraryFilePath; GELOGI("Function def libraray path is %s.", graph_def_path.c_str()); - bool read = ge::parser::ReadProtoFromText(graph_def_path.c_str(), &graph_def_library); + bool read = ge::ReadProtoFromText(graph_def_path.c_str(), &graph_def_library); if (!read) { GELOGE(INTERNAL_ERROR, "Get subgraph library failed. " @@ -1206,12 +1204,12 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr GELOGI("Parse file %s", model_path); // Store objects parsed from pb files domi::tensorflow::GraphDef ori_def; - bool read = ge::parser::ReadProtoFromBinaryFile(model_path, &ori_def); + bool read = ge::ReadProtoFromBinaryFile(model_path, &ori_def); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, return INTERNAL_ERROR, "read_proto_from_binary failed."); // Trim graph by user input and output. domi::tensorflow::GraphDef graph_def; - if (ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) { + if (domi::GetContext().input_dims.empty() && domi::GetContext().out_nodes_map.empty()) { graph_def = ori_def; } else { GELOGI("Before Trim, the Graph Node size is:%d", ori_def.node_size()); @@ -1304,12 +1302,14 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()), "Add node_def to PreChecker failed, node name: %s.", node.name().c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&node) != SUCCESS, return FAILED, - "Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&node) != SUCCESS, + return FAILED, "Check op[%s] failed, name repeat in tensorflow pb file.", + node.name().c_str()); GE_CHK_BOOL_EXEC_NOLOG( node.op() == TENSORFLOWF_NODE_OP_IDENTITY, - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckType(&node, true) != SUCCESS, return FAILED, - "Check op[%s]'s optype failed, type is not supported.", node.name().c_str());) + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + PreChecker::Instance().CheckType(&node, true) != SUCCESS, + return FAILED, "Check op[%s]'s optype failed, type is not supported.", node.name().c_str());) } bool has_error = false; @@ -1346,7 +1346,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro GELOGD("[TF Parse] get op nodes context from graph success"); // Infer input formats - ge::GetParserContext().format = InferInputFormats(); + domi::GetContext().format = InferInputFormats(); GELOGD("[TF Parse] infer input formats success"); // Building input-output relationship between fusionop and common op @@ -1355,13 +1355,13 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro // set user-designate-inputs-order std::vector user_inputs_order; - for (auto &input : ge::GetParserContext().user_input_dims) { + for (auto &input : domi::GetContext().user_input_dims) { user_inputs_order.push_back(input.first); } graph->SetInputsOrder(user_inputs_order); ret = AddFusionNodeDef(scope_graph, op_node_name_list); - if (ret != SUCCESS) { + if(ret != SUCCESS) { GELOGE(ret, "Add fusion NodeDef failed."); DeleteFuisonNodeDef(); return ret; @@ -1425,7 +1425,7 @@ Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDe } } - if (node_def.op() == TENSORFLOWF_NODE_OP_PLACEHOLDER || node_def.op() == ge::parser::ARG) { + if (node_def.op() == TENSORFLOWF_NODE_OP_PLACEHOLDER || node_def.op() == ge::ARG) { data_node_count++; } } @@ -1501,20 +1501,20 @@ Status TensorFlowModelParser::GeStoi(const string &input_node_name, const string int32_t tmp_index = static_cast(std::stoi(index_str.c_str(), nullptr, 10)); *index = tmp_index; } catch (std::invalid_argument &) { - ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, - {"input_node_name(" + input_node_name + ")", index_str}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10014", {"parameter", "value"}, {"input_node_name(" + input_node_name + ")", index_str}); GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is invalid argument!", input_node_name.c_str(), index_str.c_str()); return INTERNAL_ERROR; } catch (std::out_of_range &) { - ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, - {"input_node_name(" + input_node_name + ")", index_str}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10013", {"parameter", "value"}, {"input_node_name(" + input_node_name + ")", index_str}); GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is out of range!", input_node_name.c_str(), index_str.c_str()); return INTERNAL_ERROR; } catch (...) { - ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"}, - {"input_node_name(" + input_node_name + ")", index_str}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10015", {"parameter", "value"}, {"input_node_name(" + input_node_name + ")", index_str}); GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is bad argument!", input_node_name.c_str(), index_str.c_str()); return INTERNAL_ERROR; @@ -1610,21 +1610,19 @@ bool TensorFlowModelParser::MaybeFusionOp(shared_ptr &scope_grap auto &impl = scope_graph->impl_; if (TensorFlowFunsionOPUtil::MaybeFusionOp(node_def->name(), &info) || impl->IsFusionOpChild(node_def->name(), info_list)) { - GE_IF_BOOL_EXEC( - info_list.size() > 0, for (size_t i = 0; i < info_list.size(); ++i) { - fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type); - fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].description); - fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(node_def); - if (info_list[i].fusion_op_type == "Dropout" && - (node_def->op() == "Add" || node_def->op() == "RandomUniform")) { - fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(0)]); - } - if (info_list[i].fusion_op_type == "LayerNorm" && node_def->op() == "Mean") { - fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(1)]); - } - fusion_op_policy_[info_list[i].fusion_node_name] = info_list[i].scope_pass; - fusion_op_children_[node_def->name()] = info_list[i]; - }); + GE_IF_BOOL_EXEC(info_list.size() > 0, for (size_t i = 0; i < info_list.size(); ++i) { + fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type); + fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].description); + fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(node_def); + if (info_list[i].fusion_op_type == "Dropout" && (node_def->op() == "Add" || node_def->op() == "RandomUniform")) { + fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(0)]); + } + if (info_list[i].fusion_op_type == "LayerNorm" && node_def->op() == "Mean") { + fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(1)]); + } + fusion_op_policy_[info_list[i].fusion_node_name] = info_list[i].scope_pass; + fusion_op_children_[node_def->name()] = info_list[i]; + }); GE_IF_BOOL_EXEC(info_list.size() == 0, fusion_op_type_map_[info.fusion_node_name].push_back(info.fusion_op_type); fusion_op_type_map_[info.fusion_node_name].push_back(info.description); fusion_op_nodedef_map_[info.fusion_node_name].push_back(node_def); @@ -1819,7 +1817,8 @@ Status TensorFlowModelParser::UpdateFusionOpContext(shared_ptr & } Status TensorFlowModelParser::UppdateInputMap(shared_ptr &scope_graph, - const ge::ScopeFusionOpInfo &info, OpNodeContext &fusion_op_node_context, + const ge::ScopeFusionOpInfo &info, + OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context) { GE_CHECK_NOTNULL(scope_graph); for (auto &iter : normal_op_node_context.input_map) { @@ -1863,7 +1862,8 @@ Status TensorFlowModelParser::UppdateInputMap(shared_ptr &scope_ return SUCCESS; } Status TensorFlowModelParser::UppdateOutputMap(shared_ptr &scope_graph, - const ge::ScopeFusionOpInfo &info, OpNodeContext &fusion_op_node_context, + const ge::ScopeFusionOpInfo &info, + OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context) { GE_CHECK_NOTNULL(scope_graph); for (auto &iter : normal_op_node_context.output_map) { @@ -1936,7 +1936,8 @@ Status TensorFlowModelParser::EraseNormalOpOutputIfChild(shared_ptr &scope_graph, const string &op_node_name, +Status TensorFlowModelParser::UpdateNormalOpContext(shared_ptr &scope_graph, + const string &op_node_name, OpNodeContext &normal_op_node_context) { GE_CHECK_NOTNULL(scope_graph); std::map>> tmp_input_map; @@ -1967,7 +1968,7 @@ Status TensorFlowModelParser::UpdateNormalOpContext(shared_ptr & } Status ret = EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context); - if (ret != SUCCESS) { + if ( ret != SUCCESS) { return ret; } @@ -2093,7 +2094,7 @@ Status TensorFlowModelParser::ToJson(const char *model_file, const char *json_fi domi::tensorflow::GraphDef graph_def; nlohmann::json j; - GE_RETURN_WITH_LOG_IF_FALSE(ge::parser::ReadProtoFromBinaryFile(model_file, &graph_def), + GE_RETURN_WITH_LOG_IF_FALSE(ge::ReadProtoFromBinaryFile(model_file, &graph_def), "ReadProtoFromBinaryFile failed, file:%s.", model_file); Pb2Json::Message2Json(graph_def, kTfBlackFields, j, true); @@ -2107,10 +2108,10 @@ Status TensorFlowWeightsParser::ParseFromMemory(const char *data, uint32_t size, Status TensorFlowWeightsParser::Parse(const char *file, ge::Graph &graph) { return SUCCESS; } Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { - PARSER_TIMESTAMP_START(ParseProto); + GE_TIMESTAMP_START(ParseProto); GE_CHECK_NOTNULL(proto); GE_CHECK_NOTNULL(graph); - ge::GetParserContext().train_flag = true; + domi::GetContext().train_flag = true; const domi::tensorflow::GraphDef *graph_def_in = reinterpret_cast(proto); // Make a copy for operation without modifying the original graph def. @@ -2129,15 +2130,15 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, bool has_error = false; // Graphdef optimizes identity - PARSER_TIMESTAMP_START(GraphDefOptimize); + GE_TIMESTAMP_START(GraphDefOptimize); GE_RETURN_IF_ERROR(GraphDefOptimize(graph_def)); - PARSER_TIMESTAMP_END(GraphDefOptimize, "TensorFlowModelParser::GraphDefOptimize"); + GE_TIMESTAMP_END(GraphDefOptimize, "TensorFlowModelParser::GraphDefOptimize"); GELOGD("[TF Parser] graph def optimize success"); // Optimization for TVM operator - PARSER_TIMESTAMP_START(OptimizeConstNodes4CustomOp); + GE_TIMESTAMP_START(OptimizeConstNodes4CustomOp); GE_RETURN_IF_ERROR(OptimizeConstNodes4CustomOp(graph_def)); - PARSER_TIMESTAMP_END(OptimizeConstNodes4CustomOp, "TensorFlowModelParser::OptimizeConstNodes4CustomOp"); + GE_TIMESTAMP_END(OptimizeConstNodes4CustomOp, "TensorFlowModelParser::OptimizeConstNodes4CustomOp"); GELOGD("[TF Parser] optimize const nodes for custom op success"); GE_RETURN_IF_ERROR(GetTensorflowGraphInOutMap(graph_def)); @@ -2145,13 +2146,13 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, vector op_node_name_list; bool isDatasetInit = false; - PARSER_TIMESTAMP_START(AddFmkNodeDefToMap); + GE_TIMESTAMP_START(AddFmkNodeDefToMap); for (int i = 0; i < graph_def->node_size(); i++) { const domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i); - if (node_def->op() == ge::parser::IDENTITY && node_def->input_size() == 0) { + if (node_def->op() == ge::IDENTITY && node_def->input_size() == 0) { continue; } - if (node_def->op() == ge::parser::SNAPSHOT && node_def->input_size() == 0) { + if (node_def->op() == ge::SNAPSHOT && node_def->input_size() == 0) { continue; } GE_IF_BOOL_EXEC(node_def->op() == "MakeIterator", isDatasetInit = true); @@ -2165,26 +2166,26 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, Status ret = AddFmkNodeDefToMap(*graph_def, node_def, op_node_name_list); GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); } - PARSER_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap"); + GE_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap"); GELOGI("[TF Parser] TF subgraph isDatasetInit: %d.", isDatasetInit); // Verify the validity of fusionop GE_RETURN_IF_ERROR(CheckFusionOpValid()); // Build input and output relationships for all OP nodes - PARSER_TIMESTAMP_START(GetOpNodesContextFromGraph); + GE_TIMESTAMP_START(GetOpNodesContextFromGraph); GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def)); - PARSER_TIMESTAMP_END(GetOpNodesContextFromGraph, "TensorFlowModelParser::GetOpNodesContextFromGraph"); + GE_TIMESTAMP_END(GetOpNodesContextFromGraph, "TensorFlowModelParser::GetOpNodesContextFromGraph"); GELOGD("[TF Parser] Get op nodes context from graph success"); // Building input-output relationship between fusionop and common op GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, *graph_def, op_node_name_list)); GELOGI("[TF Parser] TF op node size = %zu.", op_node_name_list.size()); - PARSER_TIMESTAMP_START(AddFmkNode); + GE_TIMESTAMP_START(AddFmkNode); // Loop analysis of op_nodes and map them to nodes in graph ret = AddFmkNode(graph, scope_graph, op_node_name_list, isDatasetInit); - PARSER_TIMESTAMP_END(AddFmkNode, "TensorFlowModelParser::AddFmkNode"); + GE_TIMESTAMP_END(AddFmkNode, "TensorFlowModelParser::AddFmkNode"); GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef(); return ret, "AddFmkNode failed"); GELOGD("[TF Parser] Add framework node success"); @@ -2193,16 +2194,16 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, GE_CHK_STATUS_EXEC(ret, return ret, "AddEdges failed"); GELOGD("[TF Parser] Add edges success"); - PARSER_TIMESTAMP_START(RemoveIsolateNode); + GE_TIMESTAMP_START(RemoveIsolateNode); // Delete isolated nodes GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); - PARSER_TIMESTAMP_END(RemoveIsolateNode, "TensorFlowModelParser::RemoveIsolateNode"); - PARSER_TIMESTAMP_START(TopologicalSorting); + GE_TIMESTAMP_END(RemoveIsolateNode, "TensorFlowModelParser::RemoveIsolateNode"); + GE_TIMESTAMP_START(TopologicalSorting); GE_RETURN_IF_ERROR(graph->TopologicalSorting()); - PARSER_TIMESTAMP_END(TopologicalSorting, "TensorFlowModelParser::TopologicalSorting"); + GE_TIMESTAMP_END(TopologicalSorting, "TensorFlowModelParser::TopologicalSorting"); - ge::parser::PassManager iterator_fusion_pass; + ge::PassManager iterator_fusion_pass; try { (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", new ge::IteratorFusionPass(ge::TENSORFLOW, false)); @@ -2222,7 +2223,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, return PARAM_INVALID; } GELOGI("[TF Parser] Parse proto success."); - PARSER_TIMESTAMP_END(ParseProto, "TensorFlowModelParser::ParseProto"); + GE_TIMESTAMP_END(ParseProto, "TensorFlowModelParser::ParseProto"); return SUCCESS; } @@ -2232,7 +2233,7 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes GE_CHECK_NOTNULL(callback); GE_CHECK_NOTNULL(root_graph); - PARSER_TIMESTAMP_START(ParseProtoWithSubgraph); + GE_TIMESTAMP_START(ParseProtoWithSubgraph); std::vector> proto_holder; std::deque tasks; tasks.push_back({root_proto, "root", nullptr, "", root_graph}); @@ -2272,7 +2273,7 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes return ret; } } - PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); + GE_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); return SUCCESS; } @@ -2369,8 +2370,8 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m string node_name; bool is_control = false; if (CheckInputNodeName(input, &node_name, nullptr, &is_control) != SUCCESS) { - GELOGE(FAILED, "parse node input info failed, node %s, input %s.", output_node_def->name().c_str(), - input.c_str()); + GELOGE(FAILED, "parse node input info failed, node %s, input %s.", + output_node_def->name().c_str(), input.c_str()); return FAILED; } if (node_name == curr_node_name) { @@ -2433,8 +2434,8 @@ Status TensorFlowModelParser::GraphDefOptimizeSnapShot(domi::tensorflow::GraphDe int input_index = 0; bool is_control = false; if (CheckInputNodeName(input, &node_name, &input_index, &is_control) != SUCCESS) { - GELOGE(FAILED, "parse SnapShot input info failed, node %s, input %s.", curr_node_def->name().c_str(), - input.c_str()); + GELOGE(FAILED, "parse SnapShot input info failed, node %s, input %s.", + curr_node_def->name().c_str(), input.c_str()); return FAILED; } if (is_control) { @@ -2509,7 +2510,7 @@ Status TensorFlowModelParser::GraphDefOptimizeDestroyTemporaryVariable(domi::ten GELOGE(FAILED, "input param is nullptr."); return FAILED; } - if (nodeCurrent->op() != ge::parser::DESTROYTEMPORARYVARIABLE) { + if (nodeCurrent->op() != ge::DESTROYTEMPORARYVARIABLE) { return SUCCESS; } @@ -2521,7 +2522,7 @@ Status TensorFlowModelParser::GraphDefOptimizeDestroyTemporaryVariable(domi::ten for (int j = 0; j < graph_def->node_size(); j++) { domi::tensorflow::NodeDef *nodeTmpVar = graph_def->mutable_node(j); - GE_IF_BOOL_EXEC(nodeTmpVar->op() != ge::parser::TEMPORARYVARIABLE, continue); + GE_IF_BOOL_EXEC(nodeTmpVar->op() != ge::TEMPORARYVARIABLE, continue); google::protobuf::Map *attr_map_tmp = nodeTmpVar->mutable_attr(); domi::tensorflow::AttrValue var_name_attr_tmp = (*attr_map_tmp)[ge::VAR_ATTR_NAME]; @@ -2557,11 +2558,11 @@ Status GetTransposeInfo(GraphDef *graph_def, std::map GE_CHECK_NOTNULL(graph_def); for (int i = 0; i < graph_def->node_size(); ++i) { auto node_def = graph_def->mutable_node(i); - if (node_def->op() == ge::parser::TRANSPOSE) { + if (node_def->op() == ge::TRANSPOSE) { DelTransposeInfo transpose; transpose.node_def = node_def; transposeInfo.insert(std::make_pair(node_def->name(), transpose)); - } else if (node_def->op() == ge::parser::SOFTMAX) { + } else if (node_def->op() == ge::SOFTMAX) { softmaxInfo.insert(std::make_pair(node_def->name(), node_def->input(0))); GELOGI("softmax name:%s, input name:%s", node_def->name().c_str(), node_def->input(0).c_str()); } @@ -2630,7 +2631,7 @@ void TensorFlowModelParser::SoftmaxAddAttr(GraphDef *graph_def) { // The caller guarantees that the pointer is not null for (int i = 0; i < graph_def->node_size(); ++i) { auto node_def = graph_def->mutable_node(i); - if (node_def->op() == ge::parser::SOFTMAX) { + if (node_def->op() == ge::SOFTMAX) { domi::tensorflow::AttrValue attr_value; attr_value.set_i(1); ge::TensorFlowUtil::AddNodeAttr("axis", attr_value, node_def); @@ -2654,9 +2655,9 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph const string &node_name = node_def->name(); Status ret = AddFmkNodeDefToMap(*graph_def, node_def, op_node_name_list); GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); - if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) { + if (node_def->op() == ge::IDENTITY || node_def->op() == ge::READVARIABLEOP) { identity_to_optimize.push_back(node_def); - } else if (node_def->op() == ge::parser::SNAPSHOT) { + } else if (node_def->op() == ge::SNAPSHOT) { snapshot_to_optimize.push_back(node_def); } nodedef_map[node_name] = node_def; @@ -2687,19 +2688,19 @@ Status TensorFlowModelParser::RemoveIsolateNode(ge::ComputeGraphPtr &graph) { if (n->GetName().substr(0, 4) == "dpop") { continue; } - if ((n->GetType() == ge::parser::DATA) || - (ge::GetParserContext().out_nodes_map.find(n->GetName()) != ge::GetParserContext().out_nodes_map.end())) { + if ((n->GetType() == DATA) || + (domi::GetContext().out_nodes_map.find(n->GetName()) != domi::GetContext().out_nodes_map.end())){ GELOGI("Can not remove op [%s] because it is data or out node.", n->GetName().c_str()); continue; } - GE_IF_BOOL_EXEC((((n->GetInAllNodes().size() == 0) && (n->GetOutDataNodes().size() == 0)) || - ((n->GetType() == ge::parser::CONSTANTOP || n->GetType() == ge::parser::CONSTANT) && - (n->GetOutDataNodes().size() == 0))), - GE_CHK_STATUS_RET(ge::GraphUtils::IsolateNode(n, {}), "Isolate removed node: %s, type: %s failed", - n->GetName().c_str(), n->GetType().c_str()); - GE_CHK_STATUS_RET(ge::GraphUtils::RemoveNodeWithoutRelink(graph, n), - "Remove node: %s, type: %s without relink failed", n->GetName().c_str(), - n->GetType().c_str());); + GE_IF_BOOL_EXEC( + (((n->GetInAllNodes().size() == 0) && (n->GetOutDataNodes().size() == 0)) || + ((n->GetType() == ge::CONSTANTOP || n->GetType() == ge::CONSTANT) && (n->GetOutDataNodes().size() == 0))), + GE_CHK_STATUS_RET(ge::GraphUtils::IsolateNode(n, {}), "Isolate removed node: %s, type: %s failed", + n->GetName().c_str(), n->GetType().c_str()); + GE_CHK_STATUS_RET(ge::GraphUtils::RemoveNodeWithoutRelink(graph, n), + "Remove node: %s, type: %s without relink failed", n->GetName().c_str(), + n->GetType().c_str());); } return SUCCESS; } @@ -2708,19 +2709,19 @@ Status TensorFlowModelParser::RemoveIsolateNode(ge::ComputeGraphPtr &graph) { // if not specified, use InferInputFormats to infer, // and if the inference fails, the default NHWC format is used. domiTensorFormat_t TensorFlowModelParser::InferInputFormats() { - GE_IF_BOOL_EXEC(ge::GetParserContext().format != DOMI_TENSOR_RESERVED, return ge::GetParserContext().format); + GE_IF_BOOL_EXEC(domi::GetContext().format != DOMI_TENSOR_RESERVED, return domi::GetContext().format); domiTensorFormat_t global_input_format = DOMI_TENSOR_RESERVED; set visited_node; for (auto &node_item : nodedef_map_) { - // Infer format for data node and save it to ge::GetParserContext().format. + // Infer format for data node and save it to domi::GetContext().format. domiTensorFormat_t format = DOMI_TENSOR_RESERVED; const NodeDef *node = node_item.second; if (node == nullptr) { return format; } auto it = tensorflow_op_map.find(node->op()); - if (it != tensorflow_op_map.end() && it->second == ge::parser::DATA) { + if (it != tensorflow_op_map.end() && it->second == ge::DATA) { GE_IF_BOOL_EXEC(GetNodeFormat(node, NO_TRANSPOSE, format, visited_node) != SUCCESS, GELOGW("Cannot infer input format, the NHWC format is used by default, and you can also " "specify format by command line arguments."); @@ -2746,11 +2747,12 @@ Status TensorFlowModelParser::GetNodeFormat(const NodeDef *node, TfTranspose pre GE_IF_BOOL_EXEC(visited_node.find(node) != visited_node.end(), return SUCCESS); visited_node.emplace(node); - GE_IF_BOOL_EXEC(node->op() == TENSORFLOWF_NODE_OP_SWITCH || node->op() == TENSORFLOWF_NODE_OP_MERGE, return SUCCESS); + GE_IF_BOOL_EXEC(node->op() == TENSORFLOWF_NODE_OP_SWITCH || node->op() == TENSORFLOWF_NODE_OP_MERGE, + return SUCCESS); // If node has a data_format attribute, format is set according to data_format. domi::tensorflow::AttrValue attr; - if (ge::TensorFlowUtil::FindAttrValue(node, TENSORFLOW_ATTR_DATA_FORMAT, attr) && node->op() != ge::parser::BIASADD) { + if (ge::TensorFlowUtil::FindAttrValue(node, TENSORFLOW_ATTR_DATA_FORMAT, attr) && node->op() != ge::BIASADD) { GE_RETURN_IF_ERROR(ge::TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING)); format = (attr.s() == TENSORFLOWF_TENSOR_NCHW) ? domi::DOMI_TENSOR_NCHW : domi::DOMI_TENSOR_NHWC; @@ -2840,15 +2842,13 @@ Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, vector perm_value; - GE_IF_BOOL_EXEC( - type == domi::tensorflow::DT_INT32, - const int32_t *data = reinterpret_cast(tensor.tensor_content().data()); - for (int i = 0; i < ge::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); }); + GE_IF_BOOL_EXEC(type == domi::tensorflow::DT_INT32, + const int32_t *data = reinterpret_cast(tensor.tensor_content().data()); + for (int i = 0; i < ge::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); }); - GE_IF_BOOL_EXEC( - type == domi::tensorflow::DT_INT64, - const int64_t *data = reinterpret_cast(tensor.tensor_content().data()); - for (int i = 0; i < ge::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); }); + GE_IF_BOOL_EXEC(type == domi::tensorflow::DT_INT64, + const int64_t *data = reinterpret_cast(tensor.tensor_content().data()); + for (int i = 0; i < ge::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); }); // 0, 1, 2, 3 present dim num. vector perm_to_nchw = {0, 3, 1, 2}; @@ -2862,7 +2862,7 @@ Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, Status TensorFlowModelParser::TrimGraph(const domi::tensorflow::GraphDef &input_graph_def, domi::tensorflow::GraphDef *output_graph_def) { GE_CHECK_NOTNULL(output_graph_def); - if (!ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) { + if (!domi::GetContext().input_dims.empty() && domi::GetContext().out_nodes_map.empty()) { return TrimGraphByInput(input_graph_def, output_graph_def); } else { return TrimGraphByOutput(input_graph_def, output_graph_def); @@ -2873,7 +2873,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef // The caller guarantees that the pointer is not null std::set delete_nodes; std::set input_nodes; - for (auto &iter : ge::GetParserContext().input_dims) { + for (auto &iter : domi::GetContext().input_dims) { input_nodes.insert(iter.first); } std::map node_lookup; @@ -2881,7 +2881,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef node_lookup[node.name()] = &node; } std::vector current_inputs; - for (auto &iter : ge::GetParserContext().input_dims) { + for (auto &iter : domi::GetContext().input_dims) { current_inputs.push_back(iter.first); } while (!current_inputs.empty()) { @@ -2922,7 +2922,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef domi::tensorflow::AttrValue attr_value; TensorShapeProto *data_shape = attr_value.mutable_shape(); GE_CHECK_NOTNULL(data_shape); - const ge::ParserContext &ctx = ge::GetParserContext(); + const ge::OmgContext &ctx = domi::GetContext(); std::unordered_map> input_dims = ctx.input_dims; std::vector designated_dims = input_dims.at(node.name()); for (int32_t i = 0; i < (int32_t)designated_dims.size(); i++) { @@ -2944,11 +2944,11 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef // The caller guarantees that the pointer is not null std::set required_nodes; std::set input_nodes; - for (auto &iter : ge::GetParserContext().input_dims) { + for (auto &iter : domi::GetContext().input_dims) { required_nodes.insert(iter.first); input_nodes.insert(iter.first); } - for (auto &iter : ge::GetParserContext().out_nodes_map) { + for (auto &iter : domi::GetContext().out_nodes_map) { required_nodes.insert(iter.first); } std::map node_lookup; @@ -2956,7 +2956,7 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef node_lookup[node.name()] = &node; } std::vector current_inputs; - for (auto &iter : ge::GetParserContext().out_nodes_map) { + for (auto &iter : domi::GetContext().out_nodes_map) { current_inputs.push_back(iter.first); } while (!current_inputs.empty()) { @@ -2995,7 +2995,7 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef domi::tensorflow::AttrValue attr_value; TensorShapeProto *data_shape = attr_value.mutable_shape(); GE_CHECK_NOTNULL(data_shape); - const ge::ParserContext &ctx = ge::GetParserContext(); + const ge::OmgContext &ctx = domi::GetContext(); std::unordered_map> input_dims = ctx.input_dims; std::vector designated_dims = input_dims.at(node.name()); for (int32_t i = 0; i < (int32_t)designated_dims.size(); i++) { @@ -3075,7 +3075,7 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr &op_par op_src_vec.push_back(op_src); } shared_ptr tf_custom_fusion_op_paser = - std::dynamic_pointer_cast(tensorflow_fusion_op_parser); + std::dynamic_pointer_cast(tensorflow_fusion_op_parser); status = tf_custom_fusion_op_paser->ParseParams(op_src_vec, node); if (status != SUCCESS) { GELOGE(status, "Parse params for fusionop node %s failed", node_def->name().c_str()); @@ -3121,8 +3121,8 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap // 2.2 check whether the current op is a TVM op. GE_CHK_BOOL_EXEC_INFO( - domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) == domi::ImplyType::TVM, continue, - "op %s is not TVM op", current_op_name.c_str()); + domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) == domi::ImplyType::TVM, + continue, "op %s is not TVM op", current_op_name.c_str()); GELOGD("handle tvm op %s", current_op_name.c_str()); // 2.3 copy input to attr @@ -3310,7 +3310,7 @@ Status TensorFlowModelParser::RemoveIsolateNode(domi::tensorflow::GraphDef *grap } if ((node_inputs_outputs_map_[node_name].first.empty() && node_inputs_outputs_map_[node_name].second.empty() && node->op() != kDpop) || - (node->op() == ge::parser::CONSTANT && node_inputs_outputs_map_[node_name].second.empty())) { + (node->op() == ge::CONSTANT && node_inputs_outputs_map_[node_name].second.empty())) { GELOGI("%s will inset to node_to_delete", node_name.c_str()); node_to_delete.insert(node_name); } @@ -3402,10 +3402,10 @@ Status TensorFlowModelParser::SetOriginNodeContext(NodeDef *node_def, OpNodeCont return SUCCESS; } -void TensorFlowModelParser::GetFusionInputInfo( - const string &fusion_op_name, OpNodeContext &fusion_context, +void TensorFlowModelParser::GetFusionInputInfo(const string &fusion_op_name, OpNodeContext &fusion_context, std::map>> &remap_data_input, - std::map> &remap_ctrl_input, std::set &fusion_input_nodes) { + std::map> &remap_ctrl_input, + std::set &fusion_input_nodes) { for (const auto &fusion_input : fusion_context.input_map) { string fusion_src_name = fusion_input.first; for (const auto &fusion_idx_pair : fusion_input.second) { @@ -3420,10 +3420,10 @@ void TensorFlowModelParser::GetFusionInputInfo( } } -void TensorFlowModelParser::GetFusionOutputInfo( - const string &fusion_op_name, OpNodeContext &fusion_context, +void TensorFlowModelParser::GetFusionOutputInfo(const string &fusion_op_name, OpNodeContext &fusion_context, std::map>>> &remap_data_output, - std::map> &remap_ctrl_output, std::set &fusion_output_nodes) { + std::map> &remap_ctrl_output, + std::set &fusion_output_nodes) { for (const auto &fusion_output : fusion_context.output_map) { string fusion_dst_name = fusion_output.first; for (const auto &fusion_idx_pair : fusion_output.second) { @@ -3444,7 +3444,8 @@ void TensorFlowModelParser::UpdateInnerInputMap(const string &fusion_op_name, Op std::set &fusion_input_nodes) { std::map>> remap_data_input; std::map> remap_ctrl_input; - GetFusionInputInfo(fusion_op_name, fusion_context, remap_data_input, remap_ctrl_input, fusion_input_nodes); + GetFusionInputInfo(fusion_op_name, fusion_context, + remap_data_input, remap_ctrl_input, fusion_input_nodes); for (const auto &node_name : inner_nodes_name) { auto context_iter = op_node_context_map_.find(node_name); @@ -3487,12 +3488,14 @@ void TensorFlowModelParser::UpdateInnerInputMap(const string &fusion_op_name, Op } } + void TensorFlowModelParser::UpdateInnerOutputMap(const string &fusion_op_name, OpNodeContext &fusion_context, const std::vector &inner_nodes_name, std::set &fusion_output_nodes) { std::map>>> remap_data_output; std::map> remap_ctrl_output; - GetFusionOutputInfo(fusion_op_name, fusion_context, remap_data_output, remap_ctrl_output, fusion_output_nodes); + GetFusionOutputInfo(fusion_op_name, fusion_context, + remap_data_output, remap_ctrl_output, fusion_output_nodes); for (const auto &node_name : inner_nodes_name) { auto context_iter = op_node_context_map_.find(node_name); if (context_iter != op_node_context_map_.end()) { @@ -3508,8 +3511,8 @@ void TensorFlowModelParser::UpdateInnerOutputMap(const string &fusion_op_name, O auto data_outputs = remap_data_output[fusion_op_name + std::to_string(out_pair.second)]; for (const auto &data : data_outputs) { tmp_output_map[data.first].emplace_back(std::make_pair(out_pair.first, data.second.second)); - GELOGI("Update inner output, dst:%s, idx:%u->%u.", data.first.c_str(), out_pair.first, - data.second.second); + GELOGI("Update inner output, dst:%s, idx:%u->%u.", data.first.c_str(), + out_pair.first, data.second.second); } } } @@ -3567,7 +3570,8 @@ Status TensorFlowModelParser::UpdateInnerNodeContext(const string &fusion_op_nam } Status TensorFlowModelParser::AddFusionInnerNodeDef(shared_ptr &scope_graph, - const string &fusion_op_name, vector &node_name_list) { + const string &fusion_op_name, + vector &node_name_list) { auto &impl_scope_graph = scope_graph->impl_; GE_CHECK_NOTNULL(impl_scope_graph); ge::FusionScopesResult *fusion_result = impl_scope_graph->GetFusionScopesResults(fusion_op_name); @@ -3670,7 +3674,7 @@ Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, g return PARAM_INVALID; } const ge::Operator *op = iter->second; - ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op); + ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op); GE_CHECK_NOTNULL(op_desc); ge::NodePtr node; { @@ -3678,16 +3682,16 @@ Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, g node = graph->AddNode(op_desc); } if (node == nullptr) { - GELOGE(INTERNAL_ERROR, "Failed to Add scope inner node:%s, type:%s.", op_desc->GetName().c_str(), - op_desc->GetType().c_str()); + GELOGE(INTERNAL_ERROR, "Failed to Add scope inner node:%s, type:%s.", + op_desc->GetName().c_str(), op_desc->GetType().c_str()); return INTERNAL_ERROR; } { std::lock_guard lock(parser->nodeMapMutex_); parser->node_map_[node_name] = node; } - GELOGI("Add scope inner node successfully, node name:%s, type:%s.", op_desc->GetName().c_str(), - op_desc->GetType().c_str()); + GELOGI("Add scope inner node successfully, node name:%s, type:%s.", + op_desc->GetName().c_str(), op_desc->GetType().c_str()); return SUCCESS; } @@ -3695,12 +3699,14 @@ void TensorFlowModelParser::DumpNodeContext(const string &node_name, const OpNod GELOGD("phase:%s === Begin to dump context for node:%s ===", phase.c_str(), node_name.c_str()); for (const auto &input : ctx.input_map) { for (const auto &input_idx : input.second) { - GELOGD(" Input info: %s:%d --> in_idx %d.", input.first.c_str(), input_idx.first, input_idx.second); + GELOGD(" Input info: %s:%d --> in_idx %d.", + input.first.c_str(), input_idx.first, input_idx.second); } } for (const auto &output : ctx.output_map) { for (const auto &output_idx : output.second) { - GELOGD(" Output info: out_idx %d --> %s:%d.", output_idx.first, output.first.c_str(), output_idx.second); + GELOGD(" Output info: out_idx %d --> %s:%d.", + output_idx.first, output.first.c_str(), output_idx.second); } } GELOGD("phase:%s === End to dump context for node:%s ===", phase.c_str(), node_name.c_str()); @@ -3714,9 +3720,9 @@ void TensorFlowModelParser::DumpAllNodeContext(const string &phase) { DumpNodeContext(iter.first, iter.second, phase); } } -} // namespace ge +} // namespace domi namespace domi { -REGISTER_MODEL_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowModelParser); -REGISTER_WEIGHTS_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowWeightsParser); -} // namespace domi + REGISTER_MODEL_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowModelParser); + REGISTER_WEIGHTS_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowWeightsParser); +} diff --git a/parser/tensorflow/tensorflow_parser_register.h b/parser/tensorflow/tensorflow_parser_register.h index 17a2368..6ff0e2e 100644 --- a/parser/tensorflow/tensorflow_parser_register.h +++ b/parser/tensorflow/tensorflow_parser_register.h @@ -25,7 +25,7 @@ #include "framework/omg/parser/op_parser.h" #include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/op_def/operator.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/ge/ge_util.h" #include "parser/common/op_parser_factory.h" #include "parser/tensorflow/tensorflow_op_parser.h" #include "proto/tensorflow/node_def.pb.h" @@ -72,7 +72,7 @@ class TensorflowParserBuilder : public TensorflowWeightParserBuilder { } bool Finalize() override { - auto op_parser_adapter = ge::parser::MakeShared>(*this); + auto op_parser_adapter = ge::MakeShared>(*this); if (op_parser_adapter == nullptr) { GELOGE(FAILED, "Op parser adapter is null."); } @@ -102,7 +102,7 @@ class TensorflowOpParserAdapter : public TensorFlowOpParser { Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override { const domi::tensorflow::NodeDef *node = static_cast(op_src); GE_CHECK_NOTNULL(node); - std::shared_ptr param = ge::parser::MakeShared(); + std::shared_ptr param = ge::MakeShared(); if (param == nullptr) { GELOGE(domi::FAILED, "Param is null"); return domi::FAILED; diff --git a/parser/tensorflow/tensorflow_ref_switch_parser.cc b/parser/tensorflow/tensorflow_ref_switch_parser.cc index 08b7a30..624dde0 100644 --- a/parser/tensorflow/tensorflow_ref_switch_parser.cc +++ b/parser/tensorflow/tensorflow_ref_switch_parser.cc @@ -26,7 +26,6 @@ using domi::tensorflow::DT_FLOAT; using domi::tensorflow::AttrValue; using domi::tensorflow::NodeDef; using domi::TENSORFLOW; -using namespace ge::parser; namespace ge { // AUTO GEN PLEASE DO NOT MODIFY IT diff --git a/parser/tensorflow/tensorflow_reshape_parser.cc b/parser/tensorflow/tensorflow_reshape_parser.cc index 612c2ed..5c82176 100644 --- a/parser/tensorflow/tensorflow_reshape_parser.cc +++ b/parser/tensorflow/tensorflow_reshape_parser.cc @@ -22,10 +22,9 @@ #include "graph/utils/type_utils.h" #include "parser/common/op_parser_factory.h" #include "parser/tensorflow/tensorflow_util.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/math/math_util.h" using domi::TENSORFLOW; -using namespace ge::parser; namespace ge { Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) { @@ -48,7 +47,7 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); real_size *= tmp_dim; } - PARSER_INT64_MULCHECK(real_size, size_type); + FMK_INT64_MULCHECK(real_size, size_type); ge::TensorUtils::SetSize(ge_desc, real_size * size_type); ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); GELOGI("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u", @@ -68,7 +67,7 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr domi::tensorflow::AttrValue output_attr_value; GE_IF_BOOL_EXEC( - GetParserContext().train_flag == true, + domi::GetContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc; diff --git a/parser/tensorflow/tensorflow_shape_n_parser.cc b/parser/tensorflow/tensorflow_shape_n_parser.cc index e8c0e9c..bd74f18 100644 --- a/parser/tensorflow/tensorflow_shape_n_parser.cc +++ b/parser/tensorflow/tensorflow_shape_n_parser.cc @@ -26,7 +26,6 @@ using domi::tensorflow::AttrValue; using domi::tensorflow::DataType; using domi::tensorflow::DT_FLOAT; using domi::tensorflow::DT_INT32; -using namespace ge::parser; namespace { const std::string kShapeAttrDtype = "out_type"; diff --git a/parser/tensorflow/tensorflow_squeeze_parser.cc b/parser/tensorflow/tensorflow_squeeze_parser.cc index a1ff012..75e2c2d 100644 --- a/parser/tensorflow/tensorflow_squeeze_parser.cc +++ b/parser/tensorflow/tensorflow_squeeze_parser.cc @@ -22,16 +22,14 @@ #include "framework/common/op/attr_value_util.h" #include "framework/common/op/op_parser_util.h" #include "framework/common/util.h" -#include "framework/omg/parser/parser_inner_ctx.h" #include "graph/utils/type_utils.h" #include "parser/common/op_parser_factory.h" -#include "parser/common/acl_graph_parser_util.h" +#include "common/math/math_util.h" using domi::tensorflow::AttrValue; using std::vector; using std::shared_ptr; using domi::TENSORFLOW; -using namespace ge::parser; namespace ge { Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) { @@ -52,10 +50,10 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { tmp_dim = ge_desc.GetShape().GetDim(j); GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); - PARSER_INT64_MULCHECK(real_size, tmp_dim); + FMK_INT64_MULCHECK(real_size, tmp_dim); real_size *= tmp_dim; } - PARSER_INT64_MULCHECK(real_size, size_type); + FMK_INT64_MULCHECK(real_size, size_type); ge::TensorUtils::SetSize(ge_desc, real_size * size_type); ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u", @@ -112,7 +110,7 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr domi::tensorflow::AttrValue output_attr_value; GE_IF_BOOL_EXEC( - GetParserContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc; + domi::GetContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc; if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed"); diff --git a/parser/tensorflow/tensorflow_util.cc b/parser/tensorflow/tensorflow_util.cc index 3a05ffb..a618c9e 100644 --- a/parser/tensorflow/tensorflow_util.cc +++ b/parser/tensorflow/tensorflow_util.cc @@ -15,25 +15,25 @@ */ #include "parser/tensorflow/tensorflow_util.h" -#include #include +#include #include #include #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "framework/common/op/ge_op_utils.h" -#include "framework/omg/parser/parser_types.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" #include "graph/utils/type_utils.h" #include "parser/tensorflow/tensorflow_op_parser.h" +#include "common/math/math_util.h" using domi::tensorflow::DT_INVALID; namespace ge { using AttrValueMap = ::google::protobuf::Map; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue( - const domi::tensorflow::NodeDef *node_def, const string &attr_name, domi::tensorflow::AttrValue &attr_value) { + const domi::tensorflow::NodeDef *node_def, const string &attr_name, domi::tensorflow::AttrValue &attr_value) { GE_CHECK_NOTNULL(node_def); const google::protobuf::Map &attr = node_def->attr(); const google::protobuf::Map::const_iterator it = attr.find(attr_name); @@ -46,7 +46,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType( - const domi::tensorflow::AttrValue &attr_value, const string &type) { + const domi::tensorflow::AttrValue &attr_value, const string &type) { uint32_t num_set = 0; #define VALIDATE_FIELD(name, type_string, oneof_case) \ do { \ @@ -118,7 +118,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( - const NodeDef *node_src, const string &attr_src, domi::tensorflow::DataType &data_type) { + const NodeDef *node_src, const string &attr_src, domi::tensorflow::DataType &data_type) { GE_CHECK_NOTNULL(node_src); string node_name = node_src->name(); @@ -138,7 +138,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Pa } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromAttrValueList( - ge::GeTensorDesc &ge_desc, const domi::tensorflow::AttrValue_ListValue &a_list, int32_t i, int32_t &tf_datatype) { + ge::GeTensorDesc &ge_desc, const domi::tensorflow::AttrValue_ListValue &a_list, int32_t i, int32_t &tf_datatype) { const std::string SERIALIZE_FORMAT = "serialize_format"; const std::string SERIALIZE_DATATYPE = "serialize_datatype"; const std::string SERIALIZE_SHAPE = "serialize_shape"; @@ -162,7 +162,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromA } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::TransTensorDescriptor( - const domi::tensorflow::AttrValue &attr_value, ParserOperator *op, const uint32_t io, const string &type) { + const domi::tensorflow::AttrValue &attr_value, ParserOperator *op, const uint32_t io, const string &type) { GE_CHECK_NOTNULL(op); if (!attr_value.has_list()) { return PARAM_INVALID; @@ -191,9 +191,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr // The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate. // Here, special treatment is given to the two operators. // Adjust shape to fit resnet50 network only. - GE_IF_BOOL_EXEC((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape()); - break;); - GE_IF_BOOL_EXEC((type == ge::parser::MEAN) && (tmp_dim == 0), vector data_dim = {tmp_dim}; + GE_IF_BOOL_EXEC((type == ge::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape()); break;); + GE_IF_BOOL_EXEC((type == ge::MEAN) && (tmp_dim == 0), vector data_dim = {tmp_dim}; ge_desc.SetShape(ge::GeShape(data_dim)); break;); } ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); @@ -215,7 +214,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr return SUCCESS; } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( - const string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *node_def) { + const string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *node_def) { GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); } diff --git a/parser/tensorflow/tensorflow_util.h b/parser/tensorflow/tensorflow_util.h index 40e780f..dd53736 100644 --- a/parser/tensorflow/tensorflow_util.h +++ b/parser/tensorflow/tensorflow_util.h @@ -26,7 +26,7 @@ #include "external/graph/attr_value.h" #include "external/graph/graph.h" #include "external/graph/operator.h" -#include "framework/omg/parser/parser_types.h" +#include "framework/common/types.h" #include "framework/omg/omg_inner_types.h" #include "graph/compute_graph.h" #include "graph/ge_tensor.h" diff --git a/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc b/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc index 4ec74bd..9b17b5f 100644 --- a/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc +++ b/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc @@ -22,8 +22,6 @@ #include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_parser_register.h" -using namespace ge::parser; - namespace ge { Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *op) { GE_CHECK_NOTNULL(op_src); diff --git a/parser/tensorflow/tensorflow_variable_v2_parser.cc b/parser/tensorflow/tensorflow_variable_v2_parser.cc index 139dd0e..fac29ac 100644 --- a/parser/tensorflow/tensorflow_variable_v2_parser.cc +++ b/parser/tensorflow/tensorflow_variable_v2_parser.cc @@ -32,7 +32,6 @@ using domi::tensorflow::AttrValue; using domi::tensorflow::NodeDef; using domi::tensorflow::TensorShapeProto; -using namespace ge::parser; namespace ge { const std::string SERIALIZE_FORMAT = "serialize_format";