| @@ -424,7 +424,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph | |||
| } | |||
| // Analysis of nodedef of original tensorflow | |||
| ge::GeAttrValue::BYTES nodedef_bytes; | |||
| ge::Buffer nodedef_bytes; | |||
| GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetBytes(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, nodedef_bytes), | |||
| PARAM_INVALID, "Get type attr nodedef failed."); | |||
| domi::tensorflow::NodeDef node_def_; | |||
| @@ -1,183 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_ | |||
| #define GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "common/fmk_types.h" | |||
| #include "framework/omg/parser/parser_types.h" | |||
| #include "graph/compute_graph.h" | |||
| #include "graph/node.h" | |||
| #include "graph/types.h" | |||
| #include "graph/utils/attr_utils.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/tensor_utils.h" | |||
| #include "register/op_registry.h" | |||
| namespace ge { | |||
| enum InFmtSupportEnum { | |||
| InFmtSupportUndefined, | |||
| InFmtSupportElewise, | |||
| InFmtSupport4D, | |||
| InFmtSupport5D, | |||
| InFmtSupport4D_5D, | |||
| InFmtSupportNCHW_NC1HWC0 | |||
| }; | |||
| enum InDtSupportEnum { | |||
| InDtSupportUndefined = 0, | |||
| InDtSupportAll = 1, | |||
| }; | |||
| enum OutFmtSupportEnum { | |||
| OutFmtSupportUndefined = 0, | |||
| OutFmtSupportAsInput = 1, | |||
| }; | |||
| enum OutDtSupportEnum { | |||
| OutDtSupportUndefined = 0, | |||
| OutDtSupportAsInput = 1, | |||
| }; | |||
| struct OpSupportTranInfo { | |||
| InFmtSupportEnum inputFormatSupportEnum = InFmtSupportUndefined; | |||
| InDtSupportEnum inputDataTypeSupportEnum = InDtSupportUndefined; | |||
| OutFmtSupportEnum outputFormatSupportEnum = OutFmtSupportUndefined; | |||
| OutDtSupportEnum outputDataTypeSupportEnum = OutDtSupportUndefined; | |||
| std::vector<ge::Format> inputFormats; | |||
| std::vector<ge::DataType> inputDataTypes; | |||
| ge::Format limitOutputFormat = ge::FORMAT_RESERVED; | |||
| ge::DataType limitOutputDataType = ge::DT_UNDEFINED; | |||
| }; | |||
| extern std::map<std::string, OpSupportTranInfo> g_OpSupportTranInfo; | |||
| class OpTransAddSupportReg { | |||
| public: | |||
| template <class InFmts, class InDts, class OutFmts, class OutDts> | |||
| OpTransAddSupportReg(const std::string &cceTbeTg, const std::string &opType, | |||
| InFmts inputFormats, InDts inputDataTypes, | |||
| OutFmts outputormat, OutDts outputDataType) { | |||
| auto cceTbeOpType = cceTbeTg + ":" + opType; | |||
| g_OpSupportTranInfo.erase(cceTbeOpType); | |||
| SetInputFormat(cceTbeOpType, inputFormats); | |||
| SetInputDataType(cceTbeOpType, inputDataTypes); | |||
| SetOutputFormat(cceTbeOpType, outputormat); | |||
| SetOutputDataType(cceTbeOpType, outputDataType); | |||
| } | |||
| ~OpTransAddSupportReg() = default; | |||
| private: | |||
| void SetInputFormat(std::string opType, | |||
| const std::vector<ge::Format>& supportFormat) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| for (auto& format : supportFormat) { | |||
| opInfo.inputFormats.push_back(format); | |||
| } | |||
| } | |||
| void SetInputFormat(std::string opType, ge::Format supportFormat) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| opInfo.inputFormats.push_back(supportFormat); | |||
| } | |||
| void SetInputFormat(std::string opType, InFmtSupportEnum enumFormat) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| opInfo.inputFormatSupportEnum = enumFormat; | |||
| switch (enumFormat) { | |||
| case InFmtSupportElewise: | |||
| opInfo.inputFormats = {ge::FORMAT_FRACTAL_Z, ge::FORMAT_HWCN, | |||
| ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC, | |||
| ge::FORMAT_NCHW}; | |||
| break; | |||
| case InFmtSupport4D: | |||
| opInfo.inputFormats = {ge::FORMAT_HWCN, ge::FORMAT_NHWC, | |||
| ge::FORMAT_NCHW}; | |||
| break; | |||
| case InFmtSupport5D: | |||
| opInfo.inputFormats = {ge::FORMAT_NC1HWC0}; | |||
| break; | |||
| case InFmtSupport4D_5D: | |||
| opInfo.inputFormats = {ge::FORMAT_HWCN, ge::FORMAT_NHWC, | |||
| ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0}; | |||
| break; | |||
| case InFmtSupportNCHW_NC1HWC0: | |||
| opInfo.inputFormats = {ge::FORMAT_NC1HWC0, ge::FORMAT_NCHW}; | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| } | |||
| void SetInputDataType(std::string opType, | |||
| const std::vector<ge::DataType>& supportDataType) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| for (auto& dataType : supportDataType) { | |||
| opInfo.inputDataTypes.push_back(dataType); | |||
| } | |||
| } | |||
| void SetInputDataType(std::string opType, ge::DataType supportDataType) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| opInfo.inputDataTypes.push_back(supportDataType); | |||
| } | |||
| void SetInputDataType(std::string opType, InDtSupportEnum enumDataType) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| opInfo.inputDataTypeSupportEnum = enumDataType; | |||
| } | |||
| void SetOutputFormat(std::string opType, ge::Format limitOutputormat) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| opInfo.limitOutputFormat = limitOutputormat; | |||
| } | |||
| void SetOutputFormat(std::string opType, OutFmtSupportEnum enumFormat) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| opInfo.outputFormatSupportEnum = enumFormat; | |||
| } | |||
| void SetOutputDataType(std::string opType, ge::DataType limitOutputDataType) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| opInfo.limitOutputDataType = limitOutputDataType; | |||
| } | |||
| void SetOutputDataType(std::string opType, OutDtSupportEnum enumDataType) { | |||
| auto& opInfo = g_OpSupportTranInfo[opType]; | |||
| opInfo.outputDataTypeSupportEnum = enumDataType; | |||
| } | |||
| }; | |||
| #define TBE_SET_FORMAT_DATAYPE_INFO(cce_tbe, op, inputFormats, inputDataType, \ | |||
| outFormats, outputDataTypes) \ | |||
| TBE_SET_FORMAT_DATAYPE_INFO_UNIQ_HELPER(__COUNTER__, #cce_tbe, op, \ | |||
| inputFormats, inputDataType, \ | |||
| outFormats, outputDataTypes) | |||
| #define TBE_SET_FORMAT_DATAYPE_INFO_UNIQ_HELPER(ctr, cce_tbe, op, \ | |||
| inputFormats, inputDataType, \ | |||
| outFormats, outputDataTypes) \ | |||
| TBE_SET_FORMAT_DATAYPE_INFO_UNIQ(ctr, cce_tbe, op, inputFormats, \ | |||
| inputDataType, outFormats, outputDataTypes) | |||
| #define TBE_SET_FORMAT_DATAYPE_INFO_UNIQ(ctr, cce_tbe, op, inputFormats, \ | |||
| inputDataType, outFormats, \ | |||
| outputDataTypes) \ | |||
| OpTransAddSupportReg __gOpTransAddSupportReg##ctr( \ | |||
| cce_tbe, op, inputFormats, inputDataType, outFormats, outputDataTypes); | |||
| } // namespace domi | |||
| #endif // GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_ | |||
| @@ -35,67 +35,15 @@ namespace ge { | |||
| class ParserGraphOptimizer { | |||
| public: | |||
| explicit ParserGraphOptimizer(ge::ComputeGraphPtr graph, domi::FrameworkType type = domi::TENSORFLOW) | |||
| : graph_(graph), fmktype_(type), local_fmk_op_flag_(false) {} | |||
| : graph_(graph), fmktype_(type) {} | |||
| ~ParserGraphOptimizer() {} | |||
| domi::Status Optimize(); | |||
| domi::Status OptimizeAfterCal(); | |||
| domi::Status FusionFmkop(); | |||
| inline bool IsHCOMOp(const string &op_type) { | |||
| return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) || | |||
| (op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) || | |||
| (op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter"); | |||
| } | |||
| void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; } | |||
| const bool GetLocalFmkopFlag() const { return local_fmk_op_flag_; } | |||
| void SetFuncBinPath(std::string isFuncBinPath) { func_bin_path_ = isFuncBinPath; } | |||
| const std::string GetFuncBinPath() const { return func_bin_path_; } | |||
| domi::Status InsertHWCK2FZ(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, | |||
| enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); | |||
| domi::Status Insert4DTo5DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format src_out_format, enum ge::DataType src_out_data_type, | |||
| enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); | |||
| domi::Status InsertFZ2HWCK(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, | |||
| enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); | |||
| domi::Status Insert5DTo4DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format src_out_format, enum ge::DataType src_out_data_type, | |||
| enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); | |||
| ge::OpDescPtr CreateCastOp(enum ge::DataType input_datatype, enum ge::DataType output_datatype, ge::Format format); | |||
| ge::OpDescPtr CreatePermuteOp(enum ge::Format input_format, enum ge::Format output_format); | |||
| ge::OpDescPtr CreateTransDataOp(enum ge::Format input_format); | |||
| domi::Status NewNodeAddEdges(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, ge::NodePtr first, | |||
| ge::NodePtr second, ge::NodePtr third); | |||
| domi::Status InsertVar5DTo4D(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, | |||
| enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, | |||
| enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); | |||
| ge::OpDescPtr CreateTranslateOp(enum ge::Format inFormat, ge::DataType inDatatype, enum ge::Format outFormat, | |||
| ge::DataType outDatatype); | |||
| private: | |||
| ge::ComputeGraphPtr graph_; | |||
| domi::FrameworkType fmktype_; | |||
| // local fmkop flag | |||
| bool local_fmk_op_flag_; | |||
| std::string func_bin_path_; | |||
| domi::Status FindFmkNodeCluser(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map); | |||
| @@ -122,7 +70,6 @@ class ParserGraphOptimizer { | |||
| vector<ge::InControlAnchorPtr> &input_control_anchors, | |||
| vector<ge::OutControlAnchorPtr> &output_control_anchors, ge::NodePtr fusion_node); | |||
| domi::Status MakeTfProtoDef(); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ | |||
| @@ -32,8 +32,6 @@ Status IteratorFusionPass::Run(ge::ComputeGraphPtr graph) { | |||
| REPORT_CALL_ERROR("E19999", "New ParserGraphOptimizer failed"); | |||
| return FAILED; | |||
| } | |||
| graph_optimizer->SetLocalFmkopFlag(local_fmk_op_flag_); | |||
| return graph_optimizer->FusionFmkop(); | |||
| } | |||
| } // namespace ge | |||
| @@ -23,8 +23,8 @@ | |||
| namespace ge { | |||
| class IteratorFusionPass : public GraphPass { | |||
| public: | |||
| IteratorFusionPass(domi::FrameworkType type, bool local_fmk_op_flag) | |||
| : fmk_type_(type), local_fmk_op_flag_(local_fmk_op_flag) {} | |||
| IteratorFusionPass(domi::FrameworkType type) | |||
| : fmk_type_(type) {} | |||
| virtual ~IteratorFusionPass() {} | |||
| @@ -32,7 +32,6 @@ class IteratorFusionPass : public GraphPass { | |||
| private: | |||
| domi::FrameworkType fmk_type_; | |||
| bool local_fmk_op_flag_; | |||
| }; | |||
| } // namespace ge | |||
| @@ -2375,7 +2375,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||
| ge::parser::PassManager iterator_fusion_pass; | |||
| try { | |||
| (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", | |||
| new ge::IteratorFusionPass(domi::TENSORFLOW, false)); | |||
| new ge::IteratorFusionPass(domi::TENSORFLOW)); | |||
| } catch (std::bad_alloc &e) { | |||
| GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | |||
| return INTERNAL_ERROR; | |||
| @@ -83,61 +83,61 @@ class GeAttrValueImp { | |||
| static map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> attr_val_one_type_map_; | |||
| static map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> attr_val_list_type_map_; | |||
| static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val); | |||
| static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val); | |||
| static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, int64_t val); | |||
| static bool SetValue(proto::AttrDef &attr_def, float val); | |||
| static bool SetValue(proto::AttrDef &attr_def, bool val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const std::string &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeTensorDesc &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const Buffer &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const NamedAttrs &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const ComputeGraphPtr &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const vector<int64_t> &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const vector<int32_t> &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const vector<uint32_t> &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const std::vector<float> &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const std::vector<bool> &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const std::vector<std::string> &val); | |||
| static bool SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensorPtr> &value); | |||
| static bool SetValue(proto::AttrDef &proto_attr_val, const vector<ConstGeTensorPtr> &value); | |||
| static bool SetValue(proto::AttrDef &attr_def, const vector<GeTensor> &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BOOL &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const std::vector<GeTensorDesc> &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const std::vector<Buffer> &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const std::vector<NamedAttrs> &val); | |||
| static bool SetValue(proto::AttrDef &attr_def, const std::vector<ComputeGraphPtr> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, int64_t &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, float &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, bool &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, std::string &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensorPtr &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::TENSOR_DESC &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val); | |||
| GeTensorDesc &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::NAMED_ATTRS &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val); | |||
| NamedAttrs &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ComputeGraphPtr &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::LIST_INT &val); | |||
| std::vector<int64_t> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::LIST_FLOAT &val); | |||
| std::vector<float> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::LIST_BOOL &val); | |||
| std::vector<bool> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::LIST_STR &val); | |||
| std::vector<std::string> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::LIST_TENSOR &val); | |||
| std::vector<GeTensorPtr> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector<GeTensor> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::LIST_TENSOR_DESC &val); | |||
| std::vector<GeTensorDesc> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::LIST_BYTES &val); | |||
| std::vector<Buffer> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::LIST_NAMED_ATTRS &val); | |||
| std::vector<NamedAttrs> &val); | |||
| static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, | |||
| GeAttrValue::LIST_GRAPH &val); | |||
| std::vector<ComputeGraphPtr> &val); | |||
| // Value will be moved | |||
| static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer); | |||
| static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer); | |||
| @@ -246,30 +246,30 @@ GeAttrValue GeAttrValue::Copy() const { | |||
| return GRAPH_FAILED; \ | |||
| } | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524 | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR_DESC>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::GRAPH>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>) | |||
| ATTR_VALUE_SET_GET_IMP(std::string) | |||
| ATTR_VALUE_SET_GET_IMP(vector<std::string>) | |||
| ATTR_VALUE_SET_GET_IMP(int64_t) | |||
| ATTR_VALUE_SET_GET_IMP(vector<int64_t>) | |||
| ATTR_VALUE_SET_GET_IMP(float) // lint !e524 | |||
| ATTR_VALUE_SET_GET_IMP(vector<float>) | |||
| ATTR_VALUE_SET_GET_IMP(bool) | |||
| ATTR_VALUE_SET_GET_IMP(vector<bool>) | |||
| ATTR_VALUE_SET_GET_IMP(GeTensorDesc) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeTensorDesc>) | |||
| ATTR_VALUE_SET_GET_IMP(GeTensorPtr) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeTensorPtr>) | |||
| ATTR_VALUE_SET_GET_IMP(ComputeGraphPtr) | |||
| ATTR_VALUE_SET_GET_IMP(vector<ComputeGraphPtr>) | |||
| ATTR_VALUE_SET_GET_IMP(Buffer) | |||
| ATTR_VALUE_SET_GET_IMP(vector<Buffer>) | |||
| ATTR_VALUE_SET_GET_IMP(NamedAttrs) | |||
| ATTR_VALUE_SET_GET_IMP(vector<NamedAttrs>) | |||
| /*lint -e665*/ | |||
| ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>) | |||
| ATTR_VALUE_SET_GET_IMP(vector<vector<float>>) | |||
| /*lint +e665*/ | |||
| ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665 | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665 | |||
| ATTR_VALUE_SET_GET_IMP(DataType) // lint !e665 | |||
| #undef ATTR_VALUE_SET_GET_IMP | |||
| @@ -569,7 +569,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeTen | |||
| return true; | |||
| } | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::BYTES &value) { | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const Buffer &value) { | |||
| if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { | |||
| return false; | |||
| } | |||
| @@ -578,7 +578,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue: | |||
| return true; | |||
| } | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::BYTES> &value) { | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<Buffer> &value) { | |||
| if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, | |||
| proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { | |||
| return false; | |||
| @@ -592,7 +592,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAtt | |||
| return true; | |||
| } | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) { | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const NamedAttrs &value) { | |||
| if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | |||
| return false; | |||
| } | |||
| @@ -606,7 +606,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue: | |||
| return true; | |||
| } | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NAMED_ATTRS> &value) { | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<NamedAttrs> &value) { | |||
| if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, | |||
| proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { | |||
| return false; | |||
| @@ -822,7 +822,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| return true; | |||
| } | |||
| bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) { | |||
| bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, Buffer &value) { | |||
| if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { | |||
| return false; | |||
| } | |||
| @@ -833,7 +833,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| } | |||
| bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | |||
| vector<GeAttrValue::BYTES> &value) { | |||
| vector<Buffer> &value) { | |||
| value.clear(); | |||
| if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, | |||
| ListValueItemCheck(bt))) { | |||
| @@ -847,7 +847,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| } | |||
| bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | |||
| GeAttrValue::NAMED_ATTRS &value) { | |||
| NamedAttrs &value) { | |||
| if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | |||
| return false; | |||
| } | |||
| @@ -860,7 +860,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| } | |||
| bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | |||
| vector<GeAttrValue::NAMED_ATTRS> &value) { | |||
| vector<NamedAttrs> &value) { | |||
| value.clear(); | |||
| if (!AttrUtilsHelper::GetValueCheckListType( | |||
| proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { | |||
| @@ -868,7 +868,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| } | |||
| auto &list = proto_attr_val.list(); | |||
| for (const auto &item : list.na()) { | |||
| value.emplace_back(GeAttrValue::NAMED_ATTRS()); | |||
| value.emplace_back(NamedAttrs()); | |||
| if (value.empty()) { | |||
| return false; | |||
| } | |||
| @@ -1107,7 +1107,7 @@ ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) | |||
| ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) | |||
| ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) | |||
| ATTR_UTILS_SET_IMP(Tensor, GeTensor) | |||
| ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) | |||
| ATTR_UTILS_SET_GET_IMP(NamedAttrs, NamedAttrs) | |||
| ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) | |||
| ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) | |||
| /*lint -e665*/ | |||
| @@ -1124,7 +1124,7 @@ ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector<GeTensorDesc>) | |||
| ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>) | |||
| ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>) | |||
| ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>) | |||
| ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>) | |||
| ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<NamedAttrs>) | |||
| ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>) | |||
| ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>) | |||
| ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) // lint !e665 | |||
| @@ -314,6 +314,7 @@ set(PARSER_UT_FILES | |||
| "testcase/onnx_parser_testcase/message2operator_unittest.cc" | |||
| "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | |||
| "testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc" | |||
| "testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc" | |||
| ) | |||
| ############ libut_parser_common.a ############ | |||
| @@ -0,0 +1,62 @@ | |||
| #include <gtest/gtest.h> | |||
| #include <iostream> | |||
| #include "graph/utils/attr_utils.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "ut/parser/parser_ut_utils.h" | |||
| #include "common/util.h" | |||
| #include "tensorflow/iterator_fusion_pass.h" | |||
| #include "parser/common/acl_graph_parser_util.h" | |||
| #define private public | |||
| #include "tensorflow/graph_optimizer.h" | |||
| #undef private | |||
| namespace ge { | |||
| class UtestGraphOptimizer : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| namespace { | |||
| ComputeGraphPtr MakeGraph() { | |||
| ge::ut::GraphBuilder builder("graph"); | |||
| std::string name = "graph"; | |||
| std::string original_type; | |||
| original_type = "IteratorV2"; // | |||
| auto data1 = builder.AddNode(name + "_" + original_type, ge::parser::FRAMEWORKOP, 1, 1); | |||
| ge::AttrUtils::SetStr(data1->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); | |||
| original_type = "IteratorGetNext"; | |||
| auto data2 = builder.AddNode(name + "_" + original_type + "2", ge::parser::FRAMEWORKOP, 1, 2); | |||
| ge::AttrUtils::SetStr(data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); | |||
| string nodefStr; | |||
| AttrUtils::SetZeroCopyBytes( | |||
| data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | |||
| Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length())); | |||
| original_type = "IteratorGetNext"; | |||
| auto data3 = builder.AddNode(name + "_" + original_type + "3", ge::parser::FRAMEWORKOP, 2, 1); | |||
| ge::AttrUtils::SetStr(data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); | |||
| AttrUtils::SetZeroCopyBytes( | |||
| data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | |||
| Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length())); | |||
| builder.AddDataEdge(data1, 0, data2, 0); | |||
| builder.AddDataEdge(data2, 0, data3, 0); | |||
| builder.AddDataEdge(data2, 1, data3, 1); | |||
| return builder.GetGraph(); | |||
| } | |||
| } | |||
| TEST_F(UtestGraphOptimizer, graph_optimizer) { | |||
| ge::ComputeGraphPtr graph = MakeGraph(); | |||
| ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW); | |||
| EXPECT_NE(iteratorFusionPass.Run(graph), ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestGraphOptimizer, graph_optimizer_output) { | |||
| ge::ComputeGraphPtr graph = MakeGraph(); | |||
| domi::FrameworkType type = domi::TENSORFLOW; | |||
| ge::ParserGraphOptimizer parserGraphOptimizer(graph, type); | |||
| vector<ge::InDataAnchorPtr> input_anchors; | |||
| vector<ge::OutDataAnchorPtr> output_anchors; | |||
| ge::OpDescPtr fusion_op_desc; | |||
| EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors, fusion_op_desc), ge::SUCCESS); | |||
| EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors, fusion_op_desc), ge::SUCCESS); | |||
| } | |||
| } | |||