| @@ -915,8 +915,9 @@ Status TensorFlowModelParser::CheckOpType(const domi::tensorflow::NodeDef *node_ | |||||
| GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape"); | GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape"); | ||||
| GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP; | GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP; | ||||
| GELOGI("Set op %s to frameworkop", node_name.c_str()); | GELOGI("Set op %s to frameworkop", node_name.c_str()); | ||||
| framework_ops_[node_name] = node_def;); | |||||
| ); | |||||
| framework_ops_[node_name] = node_def; | |||||
| ); | |||||
| ); | |||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN, | op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN, | ||||
| @@ -1777,8 +1778,8 @@ bool TensorFlowModelParser::MaybeFusionOp(shared_ptr<ge::ScopeGraph> &scope_grap | |||||
| std::vector<ge::ScopeFusionOpInfo> info_list; | std::vector<ge::ScopeFusionOpInfo> info_list; | ||||
| auto &impl = scope_graph->impl_; | auto &impl = scope_graph->impl_; | ||||
| if (impl->IsFusionOpChild(node_def->name(), info_list)) { | if (impl->IsFusionOpChild(node_def->name(), info_list)) { | ||||
| GE_IF_BOOL_EXEC( | |||||
| info_list.size() > 0, for (size_t i = 0; i < info_list.size(); ++i) { | |||||
| GE_IF_BOOL_EXEC(info_list.size() > 0, | |||||
| for (size_t i = 0; i < info_list.size(); ++i) { | |||||
| fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type); | fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type); | ||||
| fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].description); | fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].description); | ||||
| fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(node_def); | fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(node_def); | ||||
| @@ -3480,35 +3481,35 @@ void TensorFlowModelParser::RemoveInputAttr(domi::tensorflow::NodeDef *node_def, | |||||
| attr_map->find(ge::ATTR_NAME_INPUT_TENSOR_DESC); | attr_map->find(ge::ATTR_NAME_INPUT_TENSOR_DESC); | ||||
| if (it == attr_map->end()) { | if (it == attr_map->end()) { | ||||
| GELOGW("Failed to find input desc from tf node_def[%s]", node_def->name().c_str()); | GELOGW("Failed to find input desc from tf node_def[%s]", node_def->name().c_str()); | ||||
| } else { | |||||
| domi::tensorflow::AttrValue *input_attr_value = &(it->second); | |||||
| auto tmp_attr = input_attr_value->mutable_list()->mutable_func(); | |||||
| auto attr_it = tmp_attr->begin(); | |||||
| int index = 0; | |||||
| for (auto input_it = inputs->begin(); input_it != inputs->end(); ++input_it, ++index) { | |||||
| // 1.decide whether to remove the input | |||||
| bool flag = false; | |||||
| for (auto &remove_input : remove_inputs_map) { | |||||
| string remove_input_name = remove_input.first; | |||||
| vector<int> remove_input_indexs = remove_input.second; | |||||
| if ((*input_it) == remove_input_name && | |||||
| std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) { | |||||
| GELOGD("Remove input attr:%s, index:%d", remove_input_name.c_str(), index); | |||||
| flag = true; | |||||
| break; | |||||
| } | |||||
| return; | |||||
| } | |||||
| domi::tensorflow::AttrValue *input_attr_value = &(it->second); | |||||
| auto tmp_attr = input_attr_value->mutable_list()->mutable_func(); | |||||
| auto attr_it = tmp_attr->begin(); | |||||
| int index = 0; | |||||
| for (auto input_it = inputs->begin(); input_it != inputs->end(); ++input_it, ++index) { | |||||
| // 1.decide whether to remove the input | |||||
| bool flag = false; | |||||
| for (auto &remove_input : remove_inputs_map) { | |||||
| string remove_input_name = remove_input.first; | |||||
| vector<int> remove_input_indexs = remove_input.second; | |||||
| if ((*input_it) == remove_input_name && | |||||
| std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) { | |||||
| GELOGD("Remove input attr:%s, index:%d", remove_input_name.c_str(), index); | |||||
| flag = true; | |||||
| break; | |||||
| } | } | ||||
| } | |||||
| if (flag) { | |||||
| // 2.1 remove the input attr | |||||
| if (!tmp_attr->empty() && (attr_it != tmp_attr->end())) { | |||||
| attr_it = tmp_attr->erase(attr_it); | |||||
| } else { | |||||
| ++attr_it; | |||||
| } | |||||
| if (flag) { | |||||
| // 2.1 remove the input attr | |||||
| if (!tmp_attr->empty() && (attr_it != tmp_attr->end())) { | |||||
| attr_it = tmp_attr->erase(attr_it); | |||||
| } else { | } else { | ||||
| ++attr_it; | ++attr_it; | ||||
| } | } | ||||
| } else { | |||||
| ++attr_it; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -64,29 +64,13 @@ Status TensorFlowRefSwitchParser::ParseParams(const Message *op_src, ge::OpDescP | |||||
| op.Name(node->name()); | op.Name(node->name()); | ||||
| GELOGI("RefSwitch Op %s ParseParams Begin.", node->name().c_str()); | GELOGI("RefSwitch Op %s ParseParams Begin.", node->name().c_str()); | ||||
| GE_RETURN_IF_ERROR(PreParseParams(node, &op)); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(ParseT(node, &op), "Parse T for node %s failed.", node->name().c_str()); | GE_RETURN_WITH_LOG_IF_ERROR(ParseT(node, &op), "Parse T for node %s failed.", node->name().c_str()); | ||||
| GE_RETURN_IF_ERROR(PostParseParams(node, &op)); | |||||
| Status status = ConvertToOpDesc(op, op_dest); | Status status = ConvertToOpDesc(op, op_dest); | ||||
| return status; | return status; | ||||
| } | } | ||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| Status TensorFlowRefSwitchParser::PreParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op) { | |||||
| (void)node; | |||||
| (void)op; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TensorFlowRefSwitchParser::PostParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op) { | |||||
| (void)node; | |||||
| (void)op; | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_OP_PARSER_CREATOR(TENSORFLOW, REFSWITCH, TensorFlowRefSwitchParser); | REGISTER_OP_PARSER_CREATOR(TENSORFLOW, REFSWITCH, TensorFlowRefSwitchParser); | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -35,26 +35,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowRefSwitchParser : public TensorFlowOpPars | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | ||||
| protected: | protected: | ||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief 解析模型文件信息 | |||||
| * @param [in] v_input_const 待解析的模型数据 | |||||
| * @param [out] node 解析后的模型数据 | |||||
| * @return SUCCESS 解析成功 | |||||
| * @return FAILED 解析失败 | |||||
| */ | |||||
| Status PreParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief 解析模型文件信息 | |||||
| * @param [in] v_input_const 待解析的模型数据 | |||||
| * @param [out] node 解析后的模型数据 | |||||
| * @return SUCCESS 解析成功 | |||||
| * @return FAILED 解析失败 | |||||
| */ | |||||
| Status PostParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| * @brief 解析模型文件信息 | * @brief 解析模型文件信息 | ||||
| @@ -100,16 +100,12 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
| ShapeNOperator op; | ShapeNOperator op; | ||||
| op.Name(node->name()); | op.Name(node->name()); | ||||
| GE_RETURN_IF_ERROR(PreParseParams(node, &op)); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(ParseInType(node, &op), "Parse in type for node %s failed.", node->name().c_str()); | GE_RETURN_WITH_LOG_IF_ERROR(ParseInType(node, &op), "Parse in type for node %s failed.", node->name().c_str()); | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ParseN(node, &op), "Parse N for node %s failed.", node->name().c_str()); | GE_RETURN_WITH_LOG_IF_ERROR(ParseN(node, &op), "Parse N for node %s failed.", node->name().c_str()); | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ParseOutType(node, &op), "Parse out type for node %s failed.", node->name().c_str()); | GE_RETURN_WITH_LOG_IF_ERROR(ParseOutType(node, &op), "Parse out type for node %s failed.", node->name().c_str()); | ||||
| GE_RETURN_IF_ERROR(PostParseParams(node, &op)); | |||||
| // add dynamic input/output | // add dynamic input/output | ||||
| domi::tensorflow::AttrValue attr_num; | domi::tensorflow::AttrValue attr_num; | ||||
| CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr_num), | CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr_num), | ||||
| @@ -154,18 +150,5 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) { | |||||
| (void)node; | |||||
| (void)op; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) { | |||||
| (void)node; | |||||
| (void)op; | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SHAPEN, TensorFlowShapeNParser); | REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SHAPEN, TensorFlowShapeNParser); | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -27,9 +27,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | ||||
| protected: | protected: | ||||
| Status PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op); | |||||
| Status PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op); | |||||
| static Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | static Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | ||||
| static Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | static Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | ||||
| static Status ParseOutType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | static Status ParseOutType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | ||||