| @@ -6,7 +6,6 @@ set(SRC_LIST | |||||
| "tensorflow/tensorflow_enter_parser.cc" | "tensorflow/tensorflow_enter_parser.cc" | ||||
| "tensorflow/tensorflow_fill_parser.cc" | "tensorflow/tensorflow_fill_parser.cc" | ||||
| "tensorflow/tensorflow_frameworkop_parser.cc" | "tensorflow/tensorflow_frameworkop_parser.cc" | ||||
| "tensorflow/tensorflow_fusionop_util.cc" | |||||
| "tensorflow/tensorflow_identity_parser.cc" | "tensorflow/tensorflow_identity_parser.cc" | ||||
| "tensorflow/tensorflow_merge_parser.cc" | "tensorflow/tensorflow_merge_parser.cc" | ||||
| "tensorflow/tensorflow_no_op_parser.cc" | "tensorflow/tensorflow_no_op_parser.cc" | ||||
| @@ -71,7 +71,6 @@ PARSER_TENSORFLOW_SRC_FILES := \ | |||||
| tensorflow/tensorflow_enter_parser.cc \ | tensorflow/tensorflow_enter_parser.cc \ | ||||
| tensorflow/tensorflow_fill_parser.cc \ | tensorflow/tensorflow_fill_parser.cc \ | ||||
| tensorflow/tensorflow_frameworkop_parser.cc \ | tensorflow/tensorflow_frameworkop_parser.cc \ | ||||
| tensorflow/tensorflow_fusionop_util.cc \ | |||||
| tensorflow/tensorflow_identity_parser.cc \ | tensorflow/tensorflow_identity_parser.cc \ | ||||
| tensorflow/tensorflow_merge_parser.cc \ | tensorflow/tensorflow_merge_parser.cc \ | ||||
| tensorflow/tensorflow_no_op_parser.cc \ | tensorflow/tensorflow_no_op_parser.cc \ | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| #include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
| #include "parser/tensorflow/tensorflow_fusionop_util.h" | |||||
| #include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
| #include "parser/tensorflow/tensorflow_util.h" | #include "parser/tensorflow/tensorflow_util.h" | ||||
| #include "proto/tensorflow/graph.pb.h" | #include "proto/tensorflow/graph.pb.h" | ||||
| @@ -1,378 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/tensorflow/tensorflow_fusionop_util.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "parser/common/util.h" | |||||
| #include "parser/tensorflow/tensorflow_parser.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include <iostream> | |||||
| #include <cstdlib> | |||||
| #include <memory> | |||||
| using domi::tensorflow::NodeDef; | |||||
| namespace ge { | |||||
| // constraint: At present, only a few fixed fusion operators are supported, | |||||
| // and forward matching method is used for recognition | |||||
| // eg: in the MaskRCNN network, | |||||
| // clip_boxes are treated as fusion operators but generate_rpn_proposals/clip_boxes is also fused | |||||
| // considered to be a child operator of generate_rpn_proposals. | |||||
| // clip_boxes | |||||
| // fastrcnn_predictions | |||||
| // decode_bbox_target | |||||
| // generate_rpn_proposals | |||||
| // roi_align | |||||
| // cond_1/roi_align | |||||
| namespace { | |||||
| const char *const kLstmCellKernelFw = "fw/basic_lstm_cell/kernel"; | |||||
| const char *const kLstmCellKernelBw = "bw/basic_lstm_cell/kernel"; | |||||
| const char *const kLstmCellBiasFw = "fw/basic_lstm_cell/bias"; | |||||
| const char *const kLstmCellBiasBw = "bw/basic_lstm_cell/bias"; | |||||
| const char *const kAttentionDecoderEmbeeding = "embedding_attention_decoder/embedding"; | |||||
| const char *const kAttentionDecoderAttenW0 = "embedding_attention_decoder/attention_decoder/AttnW_0"; | |||||
| const char *const kAttentionDecoderAttenVa = "embedding_attention_decoder/attention_decoder/AttnV_0"; | |||||
| const char *const kAttentionDecoderAttentionDecoderKernel = "embedding_attention_decoder/attention_decoder/kernel"; | |||||
| const char *const kAttentionDecoderAtteBias = "embedding_attention_decoder/attention_decoder/bias"; | |||||
| const char *const kAttentionDecoderCell0GatesKernel = | |||||
| "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_0/gru_cell/gates/kernel"; | |||||
| const char *const kAttentionDecoderCell0GatesBias = | |||||
| "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_0/gru_cell/gates/bias"; | |||||
| const char *const kAttentionDecoderCell0CandidateKernel = | |||||
| "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_0/gru_cell/candidate/kernel"; | |||||
| const char *const kAttentionDecoderCell0CandidateBias = | |||||
| "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_0/gru_cell/candidate/bias"; | |||||
| const char *const kAttentionDecoderCell1GatesKernel = | |||||
| "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_1/gru_cell/gates/kernel"; | |||||
| const char *const kAttentionDecoderCell1GatesBias = | |||||
| "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_1/gru_cell/gates/bias"; | |||||
| const char *const kAttentionDecoderCell1CandidateKernel = | |||||
| "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_1/gru_cell/candidate/kernel"; | |||||
| const char *const kAttentionDecoderCell1CandidateBias = | |||||
| "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_1/gru_cell/candidate/bias"; | |||||
| const char *const kAttentionDecoderAttention0Kernel = | |||||
| "embedding_attention_decoder/attention_decoder/Attention_0/kernel"; | |||||
| const char *const kAttentionDecoderAttention0Bias = "embedding_attention_decoder/attention_decoder/Attention_0/bias"; | |||||
| const char *const kAttentionDecoderAttnOutputProjectionKernel = | |||||
| "embedding_attention_decoder/attention_decoder/AttnOutputProjection/kernel"; | |||||
| const char *const kAttentionDecoderAttnOutputProjectionBias = | |||||
| "embedding_attention_decoder/attention_decoder/AttnOutputProjection/bias"; | |||||
| const char *const kHuberLossFill = "gradients/Fill"; | |||||
| const char *const kHuberLossConst = "huber_loss/Const"; | |||||
| const char *const kHuberLossMul2X = "huber_loss/Mul_2/x"; | |||||
| const char *const kSparseSoftmaxConst = "sparse_softmax_cross_entropy_loss/Const"; | |||||
| const char *const kDeeplabV3ConfusionMatrix = "Select"; | |||||
| const char *const kDeeplabV3ConfusionMatrix1 = "ToFloat_1"; | |||||
| const char *const kConstantFoldingSuffix = "ConstantFolding/"; | |||||
| } // namespace | |||||
| vector<string> const_op_update_vec = {kLstmCellKernelFw, | |||||
| kLstmCellKernelBw, | |||||
| kLstmCellBiasFw, | |||||
| kLstmCellBiasBw, | |||||
| kAttentionDecoderAttenW0, | |||||
| kAttentionDecoderAttention0Kernel, | |||||
| kAttentionDecoderAttnOutputProjectionKernel, | |||||
| kAttentionDecoderAttentionDecoderKernel, | |||||
| kAttentionDecoderCell0GatesKernel, | |||||
| kAttentionDecoderCell0CandidateKernel, | |||||
| kAttentionDecoderCell1GatesKernel, | |||||
| kAttentionDecoderCell1CandidateKernel, | |||||
| kAttentionDecoderAttention0Bias, | |||||
| kAttentionDecoderAttnOutputProjectionBias, | |||||
| kAttentionDecoderAtteBias, | |||||
| kAttentionDecoderCell0GatesBias, | |||||
| kAttentionDecoderCell0CandidateBias, | |||||
| kAttentionDecoderCell1GatesBias, | |||||
| kAttentionDecoderCell1CandidateBias, | |||||
| kAttentionDecoderEmbeeding, | |||||
| kAttentionDecoderAttenVa, | |||||
| kHuberLossFill, | |||||
| kHuberLossConst, | |||||
| kHuberLossMul2X, | |||||
| kSparseSoftmaxConst, | |||||
| kDeeplabV3ConfusionMatrix, | |||||
| kDeeplabV3ConfusionMatrix1}; | |||||
| static map<string, string> tensorflow_fusionop_map = { | |||||
| }; | |||||
| // <Types of fusion operators, Number of children operators> | |||||
| static map<string, vector<int>> tensorflow_fusionop_children_nums_map = { | |||||
| {ge::parser::CLIPBOXES, {8}}, | |||||
| {ge::parser::FASTRCNNPREDICTIONS, {118, 119, 120, 123, 125}}, | |||||
| {ge::parser::RPNPROPOSALS, {75, 85, 97}}, | |||||
| {ge::parser::DECODEBBOX, {24, 28}}, | |||||
| {ge::parser::ROIALIGN, {82, 83, 84}}, | |||||
| {ge::parser::FUSIONBATCHNORM, {8}}, | |||||
| {ge::parser::GETSPAN, {81, 71, 91}}, // The pbtxt only has 62 nodes when test GetSpan sub net. However the | |||||
| {ge::parser::HUBERLOSSGRAD, {8, 9, 10, 20, 21}}, | |||||
| }; | |||||
| // <Types of fusion operators, Name of children operators(Remove the prefixes and/)> | |||||
| static map<string, vector<string>> tensorflow_fusionop_children_names_map = { | |||||
| {ge::parser::FUSIONBATCHNORM, {"add/y", "add", "Rsqrt", "mul", "mul_1", "mul_2", "sub", "add_1"}}, | |||||
| {ge::parser::GETSPAN, {}}, | |||||
| {ge::parser::HUBERLOSSGRAD, {}}, | |||||
| }; | |||||
| // ----------------------------Index table of input and output of fusion operator-------------- | |||||
| // The specific operator is the input and output of the whole fusion operator, and the index number is specified | |||||
| // Considering that an operator may have multiple inputs / outputs, vector is used to save | |||||
| // search method: new_index=vector(old_index), | |||||
| // Generally, the old index is 0. If the new index value is kFusionDisableIndex, the edge can be ignored. | |||||
| // If it is control edge input, the index is graph::kControlSlot(-1). | |||||
| static map<string, vector<std::pair<string, vector<int32_t>>>> tensorflow_fusionop_inputs_map = { | |||||
| {ge::parser::FUSIONBATCHNORM, | |||||
| {{"mul_1", {0, kFusionDisableIndex}}, | |||||
| {"mul", {1, 1}}, | |||||
| {"sub", {2, kFusionDisableIndex}}, | |||||
| {"mul_2", {3, kFusionDisableIndex}}, | |||||
| {"add", {4, kFusionDisableIndex}}}}, | |||||
| {ge::parser::GETSPAN, {{"transpose", {0}}, {"TensorArray", {1}}, {"transpose_1", {2}}}}, | |||||
| {ge::parser::HUBERLOSSGRAD, {{"Sub_1_grad/Neg", {1}}, {"Abs_grad/Sign", {0}}}}, | |||||
| }; | |||||
| static map<string, vector<std::pair<string, vector<int32_t>>>> tensorflow_fusionop_outputs_map = { | |||||
| {ge::parser::FUSIONBATCHNORM, {{"add_1", {0}}}}, | |||||
| {ge::parser::GETSPAN, {{"while/Exit_1", {0}}, {"while/Exit_2", {1}}}}, | |||||
| {ge::parser::HUBERLOSSGRAD, {{"Abs_grad/mul", {0}}}}, | |||||
| }; | |||||
| map<string, vector<std::pair<string, uint32_t>>> tensorflow_fusionop_input_const_weight_index_map = { | |||||
| {ge::parser::FUSIONBATCHNORM, {{"mul", 0}, {"sub", 1}, {"mul_2", 2}, {"add", 3}}}, | |||||
| }; | |||||
| // Can a string be converted to an integer | |||||
| bool TensorFlowFunsionOPUtil::IsIntegerStr(const string &index_str) { | |||||
| try { | |||||
| if (std::stoi(index_str) > 0) { | |||||
| return true; | |||||
| } | |||||
| } catch (std::invalid_argument &) { | |||||
| GELOGE(FAILED, "index_str:%s is invalid", index_str.c_str()); | |||||
| } catch (std::out_of_range &) { | |||||
| GELOGE(FAILED, "index_str:%s is out of range", index_str.c_str()); | |||||
| } catch (...) { | |||||
| GELOGE(FAILED, "index_str:%s cannot change to int s", index_str.c_str()); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| // Get child node name of fusion operator. | |||||
| // eg: input: fastrcnn_predictions/map/TensorArray_2 output: map/TensorArray_2 | |||||
| string TensorFlowFunsionOPUtil::GetChildName(const string &node_name, const string &fusion_node_name) { | |||||
| GE_CHK_BOOL_EXEC_NOLOG( | |||||
| (node_name.length() - fusion_node_name.length()) > 0, GELOGW("fusion_node_name length not valid."); return "";); | |||||
| string child_name; | |||||
| string sub_name; | |||||
| // node_name begin with "ConstantFolding/" | |||||
| if (node_name.find(kConstantFoldingSuffix) == 0) { | |||||
| auto length = strlen(kConstantFoldingSuffix); | |||||
| sub_name = | |||||
| node_name.substr(fusion_node_name.length() + length, node_name.length() - fusion_node_name.length() - length); | |||||
| } else { | |||||
| sub_name = node_name.substr(fusion_node_name.length(), node_name.length() - fusion_node_name.length()); | |||||
| } | |||||
| auto index = sub_name.find('/'); | |||||
| if (index != string::npos) { | |||||
| child_name = sub_name.substr(index + 1, sub_name.length() - index - 1); | |||||
| } | |||||
| return child_name; | |||||
| } | |||||
| // Check whether the operator node name can be a fusion operator | |||||
| bool TensorFlowFunsionOPUtil::MaybeFusionOp(const string &node_name, ScopeFusionOpInfo *info) { | |||||
| GE_CHK_BOOL_EXEC(info != nullptr, return false, "info is null."); | |||||
| info->node_name = node_name; | |||||
| // Direct forward matching | |||||
| for (auto iter = tensorflow_fusionop_map.begin(); iter != tensorflow_fusionop_map.end(); ++iter) { | |||||
| const string fop_name = iter->first; | |||||
| string node_name_tmp = node_name; | |||||
| // begin with "ConstantFolding/" | |||||
| if (node_name_tmp.find(kConstantFoldingSuffix) == 0) { | |||||
| auto length = strlen(kConstantFoldingSuffix); | |||||
| node_name_tmp = node_name.substr(length, node_name.length() - length); | |||||
| } | |||||
| // not match | |||||
| if (node_name_tmp.find(fop_name) != 0) { | |||||
| continue; | |||||
| } | |||||
| // match,"FusionName/" scene: | |||||
| if (node_name_tmp.substr(fop_name.length(), 1) == string("/")) { | |||||
| info->fusion_node_name = fop_name; | |||||
| info->fusion_op_type = tensorflow_fusionop_map[fop_name]; | |||||
| info->description = ""; | |||||
| info->scope_pass = false; | |||||
| return true; | |||||
| } | |||||
| // match "FusionName_Index/" scene: | |||||
| // special characters need unified definition | |||||
| string sub_name = node_name_tmp.substr(fop_name.length(), node_name_tmp.length() - fop_name.length()); | |||||
| auto index = sub_name.find('/'); | |||||
| if ((sub_name.substr(0, 1) == string("_")) && (index > 1) && IsIntegerStr(sub_name.substr(1, index - 1))) { | |||||
| info->fusion_node_name = fop_name + sub_name.substr(0, index); | |||||
| info->fusion_op_type = tensorflow_fusionop_map[fop_name]; | |||||
| info->description = ""; | |||||
| info->scope_pass = false; | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| // Confirm whether it is a fusion operator | |||||
| bool TensorFlowFunsionOPUtil::IsFusionOp(const domi::tensorflow::NodeDef *node_def) { | |||||
| GE_CHK_BOOL_EXEC(node_def != nullptr, return false, "node_def is null."); | |||||
| string type = node_def->op(); | |||||
| auto iter = tensorflow_fusionop_children_nums_map.find(type); | |||||
| return iter != tensorflow_fusionop_children_nums_map.end(); | |||||
| } | |||||
| // Check the validity of fusion operator (all child nodes) | |||||
| Status TensorFlowFunsionOPUtil::CheckFusionOpChildren(const string &fusion_node_name, | |||||
| const vector<const domi::tensorflow::NodeDef *> &nodedef_list, | |||||
| const string &funsion_op_type) { | |||||
| // Number matching of fusion operators | |||||
| auto iter_children_nums = tensorflow_fusionop_children_nums_map.find(funsion_op_type); | |||||
| if (iter_children_nums == tensorflow_fusionop_children_nums_map.end()) { | |||||
| REPORT_INNER_ERROR("E19999", "Op[%s]'s optype[%s] not a Fusion OP, check invalid", | |||||
| fusion_node_name.c_str(), funsion_op_type.c_str()); | |||||
| GELOGE(domi::INTERNAL_ERROR, | |||||
| "Op[%s]'s optype[%s] not a Fusion OP!", fusion_node_name.c_str(), funsion_op_type.c_str()); | |||||
| return domi::INTERNAL_ERROR; | |||||
| } | |||||
| vector<int> children_nums = iter_children_nums->second; | |||||
| bool find = false; | |||||
| int children_num = nodedef_list.size(); | |||||
| for (uint32_t i = 0; i < children_nums.size(); i++) { | |||||
| if (children_nums[i] == children_num) { | |||||
| find = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!find) { | |||||
| REPORT_INNER_ERROR("E19999", "CheckFusionOp op[%s]'s optype[%s] children_nums[%d] is not the same for define", | |||||
| fusion_node_name.c_str(), funsion_op_type.c_str(), children_num); | |||||
| GELOGE(domi::INTERNAL_ERROR, | |||||
| "Op[%s]'s optype[%s] children_nums:%d is not the same for define.", | |||||
| fusion_node_name.c_str(), | |||||
| funsion_op_type.c_str(), | |||||
| children_num); | |||||
| return domi::INTERNAL_ERROR; | |||||
| } | |||||
| // Key children operators matching | |||||
| auto iter_children_names = tensorflow_fusionop_children_names_map.find(funsion_op_type); | |||||
| if (iter_children_names != tensorflow_fusionop_children_names_map.end()) { | |||||
| vector<string> children_names = iter_children_names->second; | |||||
| if (!children_names.empty()) { | |||||
| uint32_t count = 0; | |||||
| for (uint32_t i = 0; i < children_names.size(); i++) { | |||||
| for (uint32_t j = 0; j < nodedef_list.size(); j++) { | |||||
| const domi::tensorflow::NodeDef *node_def = nodedef_list[j]; | |||||
| GE_CHECK_NOTNULL(node_def); | |||||
| string node_name = node_def->name(); | |||||
| string child_name = GetChildName(node_name, fusion_node_name); | |||||
| if (children_names[i] == child_name) { | |||||
| count++; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| GE_IF_BOOL_EXEC(count != children_names.size(), | |||||
| REPORT_INNER_ERROR("E19999", "Op[%s]'s optype[%s] has no enough importance child.", fusion_node_name.c_str(), | |||||
| funsion_op_type.c_str()); | |||||
| GELOGE(domi::INTERNAL_ERROR, "Op[%s]'s optype[%s] has no enough importance child.", fusion_node_name.c_str(), | |||||
| funsion_op_type.c_str()); | |||||
| return domi::INTERNAL_ERROR;); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| // Get the child node of the fusion operator as the input / output index number of the whole fusion operator | |||||
| Status TensorFlowFunsionOPUtil::GetNodeindex( | |||||
| const ScopeFusionOpInfo &info, const int32_t old_index, int32_t &new_index, | |||||
| const map<string, vector<std::pair<string, vector<int32_t>>>> &fusionop_context_map) { | |||||
| auto iter = fusionop_context_map.find(info.fusion_op_type); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter == fusionop_context_map.end(), | |||||
| return domi::INTERNAL_ERROR, | |||||
| "Op[%s] could not find item of optype[%s] in fusionop_context_map", | |||||
| info.node_name.c_str(), info.fusion_op_type.c_str()); | |||||
| vector<std::pair<string, vector<int32_t>>> pairs = iter->second; | |||||
| string child_name = GetChildName(info.node_name, info.fusion_node_name); | |||||
| GELOGI("GetNodeindex: info.node_name:%s, old_index:%d", info.node_name.c_str(), old_index); | |||||
| for (const auto &pair : pairs) { | |||||
| if (pair.first == child_name) { | |||||
| vector<int32_t> indexs = pair.second; | |||||
| if (static_cast<int32_t>(indexs.size()) < (old_index + 1)) { | |||||
| new_index = kFusionDisableIndex; | |||||
| return SUCCESS; | |||||
| } | |||||
| if (old_index != -1) { | |||||
| new_index = indexs[old_index]; | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| } | |||||
| new_index = kFusionDisableIndex; | |||||
| return SUCCESS; | |||||
| } | |||||
| // Get the input index of the fusion operator | |||||
| Status TensorFlowFunsionOPUtil::GetInPutIndex(const ScopeFusionOpInfo &info, const int32_t old_index, | |||||
| int32_t &new_index) { | |||||
| return GetNodeindex(info, old_index, new_index, tensorflow_fusionop_inputs_map); | |||||
| } | |||||
| // Get the output index of the fusion operator | |||||
| Status TensorFlowFunsionOPUtil::GetOutPutIndex(const ScopeFusionOpInfo &info, const int32_t old_index, | |||||
| int32_t &new_index) { | |||||
| return GetNodeindex(info, old_index, new_index, tensorflow_fusionop_outputs_map); | |||||
| } | |||||
| bool TensorFlowFunsionOPUtil::FusionOpChildIgnore(const ScopeFusionOpInfo &info) { | |||||
| // If the small operator is not in the input and output index table of the fusion operator, | |||||
| // it is unnecessary to establish the edge relationship and can be ignored | |||||
| int32_t old_index = 0; | |||||
| int32_t in_new_index = 0; | |||||
| int32_t out_new_index = 0; | |||||
| GE_CHK_STATUS(GetInPutIndex(info, old_index, in_new_index), "GetInPutIndex failed"); | |||||
| GE_CHK_STATUS(GetOutPutIndex(info, old_index, out_new_index), "GetOutPutIndex failed"); | |||||
| return (in_new_index == kFusionDisableIndex) && (out_new_index == kFusionDisableIndex); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,130 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_FUSIONOP_UTIL_H_ | |||||
| #define GE_PARSER_TENSORFLOW_TENSORFLOW_FUSIONOP_UTIL_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "common/string_util.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| #include "proto/tensorflow/graph.pb.h" | |||||
| #include "external/register/scope/scope_fusion_pass_register.h" | |||||
| #include "register/scope/scope_graph_impl.h" | |||||
| namespace ge { | |||||
| using std::string; | |||||
| using std::vector; | |||||
| extern map<string, vector<std::pair<string, uint32_t>>> tensorflow_fusionop_input_const_weight_index_map; | |||||
| extern vector<string> const_op_update_vec; | |||||
| class TensorFlowFunsionOPUtil { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Check whether the operator can be a fusion operator | |||||
| * @param [in] node_name operation name | |||||
| * @return info fusion operator description | |||||
| * @return true maybe | |||||
| * @return false maybe not | |||||
| * @author | |||||
| */ | |||||
| static bool MaybeFusionOp(const string &node_name, ScopeFusionOpInfo *info); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Confirm whether it is a fusion operator | |||||
| * @param [in] nodeDef | |||||
| * @return true | |||||
| * @return false | |||||
| * @author | |||||
| */ | |||||
| static bool IsFusionOp(const domi::tensorflow::NodeDef *node_def); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Check the validity of fusion operator(All child nodes) | |||||
| * @param [in] fusion_node_name fusion operator name | |||||
| * @param [in] nodedef_list child nodes list | |||||
| * @param [in] funsion_op_type fusion operator type | |||||
| * @return legal/illegal | |||||
| * @author | |||||
| */ | |||||
| static Status CheckFusionOpChildren(const string &fusion_node_name, | |||||
| const vector<const domi::tensorflow::NodeDef *> &nodedef_list, | |||||
| const string &funsion_op_type); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief get inPut index of the fusion operator | |||||
| * @param [in] info Child node description of fusion operator | |||||
| * @param [in] old_index Child node original index | |||||
| * @return old_index As input index of the fusion operator | |||||
| * @return return code | |||||
| * @author | |||||
| */ | |||||
| static Status GetInPutIndex(const ScopeFusionOpInfo &info, const int32_t old_index, int32_t &new_index); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief get outPut index of the fusion operator | |||||
| * @param [in] info Child node description of fusion operator | |||||
| * @param [in] old_index Child node original index | |||||
| * @return old_index As output index of the fusion operator | |||||
| * @return 返回码 | |||||
| * @author | |||||
| */ | |||||
| static Status GetOutPutIndex(const ScopeFusionOpInfo &info, const int32_t old_index, int32_t &new_index); | |||||
| static bool FusionOpChildIgnore(const ScopeFusionOpInfo &info); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Get child node name of fusion operator eg: input: fastrcnn_predictions/map/TensorArray_2 output | |||||
| * :map/TensorArray_2 | |||||
| * @param [in] node_name node name | |||||
| * @param [in] fusion_node_name fusion node name | |||||
| * @return Child node name of the fusion node | |||||
| * @author | |||||
| */ | |||||
| static string GetChildName(const string &node_name, const string &fusion_node_name); | |||||
| private: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief whether a string can be converted to an integer | |||||
| * @param [in] indexstr Operator suffix index | |||||
| * @return true can | |||||
| * @return false can not | |||||
| * @author | |||||
| */ | |||||
| static bool IsIntegerStr(const string &index_str); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Get child node of fusion operator | |||||
| * @param [in] info Description of fusion operator | |||||
| * @param [in] old_index original index | |||||
| * @return new_index Fusion operator index | |||||
| * @author | |||||
| */ | |||||
| static Status GetNodeindex(const ScopeFusionOpInfo &info, const int32_t old_index, int32_t &new_index, | |||||
| const std::map<string, vector<std::pair<string, vector<int32_t>>>> &fusionop_context_map); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_PARSER_TENSORFLOW_TENSORFLOW_FUSIONOP_UTIL_H_ | |||||
| @@ -48,15 +48,12 @@ | |||||
| #include "parser/tensorflow/tensorflow_custom_parser_adapter.h" | #include "parser/tensorflow/tensorflow_custom_parser_adapter.h" | ||||
| #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" | #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" | ||||
| #include "parser/tensorflow/tensorflow_fusion_op_parser.h" | #include "parser/tensorflow/tensorflow_fusion_op_parser.h" | ||||
| #include "parser/tensorflow/tensorflow_fusionop_util.h" | |||||
| #include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
| #include "parser/tensorflow/tensorflow_util.h" | #include "parser/tensorflow/tensorflow_util.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "register/scope/scope_graph_impl.h" | |||||
| #include "register/scope/scope_pass_registry_impl.h" | #include "register/scope/scope_pass_registry_impl.h" | ||||
| #include "parser/common/auto_mapping_subgraph_io_index_func.h" | #include "parser/common/auto_mapping_subgraph_io_index_func.h" | ||||
| using ge::const_op_update_vec; | |||||
| using ge::OpParserFactory; | using ge::OpParserFactory; | ||||
| using ge::Pb2Json; | using ge::Pb2Json; | ||||
| using ge::PreChecker; | using ge::PreChecker; | ||||
| @@ -80,7 +77,6 @@ using ge::TENSORFLOWF_NODE_OP_SWITCH; | |||||
| using ge::TENSORFLOWF_NODE_OP_TRANSPOSE; | using ge::TENSORFLOWF_NODE_OP_TRANSPOSE; | ||||
| using ge::TENSORFLOWF_TENSOR_NCHW; | using ge::TENSORFLOWF_TENSOR_NCHW; | ||||
| using ge::TENSORFLOWF_TENSOR_NHWC; | using ge::TENSORFLOWF_TENSOR_NHWC; | ||||
| using ge::TensorFlowFunsionOPUtil; | |||||
| using ge::TensorFlowFusionCustomParserAdapter; | using ge::TensorFlowFusionCustomParserAdapter; | ||||
| using ge::TensorFlowFusionOpParser; | using ge::TensorFlowFusionOpParser; | ||||
| using ge::TensorFlowOpParser; | using ge::TensorFlowOpParser; | ||||
| @@ -1239,9 +1235,6 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||||
| "add node failed."); | "add node failed."); | ||||
| } | } | ||||
| // Verify the validity of fusionop | |||||
| GE_RETURN_IF_ERROR(CheckFusionOpValid()); | |||||
| // The fusion operator has passed the verification. | // The fusion operator has passed the verification. | ||||
| // The errors of internal non key operators (which will be ignored later) | // The errors of internal non key operators (which will be ignored later) | ||||
| // do not affect the transformation of the whole model, | // do not affect the transformation of the whole model, | ||||
| @@ -1476,9 +1469,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||||
| GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(graph_def, node_def, op_node_name_list), has_error = true); | GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(graph_def, node_def, op_node_name_list), has_error = true); | ||||
| } | } | ||||
| // Verify the validity of fusionop | |||||
| GE_RETURN_IF_ERROR(CheckFusionOpValid()); | |||||
| // The fusion operator has passed the verification. | // The fusion operator has passed the verification. | ||||
| // The errors of internal non key operators (which will be ignored later) | // The errors of internal non key operators (which will be ignored later) | ||||
| // do not affect the transformation of the whole model, | // do not affect the transformation of the whole model, | ||||
| @@ -1764,8 +1754,7 @@ bool TensorFlowModelParser::MaybeFusionOp(shared_ptr<ge::ScopeGraph> &scope_grap | |||||
| ge::ScopeFusionOpInfo info; | ge::ScopeFusionOpInfo info; | ||||
| std::vector<ge::ScopeFusionOpInfo> info_list; | std::vector<ge::ScopeFusionOpInfo> info_list; | ||||
| auto &impl = scope_graph->impl_; | auto &impl = scope_graph->impl_; | ||||
| if (TensorFlowFunsionOPUtil::MaybeFusionOp(node_def->name(), &info) || | |||||
| impl->IsFusionOpChild(node_def->name(), info_list)) { | |||||
| if (impl->IsFusionOpChild(node_def->name(), info_list)) { | |||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| info_list.size() > 0, for (size_t i = 0; i < info_list.size(); ++i) { | info_list.size() > 0, for (size_t i = 0; i < info_list.size(); ++i) { | ||||
| fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type); | fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type); | ||||
| @@ -1821,9 +1810,6 @@ bool TensorFlowModelParser::FusionOpChildIgnore(shared_ptr<ge::ScopeGraph> &scop | |||||
| // Scope fusion strategy | // Scope fusion strategy | ||||
| auto &impl = scope_graph->impl_; | auto &impl = scope_graph->impl_; | ||||
| ignore = impl->FusionOpChildIgnore(info); | ignore = impl->FusionOpChildIgnore(info); | ||||
| } else { | |||||
| // Full match fusion strategy | |||||
| ignore = TensorFlowFunsionOPUtil::FusionOpChildIgnore(info); | |||||
| } | } | ||||
| return ignore; | return ignore; | ||||
| } | } | ||||
| @@ -1832,11 +1818,7 @@ bool TensorFlowModelParser::IsFusionOp(shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| const domi::tensorflow::NodeDef *node_def) { | const domi::tensorflow::NodeDef *node_def) { | ||||
| // The caller guarantees that the pointer is not null | // The caller guarantees that the pointer is not null | ||||
| auto &impl = scope_graph->impl_; | auto &impl = scope_graph->impl_; | ||||
| if (TensorFlowFunsionOPUtil::IsFusionOp(node_def) || impl->IsFusionOp(node_def)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| return (impl->IsFusionOp(node_def)); | |||||
| } | } | ||||
| Status TensorFlowModelParser::GetInPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | Status TensorFlowModelParser::GetInPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | ||||
| const int32_t old_index, int32_t &new_index) { | const int32_t old_index, int32_t &new_index) { | ||||
| @@ -1845,10 +1827,7 @@ Status TensorFlowModelParser::GetInPutIndex(shared_ptr<ge::ScopeGraph> &scope_gr | |||||
| if (info.scope_pass) { | if (info.scope_pass) { | ||||
| auto &impl = scope_graph->impl_; | auto &impl = scope_graph->impl_; | ||||
| ret = impl->GetInputOrOutputIndex(info, old_index, true, new_index); | ret = impl->GetInputOrOutputIndex(info, old_index, true, new_index); | ||||
| } else { | |||||
| ret = TensorFlowFunsionOPUtil::GetInPutIndex(info, old_index, new_index); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| Status TensorFlowModelParser::GetOutPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | Status TensorFlowModelParser::GetOutPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | ||||
| @@ -1858,33 +1837,10 @@ Status TensorFlowModelParser::GetOutPutIndex(shared_ptr<ge::ScopeGraph> &scope_g | |||||
| if (info.scope_pass) { | if (info.scope_pass) { | ||||
| auto &impl = scope_graph->impl_; | auto &impl = scope_graph->impl_; | ||||
| ret = impl->GetInputOrOutputIndex(info, old_index, false, new_index); | ret = impl->GetInputOrOutputIndex(info, old_index, false, new_index); | ||||
| } else { | |||||
| ret = TensorFlowFunsionOPUtil::GetOutPutIndex(info, old_index, new_index); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| Status TensorFlowModelParser::CheckFusionOpValid() { | |||||
| for (auto &iter : fusion_op_nodedef_map_) { | |||||
| const string fusion_node_name = iter.first; | |||||
| vector<const NodeDef *> nodedef_list = iter.second; | |||||
| vector<string> funsion_op_info = fusion_op_type_map_[fusion_node_name]; | |||||
| // vecotr index 0 is fusion_op_type | |||||
| const string funsion_op_type = funsion_op_info[0]; | |||||
| if (!fusion_op_policy_[fusion_node_name]) { | |||||
| // Check the validity of the fusion_op_nodedef_map children operator | |||||
| GE_RETURN_IF_ERROR( | |||||
| TensorFlowFunsionOPUtil::CheckFusionOpChildren(fusion_node_name, nodedef_list, funsion_op_type)); | |||||
| // Because there are many scenes in tensorflow graph, | |||||
| // in order to avoid the problem of omission, the error is returned directly. | |||||
| // In the future, functions like rollback can be implemented according to the definition of fusion operator | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool TensorFlowModelParser::ConstOpNeedUpdate(const string &op_name) { | bool TensorFlowModelParser::ConstOpNeedUpdate(const string &op_name) { | ||||
| if (nodedef_map_[op_name]->op() != TENSORFLOWF_NODE_OP_CONST) { | if (nodedef_map_[op_name]->op() != TENSORFLOWF_NODE_OP_CONST) { | ||||
| // Normal op need to update | // Normal op need to update | ||||
| @@ -1900,9 +1856,6 @@ bool TensorFlowModelParser::ConstOpNeedUpdate(const string &op_name) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| if (std::find(const_op_update_vec.begin(), const_op_update_vec.end(), op_name) == const_op_update_vec.end()) { | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -2336,9 +2289,6 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||||
| PARSER_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap"); | PARSER_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap"); | ||||
| GELOGI("[TF Parser] TF subgraph isDatasetInit: %d.", isDatasetInit); | GELOGI("[TF Parser] TF subgraph isDatasetInit: %d.", isDatasetInit); | ||||
| // Verify the validity of fusionop | |||||
| GE_RETURN_IF_ERROR(CheckFusionOpValid()); | |||||
| // Build input and output relationships for all OP nodes | // Build input and output relationships for all OP nodes | ||||
| PARSER_TIMESTAMP_START(GetOpNodesContextFromGraph); | PARSER_TIMESTAMP_START(GetOpNodesContextFromGraph); | ||||
| GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def)); | GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def)); | ||||
| @@ -36,12 +36,12 @@ | |||||
| #include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
| #include "omg/parser/weights_parser.h" | #include "omg/parser/weights_parser.h" | ||||
| #include "parser/tensorflow/tensorflow_fusion_op_parser.h" | #include "parser/tensorflow/tensorflow_fusion_op_parser.h" | ||||
| #include "parser/tensorflow/tensorflow_fusionop_util.h" | |||||
| #include "parser/tensorflow/tensorflow_util.h" | #include "parser/tensorflow/tensorflow_util.h" | ||||
| #include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
| #include "proto/tensorflow/graph.pb.h" | #include "proto/tensorflow/graph.pb.h" | ||||
| #include "proto/tensorflow/node_def.pb.h" | #include "proto/tensorflow/node_def.pb.h" | ||||
| #include "proto/tensorflow/graph_library.pb.h" | #include "proto/tensorflow/graph_library.pb.h" | ||||
| #include "register/scope/scope_graph_impl.h" | |||||
| #include "external/register/scope/scope_fusion_pass_register.h" | #include "external/register/scope/scope_fusion_pass_register.h" | ||||
| #include "scope/scope_pass_manager.h" | #include "scope/scope_pass_manager.h" | ||||
| #include "common/parser_utils.h" | #include "common/parser_utils.h" | ||||
| @@ -312,15 +312,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| const ge::ScopeFusionOpInfo &info, | const ge::ScopeFusionOpInfo &info, | ||||
| const int32_t old_index, | const int32_t old_index, | ||||
| int32_t &new_index); | int32_t &new_index); | ||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Check the validity of fusionop,put it into op_node_name_list if Misjudgement | |||||
| * @param op_node_name_list | |||||
| * @return SUCCESS check successfully | |||||
| * @return FAILED check failed | |||||
| */ | |||||
| Status CheckFusionOpValid(); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -0,0 +1,174 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_OPS_STUB_H | |||||
| #define MAIN_OPS_STUB_H | |||||
| #include "external/graph/operator_reg.h" | |||||
| #include "register/op_registry.h" | |||||
| namespace ge { | |||||
| // for ir | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Variable) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .ATTR(value, Tensor, Tensor()) | |||||
| .OP_END_FACTORY_REG(Variable) | |||||
| REG_OP(Const) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(value, Tensor, Tensor()) | |||||
| .ATTR(dtype, Int, 0) | |||||
| .OP_END_FACTORY_REG(Const) | |||||
| REG_OP(Assign) | |||||
| .INPUT(resource, TensorType::ALL()) | |||||
| .INPUT(value, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .OP_END_FACTORY_REG(Assign) REG_OP(Sqrt) | |||||
| .INPUT(x, TensorType{(DT_FLOAT.DT_FLOAT16)}) | |||||
| .OUTPUT(y, TensorType{(DT_FLOAT, DT_FLOAT16)}) | |||||
| .ATTR(T, Int, 1) | |||||
| .ATTR(alpha, Float, 1.0) | |||||
| .ATTR(beta, Float, 0.0) | |||||
| .OP_END_FACTORY_REG(Sqrt) | |||||
| REG_OP(Conv2D) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .REQUIRED_ATTR(strides, ListInt) | |||||
| .REQUIRED_ATTR(pads, ListInt) | |||||
| .ATTR(dilations, ListInt, {1, 1, 1, 1}) | |||||
| .ATTR(groups, Int, 1) | |||||
| .ATTR(data_format, String, "NHWC") | |||||
| .ATTR(offset_x, Int, 0) | |||||
| .OP_END_FACTORY_REG(Conv2D) | |||||
| REG_OP(If) | |||||
| .INPUT(cond, TensorType::ALL()) | |||||
| .DYNAMIC_INPUT(input, TensorType::ALL()) | |||||
| .DYNAMIC_OUTPUT(output, TensorType::ALL()) | |||||
| .GRAPH(then_branch) | |||||
| .GRAPH(else_branch) | |||||
| .OP_END_FACTORY_REG(If) | |||||
| REG_OP(Add) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OP_END_FACTORY_REG(Add) | |||||
| REG_OP(Identity) | |||||
| .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OP_END_FACTORY_REG(Identity) | |||||
| REG_OP(Abs) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||||
| .OP_END_FACTORY_REG(Abs) | |||||
| // for plugin | |||||
| static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status ParseParamByOpFuncStub(const ge::Operator &op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status ParseSubgraphPostFnIfStub(const std::string& subgraph_name, const ge::Graph& graph) { | |||||
| domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = | |||||
| domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); | |||||
| if (auto_mapping_subgraph_index_func == nullptr) { | |||||
| std::cout<<"auto mapping if subgraph func is nullptr!"<<std::endl; | |||||
| return FAILED; | |||||
| } | |||||
| return auto_mapping_subgraph_index_func(graph, | |||||
| [&](int data_index, int &parent_index) -> Status { | |||||
| parent_index = data_index + 1; | |||||
| return SUCCESS; | |||||
| }, | |||||
| [&](int output_index, int &parent_index) -> Status { | |||||
| parent_index = output_index; | |||||
| return SUCCESS; | |||||
| }); | |||||
| } | |||||
| // caffe plugin | |||||
| REGISTER_CUSTOM_OP("Data") | |||||
| .FrameworkType(domi::CAFFE) | |||||
| .OriginOpType("Input") | |||||
| .ParseParamsFn(ParseParamsStub); | |||||
| REGISTER_CUSTOM_OP("Abs") | |||||
| .FrameworkType(domi::CAFFE) | |||||
| .OriginOpType("AbsVal") | |||||
| .ParseParamsFn(ParseParamsStub); | |||||
| // onnx plugin | |||||
| REGISTER_CUSTOM_OP("Conv2D") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Conv") | |||||
| .ParseParamsFn(ParseParamsStub); | |||||
| REGISTER_CUSTOM_OP("If") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType({"ai.onnx::9::If", | |||||
| "ai.onnx::10::If", | |||||
| "ai.onnx::11::If", | |||||
| "ai.onnx::12::If", | |||||
| "ai.onnx::13::If"}) | |||||
| .ParseParamsFn(ParseParamsStub) | |||||
| .ParseParamsByOperatorFn(ParseParamByOpFuncStub) | |||||
| .ParseSubgraphPostFn(ParseSubgraphPostFnIfStub); | |||||
| REGISTER_CUSTOM_OP("Add") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Add") | |||||
| .ParseParamsFn(ParseParamsStub); | |||||
| REGISTER_CUSTOM_OP("Identity") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Identity") | |||||
| .ParseParamsFn(ParseParamsStub); | |||||
| // tf plugin | |||||
| REGISTER_CUSTOM_OP("Add") | |||||
| .FrameworkType(domi::TENSORFLOW) | |||||
| .OriginOpType("Add") | |||||
| .ParseParamsFn(ParseParamsStub); | |||||
| } // namespace ge | |||||
| #endif // MAIN_OPS_STUB_H | |||||
| @@ -290,7 +290,6 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_enter_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_enter_parser.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fill_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_fill_parser.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_frameworkop_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_frameworkop_parser.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fusionop_util.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_op_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_op_parser.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_identity_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_identity_parser.cc" | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "external/parser/caffe_parser.h" | #include "external/parser/caffe_parser.h" | ||||
| #include "st/parser_st_utils.h" | #include "st/parser_st_utils.h" | ||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| #include "tests/depends/ops_stub/ops_stub.h" | |||||
| namespace ge { | namespace ge { | ||||
| class STestCaffeParser : public testing::Test { | class STestCaffeParser : public testing::Test { | ||||
| @@ -61,19 +62,6 @@ void STestCaffeParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Abs) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||||
| .OP_END_FACTORY_REG(Abs) | |||||
| } | |||||
| TEST_F(STestCaffeParser, caffe_parser_user_output_with_default) { | TEST_F(STestCaffeParser, caffe_parser_user_output_with_default) { | ||||
| std::string case_dir = __FILE__; | std::string case_dir = __FILE__; | ||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
| @@ -99,4 +87,4 @@ TEST_F(STestCaffeParser, caffe_parser_user_output_with_default) { | |||||
| EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); | EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); | ||||
| } | } | ||||
| } // namespace ge | |||||
| } // namespace ge | |||||
| @@ -23,6 +23,7 @@ | |||||
| #include "external/parser/onnx_parser.h" | #include "external/parser/onnx_parser.h" | ||||
| #include "st/parser_st_utils.h" | #include "st/parser_st_utils.h" | ||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| #include "tests/depends/ops_stub/ops_stub.h" | |||||
| namespace ge { | namespace ge { | ||||
| class STestOnnxParser : public testing::Test { | class STestOnnxParser : public testing::Test { | ||||
| @@ -100,61 +101,6 @@ void STestOnnxParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Const) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||||
| DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .ATTR(value, Tensor, Tensor()) | |||||
| .OP_END_FACTORY_REG(Const) | |||||
| REG_OP(Conv2D) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .REQUIRED_ATTR(strides, ListInt) | |||||
| .REQUIRED_ATTR(pads, ListInt) | |||||
| .ATTR(dilations, ListInt, {1, 1, 1, 1}) | |||||
| .ATTR(groups, Int, 1) | |||||
| .ATTR(data_format, String, "NHWC") | |||||
| .ATTR(offset_x, Int, 0) | |||||
| .OP_END_FACTORY_REG(Conv2D) | |||||
| REG_OP(If) | |||||
| .INPUT(cond, TensorType::ALL()) | |||||
| .DYNAMIC_INPUT(input, TensorType::ALL()) | |||||
| .DYNAMIC_OUTPUT(output, TensorType::ALL()) | |||||
| .GRAPH(then_branch) | |||||
| .GRAPH(else_branch) | |||||
| .OP_END_FACTORY_REG(If) | |||||
| REG_OP(Add) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OP_END_FACTORY_REG(Add) | |||||
| REG_OP(Identity) | |||||
| .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OP_END_FACTORY_REG(Identity) | |||||
| } | |||||
| TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { | TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { | ||||
| std::string case_dir = __FILE__; | std::string case_dir = __FILE__; | ||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
| @@ -182,4 +128,4 @@ TEST_F(STestOnnxParser, onnx_parser_if_node) { | |||||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | ||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | EXPECT_EQ(ret, GRAPH_SUCCESS); | ||||
| } | } | ||||
| } // namespace ge | |||||
| } // namespace ge | |||||
| @@ -22,6 +22,7 @@ | |||||
| #include "parser/common/register_tbe.h" | #include "parser/common/register_tbe.h" | ||||
| #include "external/parser/tensorflow_parser.h" | #include "external/parser/tensorflow_parser.h" | ||||
| #include "st/parser_st_utils.h" | #include "st/parser_st_utils.h" | ||||
| #include "tests/depends/ops_stub/ops_stub.h" | |||||
| namespace ge { | namespace ge { | ||||
| class STestTensorflowParser : public testing::Test { | class STestTensorflowParser : public testing::Test { | ||||
| @@ -54,26 +55,6 @@ void STestTensorflowParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Add) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OP_END_FACTORY_REG(Add) | |||||
| } | |||||
| TEST_F(STestTensorflowParser, tensorflow_parser_success) { | TEST_F(STestTensorflowParser, tensorflow_parser_success) { | ||||
| RegisterCustomOp(); | RegisterCustomOp(); | ||||
| @@ -94,4 +75,4 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { | |||||
| ASSERT_EQ(net_out_name.size(), 1); | ASSERT_EQ(net_out_name.size(), 1); | ||||
| EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | ||||
| } | } | ||||
| } // namespace ge | |||||
| } // namespace ge | |||||
| @@ -291,7 +291,6 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_enter_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_enter_parser.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fill_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_fill_parser.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_frameworkop_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_frameworkop_parser.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fusionop_util.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_op_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_op_parser.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_identity_parser.cc" | "${PARSER_DIR}/parser/tensorflow/tensorflow_identity_parser.cc" | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "external/parser/caffe_parser.h" | #include "external/parser/caffe_parser.h" | ||||
| #include "ut/parser/parser_ut_utils.h" | #include "ut/parser/parser_ut_utils.h" | ||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| #include "tests/depends/ops_stub/ops_stub.h" | |||||
| namespace ge { | namespace ge { | ||||
| class UtestCaffeParser : public testing::Test { | class UtestCaffeParser : public testing::Test { | ||||
| @@ -41,20 +42,7 @@ class UtestCaffeParser : public testing::Test { | |||||
| void RegisterCustomOp(); | void RegisterCustomOp(); | ||||
| }; | }; | ||||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| void UtestCaffeParser::RegisterCustomOp() { | void UtestCaffeParser::RegisterCustomOp() { | ||||
| REGISTER_CUSTOM_OP("Data") | |||||
| .FrameworkType(domi::CAFFE) | |||||
| .OriginOpType("Input") | |||||
| .ParseParamsFn(ParseParams); | |||||
| REGISTER_CUSTOM_OP("Abs") | |||||
| .FrameworkType(domi::CAFFE) | |||||
| .OriginOpType("AbsVal") | |||||
| .ParseParamsFn(ParseParams); | |||||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
| for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | OpRegistrationTbe::Instance()->Finalize(reg_data); | ||||
| @@ -63,19 +51,6 @@ void UtestCaffeParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Abs) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||||
| .OP_END_FACTORY_REG(Abs) | |||||
| } | |||||
| TEST_F(UtestCaffeParser, caffe_parser_user_output_with_name_and_index) { | TEST_F(UtestCaffeParser, caffe_parser_user_output_with_name_and_index) { | ||||
| std::string case_dir = __FILE__; | std::string case_dir = __FILE__; | ||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
| @@ -155,4 +130,4 @@ TEST_F(UtestCaffeParser, caffe_parser_user_output_with_default) { | |||||
| EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); | EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); | ||||
| } | } | ||||
| } // namespace ge | |||||
| } // namespace ge | |||||
| @@ -24,6 +24,7 @@ | |||||
| #include "external/parser/onnx_parser.h" | #include "external/parser/onnx_parser.h" | ||||
| #include "ut/parser/parser_ut_utils.h" | #include "ut/parser/parser_ut_utils.h" | ||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| #include "tests/depends/ops_stub/ops_stub.h" | |||||
| namespace ge { | namespace ge { | ||||
| class UtestOnnxParser : public testing::Test { | class UtestOnnxParser : public testing::Test { | ||||
| @@ -39,60 +40,7 @@ class UtestOnnxParser : public testing::Test { | |||||
| void RegisterCustomOp(); | void RegisterCustomOp(); | ||||
| }; | }; | ||||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { | |||||
| domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = | |||||
| domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); | |||||
| if (auto_mapping_subgraph_index_func == nullptr) { | |||||
| std::cout<<"auto mapping if subgraph func is nullptr!"<<std::endl; | |||||
| return FAILED; | |||||
| } | |||||
| return auto_mapping_subgraph_index_func(graph, | |||||
| [&](int data_index, int &parent_index) -> Status { | |||||
| parent_index = data_index + 1; | |||||
| return SUCCESS; | |||||
| }, | |||||
| [&](int output_index, int &parent_index) -> Status { | |||||
| parent_index = output_index; | |||||
| return SUCCESS; | |||||
| }); | |||||
| } | |||||
| void UtestOnnxParser::RegisterCustomOp() { | void UtestOnnxParser::RegisterCustomOp() { | ||||
| REGISTER_CUSTOM_OP("Conv2D") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Conv") | |||||
| .ParseParamsFn(ParseParams); | |||||
| // register if op info to GE | |||||
| REGISTER_CUSTOM_OP("If") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType({"ai.onnx::9::If", | |||||
| "ai.onnx::10::If", | |||||
| "ai.onnx::11::If", | |||||
| "ai.onnx::12::If", | |||||
| "ai.onnx::13::If"}) | |||||
| .ParseParamsFn(ParseParams) | |||||
| .ParseParamsByOperatorFn(ParseParamByOpFunc) | |||||
| .ParseSubgraphPostFn(ParseSubgraphPostFnIf); | |||||
| REGISTER_CUSTOM_OP("Add") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Add") | |||||
| .ParseParamsFn(ParseParams); | |||||
| REGISTER_CUSTOM_OP("Identity") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Identity") | |||||
| .ParseParamsFn(ParseParams); | |||||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
| for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | OpRegistrationTbe::Instance()->Finalize(reg_data); | ||||
| @@ -101,61 +49,6 @@ void UtestOnnxParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Const) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||||
| DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .ATTR(value, Tensor, Tensor()) | |||||
| .OP_END_FACTORY_REG(Const) | |||||
| REG_OP(Conv2D) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .REQUIRED_ATTR(strides, ListInt) | |||||
| .REQUIRED_ATTR(pads, ListInt) | |||||
| .ATTR(dilations, ListInt, {1, 1, 1, 1}) | |||||
| .ATTR(groups, Int, 1) | |||||
| .ATTR(data_format, String, "NHWC") | |||||
| .ATTR(offset_x, Int, 0) | |||||
| .OP_END_FACTORY_REG(Conv2D) | |||||
| REG_OP(If) | |||||
| .INPUT(cond, TensorType::ALL()) | |||||
| .DYNAMIC_INPUT(input, TensorType::ALL()) | |||||
| .DYNAMIC_OUTPUT(output, TensorType::ALL()) | |||||
| .GRAPH(then_branch) | |||||
| .GRAPH(else_branch) | |||||
| .OP_END_FACTORY_REG(If) | |||||
| REG_OP(Add) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OP_END_FACTORY_REG(Add) | |||||
| REG_OP(Identity) | |||||
| .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OP_END_FACTORY_REG(Identity) | |||||
| } | |||||
| TEST_F(UtestOnnxParser, onnx_parser_if_node) { | TEST_F(UtestOnnxParser, onnx_parser_if_node) { | ||||
| std::string case_dir = __FILE__; | std::string case_dir = __FILE__; | ||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
| @@ -233,4 +126,4 @@ TEST_F(UtestOnnxParser, onnx_parser_user_output_with_tensor_failed) { | |||||
| EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
| } | } | ||||
| } // namespace ge | |||||
| } // namespace ge | |||||
| @@ -30,6 +30,7 @@ | |||||
| #include "external/parser/tensorflow_parser.h" | #include "external/parser/tensorflow_parser.h" | ||||
| #include "ut/parser/parser_ut_utils.h" | #include "ut/parser/parser_ut_utils.h" | ||||
| #include "graph/model.h" | #include "graph/model.h" | ||||
| #include "tests/depends/ops_stub/ops_stub.h" | |||||
| namespace ge { | namespace ge { | ||||
| class UtestTensorflowParser : public testing::Test { | class UtestTensorflowParser : public testing::Test { | ||||
| @@ -44,16 +45,7 @@ class UtestTensorflowParser : public testing::Test { | |||||
| void RegisterCustomOp(); | void RegisterCustomOp(); | ||||
| }; | }; | ||||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| void UtestTensorflowParser::RegisterCustomOp() { | void UtestTensorflowParser::RegisterCustomOp() { | ||||
| REGISTER_CUSTOM_OP("Add") | |||||
| .FrameworkType(domi::TENSORFLOW) | |||||
| .OriginOpType("Add") | |||||
| .ParseParamsFn(ParseParams); | |||||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
| for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | OpRegistrationTbe::Instance()->Finalize(reg_data); | ||||
| @@ -62,26 +54,6 @@ void UtestTensorflowParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Add) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OP_END_FACTORY_REG(Add) | |||||
| } | |||||
| TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | ||||
| RegisterCustomOp(); | RegisterCustomOp(); | ||||
| @@ -216,4 +188,4 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_with_external_graph) { | |||||
| ret = TensorFlowModelParser::AddExternalGraph(root_graph); | ret = TensorFlowModelParser::AddExternalGraph(root_graph); | ||||
| EXPECT_EQ(ret, INTERNAL_ERROR); | EXPECT_EQ(ret, INTERNAL_ERROR); | ||||
| } | } | ||||
| } // namespace ge | |||||
| } // namespace ge | |||||