Merge pull request !653 from 王笑天/ge_devpull/656/head
| @@ -1 +1 @@ | |||
| Subproject commit 5d062a35640733026457c91966a558769570b0f8 | |||
| Subproject commit f5c1b6d1b6b6e97d0cfcf2efd52ec8da12d32c86 | |||
| @@ -22,18 +22,18 @@ set(SRC_LIST | |||
| "caffe/caffe_custom_parser_adapter.cc" | |||
| "caffe/caffe_op_parser.cc" | |||
| "tensorflow/scope/scope_pass_manager.cc" | |||
| "tensorflow/graph_functiondef.cc" | |||
| "tensorflow/graph_optimizer.cc" | |||
| "tensorflow/graph_to_function_def.cc" | |||
| "tensorflow/parser_graph_optimizer.cc" | |||
| "tensorflow/iterator_fusion_pass.cc" | |||
| "common/op_def/arg_op.cc" | |||
| "common/op_def/constant_op.cc" | |||
| "common/op_def/fill_op.cc" | |||
| "common/op_def/frameworkop_op.cc" | |||
| "common/op_def/no_op_op.cc" | |||
| "common/op_def/ref_switch_op.cc" | |||
| "common/op_def/shape_n_op.cc" | |||
| "common/op_def/var_is_initialized_op_op.cc" | |||
| "common/op_def/variable_op.cc" | |||
| "common/op_def/arg_op_operator.cc" | |||
| "common/op_def/constant_operator.cc" | |||
| "common/op_def/fill_operator.cc" | |||
| "common/op_def/framework_op_operator.cc" | |||
| "common/op_def/no_op_operator.cc" | |||
| "common/op_def/ref_switch_operator.cc" | |||
| "common/op_def/shape_n_operator.cc" | |||
| "common/op_def/var_is_initialized_op_operator.cc" | |||
| "common/op_def/variable_operator.cc" | |||
| ) | |||
| ############ libfmk_parser.so ############ | |||
| @@ -236,14 +236,8 @@ const char *const kFieldInnerPro = "inner_product_param"; | |||
| const char *const kFieldDim = "dim"; | |||
| const char *const kFieldBiasTerm = "bias_term"; | |||
| const char *const kDevNull = "/dev/null"; | |||
| const std::string kMessage = "message"; | |||
| const std::string kLayerParameter = "LayerParameter"; | |||
| const std::string kCloseBrace = "}"; | |||
| const std::string kOptional = "optional"; | |||
| const std::string kRepeated = "repeated"; | |||
| const std::string kRequired = "required"; | |||
| const std::string kCustom = "custom"; | |||
| const std::string kBuiltin = "built-in"; | |||
| const char *const kCustom = "custom"; | |||
| const char *const kBuiltin = "built-in"; | |||
| std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, | |||
| ge::parser::NETOUTPUT}; | |||
| const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"}; | |||
| @@ -284,104 +278,104 @@ const set<string> CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxW | |||
| "Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; | |||
| Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const { | |||
| if (proto_message.input_size() > 0) { | |||
| GELOGI("This net exsit input."); | |||
| if (proto_message.input_size() <= 0) { | |||
| return SUCCESS; | |||
| } | |||
| GELOGI("This net exsit input."); | |||
| if (proto_message.input_dim_size() > 0) { | |||
| if (proto_message.input_shape_size() > 0) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11001"); | |||
| GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); | |||
| return FAILED; | |||
| } | |||
| if (proto_message.input_dim_size() > 0) { | |||
| if (proto_message.input_shape_size() > 0) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11001"); | |||
| GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); | |||
| return FAILED; | |||
| } | |||
| const int32_t input_dim_size = proto_message.input_dim_size(); | |||
| const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || | |||
| ((input_dim_size % proto_message.input_size()) != 0)); | |||
| if (is_input_invalid) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, | |||
| {std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||
| GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||
| input_dim_size, proto_message.input_size()); | |||
| return FAILED; | |||
| } | |||
| const int32_t input_dim_size = proto_message.input_dim_size(); | |||
| const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || | |||
| ((input_dim_size % proto_message.input_size()) != 0)); | |||
| if (is_input_invalid) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, | |||
| {std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||
| GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||
| input_dim_size, proto_message.input_size()); | |||
| return FAILED; | |||
| } | |||
| for (int i = 0; i < proto_message.input_size(); i++) { | |||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
| GE_CHECK_NOTNULL(layer); | |||
| layer->set_name(proto_message.input(i)); | |||
| layer->set_type(ge::parser::INPUT_TYPE); | |||
| layer->add_top(proto_message.input(i)); | |||
| for (int i = 0; i < proto_message.input_size(); i++) { | |||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
| GE_CHECK_NOTNULL(layer); | |||
| layer->set_name(proto_message.input(i)); | |||
| layer->set_type(ge::parser::INPUT_TYPE); | |||
| layer->add_top(proto_message.input(i)); | |||
| domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
| GE_CHECK_NOTNULL(input_param); | |||
| domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
| GE_CHECK_NOTNULL(shape); | |||
| for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { | |||
| // Can guarantee that it will not cross the border | |||
| shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); | |||
| } | |||
| input_data_flag = true; | |||
| } | |||
| } else if (proto_message.input_shape_size() > 0) { | |||
| if (proto_message.input_shape_size() != proto_message.input_size()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||
| {std::to_string(proto_message.input_shape_size()), | |||
| std::to_string(proto_message.input_size())}); | |||
| GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||
| proto_message.input_shape_size(), proto_message.input_size()); | |||
| return FAILED; | |||
| domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
| GE_CHECK_NOTNULL(input_param); | |||
| domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
| GE_CHECK_NOTNULL(shape); | |||
| for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { | |||
| // Can guarantee that it will not cross the border | |||
| shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); | |||
| } | |||
| input_data_flag = true; | |||
| } | |||
| } else if (proto_message.input_shape_size() > 0) { | |||
| if (proto_message.input_shape_size() != proto_message.input_size()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||
| {std::to_string(proto_message.input_shape_size()), | |||
| std::to_string(proto_message.input_size())}); | |||
| GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||
| proto_message.input_shape_size(), proto_message.input_size()); | |||
| return FAILED; | |||
| } | |||
| for (int i = 0; i < proto_message.input_size(); i++) { | |||
| int dim_size = proto_message.input_shape(i).dim_size(); | |||
| for (int i = 0; i < proto_message.input_size(); i++) { | |||
| int dim_size = proto_message.input_shape(i).dim_size(); | |||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
| GE_CHECK_NOTNULL(layer); | |||
| layer->set_name(proto_message.input(i)); | |||
| layer->set_type(ge::parser::INPUT_TYPE); | |||
| layer->add_top(proto_message.input(i)); | |||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
| GE_CHECK_NOTNULL(layer); | |||
| layer->set_name(proto_message.input(i)); | |||
| layer->set_type(ge::parser::INPUT_TYPE); | |||
| layer->add_top(proto_message.input(i)); | |||
| domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
| GE_CHECK_NOTNULL(input_param); | |||
| domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
| GE_CHECK_NOTNULL(shape); | |||
| domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
| GE_CHECK_NOTNULL(input_param); | |||
| domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
| GE_CHECK_NOTNULL(shape); | |||
| for (int j = 0; j < dim_size; j++) { | |||
| // Can guarantee that it will not cross the border | |||
| shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j))); | |||
| } | |||
| input_data_flag = true; | |||
| for (int j = 0; j < dim_size; j++) { | |||
| // Can guarantee that it will not cross the border | |||
| shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j))); | |||
| } | |||
| } else { | |||
| const ge::ParserContext &ctx = ge::GetParserContext(); | |||
| std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||
| for (int i = 0; i < proto_message.input_size(); i++) { | |||
| string name = proto_message.input(i); | |||
| if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input | |||
| REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name})); | |||
| GELOGE(FAILED, "[Find][Dim]Model has no input shape."); | |||
| return FAILED; | |||
| } | |||
| std::vector<int64_t> dims = input_dims.at(name); | |||
| size_t dim_size = dims.size(); | |||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
| GE_CHECK_NOTNULL(layer); | |||
| layer->set_name(name); | |||
| layer->set_type(ge::parser::INPUT_TYPE); | |||
| layer->add_top(proto_message.input(i)); | |||
| domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
| GE_CHECK_NOTNULL(input_param); | |||
| domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
| GE_CHECK_NOTNULL(shape); | |||
| for (size_t j = 0; j < dim_size; j++) { | |||
| shape->add_dim(dims.at(j)); | |||
| } | |||
| input_data_flag = true; | |||
| input_data_flag = true; | |||
| } | |||
| } else { | |||
| const ge::ParserContext &ctx = ge::GetParserContext(); | |||
| std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||
| for (int i = 0; i < proto_message.input_size(); i++) { | |||
| string name = proto_message.input(i); | |||
| if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input | |||
| REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name})); | |||
| GELOGE(FAILED, "[Find][Dim]Model has no input shape."); | |||
| return FAILED; | |||
| } | |||
| std::vector<int64_t> dims = input_dims.at(name); | |||
| size_t dim_size = dims.size(); | |||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
| GE_CHECK_NOTNULL(layer); | |||
| layer->set_name(name); | |||
| layer->set_type(ge::parser::INPUT_TYPE); | |||
| layer->add_top(proto_message.input(i)); | |||
| domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
| GE_CHECK_NOTNULL(input_param); | |||
| domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
| GE_CHECK_NOTNULL(shape); | |||
| for (size_t j = 0; j < dim_size; j++) { | |||
| shape->add_dim(dims.at(j)); | |||
| } | |||
| input_data_flag = true; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| @@ -423,7 +417,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons | |||
| return FAILED; | |||
| } | |||
| if (ParseLayerParameter(layer_descriptor, message, operators) != SUCCESS) { | |||
| if (ParseLayerParameter(*layer_descriptor, *message, operators) != SUCCESS) { | |||
| delete message; | |||
| GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path); | |||
| return FAILED; | |||
| @@ -536,18 +530,18 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google:: | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||
| const google::protobuf::Message *message, | |||
| Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||
| const google::protobuf::Message &message, | |||
| vector<ge::Operator> &operators) const { | |||
| auto field_name = layer_descriptor->FindFieldByName(kFieldName); | |||
| auto field_name = layer_descriptor.FindFieldByName(kFieldName); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | |||
| auto field_type = layer_descriptor->FindFieldByName(kFieldType); | |||
| auto field_type = layer_descriptor.FindFieldByName(kFieldType); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); | |||
| const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
| const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
| vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
| reflection->ListFields(*message, &field_desc); | |||
| reflection->ListFields(message, &field_desc); | |||
| for (auto &field : field_desc) { | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); | |||
| // Only care about layers | |||
| @@ -561,10 +555,10 @@ Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor | |||
| return FAILED; | |||
| } | |||
| int field_size = reflection->FieldSize(*message, field); | |||
| int field_size = reflection->FieldSize(message, field); | |||
| GELOGI("Total Layer num of model file is %d", field_size); | |||
| for (int i = 0; i < field_size; ++i) { | |||
| const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); | |||
| const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); | |||
| const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection(); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | |||
| GE_CHECK_NOTNULL(layer_reflection); | |||
| @@ -1316,7 +1310,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||
| layer_name_map[layer.name()]++; | |||
| // Set the name in proto and layer | |||
| domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); | |||
| duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) | |||
| duplicate_name_layer->set_name(new_name); | |||
| layer.set_name(new_name);) | |||
| // Insert the new operator name, the number of times of duplicate name is recorded as 1 | |||
| layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); | |||
| @@ -1539,7 +1534,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||
| layer_name_map[layer.name()]++; | |||
| // Set the name in proto and layer | |||
| domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); | |||
| duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) | |||
| duplicate_name_layer->set_name(new_name); | |||
| layer.set_name(new_name);) | |||
| // Insert the new operator name, the number of times of duplicate name is recorded as 1 | |||
| layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); | |||
| @@ -1832,13 +1828,13 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con | |||
| return FAILED; | |||
| } | |||
| if (CheckLayersSize(message) != SUCCESS) { | |||
| if (CheckLayersSize(*message) != SUCCESS) { | |||
| delete message; | |||
| message = nullptr; | |||
| return FAILED; | |||
| } | |||
| if (ParseLayerParameter(layer_descriptor, message, graph) != SUCCESS) { | |||
| if (ParseLayerParameter(*layer_descriptor, *message, graph) != SUCCESS) { | |||
| delete message; | |||
| message = nullptr; | |||
| REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path); | |||
| @@ -1852,18 +1848,18 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||
| const google::protobuf::Message *message, | |||
| Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||
| const google::protobuf::Message &message, | |||
| ge::ComputeGraphPtr &graph) { | |||
| auto field_name = layer_descriptor->FindFieldByName(kFieldName); | |||
| auto field_name = layer_descriptor.FindFieldByName(kFieldName); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | |||
| auto field_type = layer_descriptor->FindFieldByName(kFieldType); | |||
| auto field_type = layer_descriptor.FindFieldByName(kFieldType); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); | |||
| const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
| const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
| vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
| reflection->ListFields(*message, &field_desc); | |||
| reflection->ListFields(message, &field_desc); | |||
| NetParameter tmp_net; | |||
| for (auto &field : field_desc) { | |||
| @@ -1880,13 +1876,13 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto | |||
| return FAILED; | |||
| } | |||
| int field_size = reflection->FieldSize(*message, field); | |||
| int field_size = reflection->FieldSize(message, field); | |||
| GELOGI("Total Layer num of model file is %d", field_size); | |||
| for (int i = 0; i < field_size; ++i) { | |||
| const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); | |||
| const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); | |||
| LayerParameter *layer = tmp_net.add_layer(); | |||
| if (ConvertLayerProto(&layer_message, layer) != SUCCESS) { | |||
| if (ConvertLayerProto(layer_message, layer) != SUCCESS) { | |||
| GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed."); | |||
| return FAILED; | |||
| } | |||
| @@ -1907,16 +1903,16 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *message, | |||
| Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *layer) { | |||
| const google::protobuf::Reflection *layer_reflection = message->GetReflection(); | |||
| const google::protobuf::Reflection *layer_reflection = message.GetReflection(); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | |||
| vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
| layer_reflection->ListFields(*message, &field_desc); | |||
| layer_reflection->ListFields(message, &field_desc); | |||
| for (auto &field : field_desc) { | |||
| GE_CHECK_NOTNULL(field); | |||
| if (ParseLayerField(layer_reflection, message, field, layer) != SUCCESS) { | |||
| if (ParseLayerField(*layer_reflection, message, *field, layer) != SUCCESS) { | |||
| GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str()); | |||
| return FAILED; | |||
| } | |||
| @@ -1924,114 +1920,114 @@ Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *me | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection, | |||
| const google::protobuf::Message *message, | |||
| const google::protobuf::FieldDescriptor *field, | |||
| Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection &reflection, | |||
| const google::protobuf::Message &message, | |||
| const google::protobuf::FieldDescriptor &field, | |||
| google::protobuf::Message *layer) const { | |||
| GELOGD("Start to parse field: %s.", field->name().c_str()); | |||
| GELOGD("Start to parse field: %s.", field.name().c_str()); | |||
| domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); | |||
| string filed_name = field->name(); | |||
| #define CASE_FIELD_NAME(kName, method) \ | |||
| string filed_name = field.name(); | |||
| #define CASE_FIELD_NAME(kName, method, inner_message, field_ptr) \ | |||
| if (filed_name == kField##kName) { \ | |||
| string value = reflection->GetString(*message, field); \ | |||
| string value = reflection.GetString(inner_message, field_ptr); \ | |||
| GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \ | |||
| layer_proto->set_##method(value); \ | |||
| return SUCCESS; \ | |||
| } | |||
| CASE_FIELD_NAME(Name, name); | |||
| CASE_FIELD_NAME(Type, type); | |||
| CASE_FIELD_NAME(Name, name, message, &field); | |||
| CASE_FIELD_NAME(Type, type, message, &field); | |||
| #undef CASE_FIELD_NAME | |||
| #define CASE_FIELD_NAME_REPEATED(kName, method) \ | |||
| if (filed_name == kField##kName) { \ | |||
| int field_size = reflection->FieldSize(*message, field); \ | |||
| for (int i = 0; i < field_size; ++i) { \ | |||
| auto value = reflection->GetRepeatedString(*message, field, i); \ | |||
| layer_proto->add_##method(value); \ | |||
| } \ | |||
| return SUCCESS; \ | |||
| } | |||
| CASE_FIELD_NAME_REPEATED(Bottom, bottom); | |||
| CASE_FIELD_NAME_REPEATED(Top, top); | |||
| #define CASE_FIELD_NAME_REPEATED(kName, method, inner_message, field_ptr) \ | |||
| if (filed_name == kField##kName) { \ | |||
| int field_size = reflection.FieldSize(inner_message, field_ptr); \ | |||
| for (int i = 0; i < field_size; ++i) { \ | |||
| auto value = reflection.GetRepeatedString(inner_message, field_ptr, i); \ | |||
| layer_proto->add_##method(value); \ | |||
| } \ | |||
| return SUCCESS; \ | |||
| } | |||
| CASE_FIELD_NAME_REPEATED(Bottom, bottom, message, &field); | |||
| CASE_FIELD_NAME_REPEATED(Top, top, message, &field); | |||
| #undef CASE_FIELD_NAME_REPEATED | |||
| if (filed_name == kFieldBlobs) { | |||
| int field_size = reflection->FieldSize(*message, field); | |||
| int field_size = reflection.FieldSize(message, &field); | |||
| for (int i = 0; i < field_size; ++i) { | |||
| domi::caffe::BlobProto *item_message = layer_proto->add_blobs(); | |||
| const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); | |||
| if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) { | |||
| GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str()); | |||
| const google::protobuf::Message &sub_message = reflection.GetRepeatedMessage(message, &field, i); | |||
| if (ConvertBlobsProto(sub_message, item_message) != SUCCESS) { | |||
| GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field.name().c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| if (filed_name == kFieldConvParam) { | |||
| const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||
| const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); | |||
| ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param(); | |||
| ConvertConvParamProto(&sub_message, conv_param); | |||
| ConvertConvParamProto(sub_message, conv_param); | |||
| } | |||
| if (filed_name == kFieldInnerPro) { | |||
| const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||
| const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); | |||
| InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param(); | |||
| ConvertInnerProdcutProto(&sub_message, inner_product); | |||
| ConvertInnerProdcutProto(sub_message, inner_product); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message, | |||
| Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *blobs) const { | |||
| const google::protobuf::Reflection *blobs_reflection = message->GetReflection(); | |||
| const google::protobuf::Reflection *blobs_reflection = message.GetReflection(); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); | |||
| vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
| blobs_reflection->ListFields(*message, &field_desc); | |||
| blobs_reflection->ListFields(message, &field_desc); | |||
| domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs); | |||
| for (auto &field : field_desc) { | |||
| GE_CHECK_NOTNULL(field); | |||
| string feild_name = field->name(); | |||
| #define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name) \ | |||
| if (feild_name == #kName) { \ | |||
| int field_size = blobs_reflection->FieldSize(*message, field); \ | |||
| for (int i = 0; i < field_size; ++i) { \ | |||
| valuetype value = blobs_reflection->GetRepeated##method(*message, field, i); \ | |||
| blobs_proto->add_##name(value); \ | |||
| } \ | |||
| continue; \ | |||
| } | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data); | |||
| #define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name, inner_message, inner_field) \ | |||
| if (feild_name == #kName) { \ | |||
| int field_size = blobs_reflection->FieldSize(inner_message, inner_field); \ | |||
| for (int i = 0; i < field_size; ++i) { \ | |||
| valuetype value = blobs_reflection->GetRepeated##method(inner_message, inner_field, i); \ | |||
| blobs_proto->add_##name(value); \ | |||
| } \ | |||
| continue; \ | |||
| } | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data, message, field); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff, message, field); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data, message, field); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff, message, field); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data, message, field); | |||
| CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data, message, field); | |||
| #undef CASE_BLOBS_FIELD_NAME_REPEATED | |||
| #define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name) \ | |||
| if (feild_name == #kName) { \ | |||
| valuetype value = blobs_reflection->Get##method(*message, field); \ | |||
| blobs_proto->set_##name(value); \ | |||
| continue; \ | |||
| } | |||
| CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data); | |||
| CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num); | |||
| CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels); | |||
| CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height); | |||
| CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width); | |||
| #define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name, inner_message, inner_field) \ | |||
| if (feild_name == #kName) { \ | |||
| valuetype value = blobs_reflection->Get##method(inner_message, inner_field); \ | |||
| blobs_proto->set_##name(value); \ | |||
| continue; \ | |||
| } | |||
| CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data, message, field); | |||
| CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num, message, field); | |||
| CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels, message, field); | |||
| CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height, message, field); | |||
| CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width, message, field); | |||
| #undef CASE_BLOBS_FIELD_NAME | |||
| if (feild_name == kFieldShape) { | |||
| const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(*message, field); | |||
| const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(message, field); | |||
| domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape(); | |||
| ConvertBlobShapeProto(&sub_message, blob_shape); | |||
| ConvertBlobShapeProto(sub_message, blob_shape); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message *message, | |||
| Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *dest_message) const { | |||
| const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
| const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
| vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
| reflection->ListFields(*message, &field_desc); | |||
| reflection->ListFields(message, &field_desc); | |||
| domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message); | |||
| @@ -2039,21 +2035,21 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message | |||
| if (field->name() != kFieldDim) { | |||
| continue; | |||
| } | |||
| int field_size = reflection->FieldSize(*message, field); | |||
| int field_size = reflection->FieldSize(message, field); | |||
| for (int i = 0; i < field_size; ++i) { | |||
| int64_t value = reflection->GetRepeatedInt64(*message, field, i); | |||
| int64_t value = reflection->GetRepeatedInt64(message, field, i); | |||
| shape_proto->add_dim(value); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message *message, | |||
| Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *dest_message) const { | |||
| const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
| const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
| vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
| reflection->ListFields(*message, &field_desc); | |||
| reflection->ListFields(message, &field_desc); | |||
| domi::caffe::ConvolutionParameter *conv_param_proto = | |||
| PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message); | |||
| @@ -2062,18 +2058,18 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message | |||
| if (field->name() != kFieldBiasTerm) { | |||
| continue; | |||
| } | |||
| bool value = reflection->GetBool(*message, field); | |||
| bool value = reflection->GetBool(message, field); | |||
| conv_param_proto->set_bias_term(value); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message *message, | |||
| Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *dest_message) const { | |||
| const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
| const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
| vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
| reflection->ListFields(*message, &field_desc); | |||
| reflection->ListFields(message, &field_desc); | |||
| domi::caffe::InnerProductParameter *inner_product_proto = | |||
| PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message); | |||
| @@ -2082,17 +2078,17 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess | |||
| if (field->name() != kFieldBiasTerm) { | |||
| continue; | |||
| } | |||
| bool value = reflection->GetBool(*message, field); | |||
| bool value = reflection->GetBool(message, field); | |||
| inner_product_proto->set_bias_term(value); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *message) const { | |||
| const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
| Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message &message) const { | |||
| const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
| CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
| vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
| reflection->ListFields(*message, &field_desc); | |||
| reflection->ListFields(message, &field_desc); | |||
| int num_layer = 0; | |||
| int num_layers = 0; | |||
| @@ -2110,7 +2106,7 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess | |||
| return FAILED; | |||
| } | |||
| int field_size = reflection->FieldSize(*message, field); | |||
| int field_size = reflection->FieldSize(message, field); | |||
| if (field->name() == kLayerName) { | |||
| num_layer = field_size; | |||
| } else { | |||
| @@ -212,8 +212,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||
| * @return SUCCESS parse layer successfully | |||
| * @return FAILED parse layer failed | |||
| */ | |||
| Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||
| const google::protobuf::Message *message, std::vector<ge::Operator> &operators) const; | |||
| Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||
| const google::protobuf::Message &message, std::vector<ge::Operator> &operators) const; | |||
| /* | |||
| * @ingroup domi_omg | |||
| @@ -386,33 +386,33 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { | |||
| Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path, | |||
| const string &fusion_proto_name, ge::ComputeGraphPtr &graph); | |||
| Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||
| const google::protobuf::Message *message, | |||
| Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||
| const google::protobuf::Message &message, | |||
| ge::ComputeGraphPtr &graph); | |||
| Status ConvertLayerParameter(const google::protobuf::Message *layer_message, | |||
| ge::ComputeGraphPtr &graph); | |||
| Status CheckLayersSize(const google::protobuf::Message *message) const; | |||
| Status CheckLayersSize(const google::protobuf::Message &message) const; | |||
| Status ConvertLayerProto(const google::protobuf::Message *message, | |||
| Status ConvertLayerProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *layer); | |||
| Status ParseLayerField(const google::protobuf::Reflection *reflection, | |||
| const google::protobuf::Message *message, | |||
| const google::protobuf::FieldDescriptor *field, | |||
| Status ParseLayerField(const google::protobuf::Reflection &reflection, | |||
| const google::protobuf::Message &message, | |||
| const google::protobuf::FieldDescriptor &field, | |||
| google::protobuf::Message *layer) const; | |||
| Status ConvertBlobsProto(const google::protobuf::Message *message, | |||
| Status ConvertBlobsProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *blobs) const; | |||
| Status ConvertBlobShapeProto(const google::protobuf::Message *message, | |||
| Status ConvertBlobShapeProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *dest_message) const; | |||
| Status ConvertInnerProdcutProto(const google::protobuf::Message *message, | |||
| Status ConvertInnerProdcutProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *dest_message) const; | |||
| Status ConvertConvParamProto(const google::protobuf::Message *message, | |||
| Status ConvertConvParamProto(const google::protobuf::Message &message, | |||
| google::protobuf::Message *dest_message) const; | |||
| /** | |||
| * @ingroup domi_omg | |||
| @@ -431,6 +431,41 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra | |||
| return SUCCESS; | |||
| } | |||
| domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, | |||
| const std::string &input_data_names) const { | |||
| std::vector<std::string> input_names = StringUtils::Split(input_data_names, ','); | |||
| std::unordered_map<std::string, size_t> name_to_index; | |||
| for (auto &input_name : input_names) { | |||
| if (!name_to_index.emplace(input_name, name_to_index.size()).second) { | |||
| GELOGE(PARAM_INVALID, "[Check][Param] Duplicate input name[%s].", input_name.c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| for (const NodePtr &node : graph->GetDirectNode()) { | |||
| if (node->GetType() != ge::parser::DATA) { | |||
| continue; | |||
| } | |||
| auto op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| auto iter = name_to_index.find(node->GetName()); | |||
| if (iter== name_to_index.cend()) { | |||
| GELOGE(PARAM_INVALID, "[Check][Param] Input name[%s] is not in input_data_names", | |||
| node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| GELOGI("[SetSpecifyIndexAttr] set node(%s) index attr, index is %ld", | |||
| op_desc->GetName().c_str(), iter->second); | |||
| if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, iter->second)) { | |||
| REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s", | |||
| ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); | |||
| GELOGE(FAILED, "set attr %s failed for node:%s", ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name) const { | |||
| output_nodes_name.clear(); | |||
| @@ -670,6 +705,16 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||
| return PARAM_INVALID; | |||
| } | |||
| string input_data_names; | |||
| GetAclParams(parser_params, ge::ir_option::INPUT_DATA_NAMES, input_data_names); | |||
| if (!input_data_names.empty()) { | |||
| if (SetSpecifyIndexAttrByInputNames(compute_graph, input_data_names) != SUCCESS) { | |||
| GELOGE(FAILED, "[Invoke][SetIndexAttr] set index attr failed, graph:%s", | |||
| compute_graph->GetName().c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -61,6 +61,7 @@ class AclGrphParseUtil { | |||
| size_t index, OpDescPtr &op_desc); | |||
| domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | |||
| const string &is_input_adjust_hw_layout) const; | |||
| domi::Status SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, const std::string &input_data_names) const; | |||
| domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | |||
| std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const; | |||
| }; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -82,7 +82,7 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr | |||
| switch (field->type()) { | |||
| case ProtobufFieldDescriptor::TYPE_MESSAGE: { | |||
| const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); | |||
| if (0UL != tmp_message.ByteSizeLong()) { | |||
| if (tmp_message.ByteSizeLong() != 0UL) { | |||
| Message2Json(tmp_message, black_fields, json[field->name()], enum2str, depth + 1); | |||
| } | |||
| break; | |||
| @@ -122,7 +122,7 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr | |||
| case ProtobufFieldDescriptor::TYPE_FLOAT: | |||
| char str[kSignificantDigits]; | |||
| if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1){ | |||
| if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) { | |||
| json[field->name()] = str; | |||
| } else { | |||
| json[field->name()] = reflection->GetFloat(message, field); | |||
| @@ -155,10 +155,8 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||
| } | |||
| string result = ""; | |||
| for (char temp_value : type_bytes) { | |||
| uint8_t *value = 0; | |||
| value = reinterpret_cast<uint8_t *>(&temp_value); | |||
| char str[kSignificantDigits]; | |||
| if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1){ | |||
| if (sprintf_s(str, kSignificantDigits, "%c", temp_value) == -1) { | |||
| GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); | |||
| continue; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "parser/common/op_def/arg_op.h" | |||
| #include "parser/common/op_def/arg_op_operator.h" | |||
| #include <string> | |||
| #include "framework/common/fmk_types.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/op_def/constant_op.h" | |||
| #include "common/op_def/constant_operator.h" | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/op_def/fill_op.h" | |||
| #include "common/op_def/fill_operator.h" | |||
| #include "framework/common/fmk_types.h" | |||
| namespace ge { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/op_def/frameworkop_op.h" | |||
| #include "common/op_def/framework_op_operator.h" | |||
| #include <string> | |||
| #include "framework/common/fmk_types.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||
| #include "common/op_def/no_op_op.h" | |||
| #include "common/op_def/no_op_operator.h" | |||
| #include <string> | |||
| namespace ge { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -18,7 +18,6 @@ | |||
| #ifndef DOMI_OP_NO_OP_OP_H_ | |||
| #define DOMI_OP_NO_OP_OP_H_ | |||
| #include "parser/common/op_def/operator.h" | |||
| #include "framework/omg/parser/parser_types.h" | |||
| namespace ge { | |||
| class NoOpOperator : public ParserOperator { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||
| #include "common/op_def/ref_switch_op.h" | |||
| #include "common/op_def/ref_switch_operator.h" | |||
| namespace ge { | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOperator() : ParserOperator("RefSwitch") {} | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||
| #include "common/op_def/shape_n_op.h" | |||
| #include "common/op_def/shape_n_operator.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "framework/omg/parser/parser_types.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||
| #include "common/op_def/var_is_initialized_op_op.h" | |||
| #include "common/op_def/var_is_initialized_op_operator.h" | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "parser/common/op_def/variable_op.h" | |||
| #include "parser/common/op_def/variable_operator.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -92,18 +92,18 @@ PARSER_SCOPE_SRC_FILES := \ | |||
| tensorflow/scope/scope_pass_manager.cc \ | |||
| FMK_COMMON_SRC_FILES := \ | |||
| tensorflow/graph_functiondef.cc \ | |||
| tensorflow/graph_optimizer.cc \ | |||
| tensorflow/graph_to_function_def.cc \ | |||
| tensorflow/parser_graph_optimizer.cc \ | |||
| tensorflow/iterator_fusion_pass.cc \ | |||
| common/op_def/arg_op.cc \ | |||
| common/op_def/constant_op.cc \ | |||
| common/op_def/fill_op.cc \ | |||
| common/op_def/frameworkop_op.cc \ | |||
| common/op_def/no_op_op.cc \ | |||
| common/op_def/ref_switch_op.cc \ | |||
| common/op_def/shape_n_op.cc \ | |||
| common/op_def/var_is_initialized_op_op.cc \ | |||
| common/op_def/variable_op.cc \ | |||
| common/op_def/arg_op_operator.cc \ | |||
| common/op_def/constant_operator.cc \ | |||
| common/op_def/fill_operator.cc \ | |||
| common/op_def/framework_op_operator.cc \ | |||
| common/op_def/no_op_operator.cc \ | |||
| common/op_def/ref_switch_operator.cc \ | |||
| common/op_def/shape_n_operator.cc \ | |||
| common/op_def/var_is_initialized_op_operator.cc \ | |||
| common/op_def/variable_operator.cc \ | |||
| LOCAL_SRC_FILES := $(PARSER_TENSORFLOW_SRC_FILES) | |||
| LOCAL_SRC_FILES += $(PARSER_SCOPE_SRC_FILES) | |||
| @@ -69,16 +69,14 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||
| break; | |||
| } | |||
| #define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ | |||
| case dt_type: \ | |||
| { \ | |||
| case dt_type: { \ | |||
| unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \ | |||
| GE_CHECK_NOTNULL(addr_trans); \ | |||
| for (int32_t i = 0; i < (count); i++) { \ | |||
| *(addr_trans.get() + i) = static_cast<value_type>(*((addr).get() + i)); \ | |||
| } \ | |||
| (tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(value_type)); \ | |||
| break; \ | |||
| } \ | |||
| break; } \ | |||
| CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor) | |||
| CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor) | |||
| @@ -89,7 +87,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||
| #undef CASE_SET_DATA | |||
| default: | |||
| { | |||
| tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(T)); | |||
| tensor.SetData(PtrToPtr<T, uint8_t>(addr.get()), count * sizeof(T)); | |||
| break; | |||
| } | |||
| } | |||
| @@ -31,11 +31,11 @@ using GeTensorDesc = ge::GeTensorDesc; | |||
| using namespace ge::parser; | |||
| namespace { | |||
| const std::string kAttrShape = "shape"; | |||
| const std::string kAttrDataType = "dtype"; | |||
| const std::string kFileConstantPath = "file_constant_path"; | |||
| const std::string kLocation = "location"; | |||
| const std::string kOffset = "offset"; | |||
| const char *const kAttrShape = "shape"; | |||
| const char *const kAttrDataType = "dtype"; | |||
| const char *const kFileConstantPath = "file_constant_path"; | |||
| const char *const kLocation = "location"; | |||
| const char *const kOffset = "offset"; | |||
| const int64_t kOffsetCoefficient = 4096; | |||
| const char *const kFileConstant = "FileConstant"; | |||
| } | |||
| @@ -46,7 +46,7 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & | |||
| GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); | |||
| ge::onnx::TensorProto tensor_proto; | |||
| if (GetTensorProto(node, tensor_proto) != SUCCESS) { | |||
| if (GetTensorProto(*node, tensor_proto) != SUCCESS) { | |||
| REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); | |||
| GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); | |||
| return FAILED; | |||
| @@ -65,29 +65,29 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & | |||
| return SUCCESS; | |||
| } | |||
| Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto, | |||
| ge::onnx::TensorProto &tensor_proto) { | |||
| for (const auto &it : node_proto->attribute()) { | |||
| Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto &node_proto, | |||
| ge::onnx::TensorProto &tensor_proto) const { | |||
| for (const auto &it : node_proto.attribute()) { | |||
| if (it.name() != ge::kAttrNameValue) { | |||
| continue; | |||
| } | |||
| tensor_proto = it.t(); | |||
| return SUCCESS; | |||
| } | |||
| REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str()); | |||
| GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str()); | |||
| REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto.name().c_str()); | |||
| GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto.name().c_str()); | |||
| return FAILED; | |||
| } | |||
| void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||
| void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||
| std::vector<int64_t> tmp_shape; | |||
| for (int i = 0; i < tensor_proto.dims_size(); i++) { | |||
| tmp_shape.push_back(tensor_proto.dims(i)); | |||
| } | |||
| op_def.SetAttr(kAttrShape.c_str(), tmp_shape); | |||
| op_def.SetAttr(kAttrShape, tmp_shape); | |||
| } | |||
| Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||
| Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||
| int64_t data_type = tensor_proto.data_type(); | |||
| ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); | |||
| if (type >= ge::DataType::DT_UNDEFINED) { | |||
| @@ -96,11 +96,11 @@ Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor | |||
| return FAILED; | |||
| } | |||
| op_def.SetAttr(kAttrDataType.c_str(), type); | |||
| op_def.SetAttr(kAttrDataType, type); | |||
| return SUCCESS; | |||
| } | |||
| Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||
| Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||
| ge::NamedAttrs attrs; | |||
| for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { | |||
| const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); | |||
| @@ -116,12 +116,12 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro | |||
| GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||
| return FAILED; | |||
| } | |||
| op_def.SetAttr(kFileConstantPath.c_str(), attrs); | |||
| op_def.SetAttr(kFileConstantPath, attrs); | |||
| return SUCCESS; | |||
| } | |||
| Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, | |||
| ge::NamedAttrs &attrs) { | |||
| ge::NamedAttrs &attrs) const { | |||
| if (string_proto.key() == kLocation) { | |||
| AttrUtils::SetStr(attrs, kLocation, string_proto.value()); | |||
| } else { | |||
| @@ -134,7 +134,7 @@ Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProt | |||
| return FAILED; | |||
| } | |||
| if (string_proto.key() == kOffset) { | |||
| if (std::numeric_limits<int64_t>::max() / kOffsetCoefficient < value) { | |||
| if (value > (std::numeric_limits<int64_t>::max() / kOffsetCoefficient)) { | |||
| REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||
| GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||
| return FAILED; | |||
| @@ -26,11 +26,11 @@ class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser { | |||
| Status ParseParams(const Message *op_src, ge::Operator &op_def) override; | |||
| private: | |||
| Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||
| Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||
| void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||
| Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto); | |||
| Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs); | |||
| Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||
| Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||
| void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||
| Status GetTensorProto(const ge::onnx::NodeProto &node_proto, ge::onnx::TensorProto &tensor_proto) const; | |||
| Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs) const; | |||
| }; | |||
| } // namespace ge | |||
| @@ -232,7 +232,8 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra | |||
| domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); | |||
| if (post_func == nullptr) { | |||
| GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); | |||
| if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) { | |||
| if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || | |||
| parse_func_v2 == nullptr) { | |||
| GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); | |||
| return SUCCESS; | |||
| } | |||
| @@ -522,9 +523,9 @@ Status OnnxModelParser::SetOperatorInputs() { | |||
| auto src_op = output_op_iter->second; | |||
| int dst_index = input_node_index.second; | |||
| int src_index = out_node_index.second; | |||
| GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", src_index, | |||
| ParserUtils::GetOperatorName(src_op).c_str(), dst_index, | |||
| ParserUtils::GetOperatorName(dst_op).c_str()); | |||
| GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", | |||
| src_index, ParserUtils::GetOperatorName(src_op).c_str(), | |||
| dst_index, ParserUtils::GetOperatorName(dst_op).c_str()); | |||
| auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); | |||
| GE_CHECK_NOTNULL(dst_op_desc); | |||
| auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); | |||
| @@ -689,7 +690,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve | |||
| return PARAM_INVALID; | |||
| } | |||
| input_ops.emplace_back(in_op->second); | |||
| GELOGI("Model assigned input node name: %s", ParserUtils::GetOperatorName(in_op->second).c_str()); | |||
| GELOGI("Model assigned input node name: %s", | |||
| ParserUtils::GetOperatorName(in_op->second).c_str()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -717,7 +719,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec | |||
| int index = node_name_index.second; | |||
| output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)}); | |||
| out_tensor_to_nodes[output_name] = std::make_pair(node_name, index); | |||
| GELOGI("out node index %d, node:%s", index, node_name.c_str()); | |||
| GELOGI("Out node index %d, node:%s", index, node_name.c_str()); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| @@ -934,16 +936,13 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||
| cur_compute_graph->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| } | |||
| UpdateDataFormat(root_graph); | |||
| return SUCCESS; | |||
| } | |||
| Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | |||
| ClearMembers(); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), | |||
| "Run ProtoType Pass Failed"); | |||
| // 1. Get all inializer. | |||
| @@ -1174,7 +1173,8 @@ Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_o | |||
| default_out_nodes.emplace_back(output_node_info); | |||
| output_tensor_names.emplace_back(tensor_name); | |||
| GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].", | |||
| output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str()); | |||
| output_node_info.first.c_str(), | |||
| output_node_info.second, tensor_name.c_str()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -17,6 +17,8 @@ | |||
| #ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_ | |||
| #define PARSER_ONNX_ONNX_UTIL_PARSER_H_ | |||
| #include <string> | |||
| #include <cstdint> | |||
| #include "external/graph/types.h" | |||
| namespace OnnxDataType { | |||
| @@ -59,4 +61,4 @@ class OnnxUtil { | |||
| }; | |||
| } // namespace ge | |||
| #endif //PARSER_ONNX_ONNX_UTIL_PARSER_H_ | |||
| #endif // PARSER_ONNX_ONNX_UTIL_PARSER_H_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -25,6 +25,7 @@ using parser::IF; | |||
| namespace { | |||
| const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; | |||
| const int kIfNodeAttrSize = 2; | |||
| const char *kIf = "If"; | |||
| } // namespace | |||
| domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||
| ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
| @@ -33,7 +34,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||
| GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | |||
| parent_node->op_type().c_str()); | |||
| auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); | |||
| auto ret = ParseIfNodeSubgraphs(*parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "[Parse][Node] Parse if node failed."); | |||
| REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); | |||
| @@ -44,19 +45,19 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||
| } | |||
| domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||
| ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
| ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) const { | |||
| if (parent_node->attribute_size() != kIfNodeAttrSize) { | |||
| GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||
| REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||
| if (parent_node.attribute_size() != kIfNodeAttrSize) { | |||
| GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); | |||
| REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); | |||
| return FAILED; | |||
| } | |||
| GELOGD("node attribute size:%d.", parent_node->attribute_size()); | |||
| GELOGD("node attribute size:%d.", parent_node.attribute_size()); | |||
| std::set<std::string> all_inputs; | |||
| // for onnx graph, the first attribute may be else branch and the second attribute may be then branch | |||
| for (int i = 0; i < parent_node->attribute_size(); i++) { | |||
| ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i); | |||
| for (int i = 0; i < parent_node.attribute_size(); i++) { | |||
| ge::onnx::AttributeProto *attribute = parent_node.mutable_attribute(i); | |||
| GE_CHECK_NOTNULL(attribute); | |||
| std::string attr_name = attribute->name(); | |||
| auto itr = kAttrNameToIndex.find(attr_name); | |||
| @@ -68,7 +69,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||
| return FAILED; | |||
| } | |||
| std::string unique_subgraph_name; | |||
| std::string node_name = parent_node->name(); | |||
| std::string node_name = parent_node.name(); | |||
| if (!parent_graph_name.empty()) { | |||
| node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); | |||
| } | |||
| @@ -90,7 +91,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||
| AddInputNodeForGraph(all_inputs, *onnx_graph); | |||
| } | |||
| AddInputForParentNode(all_inputs, *parent_node); | |||
| AddInputForParentNode(all_inputs, parent_node); | |||
| return SUCCESS; | |||
| } | |||
| @@ -135,5 +136,5 @@ void IfSubgraphAdapter::AddInputForParentNode(const std::set<std::string> &all_i | |||
| parent_node.add_input(input_name); | |||
| } | |||
| } | |||
| REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter); | |||
| REGISTER_SUBGRAPH_ADAPTER_CREATOR(kIf, IfSubgraphAdapter); | |||
| } // namespace ge | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -20,6 +20,7 @@ | |||
| #include <set> | |||
| #include <string> | |||
| #include "subgraph_adapter.h" | |||
| #include "parser/onnx/onnx_util.h" | |||
| namespace ge { | |||
| class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | |||
| @@ -30,7 +31,7 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | |||
| const std::string &parent_graph_name = "") override; | |||
| private: | |||
| domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
| domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | |||
| const std::string &parent_graph_name) const; | |||
| domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -36,8 +36,6 @@ | |||
| #include "proto/onnx/ge_onnx.pb.h" | |||
| #include "external/register/register_error_codes.h" | |||
| #include "framework/omg/parser/parser_types.h" | |||
| #include "parser/onnx/onnx_util.h" | |||
| namespace ge { | |||
| class PARSER_FUNC_VISIBILITY SubgraphAdapter { | |||
| public: | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -62,7 +62,6 @@ protected: | |||
| * @brief SubgraphAdapter creation function | |||
| * @return Created SubgraphAdapter | |||
| */ | |||
| // typedef shared_ptr<SubgraphAdapter> (*CREATOR_FUN)(void); | |||
| using CREATOR_FUN = std::function<std::shared_ptr<SubgraphAdapter>(void)>; | |||
| /** | |||
| @@ -105,7 +104,7 @@ public: | |||
| * @param [in] op_type Op type | |||
| * @param [in] clazz SubgraphAdapter implementation class | |||
| */ | |||
| #define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ | |||
| #define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ | |||
| std::shared_ptr<SubgraphAdapter> Creator_##op_type##_Subgraph_Adapter() { \ | |||
| std::shared_ptr<clazz> ptr(new (std::nothrow) clazz()); \ | |||
| if (ptr == nullptr) { \ | |||
| @@ -167,6 +167,33 @@ class H2CC(object): | |||
| del self.stack_template | |||
| del self.func_list_exist | |||
| @staticmethod | |||
| def implement_function(func): | |||
| function_def = '' | |||
| function_def += '{\n' | |||
| all_items = func.split() | |||
| start = 0 | |||
| return_type = all_items[start] | |||
| if return_type == "const": | |||
| start += 1 | |||
| return_type = all_items[start] | |||
| if return_type.startswith(('std::map', 'std::set', 'std::vector')): | |||
| return_type = "std::map" | |||
| if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): | |||
| return_type = "Ptr" | |||
| if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): | |||
| return_type += "&" | |||
| if RETURN_STATEMENTS.__contains__(return_type): | |||
| function_def += RETURN_STATEMENTS[return_type] | |||
| else: | |||
| logging.warning("Unhandled return type[%s]", return_type) | |||
| function_def += '\n' | |||
| function_def += '}\n' | |||
| function_def += '\n' | |||
| return function_def | |||
| def just_skip(self): | |||
| # skip blank line or comment | |||
| if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( | |||
| @@ -263,6 +290,7 @@ class H2CC(object): | |||
| logging.info('Added %s functions', len(self.func_list_exist)) | |||
| logging.info('Successfully converted,please see ' + self.output_file) | |||
| def handle_func1(self, line): | |||
| """ | |||
| :param line: | |||
| @@ -461,12 +489,6 @@ class H2CC(object): | |||
| logging.info("func_name[%s]", func_name) | |||
| return line, func_name | |||
| def write_func_content(self, content, func_name, need_generate): | |||
| if not (func_name in self.func_list_exist) and need_generate: | |||
| self.output_fd.write(content) | |||
| self.func_list_exist.append(func_name) | |||
| logging.info('add func:[%s]', func_name) | |||
| def gen_comment(self, start_i): | |||
| comment_line = '' | |||
| # Function comments are on top of function declarations, copy them over | |||
| @@ -488,32 +510,11 @@ class H2CC(object): | |||
| break | |||
| return comment_line | |||
| @staticmethod | |||
| def implement_function(func): | |||
| function_def = '' | |||
| function_def += '{\n' | |||
| all_items = func.split() | |||
| start = 0 | |||
| return_type = all_items[start] | |||
| if return_type == "const": | |||
| start += 1 | |||
| return_type = all_items[start] | |||
| if return_type.startswith(('std::map', 'std::set', 'std::vector')): | |||
| return_type = "std::map" | |||
| if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): | |||
| return_type = "Ptr" | |||
| if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): | |||
| return_type += "&" | |||
| if RETURN_STATEMENTS.__contains__(return_type): | |||
| function_def += RETURN_STATEMENTS[return_type] | |||
| else: | |||
| logging.warning("Unhandled return type[%s]", return_type) | |||
| function_def += '\n' | |||
| function_def += '}\n' | |||
| function_def += '\n' | |||
| return function_def | |||
| def write_func_content(self, content, func_name, need_generate): | |||
| if not (func_name in self.func_list_exist) and need_generate: | |||
| self.output_fd.write(content) | |||
| self.func_list_exist.append(func_name) | |||
| logging.info('add func:[%s]', func_name) | |||
| def collect_header_files(path): | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph_functiondef.h" | |||
| #include "graph_to_function_def.h" | |||
| #include <iostream> | |||
| #include "common/fmk_error_codes.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -20,7 +20,7 @@ | |||
| #include "framework/omg/parser/parser_types.h" | |||
| #include "common/util.h" | |||
| #include "graph_optimizer.h" | |||
| #include "parser_graph_optimizer.h" | |||
| #include "framework/common/ge_inner_error_codes.h" | |||
| namespace ge { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,14 +14,14 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph_optimizer.h" | |||
| #include "parser_graph_optimizer.h" | |||
| #include "graph/op_types.h" | |||
| #include "common/types_map.h" | |||
| #include "common/util.h" | |||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/utils/type_utils.h" | |||
| #include "graph_functiondef.h" | |||
| #include "graph_to_function_def.h" | |||
| #include "parser/common/acl_graph_parser_util.h" | |||
| #include "register/op_registry.h" | |||
| @@ -188,7 +188,10 @@ Status CollectNodeFuncs(vector<ge::NodePtr> &nodes, FunctionDefLibrary *library) | |||
| Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { | |||
| ComputeGraphPtr sub_graph = nullptr; | |||
| GE_MAKE_SHARED(sub_graph = std::make_shared<ComputeGraph>("subGraph"), sub_graph = nullptr; return PARAM_INVALID); | |||
| GE_MAKE_SHARED( | |||
| sub_graph = std::make_shared<ComputeGraph>("subGraph"), | |||
| sub_graph = nullptr; | |||
| return PARAM_INVALID); | |||
| unordered_map<string, NodePtr> node_map; | |||
| vector<InDataAnchorPtr> input_anchors; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "parser/common/op_def/arg_op.h" | |||
| #include "parser/common/op_def/arg_op_operator.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||
| #include "graph/compute_graph.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -19,7 +19,7 @@ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "parser/common/acl_graph_parser_util.h" | |||
| #include "parser/common/op_def/constant_op.h" | |||
| #include "parser/common/op_def/constant_operator.h" | |||
| #include "parser/common/op_def/ir_pb_converter.h" | |||
| #include "parser/common/util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_ | |||
| #define GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_ | |||
| #include "common/op_def/constant_op.h" | |||
| #include "common/op_def/constant_operator.h" | |||
| #include "parser/common/data_op_parser.h" | |||
| #include "parser/tensorflow/tensorflow_op_parser.h" | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "parser/common/op_def/fill_op.h" | |||
| #include "parser/common/op_def/fill_operator.h" | |||
| #include "parser/tensorflow/tensorflow_parser_register.h" | |||
| #include "framework/omg/parser/parser_types.h" | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "parser/common/op_def/frameworkop_op.h" | |||
| #include "parser/common/op_def/framework_op_operator.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "framework/omg/parser/parser_types.h" | |||
| @@ -18,7 +18,7 @@ | |||
| #include "common/util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "parser/common/op_def/ir_pb_converter.h" | |||
| #include "parser/common/op_def/no_op_op.h" | |||
| #include "parser/common/op_def/no_op_operator.h" | |||
| #include "parser/common/op_parser_factory.h" | |||
| using domi::TENSORFLOW; | |||
| @@ -17,7 +17,7 @@ | |||
| #include "parser/tensorflow/tensorflow_ref_switch_parser.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "parser/common/op_def/ir_pb_converter.h" | |||
| #include "parser/common/op_def/ref_switch_op.h" | |||
| #include "parser/common/op_def/ref_switch_operator.h" | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "parser/common/util.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_ | |||
| #define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_ | |||
| #include "common/op_def/ref_switch_op.h" | |||
| #include "common/op_def/ref_switch_operator.h" | |||
| #include "parser/tensorflow/tensorflow_op_parser.h" | |||
| namespace ge { | |||
| @@ -18,7 +18,7 @@ | |||
| #include "parser/common/op_def/ir_pb_converter.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "parser/common/op_def/shape_n_op.h" | |||
| #include "parser/common/op_def/shape_n_operator.h" | |||
| #include "parser/common/util.h" | |||
| using domi::TENSORFLOW; | |||
| @@ -17,7 +17,7 @@ | |||
| #ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_ | |||
| #define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_ | |||
| #include "common/op_def/shape_n_op.h" | |||
| #include "common/op_def/shape_n_operator.h" | |||
| #include "parser/tensorflow/tensorflow_op_parser.h" | |||
| namespace ge { | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "parser/common/op_def/var_is_initialized_op_op.h" | |||
| #include "parser/common/op_def/var_is_initialized_op_operator.h" | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "parser/tensorflow/tensorflow_op_parser.h" | |||
| #include "parser/tensorflow/tensorflow_parser_register.h" | |||
| @@ -21,7 +21,7 @@ | |||
| #include "graph/op_desc.h" | |||
| #include "graph/utils/attr_utils.h" | |||
| #include "graph/utils/tensor_utils.h" | |||
| #include "parser/common/op_def/variable_op.h" | |||
| #include "parser/common/op_def/variable_operator.h" | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "parser/tensorflow/tensorflow_op_parser.h" | |||
| #include "parser/tensorflow/tensorflow_parser_register.h" | |||
| @@ -249,17 +249,17 @@ set(PARSER_SRC_FILES | |||
| "${PARSER_DIR}/parser/common/convert/message2operator.cc" | |||
| "${PARSER_DIR}/parser/common/data_op_parser.cc" | |||
| "${PARSER_DIR}/parser/common/model_saver.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/arg_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/constant_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/fill_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/frameworkop_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/arg_op_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/constant_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/fill_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/framework_op_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/no_op_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/no_op_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/ref_switch_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/shape_n_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/variable_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/ref_switch_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/shape_n_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/variable_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_map.cc" | |||
| "${PARSER_DIR}/parser/common/op_parser_factory.cc" | |||
| "${PARSER_DIR}/parser/common/parser_api.cc" | |||
| @@ -284,8 +284,8 @@ set(PARSER_SRC_FILES | |||
| "${PARSER_DIR}/parser/onnx/onnx_util.cc" | |||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" | |||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/graph_to_function_def.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/parser_graph_optimizer.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc" | |||
| @@ -765,7 +765,7 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) | |||
| layer->set_name("Abs"); | |||
| layer->set_type("AbsVal"); | |||
| Status ret = weightParser.CheckLayersSize(layer); | |||
| Status ret = weightParser.CheckLayersSize(*layer); | |||
| EXPECT_EQ(ret, FAILED); | |||
| } | |||
| @@ -809,8 +809,54 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) | |||
| const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | |||
| const google::protobuf::Message *message = proto->New(); | |||
| Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); | |||
| Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); | |||
| delete message; | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_DimSize_0) | |||
| { | |||
| CaffeModelParser modelParser; | |||
| domi::caffe::NetParameter net; | |||
| net.add_input("111"); | |||
| net.add_input_shape(); | |||
| bool input_data_flag = true; | |||
| Status ret = modelParser.ParseInput(net, input_data_flag); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_Err1) | |||
| { | |||
| CaffeModelParser modelParser; | |||
| domi::caffe::NetParameter net; | |||
| net.add_input("111"); | |||
| net.add_input("222"); | |||
| net.add_input_shape(); | |||
| bool input_data_flag = true; | |||
| Status ret = modelParser.ParseInput(net, input_data_flag); | |||
| EXPECT_EQ(ret, FAILED); | |||
| } | |||
| TEST_F(STestCaffeParser, CaffeModelParser_ParserLayerParameter_Succ) | |||
| { | |||
| CaffeModelParser modelParser; | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/origin_models/"; | |||
| const char *model_path = model_file.c_str(); | |||
| std::string custom_proto = model_file; | |||
| std::string caffe_proto = model_file; | |||
| std::vector<ge::Operator> operators; | |||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Data", "Input"); | |||
| ge::Operator op_src = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||
| operators.emplace_back(op_src); | |||
| model_file = case_dir + "/origin_models/caffe_add.pbtxt"; | |||
| custom_proto = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; | |||
| model_path = model_file.c_str(); | |||
| std::string caffe_proto_path = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; | |||
| auto ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto_path, operators); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| } // namespace ge | |||
| @@ -34,17 +34,17 @@ | |||
| #include "external/parser/tensorflow_parser.h" | |||
| #include "parser/tensorflow/tensorflow_constant_parser.h" | |||
| #include "common/types.h" | |||
| #include "parser/common/op_def/variable_op.h" | |||
| #include "parser/common/op_def/variable_operator.h" | |||
| #include "parser/tensorflow/tensorflow_ref_switch_parser.h" | |||
| #include "parser/tensorflow/tensorflow_fusion_op_parser.h" | |||
| #include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h" | |||
| #include "parser/common/op_def/arg_op.h" | |||
| #include "parser/common/op_def/arg_op_operator.h" | |||
| #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" | |||
| #include "parser/tensorflow/tensorflow_reshape_parser.h" | |||
| #include "parser/tensorflow/tensorflow_custom_parser_adapter.h" | |||
| #include "parser/tensorflow/tensorflow_squeeze_parser.h" | |||
| #include "parser/tensorflow/graph_functiondef.h" | |||
| #include "parser/tensorflow/graph_optimizer.h" | |||
| #include "parser/tensorflow/graph_to_function_def.h" | |||
| #include "parser/tensorflow/parser_graph_optimizer.h" | |||
| #include "cce/dnn_base_def.hpp" | |||
| #include "parser/tensorflow/scope/scope_pass_manager.h" | |||
| #include "parser/tensorflow/tensorflow_util.h" | |||
| @@ -52,10 +52,10 @@ | |||
| #include "parser/tensorflow/tensorflow_enter_parser.h" | |||
| #include "parser/common/op_def/ir_pb_converter.h" | |||
| #include "parser/common/tuple.h" | |||
| #include "common/op_def/frameworkop_op.h" | |||
| #include "common/op_def/shape_n_op.h" | |||
| #include "common/op_def/var_is_initialized_op_op.h" | |||
| #include "common/op_def/fill_op.h" | |||
| #include "common/op_def/framework_op_operator.h" | |||
| #include "common/op_def/shape_n_operator.h" | |||
| #include "common/op_def/var_is_initialized_op_operator.h" | |||
| #include "common/op_def/fill_operator.h" | |||
| #include "common/convert/pb2json.h" | |||
| #include "common/convert/message2operator.h" | |||
| #include "parser/common/proto_file_parser.h" | |||
| @@ -70,7 +70,7 @@ | |||
| #include "parser/common/prototype_pass_manager.h" | |||
| #include "parser/common/register_tbe.h" | |||
| #include "parser/common/pass_manager.h" | |||
| #include "parser/tensorflow/graph_optimizer.h" | |||
| #include "parser/tensorflow/parser_graph_optimizer.h" | |||
| #include "metadef/inc/register/scope/scope_pass_registry_impl.h" | |||
| #include "register/scope/scope_fusion_pass_register.h" | |||
| #undef protected | |||
| @@ -678,6 +678,7 @@ namespace { | |||
| if ((_name== "S") || (_name == "K")) { | |||
| int index = 0; | |||
| ge::AttrUtils::SetInt(opDef, "T", 1); | |||
| ge::AttrUtils::SetInt(opDef, "arg_index", index); | |||
| ge::AttrUtils::SetInt(opDef, "ret_index", index); | |||
| @@ -1029,7 +1030,9 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { | |||
| ParserOperator unused("Add"); | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/origin_models/tf_add.pb"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, | |||
| }; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| @@ -1043,6 +1046,21 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { | |||
| EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | |||
| } | |||
| TEST_F(STestTensorflowParser, tensorflow_parser_failed_for_input_data_names_error) { | |||
| RegisterCustomOp(); | |||
| std::string case_dir = __FILE__; | |||
| ParserOperator unused("Add"); | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/origin_models/tf_add.pb"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_3")}, | |||
| }; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, ge::GRAPH_FAILED); | |||
| } | |||
| TEST_F(STestTensorflowParser, tensorflow_model_Failed) { | |||
| ge::Graph graph; | |||
| std::string caseDir = __FILE__; | |||
| @@ -250,17 +250,17 @@ set(PARSER_SRC_FILES | |||
| "${PARSER_DIR}/parser/common/convert/message2operator.cc" | |||
| "${PARSER_DIR}/parser/common/data_op_parser.cc" | |||
| "${PARSER_DIR}/parser/common/model_saver.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/arg_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/constant_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/fill_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/frameworkop_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/arg_op_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/constant_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/fill_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/framework_op_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/no_op_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/no_op_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/ref_switch_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/shape_n_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/variable_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_op.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/ref_switch_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/shape_n_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/variable_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_operator.cc" | |||
| "${PARSER_DIR}/parser/common/op_map.cc" | |||
| "${PARSER_DIR}/parser/common/op_parser_factory.cc" | |||
| "${PARSER_DIR}/parser/common/parser_api.cc" | |||
| @@ -285,8 +285,8 @@ set(PARSER_SRC_FILES | |||
| "${PARSER_DIR}/parser/onnx/onnx_util.cc" | |||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" | |||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/graph_to_function_def.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/parser_graph_optimizer.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc" | |||
| @@ -739,6 +739,9 @@ TEST_F(UtestCaffeParser, CaffeModelParser_CustomProtoParse_test) | |||
| Status ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto, operators); | |||
| EXPECT_EQ(ret, PARAM_INVALID); | |||
| ret = modelParser.CustomProtoParse("", custom_proto, caffe_proto, operators); | |||
| EXPECT_EQ(ret, FAILED); | |||
| model_file = case_dir + "/caffe_model/caffe_add.pbtxt"; | |||
| custom_proto = case_dir + "/../../../../../metadef/proto/caffe/caffe.proto"; | |||
| model_path = model_file.c_str(); | |||
| @@ -890,7 +893,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) | |||
| layer->set_name("Abs"); | |||
| layer->set_type("AbsVal"); | |||
| Status ret = weightParser.CheckLayersSize(layer); | |||
| Status ret = weightParser.CheckLayersSize(*layer); | |||
| EXPECT_EQ(ret, FAILED); | |||
| } | |||
| @@ -902,7 +905,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) | |||
| layer->set_name("Abs"); | |||
| layer->set_type("AbsVal"); | |||
| Status ret = weightParser.ConvertLayerProto(&net, &net); | |||
| Status ret = weightParser.ConvertLayerProto(net, &net); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| BlobProto* blob = layer->add_blobs(); | |||
| @@ -911,16 +914,16 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) | |||
| BlobShape* shap = blob->mutable_shape(); | |||
| shap->add_dim(1); | |||
| shap->add_dim(2); | |||
| ret = weightParser.ConvertBlobsProto(&net, &net); | |||
| ret = weightParser.ConvertBlobsProto(net, &net); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| ret = weightParser.ConvertBlobShapeProto(&net, &net); | |||
| ret = weightParser.ConvertBlobShapeProto(net, &net); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| ret = weightParser.ConvertConvParamProto(&net, &net); | |||
| ret = weightParser.ConvertConvParamProto(net, &net); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| ret = weightParser.ConvertInnerProdcutProto(&net, &net); | |||
| ret = weightParser.ConvertInnerProdcutProto(net, &net); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| @@ -1133,7 +1136,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) | |||
| const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | |||
| const google::protobuf::Message *message = proto->New(); | |||
| Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); | |||
| Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); | |||
| delete message; | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| @@ -1163,7 +1166,7 @@ TEST_F(UtestCaffeParser, CaffeModelParser_ParseLayerParameter_test) | |||
| google::protobuf::DynamicMessageFactory factory; | |||
| const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | |||
| google::protobuf::Message *message = proto->New(); | |||
| Status ret = modelParser.ParseLayerParameter(descriptor, message, operators); | |||
| Status ret = modelParser.ParseLayerParameter(*descriptor, *message, operators); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| const domi::FrameworkType fmk_type = domi::TENSORFLOW; | |||
| @@ -7,7 +7,7 @@ | |||
| #include "tensorflow/iterator_fusion_pass.h" | |||
| #include "parser/common/acl_graph_parser_util.h" | |||
| #define private public | |||
| #include "tensorflow/graph_optimizer.h" | |||
| #include "tensorflow/parser_graph_optimizer.h" | |||
| #undef private | |||
| namespace ge { | |||
| class UtestGraphOptimizer : public testing::Test { | |||
| @@ -381,7 +381,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) | |||
| OnnxFileConstantParser parser; | |||
| ge::onnx::NodeProto input_node; | |||
| ge::onnx::TensorProto tensor_proto; | |||
| Status ret = parser.GetTensorProto(&input_node, tensor_proto); | |||
| Status ret = parser.GetTensorProto(input_node, tensor_proto); | |||
| EXPECT_EQ(ret, FAILED); | |||
| ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||
| @@ -391,7 +391,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) | |||
| ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||
| *attribute_tensor = tensor_proto; | |||
| ret = parser.GetTensorProto(&input_node, tensor_proto); | |||
| ret = parser.GetTensorProto(input_node, tensor_proto); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| } | |||
| @@ -38,17 +38,17 @@ | |||
| #include "tests/depends/ops_stub/ops_stub.h" | |||
| #include "parser/tensorflow/tensorflow_constant_parser.h" | |||
| #include "common/types.h" | |||
| #include "parser/common/op_def/variable_op.h" | |||
| #include "parser/common/op_def/variable_operator.h" | |||
| #include "parser/tensorflow/tensorflow_ref_switch_parser.h" | |||
| #include "parser/tensorflow/tensorflow_fusion_op_parser.h" | |||
| #include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h" | |||
| #include "parser/common/op_def/arg_op.h" | |||
| #include "parser/common/op_def/arg_op_operator.h" | |||
| #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" | |||
| #include "parser/tensorflow/tensorflow_reshape_parser.h" | |||
| #include "parser/tensorflow/tensorflow_custom_parser_adapter.h" | |||
| #include "parser/tensorflow/tensorflow_squeeze_parser.h" | |||
| #include "parser/tensorflow/graph_functiondef.h" | |||
| #include "parser/tensorflow/graph_optimizer.h" | |||
| #include "parser/tensorflow/graph_to_function_def.h" | |||
| #include "parser/tensorflow/parser_graph_optimizer.h" | |||
| #include "cce/dnn_base_def.hpp" | |||
| #include "parser/tensorflow/scope/scope_pass_manager.h" | |||
| #include "parser/tensorflow/tensorflow_util.h" | |||
| @@ -56,10 +56,10 @@ | |||
| #include "parser/tensorflow/tensorflow_enter_parser.h" | |||
| #include "parser/common/op_def/ir_pb_converter.h" | |||
| #include "parser/common/tuple.h" | |||
| #include "common/op_def/frameworkop_op.h" | |||
| #include "common/op_def/shape_n_op.h" | |||
| #include "common/op_def/var_is_initialized_op_op.h" | |||
| #include "common/op_def/fill_op.h" | |||
| #include "common/op_def/framework_op_operator.h" | |||
| #include "common/op_def/shape_n_operator.h" | |||
| #include "common/op_def/var_is_initialized_op_operator.h" | |||
| #include "common/op_def/fill_operator.h" | |||
| #include "common/convert/pb2json.h" | |||
| #include "common/convert/message2operator.h" | |||
| #include "parser/common/proto_file_parser.h" | |||
| @@ -73,7 +73,7 @@ | |||
| #include "parser/common/prototype_pass_manager.h" | |||
| #include "parser/common/register_tbe.h" | |||
| #include "parser/common/pass_manager.h" | |||
| #include "parser/tensorflow/graph_optimizer.h" | |||
| #include "parser/tensorflow/parser_graph_optimizer.h" | |||
| #include "metadef/inc/register/scope/scope_pass_registry_impl.h" | |||
| #include "register/scope/scope_fusion_pass_register.h" | |||
| #include "common/op_map.h" | |||
| @@ -1032,7 +1032,9 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||
| ParserOperator unused("Add"); | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, | |||
| }; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| @@ -1046,6 +1048,21 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||
| EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_parser_input_data_names_failed) { | |||
| RegisterCustomOp(); | |||
| std::string case_dir = __FILE__; | |||
| ParserOperator unused("Add"); | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params = { | |||
| {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_2")}, | |||
| }; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, ge::GRAPH_FAILED); | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_model_Failed) { | |||
| ge::Graph graph; | |||
| std::string caseDir = __FILE__; | |||