From bc207ebc1b9b8b8eeef4f841654c82f9d0b3364e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=9B=E6=A5=A0?= Date: Mon, 30 Aug 2021 16:36:36 +0800 Subject: [PATCH 01/16] use true types instead of GeAttrValue:: --- parser/tensorflow/graph_functiondef.cc | 2 +- parser/tensorflow/graph_optimizer.cc | 48 ++++---- tests/depends/graph/src/attr_util_stub.cc | 128 +++++++++++----------- 3 files changed, 89 insertions(+), 89 deletions(-) diff --git a/parser/tensorflow/graph_functiondef.cc b/parser/tensorflow/graph_functiondef.cc index f0230de..cac9e4f 100644 --- a/parser/tensorflow/graph_functiondef.cc +++ b/parser/tensorflow/graph_functiondef.cc @@ -424,7 +424,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph } // Analysis of nodedef of original tensorflow - ge::GeAttrValue::BYTES nodedef_bytes; + ge::Buffer nodedef_bytes; GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetBytes(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, nodedef_bytes), PARAM_INVALID, "Get type attr nodedef failed."); domi::tensorflow::NodeDef node_def_; diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc index f46b950..ff211b8 100644 --- a/parser/tensorflow/graph_optimizer.cc +++ b/parser/tensorflow/graph_optimizer.cc @@ -515,10 +515,10 @@ Status SetNodedefProto(domi::tensorflow::NodeDef &proto, ge::NodePtr n, string o return SUCCESS; } -typedef Status (*PIOListHandle)(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +typedef Status (*PIOListHandle)(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc); -Status GatherV2IOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status GatherV2IOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int tparams; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Tparams", tparams)), @@ -546,7 +546,7 @@ Status GatherV2IOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LI return SUCCESS; } -Status ConstIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status ConstIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int dtype; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "dtype", dtype)), return PARAM_INVALID, "Get dtype error."); @@ -556,7 +556,7 @@ Status ConstIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_ return SUCCESS; } -Status MaxMinIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status MaxMinIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int attrT; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", attrT)), @@ -574,7 +574,7 @@ Status MaxMinIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST return SUCCESS; } -Status CastIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status CastIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int srcT; int dstT; @@ -592,7 +592,7 @@ Status CastIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_I return SUCCESS; } -Status AddIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, ge::OpDescPtr &opDesc) { +Status AddIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int type; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", type)), REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", @@ -607,7 +607,7 @@ Status AddIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_IN return SUCCESS; } -Status LessIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status LessIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int dtype; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", dtype)), @@ -622,7 +622,7 @@ Status LessIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_I return SUCCESS; } -Status MulIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, ge::OpDescPtr &opDesc) { +Status MulIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int dataType; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, ge::ATTR_NAME_T, dataType)), REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ge::ATTR_NAME_T.c_str(), @@ -638,7 +638,7 @@ Status MulIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_IN return SUCCESS; } -Status RealDivIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status RealDivIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int t; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), @@ -654,7 +654,7 @@ Status RealDivIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIS return SUCCESS; } -Status SelectIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status SelectIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int t; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), @@ -671,7 +671,7 @@ Status SelectIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST return SUCCESS; } -Status SqrtIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status SqrtIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int dataType; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, ge::ATTR_NAME_T, dataType)), @@ -687,7 +687,7 @@ Status SqrtIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_I return SUCCESS; } -Status TruncatedNormalIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status TruncatedNormalIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int t; int dtype; @@ -707,7 +707,7 @@ Status TruncatedNormalIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrVa return SUCCESS; } -Status PackIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status PackIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int t; int n; @@ -729,7 +729,7 @@ Status PackIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_I return SUCCESS; } -Status DropOutGenMaskIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status DropOutGenMaskIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { input_list.push_back(domi::tensorflow::DT_INT64); input_list.push_back(domi::tensorflow::DT_FLOAT); @@ -738,7 +738,7 @@ Status DropOutGenMaskIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrVal return SUCCESS; } -Status ExpandDimsIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status ExpandDimsIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { int dataType; int dimType; @@ -759,7 +759,7 @@ Status ExpandDimsIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue:: return SUCCESS; } -Status SqueezeIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, +Status SqueezeIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { // Set - TENSORFLOW_IN_DATATYPE/TENSORFLOW_OUT_DATATYPE int dataType; @@ -785,7 +785,7 @@ Status SqueezeIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIS return SUCCESS; } -Status TopKV2IOList(ge::GeAttrValue::LIST_INT &inputList, ge::GeAttrValue::LIST_INT &outputList, +Status TopKV2IOList(std::vector &inputList, std::vector &outputList, ge::OpDescPtr &opDesc) { int t; GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), @@ -829,8 +829,8 @@ Status CreateNodeDefBytes(ge::NodePtr n, string originalType, mapGetName() = %s.\n", n->GetName().c_str()); // Set - NodeDef PROTO domi::tensorflow::NodeDef proto; - ge::GeAttrValue::LIST_INT inputList; - ge::GeAttrValue::LIST_INT outputList; + std::vector inputList; + std::vector outputList; ret = SetNodedefProto(proto, n, originalType); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "SetNodedefProto failed."); @@ -891,7 +891,7 @@ Status CreateNodeDefBytes(ge::NodePtr n, string originalType, map(nodefStr.data()), nodefStr.length())); @@ -1279,7 +1279,7 @@ Status CreateFuncDefBytes(ge::NodePtr n, string original_type, string func_bin_p GELOGI("len =%d\n", len); - ge::GeAttrValue::BYTES funcDefBytes; + ge::Buffer funcDefBytes; funcDefBytes = ge::Buffer::CopyFrom((std::uint8_t *)buf, len); (void)ge::AttrUtils::SetBytes(opDesc, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes); GELOGI("funcDefBytes.GetSize() =%zu", funcDefBytes.GetSize()); @@ -1453,7 +1453,7 @@ Status CollectNodeFuncs(vector &nodes, FunctionDefLibrary *library) GE_CHECK_NOTNULL(node); OpDescPtr opDef = node->GetOpDesc(); string funcdefStr; - ge::GeAttrValue::BYTES funcDefBytes; + ge::Buffer funcDefBytes; GE_IF_BOOL_EXEC( AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib; @@ -1648,7 +1648,7 @@ Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map // rebuild output anchor Status ParserGraphOptimizer::RebuildOutputAnchors(vector &output_anchors, ge::OpDescPtr fusion_op_desc) { - ge::GeAttrValue::LIST_INT output_list; + std::vector output_list; GE_CHECK_NOTNULL(fusion_op_desc); // create input desc @@ -1679,7 +1679,7 @@ Status ParserGraphOptimizer::RebuildOutputAnchors(vector & // rebuild input desc Status ParserGraphOptimizer::RebuildInputAnchors(vector &input_anchors, ge::OpDescPtr fusion_op_desc) { - ge::GeAttrValue::LIST_INT input_list; + std::vector input_list; GE_CHECK_NOTNULL(fusion_op_desc); // add input desc for (auto in_anchor : input_anchors) { diff --git a/tests/depends/graph/src/attr_util_stub.cc b/tests/depends/graph/src/attr_util_stub.cc index a51bd84..5bb3cc7 100644 --- a/tests/depends/graph/src/attr_util_stub.cc +++ b/tests/depends/graph/src/attr_util_stub.cc @@ -83,61 +83,61 @@ class GeAttrValueImp { static map attr_val_one_type_map_; static map attr_val_list_type_map_; - static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val); - static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val); - static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val); + static bool SetValue(proto::AttrDef &attr_def, int64_t val); + static bool SetValue(proto::AttrDef &attr_def, float val); + static bool SetValue(proto::AttrDef &attr_def, bool val); + static bool SetValue(proto::AttrDef &attr_def, const std::string &val); static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val); static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val); + static bool SetValue(proto::AttrDef &attr_def, const GeTensorDesc &val); + static bool SetValue(proto::AttrDef &attr_def, const Buffer &val); + static bool SetValue(proto::AttrDef &attr_def, const NamedAttrs &val); + static bool SetValue(proto::AttrDef &attr_def, const ComputeGraphPtr &val); static bool SetValue(proto::AttrDef &attr_def, const vector &val); static bool SetValue(proto::AttrDef &attr_def, const vector &val); static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val); + static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); + static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); + static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val); - - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BOOL &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val); + static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); + static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); + static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); + static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); + + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, int64_t &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, float &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, bool &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, std::string &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensorPtr &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::TENSOR_DESC &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val); + GeTensorDesc &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::NAMED_ATTRS &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val); + NamedAttrs &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ComputeGraphPtr &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_INT &val); + std::vector &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_FLOAT &val); + std::vector &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_BOOL &val); + std::vector &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_STR &val); + std::vector &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_TENSOR &val); + std::vector &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_TENSOR_DESC &val); + std::vector &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_BYTES &val); + std::vector &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_NAMED_ATTRS &val); + std::vector &val); static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_GRAPH &val); + std::vector &val); // Value will be moved static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer); static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer); @@ -246,30 +246,30 @@ GeAttrValue GeAttrValue::Copy() const { return GRAPH_FAILED; \ } -ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524 -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) -ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(std::string) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(int64_t) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(float) // lint !e524 +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(bool) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeTensorDesc) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeTensorPtr) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(ComputeGraphPtr) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(Buffer) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(NamedAttrs) +ATTR_VALUE_SET_GET_IMP(vector) /*lint -e665*/ ATTR_VALUE_SET_GET_IMP(vector>) ATTR_VALUE_SET_GET_IMP(vector>) /*lint +e665*/ ATTR_VALUE_SET_GET_IMP(vector) // lint !e665 -ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665 +ATTR_VALUE_SET_GET_IMP(DataType) // lint !e665 #undef ATTR_VALUE_SET_GET_IMP @@ -569,7 +569,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { return false; @@ -592,7 +592,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { return false; @@ -822,7 +822,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM return true; } -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) { +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, Buffer &value) { if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { return false; } @@ -833,7 +833,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM } bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { + vector &value) { value.clear(); if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, ListValueItemCheck(bt))) { @@ -847,7 +847,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM } bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - GeAttrValue::NAMED_ATTRS &value) { + NamedAttrs &value) { if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { return false; } @@ -860,7 +860,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM } bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { + vector &value) { value.clear(); if (!AttrUtilsHelper::GetValueCheckListType( proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { @@ -868,7 +868,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM } auto &list = proto_attr_val.list(); for (const auto &item : list.na()) { - value.emplace_back(GeAttrValue::NAMED_ATTRS()); + value.emplace_back(NamedAttrs()); if (value.empty()) { return false; } @@ -1107,7 +1107,7 @@ ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) ATTR_UTILS_SET_IMP(Tensor, GeTensor) -ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) +ATTR_UTILS_SET_GET_IMP(NamedAttrs, NamedAttrs) ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) /*lint -e665*/ @@ -1124,7 +1124,7 @@ ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector) ATTR_UTILS_SET_IMP(ListTensor, vector) ATTR_UTILS_SET_IMP(ListTensor, vector) ATTR_UTILS_SET_IMP(ListTensor, vector) -ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector) +ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector) ATTR_UTILS_SET_GET_IMP(ListBytes, vector) ATTR_UTILS_SET_GET_IMP(ListGraph, vector) ATTR_UTILS_SET_GET_IMP(ListDataType, vector) // lint !e665 From e3d033ae283a120400b53d47fcbd83428fb5e6a6 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Fri, 10 Sep 2021 09:25:25 +0800 Subject: [PATCH 02/16] use true types istead of GeAttrValue:: --- parser/tensorflow/graph_optimizer.cc | 1844 +---------------- parser/tensorflow/graph_optimizer.h | 57 +- parser/tensorflow/iterator_fusion_pass.cc | 2 - parser/tensorflow/iterator_fusion_pass.h | 5 +- parser/tensorflow/tensorflow_parser.cc | 2 +- tests/ut/parser/CMakeLists.txt | 2 + tests/ut/parser/graph_builder_utils.cc | 48 + tests/ut/parser/graph_builder_utils.h | 48 + .../graph_optimizer_unittest.cc | 71 + 9 files changed, 175 insertions(+), 1904 deletions(-) create mode 100644 tests/ut/parser/graph_builder_utils.cc create mode 100644 tests/ut/parser/graph_builder_utils.h create mode 100644 tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc index ff211b8..a2488b3 100644 --- a/parser/tensorflow/graph_optimizer.cc +++ b/parser/tensorflow/graph_optimizer.cc @@ -89,1251 +89,6 @@ const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; const char *const kShapeNodeName = "Shape"; } // namespace -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map g_OpSupportTranInfo = {}; - -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::MUL, - std::vector({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, - ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), - InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::L2LOSS, - std::vector({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC, - ge::FORMAT_HWCN}), // inputformats - ge::DT_FLOAT, ge::FORMAT_NC1HWC0, ge::DT_FLOAT) - -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONVGRADFILTER, InFmtSupportUndefined, InDtSupportUndefined, - ge::FORMAT_FRACTAL_Z, ge::DT_FLOAT) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2DBACKPROPINPUT, InFmtSupportUndefined, InDtSupportUndefined, - ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::BIASADDGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, - ge::DT_FLOAT) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::BIASADD, ge::FORMAT_NCHW, ge::DT_FLOAT, ge::FORMAT_NCHW, ge::DT_FLOAT) - -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATION, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, - ge::DT_FLOAT16) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATIONGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, - ge::DT_FLOAT16) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SOFTMAX, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, - ge::DT_FLOAT16) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SOFTMAX, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) - -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPROPFILTER, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, - ge::FORMAT_C1HWNCoC0, ge::DT_FLOAT) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPORPINPUT, InFmtSupportUndefined, InDtSupportUndefined, - OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DFORWARDNATIVE, InFmtSupportUndefined, InDtSupportUndefined, - OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORM, InFmtSupportUndefined, InDtSupportUndefined, - OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORMGRAD, InFmtSupportUndefined, InDtSupportUndefined, - OutFmtSupportAsInput, OutDtSupportUndefined) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2D, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, - OutDtSupportUndefined) - -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::RESHAPE, ge::FORMAT_NHWC, InDtSupportAll, ge::FORMAT_NHWC, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, InFmtSupport5D, ge::DT_FLOAT16, - OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MAXIMUM_GRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::APPLYRMSPROP, - std::vector({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, - ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), - ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DROPOUTDOMASK, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::LOG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRTGRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOIDGRAD, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOID, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ARGMAX, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::AVGPOOLGRAD, InFmtSupport5D, ge::DT_FLOAT16, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::NEG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RECIPROCAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQUARE, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUB, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUM, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MATMUL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GATHERV2, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATEREQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::REALDIV, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::STRIDEDSLICE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TILE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TFRELU6, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RELU6GRAD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::EQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATER, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SELECT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_BATCH_MATMUL, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TRANSPOSE, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, - OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::STREAMMERGE, - std::vector({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), - InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) -TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::MEMCPYASYNC, - std::vector({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), - InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) - -bool GetCceTbeTransInfo(string opType, OpSupportTranInfo &opSupportInfo) { - static bool fmtInited = false; - GE_IF_BOOL_EXEC( - !fmtInited, fmtInited = true; - if (domi::OpRegistry().Instance()->GetImplyType(ge::parser::DEPTHWISEWEIGHT4D26D) == domi::ImplyType::TVM) { - auto it = g_OpSupportTranInfo.find(string("TBE:") + ge::parser::MUL); - if (it != g_OpSupportTranInfo.end()) { - auto &fmts = it->second.inputFormats; - auto itFmt = std::find(fmts.begin(), fmts.end(), ge::FORMAT_NC1HWC0); - fmts.erase(itFmt); - } - }) - string cceTbeOpType = "TBE"; - GE_IF_BOOL_EXEC(domi::OpRegistry().Instance()->GetImplyType(opType) == domi::ImplyType::BUILDIN, - cceTbeOpType = "CCE";) - cceTbeOpType = cceTbeOpType + ":" + opType; - GE_IF_BOOL_EXEC(g_OpSupportTranInfo.find(cceTbeOpType) != g_OpSupportTranInfo.end(), - opSupportInfo = g_OpSupportTranInfo[cceTbeOpType]; - return true;) - return false; -} - -Status ParserGraphOptimizer::Optimize() { return SUCCESS; } - -Status ParserGraphOptimizer::OptimizeAfterCal() { return SUCCESS; } - -void SetStringAttr(const string &originalType, OpDescPtr &opDesc, - google::protobuf::Map *tfAttr, - const pair &attr) { - string s; - (void)AttrUtils::GetStr(opDesc, attr.first, s); - - if (originalType == "ParallelMapDataset" || originalType == "FilterDataset" || - originalType == "MapAndBatchDatasetV2") { - ::domi::tensorflow::NameAttrList *nameAttrList = (*tfAttr)[attr.first].mutable_func(); - nameAttrList->set_name(s); - } else { - (*tfAttr)[attr.first].set_s(s); - } -} - -void SetIntAttr(const string &originalType, OpDescPtr &opDesc, - google::protobuf::Map *tfAttr, - const pair &attr) { - int32_t i = 0; - (void)AttrUtils::GetInt(opDesc, attr.first, i); - - if (originalType == "Pack" && (attr.first == "axis" || attr.first == "N")) { - (*tfAttr)[attr.first].set_i(i); - } else if (originalType == "TruncatedNormal" && (attr.first == "seed" || attr.first == "seed2")) { - (*tfAttr)[attr.first].set_i(i); - } else { - (*tfAttr)[attr.first].set_type((domi::tensorflow::DataType)i); - } -} - -void SetSqueezeDims(const string &originalType, google::protobuf::Map *tfAttr, - const pair &attr, const vector &intList, - const domi::tensorflow::AttrValue &attrValue, domi::tensorflow::AttrValue_ListValue *list) { - if (originalType == "Squeeze" && (attr.first == "squeeze_dims")) { - for (auto i : intList) { - list->add_i(i); - } - (*tfAttr)[attr.first] = attrValue; - } -} - -void SetListIntAttr(const string &originalType, OpDescPtr &opDesc, - google::protobuf::Map *tfAttr, - const pair &attr) { - vector intList; - (void)AttrUtils::GetListInt(opDesc, attr.first, intList); - - domi::tensorflow::AttrValue attrValue; - domi::tensorflow::AttrValue_ListValue *list = attrValue.mutable_list(); - - vector::iterator iter = std::find(is_dataset_op_vec.begin(), is_dataset_op_vec.end(), originalType); - if (iter != is_dataset_op_vec.end()) { - if (attr.first == "Toutput_types" || attr.first == "output_types") { - for (auto i : intList) { - list->add_type((domi::tensorflow::DataType)i); - } - (*tfAttr)[attr.first] = attrValue; - } else if ((originalType == "ParallelMapDataset" || originalType == "FilterDataset" || - originalType == "MapAndBatchDatasetV2") && - attr.first == "Targuments") { - (*tfAttr)[attr.first] = attrValue; - } - } else { - SetSqueezeDims(originalType, tfAttr, attr, intList, attrValue, list); - } -} - -void SetListListIntAttr(const string &originalType, OpDescPtr &opDesc, - google::protobuf::Map *tfAttr, - const pair &attr) { - vector> intListList; - (void)AttrUtils::GetListListInt(opDesc, attr.first, intListList); - - domi::tensorflow::AttrValue attrValue; - domi::tensorflow::AttrValue_ListValue *list = attrValue.mutable_list(); - - if ((originalType == "IteratorV2" || originalType == "BatchDatasetV2" || originalType == "IteratorGetNext" || - originalType == "ParallelMapDataset" || originalType == "DeviceQueueDataset" || originalType == "QueueDataset" || - originalType == "FilterDataset" || originalType == "MapAndBatchDatasetV2") && - attr.first == "output_shapes") { - for (size_t ill = 0; ill < intListList.size(); ill++) { - TensorShapeProto *tensorShape = list->add_shape(); - auto intList_ = intListList[ill]; - for (auto i : intList_) { - TensorShapeProto_Dim *dim = tensorShape->add_dim(); - dim->set_size(i); - } - } - (*tfAttr)[attr.first] = attrValue; - } else if (originalType == "TensorDataset" && attr.first == "output_shapes") { - domi::tensorflow::TensorShapeProto *tensorShape = list->add_shape(); - domi::tensorflow::TensorShapeProto_Dim *dim = tensorShape->add_dim(); - dim->set_size(0); - (*tfAttr)[attr.first] = attrValue; - } -} - -void SetTensorValue(const ge::ConstGeTensorPtr &geTensor, domi::tensorflow::TensorProto *tfTensor, int32_t dataCount) { - if (dataCount > 1) { - tfTensor->set_tensor_content(geTensor->GetData().data(), geTensor->GetData().size()); - } else { - switch (geTensor->GetTensorDesc().GetDataType()) { - case ge::DT_FLOAT: { - float f = *(reinterpret_cast(geTensor->GetData().data())); - tfTensor->add_float_val(f); - break; - } - - case ge::DT_INT32: { - int32_t i = *(reinterpret_cast(geTensor->GetData().data())); - tfTensor->add_int_val(i); - break; - } - - case ge::DT_BOOL: { - bool b = *(reinterpret_cast(geTensor->GetData().data())); - tfTensor->add_bool_val(b); - break; - } - - case ge::DT_INT64: { - int64_t i = *(reinterpret_cast(geTensor->GetData().data())); - tfTensor->add_int64_val(i); - break; - } - - case ge::DT_FLOAT16: { - int32_t f = *(reinterpret_cast(geTensor->GetData().data())); - tfTensor->add_half_val(f); - break; - } - - default: { - GELOGW("SetTensorValue not support the data type %s.", - ge::TypeUtils::DataTypeToSerialString(geTensor->GetTensorDesc().GetDataType()).c_str()); - } - } - } -} - -Status SetTensorAttr(ge::OpDescPtr &opDesc, google::protobuf::Map *tfAttr, - const pair &attr) { - ge::ConstGeTensorPtr ge_tensor; - (void)ge::AttrUtils::GetTensor(opDesc, attr.first, ge_tensor); - - domi::tensorflow::TensorProto *tf_tensor = (*tfAttr)[attr.first].mutable_tensor(); - - // Set datatype - domi::tensorflow::DataType datatype; - auto ge_datatype = ge_tensor->GetTensorDesc().GetDataType(); - int32_t data_count = 1; - switch (ge_datatype) { - case ge::DataType::DT_FLOAT: - datatype = domi::tensorflow::DataType::DT_FLOAT; - data_count = ge_tensor->GetData().size() / sizeof(float); - break; - case ge::DataType::DT_FLOAT16: - datatype = domi::tensorflow::DataType::DT_HALF; - data_count = ge_tensor->GetData().size() / sizeof(int16_t); - break; - case ge::DataType::DT_INT32: - datatype = domi::tensorflow::DataType::DT_INT32; - data_count = ge_tensor->GetData().size() / sizeof(int32_t); - break; - case ge::DataType::DT_INT64: - datatype = domi::tensorflow::DataType::DT_INT64; - data_count = ge_tensor->GetData().size() / sizeof(int64_t); - break; - case ge::DataType::DT_UINT8: - datatype = domi::tensorflow::DataType::DT_UINT8; - data_count = ge_tensor->GetData().size() / sizeof(uint8_t); - break; - case ge::DataType::DT_BOOL: - datatype = domi::tensorflow::DataType::DT_BOOL; - data_count = ge_tensor->GetData().size() / sizeof(bool); - break; - default: - REPORT_INNER_ERROR("E19999", "datatype:%d of Attr:%s in node:%s:%s is not supported", - ge_datatype, attr.first.c_str(), opDesc->GetName().c_str(), opDesc->GetType().c_str()); - GELOGE(PARAM_INVALID, "NO SUPPORT datatype = %s", ge::TypeUtils::DataTypeToSerialString(ge_datatype).c_str()); - return PARAM_INVALID; - } - tf_tensor->set_dtype(datatype); - - domi::tensorflow::TensorShapeProto *tf_shape = tf_tensor->mutable_tensor_shape(); - for (auto dim : ge_tensor->GetTensorDesc().GetShape().GetDims()) { - domi::tensorflow::TensorShapeProto_Dim *tf_dims = tf_shape->add_dim(); - tf_dims->set_size(dim); - } - - SetTensorValue(ge_tensor, tf_tensor, data_count); - return SUCCESS; -} - -Status SetNodedefProto(domi::tensorflow::NodeDef &proto, ge::NodePtr n, string original_type) { - GELOGI("graph_optimizer.cpp && SetNodedefProto"); - // Set proto head - Status ret; - auto op_desc = n->GetOpDesc(); - GELOGI("n->GetName =%s, original_type =%s", n->GetName().c_str(), original_type.c_str()); - proto.set_name(n->GetName()); - proto.set_op(original_type); - - // Set input - auto input_names = op_desc->GetInputName(); - - for (auto anchor : n->GetAllInDataAnchors()) { - GE_IF_BOOL_EXEC(anchor == nullptr || anchor->GetPeerOutAnchor() == nullptr || - anchor->GetPeerOutAnchor()->GetOwnerNode() == nullptr || - anchor->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc() == nullptr, - continue); - OutDataAnchorPtr src_anchor = anchor->GetPeerOutAnchor(); - NodePtr src_node = anchor->GetPeerOutAnchor()->GetOwnerNode(); - OpDescPtr src_opdesc = src_node->GetOpDesc(); - GELOGI("inedge src:%s, src_out_index:%d, dst:%s, dst_in_index:%d", src_opdesc->GetName().c_str(), - src_anchor->GetIdx(), op_desc->GetName().c_str(), anchor->GetIdx()); - string inputName; - inputName = src_opdesc->GetName() + ":" + "output:" + std::to_string(src_anchor->GetIdx()); - GELOGI("inputName =%s\n", inputName.c_str()); - proto.add_input(inputName); - } - - // Set device - proto.set_device("CPU"); - - // Set proto attr - google::protobuf::Map *tf_attr = proto.mutable_attr(); - map allattrs = op_desc->GetAllAttrs(); - allattrs.erase(ge::ATTR_NAME_FRAMEWORK_FWK_TYPE); - allattrs.erase(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE); - if (original_type == "Add") { - allattrs.erase(ge::ATTR_NAME_MODE); - } else if (original_type == "IteratorGetNext") { - allattrs.erase("output_num"); - } - - for (const auto &attr : allattrs) { - ge::GeAttrValue::ValueType v_t = attr.second.GetValueType(); - switch (v_t) { - case ge::GeAttrValue::ValueType::VT_STRING: { - SetStringAttr(original_type, op_desc, tf_attr, attr); - - break; - } - - case ge::GeAttrValue::ValueType::VT_INT: { - SetIntAttr(original_type, op_desc, tf_attr, attr); - - break; - } - case ge::GeAttrValue::ValueType::VT_BOOL: { - bool i = false; - (void)ge::AttrUtils::GetBool(op_desc, attr.first, i); - (*tf_attr)[attr.first].set_b(i); - break; - } - case ge::GeAttrValue::ValueType::VT_LIST_INT: { - SetListIntAttr(original_type, op_desc, tf_attr, attr); - - break; - } - case ge::GeAttrValue::ValueType::VT_LIST_LIST_INT: { - SetListListIntAttr(original_type, op_desc, tf_attr, attr); - - break; - } - case ge::GeAttrValue::ValueType::VT_TENSOR: { - ret = SetTensorAttr(op_desc, tf_attr, attr); - GE_IF_BOOL_EXEC(ret != SUCCESS, return ret); - break; - } - default: - break; - } - } - - return SUCCESS; -} - -typedef Status (*PIOListHandle)(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc); - -Status GatherV2IOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int tparams; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Tparams", tparams)), - REPORT_CALL_ERROR("E19999", "Get Attr:Tparams from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get Tparams error."); - int tindices; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Tindices", tindices)), - REPORT_CALL_ERROR("E19999", "Get Attr:Tindices from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get Tindices error."); - int taxis; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Taxis", taxis)), - REPORT_CALL_ERROR("E19999", "Get Attr:Taxis from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get Taxis error."); - - // input_list - eg:{1, 3, 3} - input_list.push_back(tparams); - input_list.push_back(tindices); - input_list.push_back(taxis); - // output_list - eg:{3} - output_list.push_back(tparams); - - return SUCCESS; -} - -Status ConstIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int dtype; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "dtype", dtype)), return PARAM_INVALID, "Get dtype error."); - // output_list - {3} - output_list.push_back(dtype); - - return SUCCESS; -} - -Status MaxMinIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int attrT; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", attrT)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get Tparams error."); - - // input_list - input_list.push_back(attrT); - input_list.push_back(attrT); - - // output_list - output_list.push_back(attrT); - - return SUCCESS; -} - -Status CastIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int srcT; - int dstT; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "SrcT", srcT)), - REPORT_CALL_ERROR("E19999", "Get Attr:SrcT from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get srcT error."); - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "DstT", dstT)), - REPORT_CALL_ERROR("E19999", "Get Attr:DstT from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get dstT error."); - input_list.push_back(srcT); - output_list.push_back(dstT); - - return SUCCESS; -} - -Status AddIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { - int type; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", type)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get T error."); - - input_list.push_back(type); - input_list.push_back(type); - - output_list.push_back(type); - - return SUCCESS; -} - -Status LessIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int dtype; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", dtype)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get dtype error."); - - input_list.push_back(dtype); - input_list.push_back(dtype); - output_list.push_back(domi::tensorflow::DataType::DT_BOOL); - - return SUCCESS; -} - -Status MulIOList(std::vector &input_list, std::vector &output_list, ge::OpDescPtr &opDesc) { - int dataType; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, ge::ATTR_NAME_T, dataType)), - REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ge::ATTR_NAME_T.c_str(), - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, - "Get Tparams error."); - - input_list.push_back(dataType); - input_list.push_back(dataType); - - output_list.push_back(dataType); - - return SUCCESS; -} - -Status RealDivIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int t; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get beta error."); - - input_list.push_back(t); - input_list.push_back(t); - - output_list.push_back(t); - - return SUCCESS; -} - -Status SelectIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int t; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get e error."); - - input_list.push_back(domi::tensorflow::DataType::DT_BOOL); - input_list.push_back(t); - input_list.push_back(t); - - output_list.push_back(t); - - return SUCCESS; -} - -Status SqrtIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int dataType; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, ge::ATTR_NAME_T, dataType)), - REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ge::ATTR_NAME_T.c_str(), - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, - "Get Tparam error."); - - input_list.push_back(dataType); - - output_list.push_back(dataType); - - return SUCCESS; -} - -Status TruncatedNormalIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int t; - int dtype; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get T error."); - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "dtype", dtype)), - REPORT_CALL_ERROR("E19999", "Get Attr:dtype from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get e error."); - - input_list.push_back(t); - - output_list.push_back(dtype); - - return SUCCESS; -} - -Status PackIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int t; - int n; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get T error."); - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "N", n)), - REPORT_CALL_ERROR("E19999", "Get Attr:N from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get N error."); - - for (int i = 0; i < n; i++) { - input_list.push_back(t); - } - - output_list.push_back(t); - - return SUCCESS; -} - -Status DropOutGenMaskIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - input_list.push_back(domi::tensorflow::DT_INT64); - input_list.push_back(domi::tensorflow::DT_FLOAT); - output_list.push_back(domi::tensorflow::DT_UINT8); - - return SUCCESS; -} - -Status ExpandDimsIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - int dataType; - int dimType; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", dataType)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get T error."); - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Tdim", dimType)), - REPORT_CALL_ERROR("E19999", "Get Attr:Tdim from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get Tdim error."); - // input_list - x y data type - input_list.push_back(dataType); - input_list.push_back(dimType); - // output_list - z data type - output_list.push_back(dataType); - - return SUCCESS; -} - -Status SqueezeIOList(std::vector &input_list, std::vector &output_list, - ge::OpDescPtr &opDesc) { - // Set - TENSORFLOW_IN_DATATYPE/TENSORFLOW_OUT_DATATYPE - int dataType; - vector dimTypeList; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", dataType)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get T error."); - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetListInt(opDesc, "squeeze_dims", dimTypeList)), - REPORT_CALL_ERROR("E19999", "Get Attr:squeeze_dims from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, - "Get squeeze_dims error."); - for (auto i : dimTypeList) { - GELOGI("squeeze_dims = %d.\n", i); - } - - // input_list - x y data type - input_list.push_back(dataType); - // output_list - z data type - output_list.push_back(dataType); - - return SUCCESS; -} - -Status TopKV2IOList(std::vector &inputList, std::vector &outputList, - ge::OpDescPtr &opDesc) { - int t; - GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), - REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed", - opDesc->GetName().c_str(), opDesc->GetType().c_str()); - return PARAM_INVALID, "Get T error."); - - // input_list - eg:{1, 3} - inputList.push_back(t); - inputList.push_back(domi::tensorflow::DataType::DT_INT32); - - // output_list - eg:{1, 3} - outputList.push_back(t); - outputList.push_back(domi::tensorflow::DataType::DT_INT32); - - return SUCCESS; -} - -void CreateIOListFuncMap(map &mOpIOListFuncMap) { - mOpIOListFuncMap.insert({"GatherV2", GatherV2IOList}); - mOpIOListFuncMap.insert({"Const", ConstIOList}); - mOpIOListFuncMap.insert({"Maximum", MaxMinIOList}); - mOpIOListFuncMap.insert({"Minimum", MaxMinIOList}); - mOpIOListFuncMap.insert({"Cast", CastIOList}); - mOpIOListFuncMap.insert({"Add", AddIOList}); - mOpIOListFuncMap.insert({"Less", LessIOList}); - mOpIOListFuncMap.insert({"Mul", MulIOList}); - mOpIOListFuncMap.insert({"RealDiv", RealDivIOList}); - mOpIOListFuncMap.insert({"Select", SelectIOList}); - mOpIOListFuncMap.insert({"TruncatedNormal", TruncatedNormalIOList}); - mOpIOListFuncMap.insert({"Pack", PackIOList}); - mOpIOListFuncMap.insert({"DropOutGenMask", DropOutGenMaskIOList}); - mOpIOListFuncMap.insert({"ExpandDims", ExpandDimsIOList}); - mOpIOListFuncMap.insert({"Squeeze", SqueezeIOList}); - mOpIOListFuncMap.insert({"TopKV2", TopKV2IOList}); -} - -Status CreateNodeDefBytes(ge::NodePtr n, string originalType, map &mOpIOListFuncMap) { - Status ret; - auto opDesc = n->GetOpDesc(); - GELOGI("n->GetName() = %s.\n", n->GetName().c_str()); - // Set - NodeDef PROTO - domi::tensorflow::NodeDef proto; - std::vector inputList; - std::vector outputList; - ret = SetNodedefProto(proto, n, originalType); - GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "SetNodedefProto failed."); - - // Set inputList & outputList - // Set - TENSORFLOW_IN_DATATYPE/TENSORFLOW_OUT_DATATYPE - PIOListHandle funcPtr = nullptr; - map::iterator it = mOpIOListFuncMap.find(originalType); - if (it != mOpIOListFuncMap.end()) { - funcPtr = it->second; - } - - if (funcPtr != nullptr) { - ret = ((PIOListHandle)funcPtr)(inputList, outputList, opDesc); - if (ret != SUCCESS) { - return ret; - } - } - - vector::iterator iter = std::find(is_dataset_op_vec.begin(), is_dataset_op_vec.end(), originalType); - if (iter == is_dataset_op_vec.end()) { - (void)ge::AttrUtils::SetListInt(opDesc, ge::T_IN_DATATYPE, inputList); - (void)ge::AttrUtils::SetListInt(opDesc, ge::T_OUT_DATATYPE, outputList); - } - - // Set size - for (auto ge_desc : opDesc->GetAllOutputsDescPtr()) { - int64_t real_size = 1; - int64_t tmp_dim = 0; - auto data_type = ge_desc->GetDataType(); - - uint32_t size_type = 1; - bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type); - GE_IF_BOOL_EXEC(!type_ret, - REPORT_CALL_ERROR("E19999", "Can't get DataType:%s length of op:%s(%s)", - ge::TypeUtils::DataTypeToSerialString(data_type).c_str(), - n->GetName().c_str(), n->GetType().c_str()); - GELOGE(PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s", - ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); - return PARAM_INVALID); - - // calculate size - for (uint32_t j = 0; j < ge_desc->GetShape().GetDimNum(); ++j) { - tmp_dim = ge_desc->GetShape().GetDim(j); - GE_CHECK_GE(tmp_dim, 0); - PARSER_INT64_MULCHECK(real_size, tmp_dim); - real_size *= tmp_dim; - } - ge::TensorUtils::SetSize(*ge_desc, real_size * size_type); - ge::TensorUtils::SetRealDimCnt(*ge_desc, ge_desc->GetShape().GetDimNum()); - } - - // Serial - nodedef proto - string nodefStr; - GE_IF_BOOL_EXEC(!proto.SerializeToString(&nodefStr), - REPORT_CALL_ERROR("E19999", "Serialize nodedef to string failed, op:%s(%s)", - n->GetName().c_str(), n->GetType().c_str()); - GELOGE(PARAM_INVALID, "Serialize nodedef to string failed."); - return PARAM_INVALID); - - // Set - ATTR_NAME_FRAMEWORK_NODE_DEF - ge::Buffer nodeDefBytes; - (void)ge::AttrUtils::SetZeroCopyBytes( - opDesc, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, - nodeDefBytes.CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); - - // print proto - string nodefstr; - google::protobuf::TextFormat::PrintToString(proto, &nodefstr); - GELOGI("---> ! CreateNodeDefBytes() nodefstr : %s", nodefstr.c_str()); - return SUCCESS; -} - -Status CreateOpDefBytes(ge::NodePtr n, string original_type) { - auto opDesc = n->GetOpDesc(); - GELOGI("n->GetName() =%s, original_type =%s", n->GetName().c_str(), original_type.c_str()); - - // Set - OpDef PROTO - domi::tensorflow::OpDef proto; - proto.set_name(original_type); - - if (original_type == "Const") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); - outputArgdef->set_name("output"); - outputArgdef->set_type_attr("dtype"); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("value"); - attr1->set_type("tensor"); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("dtype"); - attr2->set_type("type"); - } else if (original_type == "TensorDataset") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *inputArgdef = proto.add_input_arg(); - inputArgdef->set_name("components"); - inputArgdef->set_type_list_attr("Toutput_types"); - - domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); - outputArgdef->set_name("handle"); - outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("Toutput_types"); - attr1->set_type("list(type)"); - attr1->set_has_minimum(true); - attr1->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("output_shapes"); - attr2->set_type("list(shape)"); - attr2->set_has_minimum(true); - attr2->set_minimum(1); - - // Set stateful - proto.set_is_stateful(true); - } else if (original_type == "QueueDataset") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *inputArgdef = proto.add_input_arg(); - inputArgdef->set_name("input_dataset"); - inputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); - outputArgdef->set_name("handle"); - outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("sourcedata"); - attr1->set_type("string"); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("output_types"); - attr2->set_type("list(type)"); - attr2->set_has_minimum(true); - attr2->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); - attr3->set_name("output_shapes"); - attr3->set_type("list(shape)"); - attr3->set_has_minimum(true); - attr3->set_minimum(1); - - // Set stateful - proto.set_is_stateful(true); - } else if (original_type == "DeviceQueueDataset") { - // Set output_arg - domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); - outputArgdef->set_name("handle"); - outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("channel_name"); - attr1->set_type("string"); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("output_types"); - attr2->set_type("list(type)"); - attr2->set_has_minimum(true); - attr2->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); - attr3->set_name("output_shapes"); - attr3->set_type("list(shape)"); - attr3->set_has_minimum(true); - attr3->set_minimum(1); - - // Set stateful - proto.set_is_stateful(true); - } else if (original_type == "ParallelMapDataset") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); - inputArgdef1->set_name("input_dataset"); - inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); - inputArgdef2->set_name("other_arguments"); - inputArgdef2->set_type_list_attr("Targuments"); - - domi::tensorflow::OpDef::ArgDef *inputArgdef3 = proto.add_input_arg(); - inputArgdef3->set_name("num_parallel_calls"); - inputArgdef3->set_type(::domi::tensorflow::DataType::DT_INT32); - - domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); - outputArgdef->set_name("handle"); - outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr0 = proto.add_attr(); - attr0->set_name("f"); - attr0->set_type("func"); - - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("Targuments"); - attr1->set_type("list(type)"); - attr1->set_has_minimum(true); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("output_types"); - attr2->set_type("list(type)"); - attr2->set_has_minimum(true); - attr2->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); - attr3->set_name("output_shapes"); - attr3->set_type("list(shape)"); - attr3->set_has_minimum(true); - attr3->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr4 = proto.add_attr(); - attr4->set_name("use_iter_op_parallelism"); - attr4->set_type("bool"); - ::domi::tensorflow::AttrValue *default_value = attr4->mutable_default_value(); - default_value->set_b(true); - } else if (original_type == "BatchDatasetV2") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); - inputArgdef1->set_name("input_dataset"); - inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); - inputArgdef2->set_name("batch_size"); - inputArgdef2->set_type(::domi::tensorflow::DataType::DT_INT64); - - domi::tensorflow::OpDef::ArgDef *inputArgdef3 = proto.add_input_arg(); - inputArgdef3->set_name("drop_remainder"); - inputArgdef3->set_type(::domi::tensorflow::DataType::DT_BOOL); - - domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); - outputArgdef->set_name("handle"); - outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("output_types"); - attr1->set_type("list(type)"); - attr1->set_has_minimum(true); - attr1->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("output_shapes"); - attr2->set_type("list(shape)"); - attr2->set_has_minimum(true); - attr2->set_minimum(1); - } else if (original_type == "IteratorV2") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); - outputArgdef->set_name("handle"); - outputArgdef->set_type(::domi::tensorflow::DataType::DT_RESOURCE); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("shared_name"); - attr1->set_type("string"); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("container"); - attr2->set_type("string"); - - domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); - attr3->set_name("output_types"); - attr3->set_type("list(type)"); - attr3->set_has_minimum(true); - attr3->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr4 = proto.add_attr(); - attr4->set_name("output_shapes"); - attr4->set_type("list(shape)"); - attr4->set_has_minimum(true); - attr4->set_minimum(1); - - // Set stateful - proto.set_is_stateful(true); - } else if (original_type == "MakeIterator") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); - inputArgdef1->set_name("dataset"); - inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); - inputArgdef2->set_name("iterator"); - inputArgdef2->set_type(::domi::tensorflow::DataType::DT_RESOURCE); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("_kernel"); - attr1->set_type("dp"); - - // Set stateful - proto.set_is_stateful(true); - } else if (original_type == "IteratorGetNext") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *input_argdef_1 = proto.add_input_arg(); - input_argdef_1->set_name("iterator"); - input_argdef_1->set_type(::domi::tensorflow::DataType::DT_RESOURCE); - - domi::tensorflow::OpDef::ArgDef *output_argdef = proto.add_output_arg(); - output_argdef->set_name("components"); - output_argdef->set_type_list_attr("output_types"); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("output_types"); - attr1->set_type("list(type)"); - attr1->set_has_minimum(true); - attr1->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("output_shapes"); - attr2->set_type("list(shape)"); - attr2->set_has_minimum(true); - attr2->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); - attr3->set_name("_kernel"); - attr3->set_type("dp"); - - // Set stateful - proto.set_is_stateful(true); - } else if (original_type == "FilterDataset") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); - inputArgdef1->set_name("input_dataset"); - inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); - inputArgdef2->set_name("other_arguments"); - inputArgdef2->set_type_list_attr("Targuments"); - - domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); - outputArgdef->set_name("handle"); - outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr0 = proto.add_attr(); - attr0->set_name("predicate"); - attr0->set_type("func"); - - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("Targuments"); - attr1->set_type("list(type)"); - attr1->set_has_minimum(true); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("output_types"); - attr2->set_type("list(type)"); - attr2->set_has_minimum(true); - attr2->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); - attr3->set_name("output_shapes"); - attr3->set_type("list(shape)"); - attr3->set_has_minimum(true); - attr3->set_minimum(1); - } else if (original_type == "MapAndBatchDatasetV2") { - // Set input_arg & output_arg - domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); - inputArgdef1->set_name("input_dataset"); - inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); - inputArgdef2->set_name("other_arguments"); - inputArgdef2->set_type_list_attr("Targuments"); - - domi::tensorflow::OpDef::ArgDef *inputArgdef3 = proto.add_input_arg(); - inputArgdef3->set_name("batch_size"); - inputArgdef3->set_type(::domi::tensorflow::DataType::DT_INT64); - - domi::tensorflow::OpDef::ArgDef *inputArgdef4 = proto.add_input_arg(); - inputArgdef4->set_name("num_parallel_calls"); - inputArgdef4->set_type(::domi::tensorflow::DataType::DT_INT64); - - domi::tensorflow::OpDef::ArgDef *inputArgdef5 = proto.add_input_arg(); - inputArgdef5->set_name("drop_remainder"); - inputArgdef5->set_type(::domi::tensorflow::DataType::DT_BOOL); - - domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); - outputArgdef->set_name("handle"); - outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); - - // Set domi::AttrDef - domi::tensorflow::OpDef_AttrDef *attr0 = proto.add_attr(); - attr0->set_name("f"); - attr0->set_type("func"); - - domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); - attr1->set_name("Targuments"); - attr1->set_type("list(type)"); - attr1->set_has_minimum(true); - - domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); - attr2->set_name("output_types"); - attr2->set_type("list(type)"); - attr2->set_has_minimum(true); - attr2->set_minimum(1); - - domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); - attr3->set_name("output_shapes"); - attr3->set_type("list(shape)"); - attr3->set_has_minimum(true); - attr3->set_minimum(1); - } - // set - opdef - string opdefString; - GE_IF_BOOL_EXEC(!proto.SerializeToString(&opdefString), - REPORT_CALL_ERROR("E19999", "Serialize opdef to string failed, op:%s(%s)", - n->GetName().c_str(), n->GetType().c_str()); - GELOGE(PARAM_INVALID, "Serialize opdef to string failed."); - return PARAM_INVALID); - - (void)ge::AttrUtils::SetStr(opDesc, ge::ATTR_NAME_FRAMEWORK_OP_DEF, opdefString); - - // print proto - string opdefstr; - google::protobuf::TextFormat::PrintToString(proto, &opdefstr); - GELOGI("---> ! CreateOpDefBytes() opdefstr :\n"); - GELOGI("%s", opdefstr.c_str()); - return SUCCESS; -} - -Status CreateFuncDefBytes(ge::NodePtr n, string original_type, string func_bin_path) { - GELOGI("func_bin_path = %s", func_bin_path.c_str()); - auto opDesc = n->GetOpDesc(); - - std::string func_string; - if (original_type == "ParallelMapDataset" || original_type == "MapAndBatchDatasetV2") { - GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), "f", func_string) != true, "get func string failed."); - } else if (original_type == "FilterDataset") { - GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), "predicate", func_string) != true, "get func string failed."); - } - GELOGI("func_string = %s", func_string.c_str()); - - std::string file = func_bin_path + "/" + func_string + ".bin"; - GELOGI("file = %s", file.c_str()); - - char *buf = nullptr; - int32_t len = 0; - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::parser::ReadBytesFromBinaryFile(file.c_str(), &buf, len), - REPORT_CALL_ERROR("E19999", "Read bytes from file:%s failed", file.c_str()); - return false, - "read bytes file error!"); - - GELOGI("len =%d\n", len); - - ge::Buffer funcDefBytes; - funcDefBytes = ge::Buffer::CopyFrom((std::uint8_t *)buf, len); - (void)ge::AttrUtils::SetBytes(opDesc, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes); - GELOGI("funcDefBytes.GetSize() =%zu", funcDefBytes.GetSize()); - - // print proto - if (funcDefBytes.GetSize() > 0 && funcDefBytes.GetData() != nullptr) { - domi::tensorflow::FunctionDefLibrary funcdeflib; - (void)funcdeflib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()); - - string funcdeflibrarystr; - google::protobuf::TextFormat::PrintToString(funcdeflib, &funcdeflibrarystr); - GELOGI("---> !CreateFuncDefBytes() funcdeflibrarystr :"); - } - - delete[] buf; - return SUCCESS; -} - -Status ParserGraphOptimizer::MakeTfProtoDef() { - GE_CHK_STATUS_RET(graph_->TopologicalSorting(), "graph sort failed."); - - map mOpIOListFuncMap; - CreateIOListFuncMap(mOpIOListFuncMap); - - for (ge::NodePtr n : graph_->GetDirectNode()) { - if (n->GetType() != ge::parser::FRAMEWORKOP) continue; - std::string original_type; - GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type) != true, - "get original type failed."); - - // create frameworkop nodedefbytes & TFindatatype & TFoutdatatype - vector::iterator iter = - std::find(local_framework_op_vec.begin(), local_framework_op_vec.end(), original_type); - if (iter != local_framework_op_vec.end()) { - Status ret = CreateNodeDefBytes(n, original_type, mOpIOListFuncMap); - GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Create nodedefBytes failed!"); - - vector::iterator iter_dataset = - std::find(is_dataset_op_vec.begin(), is_dataset_op_vec.end(), original_type); - if (original_type == "Const" || iter_dataset != is_dataset_op_vec.end()) { - ret = CreateOpDefBytes(n, original_type); - GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Create opdefBytes failed!"); - if (original_type == "ParallelMapDataset" || original_type == "FilterDataset" || - original_type == "MapAndBatchDatasetV2") { - ret = CreateFuncDefBytes(n, original_type, GetFuncBinPath()); - GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Create funcdefBytes failed!"); - } - } - } - } - - return SUCCESS; -} - Status ParserGraphOptimizer::FusionFmkop() { GELOGI("graph_optimizer.cpp && FustionFmkop()"); GELOGI("GetLocalFmkopFlag() =%d", GetLocalFmkopFlag()); @@ -1640,6 +395,7 @@ Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map "%s, dst node: %s.", src_ctrl->GetName().c_str(), dst->GetName().c_str()); return FAILED); + }); } return SUCCESS; @@ -1754,602 +510,4 @@ Status ParserGraphOptimizer::RebuildFusionNode(vector &inpu } return SUCCESS; } - -Status ParserGraphOptimizer::Insert4DTo5DTransOp(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, - enum ge::Format src_out_format, enum ge::DataType src_out_data_type, - enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type) { - bool isNCHWFP32To5DFP16 = (src_out_format == ge::FORMAT_NCHW && dst_in_format == ge::FORMAT_NC1HWC0); - if (isNCHWFP32To5DFP16) { - NodePtr cast_node = nullptr; - - if (src_out_data_type != dst_in_data_type) { - OpDescPtr cast_opdesc = CreateCastOp(src_out_data_type, dst_in_data_type, ge::FORMAT_NCHW); - cast_node = graph_->AddNode(cast_opdesc); - GE_CHK_BOOL_EXEC(cast_node != nullptr, - REPORT_CALL_ERROR("E19999", "Add Cast node to graph:%s failed", - graph_->GetName().c_str()); - return INTERNAL_ERROR, "graph add cast node fail."); - } - - OpDescPtr trans_data_opdesc = CreateTransDataOp(FORMAT_NCHW); - NodePtr trans_data_node = graph_->AddNode(trans_data_opdesc); - GE_CHK_BOOL_EXEC(trans_data_node != nullptr, - REPORT_CALL_ERROR("E19999", "Add Transdata node to graph:%s failed", - graph_->GetName().c_str()); - return INTERNAL_ERROR, "graph add TransData node node fail."); - GE_CHK_STATUS_RET(NewNodeAddEdges(src_anchor, dst_anchor, nullptr, cast_node, trans_data_node), - "NewNodeAddEdges ret fail."); - - return SUCCESS; - } - - OpDescPtr translateto5D = CreateTranslateOp(src_out_format, src_out_data_type, dst_in_format, dst_in_data_type); - GE_CHECK_NOTNULL(translateto5D); - NodePtr transNode = graph_->AddNode(translateto5D); - GE_CHECK_NOTNULL(transNode); - GELOGI("Create 4D To 5D fp32 node susscess!"); - - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transNode->GetInDataAnchor(0)), - REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - transNode->GetName().c_str(), transNode->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transNode->GetOutDataAnchor(0), dst_anchor), - REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - transNode->GetName().c_str(), transNode->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR); - - GELOGI("Create 4D To 5D susscess!"); - return SUCCESS; -} - -Status ParserGraphOptimizer::InsertFZ2HWCK(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, - enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, - enum ge::Format dstInFormat, enum ge::DataType dstInDatatype) { - GELOGI("In InsertFZ2HWCK !"); - GE_IF_BOOL_EXEC( - srcOutFormat == ge::FORMAT_FRACTAL_Z, NodePtr transHalfNode = nullptr; - if (srcOutDatatype == ge::DT_FLOAT) { - // create FZ fp32->FZ fp16 node - OpDescPtr translatetoHalf = CreateTranslateOp(srcOutFormat, srcOutDatatype, srcOutFormat, ge::DT_FLOAT16); - transHalfNode = graph_->AddNode(translatetoHalf); - GE_CHECK_NOTNULL(transHalfNode); - GELOGI("Create FZ fp32 to FZ fp16 node susscess!"); - // create FZ fp16->HWCK fp32 node - } - - OpDescPtr translatetoHWCK = CreateTranslateOp(srcOutFormat, ge::DT_FLOAT16, dstInFormat, dstInDatatype); - NodePtr transHWCKNode = graph_->AddNode(translatetoHWCK); GELOGI("Create FZ 16 to HWCK fp32 node susscess!"); - GE_CHECK_NOTNULL(transHWCKNode); if (transHalfNode) { - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transHalfNode->GetInDataAnchor(0)), - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - transHalfNode->GetName().c_str(), transHalfNode->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transHalfNode->GetOutDataAnchor(0), transHWCKNode->GetInDataAnchor(0)), - REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed", - transHalfNode->GetName().c_str(), transHalfNode->GetType().c_str(), - transHWCKNode->GetName().c_str(), transHWCKNode->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transHWCKNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - transHWCKNode->GetName().c_str(), transHWCKNode->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR); - } else { - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transHWCKNode->GetInDataAnchor(0)), - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - transHWCKNode->GetName().c_str(), transHWCKNode->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transHWCKNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - transHWCKNode->GetName().c_str(), transHWCKNode->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR); - } GELOGI("Create InsertFZ2HWCK success!");) - return SUCCESS; -} - -Status ParserGraphOptimizer::InsertVar5DTo4D(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, - enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, - enum ge::Format dstInFormat, enum ge::DataType dstInDatatype) { - GELOGI("In Insert 5D To 4D !"); - GE_IF_BOOL_EXEC( - srcOutFormat == ge::FORMAT_NC1HWC0, NodePtr cast_node = nullptr; - if (srcOutDatatype == ge::DT_FLOAT && dstInDatatype == ge::DT_FLOAT) { - auto cast_opdesc = CreateCastOp(ge::DT_FLOAT, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0); - cast_node = graph_->AddNode(cast_opdesc); - - srcOutDatatype = ge::DT_FLOAT16; - } NodePtr transHalfNode = nullptr; - OpDescPtr translateto4D = CreateTranslateOp(srcOutFormat, srcOutDatatype, dstInFormat, dstInDatatype); - NodePtr trans4DNode = graph_->AddNode(translateto4D); GELOGI("Create 5D To 4D fp32 node susscess!"); - GE_CHECK_NOTNULL(trans4DNode); - - if (cast_node) { - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, cast_node->GetInDataAnchor(0)), - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - cast_node->GetName().c_str(), cast_node->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), trans4DNode->GetInDataAnchor(0)), - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed", - cast_node->GetName().c_str(), cast_node->GetType().c_str(), - trans4DNode->GetName().c_str(), trans4DNode->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(trans4DNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - trans4DNode->GetName().c_str(), trans4DNode->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR); - } else { - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, trans4DNode->GetInDataAnchor(0)), - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - trans4DNode->GetName().c_str(), trans4DNode->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(trans4DNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - trans4DNode->GetName().c_str(), trans4DNode->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR); - } GELOGI("Create 5D To 4D susscess!");) - return SUCCESS; -} - -Status ParserGraphOptimizer::InsertHWCK2FZ(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, - enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, - enum ge::Format dstInFormat, enum ge::DataType dstInDatatype) { - GELOGI("In InsertHWCK2FZ !"); - GE_IF_BOOL_EXEC( - srcOutFormat == ge::FORMAT_HWCN, NodePtr transHalfNode = nullptr; - OpDescPtr translatetoFZ = CreateTranslateOp(srcOutFormat, srcOutDatatype, dstInFormat, ge::DT_FLOAT16); - NodePtr transHWCK2FZNode = graph_->AddNode(translatetoFZ); GELOGI("Create HWCK fp32 to FZ 16 node susscess!"); - GE_CHECK_NOTNULL(transHWCK2FZNode); - - ge::NodePtr translateHalftoFp32Node = nullptr; if (dstInDatatype == ge::DT_FLOAT) { - // create FZ fp16 ->FZ fp32 node - ge::OpDescPtr translateHalftoFp32 = CreateTranslateOp(dstInFormat, ge::DT_FLOAT16, dstInFormat, dstInDatatype); - translateHalftoFp32Node = graph_->AddNode(translateHalftoFp32); - GE_CHECK_NOTNULL(translateHalftoFp32Node); - GELOGI("Create FZ fp32 to FZ fp16 node susscess!"); - // create FZ fp16->HWCK fp32 node - } - - if (translateHalftoFp32Node) { - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transHWCK2FZNode->GetInDataAnchor(0)), - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - transHWCK2FZNode->GetName().c_str(), transHWCK2FZNode->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC( - GraphUtils::AddEdge(transHWCK2FZNode->GetOutDataAnchor(0), translateHalftoFp32Node->GetInDataAnchor(0)), - REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed", - transHWCK2FZNode->GetName().c_str(), transHWCK2FZNode->GetType().c_str(), - translateHalftoFp32Node->GetName().c_str(), translateHalftoFp32Node->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(translateHalftoFp32Node->GetOutDataAnchor(0), dst_anchor) != SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - translateHalftoFp32Node->GetName().c_str(), translateHalftoFp32Node->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR); - } else { - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transHWCK2FZNode->GetInDataAnchor(0)), - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - transHWCK2FZNode->GetName().c_str(), transHWCK2FZNode->GetType().c_str()); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transHWCK2FZNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - transHWCK2FZNode->GetName().c_str(), transHWCK2FZNode->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR); - } GELOGI("Create InsertHWCK2FZ success!");) - return SUCCESS; -} - -Status ParserGraphOptimizer::Insert5DTo4DTransOp(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, - enum ge::Format src_out_format, enum ge::DataType src_out_data_type, - enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type) { - // Status ret; - NodePtr permute_node = nullptr; - NodePtr cast_node = nullptr; - - OpDescPtr trans_data_opdesc = CreateTransDataOp(FORMAT_NC1HWC0); - NodePtr trans_data_node = graph_->AddNode(trans_data_opdesc); - GE_CHK_BOOL_EXEC(trans_data_node != nullptr, - REPORT_CALL_ERROR("E19999", "Add Transdata node to graph:%s failed", - graph_->GetName().c_str()); - return INTERNAL_ERROR, "graph add TransData node node fail."); - - if (src_out_data_type != dst_in_data_type) { - OpDescPtr cast_opdesc = CreateCastOp(src_out_data_type, dst_in_data_type, ge::FORMAT_NCHW); - cast_node = graph_->AddNode(cast_opdesc); - GE_CHK_BOOL_EXEC(cast_node != nullptr, - REPORT_CALL_ERROR("E19999", "Add Cast node to graph:%s failed", - graph_->GetName().c_str()); - return INTERNAL_ERROR, "graph add cast node fail."); - } - - if (dst_in_format == FORMAT_NHWC) { - OpDescPtr permute_opdec = CreatePermuteOp(FORMAT_NCHW, dst_in_format); - permute_node = graph_->AddNode(permute_opdec); - GE_CHK_BOOL_EXEC(permute_node != nullptr, - REPORT_CALL_ERROR("E19999", "Add Permute node to graph:%s failed", - graph_->GetName().c_str()); - return INTERNAL_ERROR, "graph add permute node fail."); - } - - GE_CHK_STATUS_RET(NewNodeAddEdges(src_anchor, dst_anchor, trans_data_node, cast_node, permute_node), - "NewNodeAddEdges ret fail."); - - return SUCCESS; -} - -Status ParserGraphOptimizer::NewNodeAddEdges(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, NodePtr first, - NodePtr second, NodePtr third) { - GE_CHECK_NOTNULL(src_anchor); - GE_CHECK_NOTNULL(dst_anchor); - OutDataAnchorPtr add_in_anchor = nullptr; - InDataAnchorPtr add_out_anchor = nullptr; - NodePtr src_node = src_anchor->GetOwnerNode(); - NodePtr dst_node = dst_anchor->GetOwnerNode(); - - if (first != nullptr) { - Status status = GraphUtils::AddEdge(src_anchor, first->GetInDataAnchor(0)); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - first->GetName().c_str(), first->GetType().c_str()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", - src_anchor->GetIdx(), 0); - if (second != nullptr) { - status = GraphUtils::AddEdge(first->GetOutDataAnchor(0), second->GetInDataAnchor(0)); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed", - first->GetName().c_str(), first->GetType().c_str(), - second->GetName().c_str(), second->GetType().c_str()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", 0, 0); - if (third != nullptr) { - status = GraphUtils::AddEdge(second->GetOutDataAnchor(0), third->GetInDataAnchor(0)); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed", - second->GetName().c_str(), second->GetType().c_str(), - third->GetName().c_str(), third->GetType().c_str()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", 0, 0); - status = GraphUtils::AddEdge(third->GetOutDataAnchor(0), dst_anchor); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - third->GetName().c_str(), third->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", - 0, dst_anchor->GetIdx()); - } else { - status = GraphUtils::AddEdge(second->GetOutDataAnchor(0), dst_anchor); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - second->GetName().c_str(), second->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", - 0, dst_anchor->GetIdx()); - } - } else { - if (third != nullptr) { - status = GraphUtils::AddEdge(first->GetOutDataAnchor(0), third->GetInDataAnchor(0)); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed", - first->GetName().c_str(), first->GetType().c_str(), - third->GetName().c_str(), third->GetType().c_str()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", 0, 0); - status = GraphUtils::AddEdge(third->GetOutDataAnchor(0), dst_anchor); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - third->GetName().c_str(), third->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", - 0, dst_anchor->GetIdx()); - } else { - status = GraphUtils::AddEdge(first->GetOutDataAnchor(0), dst_anchor); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - first->GetName().c_str(), first->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", - 0, dst_anchor->GetIdx()); - } - } - } else { - if (second != nullptr) { - Status status = GraphUtils::AddEdge(src_anchor, second->GetInDataAnchor(0)); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - second->GetName().c_str(), second->GetType().c_str()); - return INTERNAL_ERROR, - "graph add src to cast edge fail, src index:%d, dst index:%d.", src_anchor->GetIdx(), 0); - GE_IF_BOOL_EXEC( - third != nullptr, status = GraphUtils::AddEdge(second->GetOutDataAnchor(0), third->GetInDataAnchor(0)); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed", - second->GetName().c_str(), second->GetType().c_str(), - third->GetName().c_str(), third->GetType().c_str()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", 0, 0); - status = GraphUtils::AddEdge(third->GetOutDataAnchor(0), dst_anchor); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - third->GetName().c_str(), third->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", - 0, dst_anchor->GetIdx());); - GE_IF_BOOL_EXEC(third == nullptr, status = GraphUtils::AddEdge(second->GetOutDataAnchor(0), dst_anchor); - GE_CHK_BOOL_EXEC( - status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - second->GetName().c_str(), second->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR, - "graph add edge fail, src index:%d, dst index:%d.", 0, 0);); - } else { - if (third != nullptr) { - Status status = GraphUtils::AddEdge(src_anchor, third->GetInDataAnchor(0)); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), - src_anchor->GetOwnerNode()->GetType().c_str(), src_anchor->GetIdx(), - third->GetName().c_str(), third->GetType().c_str()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", 0, 0); - status = GraphUtils::AddEdge(third->GetOutDataAnchor(0), dst_anchor); - GE_CHK_BOOL_EXEC(status == SUCCESS, - REPORT_CALL_ERROR( - "E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed", - third->GetName().c_str(), third->GetType().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetType().c_str(), dst_anchor->GetIdx()); - return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", - 0, dst_anchor->GetIdx()); - } - } - } - return SUCCESS; -} - -OpDescPtr ParserGraphOptimizer::CreateTranslateOp(enum ge::Format inFormat, enum ge::DataType inDatatype, - enum ge::Format outFormat, enum ge::DataType outDatatype) { - /** - * 0. FP32 <-> FP16 - * 1. from HWCK(FP32) to FracZ(FP16); - * 2. from FracZ(FP16) to HWCK(FP32); - * 3. from NHWC(FP32) to NC1HWC0(FP16); - * 4. from NC1HWC0(FP32) to NHWC(FP32); - * 5. from NC1HWC0(FP16) to NHWC(FP32) - */ - static uint32_t transop_count = 0; - OpDescPtr op_def = nullptr; - std::stringstream sstmp; - sstmp << "translate_" << ge::parser::TRANSDATA << "_" << transop_count++; - GE_MAKE_SHARED(op_def = std::make_shared(sstmp.str().c_str(), ge::parser::TRANSLATE), op_def = nullptr; - return op_def); - GELOGI( - "create translate op:%s, input format:%s, input datatype:%s, output " - "format:%s, output datatype:%s.", - op_def->GetName().c_str(), ge::TypeUtils::FormatToSerialString(inFormat).c_str(), - ge::TypeUtils::DataTypeToSerialString(inDatatype).c_str(), ge::TypeUtils::FormatToSerialString(outFormat).c_str(), - ge::TypeUtils::DataTypeToSerialString(outDatatype).c_str()); - - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_INPUT_FORMAT, inFormat), - REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_INPUT_FORMAT.c_str(), - op_def->GetName().c_str(), op_def->GetType().c_str()); - return nullptr, - "SetInt ATTR_NAME_INPUT_FORMAT failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_INPUT_DATATYPE, inDatatype), - REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_INPUT_DATATYPE.c_str(), - op_def->GetName().c_str(), op_def->GetType().c_str()); - return nullptr, - "SetInt ATTR_NAME_INPUT_DATATYPE failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_OUTPUT_FORMAT, outFormat), - REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_OUTPUT_FORMAT.c_str(), - op_def->GetName().c_str(), op_def->GetType().c_str()); - return nullptr, - "SetInt ATTR_NAME_INPUT_DATATYPE failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_OUTPUT_DATATYPE, outDatatype), - REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_OUTPUT_DATATYPE.c_str(), - op_def->GetName().c_str(), op_def->GetType().c_str()); - return nullptr, - "SetInt ATTR_NAME_INPUT_DATATYPE failed."); - if (inDatatype != ge::DT_FLOAT16) { - GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddInputDesc(GeTensorDesc(GeShape(), inFormat)), - REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", - op_def->GetName().c_str(), op_def->GetType().c_str()); - return nullptr, - "create translate op:add input desc fail."); - } else { - GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddInputDesc(GeTensorDesc(GeShape(), inFormat, ge::DT_FLOAT16)), - REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", - op_def->GetName().c_str(), op_def->GetType().c_str()); - return nullptr, - "create translate op:add input desc fail."); - } - if (outDatatype != ge::DT_FLOAT16) { - GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddOutputDesc(GeTensorDesc(GeShape(), outFormat)), - REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", - op_def->GetName().c_str(), op_def->GetType().c_str()); - return nullptr, - "create translate op:add output desc fail."); - } else { - GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddOutputDesc(GeTensorDesc(GeShape(), outFormat, ge::DT_FLOAT16)), - REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", - op_def->GetName().c_str(), op_def->GetType().c_str()); - return nullptr, "create translate op:add output desc fail."); - } - return op_def; -} - -OpDescPtr ParserGraphOptimizer::CreatePermuteOp(enum ge::Format input_format, enum ge::Format output_format) { - static uint32_t transop_count = 0; - - std::stringstream sstmp; - sstmp << "transdata_" << ge::parser::PERMUTE << "_" << transop_count++; - - OpDescPtr op_desc = nullptr; - GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::parser::PERMUTE), op_desc = nullptr; - return op_desc); - GELOGI("create permute op:%s", op_desc->GetName().c_str()); - - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), - REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_INPUT_FORMAT.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "SetInt ATTR_NAME_INPUT_FORMAT failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), - REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_OUTPUT_FORMAT.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "SetInt ATTR_NAME_OUTPUT_FORMAT failed."); - - GE_IF_BOOL_EXEC(input_format == FORMAT_NCHW, (void)AttrUtils::SetInt(op_desc, "NCHW_to_NHWC", (int64_t)1)); - GE_IF_BOOL_EXEC(input_format == FORMAT_NHWC, (void)AttrUtils::SetInt(op_desc, "NHWC_to_NCHW", (int64_t)1)); - - GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), input_format)), - REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "create permute op:add input desc fail."); - GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), output_format)), - REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "create permute op:add output desc fail."); - - return op_desc; -} - -OpDescPtr ParserGraphOptimizer::CreateCastOp(enum ge::DataType input_data_type, enum ge::DataType output_data_type, - enum ge::Format format) { - static uint32_t transop_count = 0; - std::stringstream sstmp; - sstmp << "transdata_" << ge::parser::CAST << "_" << transop_count++; - - OpDescPtr op_desc = nullptr; - GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::parser::CAST), op_desc = nullptr; - return op_desc); - GELOGI("create cast op:%s, input datatype:%s, out datatype:%s", op_desc->GetName().c_str(), - ge::TypeUtils::DataTypeToSerialString(input_data_type).c_str(), - ge::TypeUtils::DataTypeToSerialString(output_data_type).c_str()); - - if (!(AttrUtils::SetInt(op_desc, ge::CAST_ATTR_SRCT, (int64_t)input_data_type) && - AttrUtils::SetInt(op_desc, ge::CAST_ATTR_DSTT, (int64_t)output_data_type) && - AttrUtils::SetInt(op_desc, ge::CAST_ATTR_DST_TYPE, (int64_t)output_data_type) && - AttrUtils::SetBool(op_desc, ge::CAST_ATTR_TRUNCATE, false))) { - REPORT_CALL_ERROR("E19999", "Set Attr:%s or %s or %s or %s to op:%s(%s) failed", - CAST_ATTR_SRCT.c_str(), CAST_ATTR_DSTT.c_str(), - CAST_ATTR_DST_TYPE.c_str(), CAST_ATTR_TRUNCATE.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(FAILED, "Set CAST_ATTR_SRCT or CAST_ATTR_DSTT or CAST_ATTR_DST_TYPE or CAST_ATTR_TRUNCATE fail, node: %s.", - op_desc->GetName().c_str()); - return nullptr; - } - - GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), format, input_data_type)), - REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "create cast op:add input desc fail."); - GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), format, output_data_type)), - REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "create cast op:add output desc fail."); - - return op_desc; -} -OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) { - static uint32_t transop_count = 0; - std::stringstream sstmp; - sstmp << "transdata_" << ge::parser::TRANSDATA << "_" << transop_count++; - - OpDescPtr op_desc = nullptr; - GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::parser::TRANSDATA), op_desc = nullptr; - return op_desc); - - GELOGI("create transdata op:%s, input format:%s.", op_desc->GetName().c_str(), - ge::TypeUtils::FormatToSerialString(input_format).c_str()); - enum ge::Format output_format = FORMAT_NC1HWC0; - if (input_format != FORMAT_NCHW) { - input_format = FORMAT_NC1HWC0; - output_format = FORMAT_NCHW; - } - - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), - REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_INPUT_FORMAT.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "SetInt of ATTR_NAME_INPUT_FORMAT failed."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), - REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_OUTPUT_FORMAT.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "SetInt of ATTR_NAME_OUTPUT_FORMAT failed."); - GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), input_format)), - REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "create transdata op:add input desc fail."); - GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), output_format)), - REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr, - "create transdata op:add output desc fail."); - - return op_desc; -} } // namespace ge diff --git a/parser/tensorflow/graph_optimizer.h b/parser/tensorflow/graph_optimizer.h index 9f73d69..10ec65e 100644 --- a/parser/tensorflow/graph_optimizer.h +++ b/parser/tensorflow/graph_optimizer.h @@ -35,68 +35,16 @@ namespace ge { class ParserGraphOptimizer { public: explicit ParserGraphOptimizer(ge::ComputeGraphPtr graph, domi::FrameworkType type = domi::TENSORFLOW) - : graph_(graph), fmktype_(type), local_fmk_op_flag_(false) {} + : graph_(graph), fmktype_(type) {} ~ParserGraphOptimizer() {} - domi::Status Optimize(); - - domi::Status OptimizeAfterCal(); - domi::Status FusionFmkop(); - inline bool IsHCOMOp(const string &op_type) { - return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) || - (op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) || - (op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter"); - } - - void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; } - - const bool GetLocalFmkopFlag() const { return local_fmk_op_flag_; } - - void SetFuncBinPath(std::string isFuncBinPath) { func_bin_path_ = isFuncBinPath; } - const std::string GetFuncBinPath() const { return func_bin_path_; } - - domi::Status InsertHWCK2FZ(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, - enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, - enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); - - domi::Status Insert4DTo5DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, - enum ge::Format src_out_format, enum ge::DataType src_out_data_type, - enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); - - domi::Status InsertFZ2HWCK(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, - enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, - enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); - - domi::Status Insert5DTo4DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, - enum ge::Format src_out_format, enum ge::DataType src_out_data_type, - enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); - - ge::OpDescPtr CreateCastOp(enum ge::DataType input_datatype, enum ge::DataType output_datatype, ge::Format format); - - ge::OpDescPtr CreatePermuteOp(enum ge::Format input_format, enum ge::Format output_format); - - ge::OpDescPtr CreateTransDataOp(enum ge::Format input_format); - - domi::Status NewNodeAddEdges(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, ge::NodePtr first, - ge::NodePtr second, ge::NodePtr third); - - domi::Status InsertVar5DTo4D(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, - enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, - enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); - - ge::OpDescPtr CreateTranslateOp(enum ge::Format inFormat, ge::DataType inDatatype, enum ge::Format outFormat, - ge::DataType outDatatype); - private: ge::ComputeGraphPtr graph_; domi::FrameworkType fmktype_; - // local fmkop flag - bool local_fmk_op_flag_; - std::string func_bin_path_; - + domi::Status FindFmkNodeCluser(unordered_map> &node_cluser_Map); domi::Status MarkForFusion(unordered_map> &node_cluser_Map); @@ -122,7 +70,6 @@ class ParserGraphOptimizer { vector &input_control_anchors, vector &output_control_anchors, ge::NodePtr fusion_node); - domi::Status MakeTfProtoDef(); }; } // namespace ge #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ diff --git a/parser/tensorflow/iterator_fusion_pass.cc b/parser/tensorflow/iterator_fusion_pass.cc index ae49130..14fcf9a 100644 --- a/parser/tensorflow/iterator_fusion_pass.cc +++ b/parser/tensorflow/iterator_fusion_pass.cc @@ -32,8 +32,6 @@ Status IteratorFusionPass::Run(ge::ComputeGraphPtr graph) { REPORT_CALL_ERROR("E19999", "New ParserGraphOptimizer failed"); return FAILED; } - - graph_optimizer->SetLocalFmkopFlag(local_fmk_op_flag_); return graph_optimizer->FusionFmkop(); } } // namespace ge diff --git a/parser/tensorflow/iterator_fusion_pass.h b/parser/tensorflow/iterator_fusion_pass.h index aadde8b..af590d2 100644 --- a/parser/tensorflow/iterator_fusion_pass.h +++ b/parser/tensorflow/iterator_fusion_pass.h @@ -23,8 +23,8 @@ namespace ge { class IteratorFusionPass : public GraphPass { public: - IteratorFusionPass(domi::FrameworkType type, bool local_fmk_op_flag) - : fmk_type_(type), local_fmk_op_flag_(local_fmk_op_flag) {} + IteratorFusionPass(domi::FrameworkType type) + : fmk_type_(type) {} virtual ~IteratorFusionPass() {} @@ -32,7 +32,6 @@ class IteratorFusionPass : public GraphPass { private: domi::FrameworkType fmk_type_; - bool local_fmk_op_flag_; }; } // namespace ge diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 640ab0a..5d911f5 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -2375,7 +2375,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ge::parser::PassManager iterator_fusion_pass; try { (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", - new ge::IteratorFusionPass(domi::TENSORFLOW, false)); + new ge::IteratorFusionPass(domi::TENSORFLOW)); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 65804fc..7ef01d2 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -307,6 +307,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) set(PARSER_UT_FILES + "graph_builder_utils.cc" "parser_ut_utils.cc" "testcase/common/acl_graph_parser_unittest.cc" "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" @@ -314,6 +315,7 @@ set(PARSER_UT_FILES "testcase/onnx_parser_testcase/message2operator_unittest.cc" "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" "testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc" + "testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc" ) ############ libut_parser_common.a ############ diff --git a/tests/ut/parser/graph_builder_utils.cc b/tests/ut/parser/graph_builder_utils.cc new file mode 100644 index 0000000..17e10dd --- /dev/null +++ b/tests/ut/parser/graph_builder_utils.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph_builder_utils.h" + +#include "graph/utils/graph_utils.h" + +namespace ge { +namespace ut { +NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, + DataType data_type, std::vector shape) { + auto tensor_desc = std::make_shared(); + tensor_desc->SetShape(GeShape(std::move(shape))); + tensor_desc->SetFormat(format); + tensor_desc->SetDataType(data_type); + + auto op_desc = std::make_shared(name, type); + for (int i = 0; i < in_cnt; ++i) { + op_desc->AddInputDesc(tensor_desc->Clone()); + } + for (int i = 0; i < out_cnt; ++i) { + op_desc->AddOutputDesc(tensor_desc->Clone()); + } + + return graph_->AddNode(op_desc); +} +void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { + GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); +} +void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { + GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); +} + +} // namespace ut +} // namespace ge diff --git a/tests/ut/parser/graph_builder_utils.h b/tests/ut/parser/graph_builder_utils.h new file mode 100644 index 0000000..3ba1d7c --- /dev/null +++ b/tests/ut/parser/graph_builder_utils.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ +#define MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ + +#include +#include + +#include "graph/compute_graph.h" +#include "graph/graph.h" +#include "graph/node.h" + +namespace ge { +namespace ut { +class GraphBuilder { + public: + explicit GraphBuilder(const std::string &name) { graph_ = std::make_shared(name); } + NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, + Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, + std::vector shape = {1, 1, 224, 224}); + void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx); + void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node); + ComputeGraphPtr GetGraph() { + graph_->TopologicalSorting(); + return graph_; + } + + private: + ComputeGraphPtr graph_; +}; +} // namespace ut +} // namespace ge + +#endif // MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc new file mode 100644 index 0000000..ba5af43 --- /dev/null +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -0,0 +1,71 @@ +#include +#include +#include "graph/utils/attr_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "graph_builder_utils.h" +#include "common/util.h" + +#include "tensorflow/iterator_fusion_pass.h" +#include "parser/common/acl_graph_parser_util.h" +#define private public +#include "tensorflow/graph_optimizer.h" +#undef private + +namespace ge { +class UtestGraphOptimizer : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +namespace { + ComputeGraphPtr MakeGraph() { +  ge::ut::GraphBuilder builder("graph"); +  std::string name = "graph"; +  std::string original_type ; + +  original_type = "IteratorV2"; +  auto data1 = builder.AddNode(name + "_"+original_type, ge::parser::FRAMEWORKOP, 1, 1); +  ge::AttrUtils::SetStr(data1->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + + original_type = "IteratorGetNext"; + auto data2 = builder.AddNode(name + "_"+original_type+"2", ge::parser::FRAMEWORKOP, 1, 2); + ge::AttrUtils::SetStr(data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + + string nodefStr; + AttrUtils::SetZeroCopyBytes( + data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, + Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); + + original_type = "IteratorGetNext"; + auto data3 = builder.AddNode(name + "_"+original_type+"3", ge::parser::FRAMEWORKOP, 2, 1); + ge::AttrUtils::SetStr(data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + + AttrUtils::SetZeroCopyBytes( + data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, + Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); + + builder.AddDataEdge(data1, 0, data2, 0); + builder.AddDataEdge(data2, 0, data3, 0); + builder.AddDataEdge(data2, 1, data3, 1); + return builder.GetGraph(); + } +} + +TEST_F(UtestGraphOptimizer, graph_optimizer) { + ge::ComputeGraphPtr graph = MakeGraph(); + ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW); + EXPECT_NE(iteratorFusionPass.Run(graph),ge::SUCCESS); +} + +TEST_F(UtestGraphOptimizer, graph_optimizer_output) { + ge::ComputerGraph graph = MakeGraph(); + domi::FrameworkType type = domi::TENSORFLOW; +  ge::ParserGraphOptimizer parserGraphOptimizer(graph,type); +  vector input_anchors; +  vector output_anchors; +  ge::OpDescPtr fusion_op_desc; +  EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors,fusion_op_desc),ge::SUCCESS); + EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors,fusion_op_desc),ge::SUCCESS); +} +} From c256bade9a39b9c637ede2cf9547e055b4b92946 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Fri, 10 Sep 2021 09:32:16 +0800 Subject: [PATCH 03/16] use true types istead of GeAttrValue:: --- parser/tensorflow/graph_optimizer.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc index a2488b3..fa7c623 100644 --- a/parser/tensorflow/graph_optimizer.cc +++ b/parser/tensorflow/graph_optimizer.cc @@ -91,9 +91,6 @@ const char *const kShapeNodeName = "Shape"; Status ParserGraphOptimizer::FusionFmkop() { GELOGI("graph_optimizer.cpp && FustionFmkop()"); - GELOGI("GetLocalFmkopFlag() =%d", GetLocalFmkopFlag()); - GE_IF_BOOL_EXEC(GetLocalFmkopFlag() == 1, MakeTfProtoDef()); - GE_CHECK_NOTNULL(graph_); std::unordered_map> node_cluser_Map; GE_CHK_STATUS_RET(MarkForFusion(node_cluser_Map), "find framework node to be fused fail."); From 91f68547079ed55804e6fcebdc210076902b52a6 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Fri, 10 Sep 2021 09:55:37 +0800 Subject: [PATCH 04/16] use true types istead of GeAttrValue:: --- .../graph_optimizer_testcase/graph_optimizer_unittest.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index ba5af43..6a3704a 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -59,7 +59,7 @@ TEST_F(UtestGraphOptimizer, graph_optimizer) { } TEST_F(UtestGraphOptimizer, graph_optimizer_output) { - ge::ComputerGraph graph = MakeGraph(); + ge::ComputeGraph graph = MakeGraph(); domi::FrameworkType type = domi::TENSORFLOW;   ge::ParserGraphOptimizer parserGraphOptimizer(graph,type);   vector input_anchors; From 45ad6c58201c9d69f91bc187baa8347c7fc52aaa Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Fri, 10 Sep 2021 10:04:24 +0800 Subject: [PATCH 05/16] use true types istead of GeAttrValue:: --- .../graph_optimizer_testcase/graph_optimizer_unittest.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index 6a3704a..e252c9c 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -59,7 +59,7 @@ TEST_F(UtestGraphOptimizer, graph_optimizer) { } TEST_F(UtestGraphOptimizer, graph_optimizer_output) { - ge::ComputeGraph graph = MakeGraph(); + ge::ComputeGraphPtr graph = MakeGraph(); domi::FrameworkType type = domi::TENSORFLOW;   ge::ParserGraphOptimizer parserGraphOptimizer(graph,type);   vector input_anchors; From d57c76994d9a8ae6b934476acd24862363eb84ca Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Fri, 10 Sep 2021 10:24:39 +0800 Subject: [PATCH 06/16] use true types istead of GeAttrValue:: --- .../graph_optimizer_unittest.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index e252c9c..dcee686 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -52,7 +52,7 @@ namespace { } } -TEST_F(UtestGraphOptimizer, graph_optimizer) { +TEST_F(UtestGraphOptimizer, graph_optimizer) { ge::ComputeGraphPtr graph = MakeGraph(); ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW); EXPECT_NE(iteratorFusionPass.Run(graph),ge::SUCCESS); @@ -61,11 +61,11 @@ TEST_F(UtestGraphOptimizer, graph_optimizer) { TEST_F(UtestGraphOptimizer, graph_optimizer_output) { ge::ComputeGraphPtr graph = MakeGraph(); domi::FrameworkType type = domi::TENSORFLOW; -  ge::ParserGraphOptimizer parserGraphOptimizer(graph,type); -  vector input_anchors; -  vector output_anchors; -  ge::OpDescPtr fusion_op_desc; -  EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors,fusion_op_desc),ge::SUCCESS); + ge::ParserGraphOptimizer parserGraphOptimizer(graph,type); + vector input_anchors; + vector output_anchors; + ge::OpDescPtr fusion_op_desc; + EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors,fusion_op_desc),ge::SUCCESS); EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors,fusion_op_desc),ge::SUCCESS); } } From b30930e059fbffa80f163b30bd3035ab34fc9c6e Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Fri, 10 Sep 2021 10:42:02 +0800 Subject: [PATCH 07/16] use true types istead of GeAttrValue:: --- .../graph_optimizer_unittest.cc | 107 +++++++++--------- 1 file changed, 51 insertions(+), 56 deletions(-) diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index dcee686..9d59dfd 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -4,68 +4,63 @@ #include "graph/debug/ge_attr_define.h" #include "graph_builder_utils.h" #include "common/util.h" - #include "tensorflow/iterator_fusion_pass.h" #include "parser/common/acl_graph_parser_util.h" #define private public #include "tensorflow/graph_optimizer.h" #undef private +namespace ge +{ + class UtestGraphOptimizer : public testing::Test + { + protected: + void SetUp() {} + void TearDown() {} + }; + namespace + { + ComputeGraphPtr MakeGraph() { + ge::ut::GraphBuilder builder("graph"); + std::string name = "graph"; + std::string original_type; + original_type = "IteratorV2"; // + auto data1 = builder.AddNode(name + "_" + original_type, ge::parser::FRAMEWORKOP, 1, 1); + ge::AttrUtils::SetStr(data1->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + original_type = "IteratorGetNext"; + auto data2 = builder.AddNode(name + "_" + original_type + "2", ge::parser::FRAMEWORKOP, 1, 2); + ge::AttrUtils::SetStr(data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + string nodefStr; + AttrUtils::SetZeroCopyBytes( + data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, + Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); + original_type = "IteratorGetNext"; + auto data3 = builder.AddNode(name + "_" + original_type + "3", ge::parser::FRAMEWORKOP, 2, 1); + ge::AttrUtils::SetStr(data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + AttrUtils::SetZeroCopyBytes( + data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, + Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); -namespace ge { -class UtestGraphOptimizer : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -namespace { - ComputeGraphPtr MakeGraph() { -  ge::ut::GraphBuilder builder("graph"); -  std::string name = "graph"; -  std::string original_type ; - -  original_type = "IteratorV2"; -  auto data1 = builder.AddNode(name + "_"+original_type, ge::parser::FRAMEWORKOP, 1, 1); -  ge::AttrUtils::SetStr(data1->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); - - original_type = "IteratorGetNext"; - auto data2 = builder.AddNode(name + "_"+original_type+"2", ge::parser::FRAMEWORKOP, 1, 2); - ge::AttrUtils::SetStr(data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); - - string nodefStr; - AttrUtils::SetZeroCopyBytes( - data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, - Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); - - original_type = "IteratorGetNext"; - auto data3 = builder.AddNode(name + "_"+original_type+"3", ge::parser::FRAMEWORKOP, 2, 1); - ge::AttrUtils::SetStr(data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); - - AttrUtils::SetZeroCopyBytes( - data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, - Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); - - builder.AddDataEdge(data1, 0, data2, 0); - builder.AddDataEdge(data2, 0, data3, 0); - builder.AddDataEdge(data2, 1, data3, 1); - return builder.GetGraph(); + builder.AddDataEdge(data1, 0, data2, 0); + builder.AddDataEdge(data2, 0, data3, 0); + builder.AddDataEdge(data2, 1, data3, 1); + return builder.GetGraph(); + } } -} -TEST_F(UtestGraphOptimizer, graph_optimizer) { - ge::ComputeGraphPtr graph = MakeGraph(); - ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW); - EXPECT_NE(iteratorFusionPass.Run(graph),ge::SUCCESS); -} + TEST_F(UtestGraphOptimizer, graph_optimizer) { + ge::ComputeGraphPtr graph = MakeGraph(); + ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW); + EXPECT_NE(iteratorFusionPass.Run(graph), ge::SUCCESS); + } + TEST_F(UtestGraphOptimizer, graph_optimizer_output) { + ge::ComputeGraphPtr graph = MakeGraph(); + domi::FrameworkType type = domi::TENSORFLOW; + ge::ParserGraphOptimizer parserGraphOptimizer(graph, type); -TEST_F(UtestGraphOptimizer, graph_optimizer_output) { - ge::ComputeGraphPtr graph = MakeGraph(); - domi::FrameworkType type = domi::TENSORFLOW; - ge::ParserGraphOptimizer parserGraphOptimizer(graph,type); - vector input_anchors; - vector output_anchors; - ge::OpDescPtr fusion_op_desc; - EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors,fusion_op_desc),ge::SUCCESS); - EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors,fusion_op_desc),ge::SUCCESS); -} -} + vector input_anchors; + vector output_anchors; + ge::OpDescPtr fusion_op_desc; + EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors, fusion_op_desc), ge::SUCCESS); + EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors, fusion_op_desc), ge::SUCCESS); + } +} \ No newline at end of file From 5acefb652265f144ce937e09b516a5341e7a5aaf Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Fri, 10 Sep 2021 10:53:44 +0800 Subject: [PATCH 08/16] use true types istead of GeAttrValue:: --- tests/ut/parser/graph_builder_utils.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ut/parser/graph_builder_utils.cc b/tests/ut/parser/graph_builder_utils.cc index 17e10dd..798a5a2 100644 --- a/tests/ut/parser/graph_builder_utils.cc +++ b/tests/ut/parser/graph_builder_utils.cc @@ -40,9 +40,9 @@ NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); } -void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { - GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); -} +// void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { +// GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); +// } } // namespace ut } // namespace ge From 3e3b2a4aaae736c73db41bb091dae0f98a6bcb53 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Fri, 10 Sep 2021 11:05:13 +0800 Subject: [PATCH 09/16] use true types istead of GeAttrValue:: --- tests/ut/parser/CMakeLists.txt | 1 - .../graph_optimizer_testcase/graph_optimizer_unittest.cc | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 7ef01d2..8deb30c 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -307,7 +307,6 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) set(PARSER_UT_FILES - "graph_builder_utils.cc" "parser_ut_utils.cc" "testcase/common/acl_graph_parser_unittest.cc" "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index 9d59dfd..083c2ae 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -2,7 +2,7 @@ #include #include "graph/utils/attr_utils.h" #include "graph/debug/ge_attr_define.h" -#include "graph_builder_utils.h" +#include "parser_ut_utils.cc" #include "common/util.h" #include "tensorflow/iterator_fusion_pass.h" #include "parser/common/acl_graph_parser_util.h" From d35dbc6a449ec1afa1951ef500fc7dcfbae5a4a9 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Fri, 10 Sep 2021 11:15:07 +0800 Subject: [PATCH 10/16] use true types istead of GeAttrValue:: --- .../graph_optimizer_testcase/graph_optimizer_unittest.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index 083c2ae..bf44fd3 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -20,7 +20,8 @@ namespace ge namespace { ComputeGraphPtr MakeGraph() { - ge::ut::GraphBuilder builder("graph"); + auto builder = ut::GraphBuilder("graph"); + //ge::ut::GraphBuilder builder("graph"); std::string name = "graph"; std::string original_type; original_type = "IteratorV2"; // From 03f2f43926392c903337c3605c96f355ed9bc522 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Mon, 13 Sep 2021 12:33:03 +0800 Subject: [PATCH 11/16] use true type --- tests/ut/parser/graph_builder_utils.cc | 48 -------------------------- tests/ut/parser/graph_builder_utils.h | 48 -------------------------- 2 files changed, 96 deletions(-) delete mode 100644 tests/ut/parser/graph_builder_utils.cc delete mode 100644 tests/ut/parser/graph_builder_utils.h diff --git a/tests/ut/parser/graph_builder_utils.cc b/tests/ut/parser/graph_builder_utils.cc deleted file mode 100644 index 798a5a2..0000000 --- a/tests/ut/parser/graph_builder_utils.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph_builder_utils.h" - -#include "graph/utils/graph_utils.h" - -namespace ge { -namespace ut { -NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, - DataType data_type, std::vector shape) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape(std::move(shape))); - tensor_desc->SetFormat(format); - tensor_desc->SetDataType(data_type); - - auto op_desc = std::make_shared(name, type); - for (int i = 0; i < in_cnt; ++i) { - op_desc->AddInputDesc(tensor_desc->Clone()); - } - for (int i = 0; i < out_cnt; ++i) { - op_desc->AddOutputDesc(tensor_desc->Clone()); - } - - return graph_->AddNode(op_desc); -} -void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { - GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); -} -// void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { -// GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); -// } - -} // namespace ut -} // namespace ge diff --git a/tests/ut/parser/graph_builder_utils.h b/tests/ut/parser/graph_builder_utils.h deleted file mode 100644 index 3ba1d7c..0000000 --- a/tests/ut/parser/graph_builder_utils.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ -#define MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ - -#include -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/node.h" - -namespace ge { -namespace ut { -class GraphBuilder { - public: - explicit GraphBuilder(const std::string &name) { graph_ = std::make_shared(name); } - NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, - Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, - std::vector shape = {1, 1, 224, 224}); - void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx); - void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node); - ComputeGraphPtr GetGraph() { - graph_->TopologicalSorting(); - return graph_; - } - - private: - ComputeGraphPtr graph_; -}; -} // namespace ut -} // namespace ge - -#endif // MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ From beb9a3716d110e636cd5d6501a7a15f8e1f4fbf0 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Mon, 13 Sep 2021 12:47:41 +0800 Subject: [PATCH 12/16] use true type --- .../graph_optimizer_testcase/graph_optimizer_unittest.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index bf44fd3..0627af8 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -2,7 +2,7 @@ #include #include "graph/utils/attr_utils.h" #include "graph/debug/ge_attr_define.h" -#include "parser_ut_utils.cc" +#include "ut/parser/parser_ut_utils.h" #include "common/util.h" #include "tensorflow/iterator_fusion_pass.h" #include "parser/common/acl_graph_parser_util.h" @@ -20,8 +20,7 @@ namespace ge namespace { ComputeGraphPtr MakeGraph() { - auto builder = ut::GraphBuilder("graph"); - //ge::ut::GraphBuilder builder("graph"); + ge::ut::GraphBuilder builder("graph"); std::string name = "graph"; std::string original_type; original_type = "IteratorV2"; // From 0b39deeba3eb3fc108ef185efbb6a76dc76f5620 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Mon, 13 Sep 2021 13:47:19 +0800 Subject: [PATCH 13/16] use true type --- parser/tensorflow/graph_optimizer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parser/tensorflow/graph_optimizer.h b/parser/tensorflow/graph_optimizer.h index 10ec65e..4ec7533 100644 --- a/parser/tensorflow/graph_optimizer.h +++ b/parser/tensorflow/graph_optimizer.h @@ -44,7 +44,7 @@ class ParserGraphOptimizer { private: ge::ComputeGraphPtr graph_; domi::FrameworkType fmktype_; - + domi::Status FindFmkNodeCluser(unordered_map> &node_cluser_Map); domi::Status MarkForFusion(unordered_map> &node_cluser_Map); From cfe54f97cabf07820706daa2b8e2a7a0efd7c6b0 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Mon, 13 Sep 2021 20:21:49 +0800 Subject: [PATCH 14/16] use true type --- .../graph_optimizer_unittest.cc | 90 +++++++++---------- 1 file changed, 43 insertions(+), 47 deletions(-) diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index 0627af8..94d201d 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -9,58 +9,54 @@ #define private public #include "tensorflow/graph_optimizer.h" #undef private -namespace ge -{ - class UtestGraphOptimizer : public testing::Test - { +namespace ge { + class UtestGraphOptimizer : public testing::Test { protected: void SetUp() {} void TearDown() {} }; - namespace - { - ComputeGraphPtr MakeGraph() { - ge::ut::GraphBuilder builder("graph"); - std::string name = "graph"; - std::string original_type; - original_type = "IteratorV2"; // - auto data1 = builder.AddNode(name + "_" + original_type, ge::parser::FRAMEWORKOP, 1, 1); - ge::AttrUtils::SetStr(data1->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); - original_type = "IteratorGetNext"; - auto data2 = builder.AddNode(name + "_" + original_type + "2", ge::parser::FRAMEWORKOP, 1, 2); - ge::AttrUtils::SetStr(data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); - string nodefStr; - AttrUtils::SetZeroCopyBytes( - data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, - Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); - original_type = "IteratorGetNext"; - auto data3 = builder.AddNode(name + "_" + original_type + "3", ge::parser::FRAMEWORKOP, 2, 1); - ge::AttrUtils::SetStr(data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); - AttrUtils::SetZeroCopyBytes( - data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, - Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); +namespace { + ComputeGraphPtr MakeGraph() { + ge::ut::GraphBuilder builder("graph"); + std::string name = "graph"; + std::string original_type; + original_type = "IteratorV2"; // + auto data1 = builder.AddNode(name + "_" + original_type, ge::parser::FRAMEWORKOP, 1, 1); + ge::AttrUtils::SetStr(data1->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + original_type = "IteratorGetNext"; + auto data2 = builder.AddNode(name + "_" + original_type + "2", ge::parser::FRAMEWORKOP, 1, 2); + ge::AttrUtils::SetStr(data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + string nodefStr; + AttrUtils::SetZeroCopyBytes( + data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, + Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); + original_type = "IteratorGetNext"; + auto data3 = builder.AddNode(name + "_" + original_type + "3", ge::parser::FRAMEWORKOP, 2, 1); + ge::AttrUtils::SetStr(data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + AttrUtils::SetZeroCopyBytes( + data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, + Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); - builder.AddDataEdge(data1, 0, data2, 0); - builder.AddDataEdge(data2, 0, data3, 0); - builder.AddDataEdge(data2, 1, data3, 1); - return builder.GetGraph(); - } + builder.AddDataEdge(data1, 0, data2, 0); + builder.AddDataEdge(data2, 0, data3, 0); + builder.AddDataEdge(data2, 1, data3, 1); + return builder.GetGraph(); } +} +TEST_F(UtestGraphOptimizer, graph_optimizer) { + ge::ComputeGraphPtr graph = MakeGraph(); + ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW); + EXPECT_NE(iteratorFusionPass.Run(graph), ge::SUCCESS); +} +TEST_F(UtestGraphOptimizer, graph_optimizer_output) { + ge::ComputeGraphPtr graph = MakeGraph(); + domi::FrameworkType type = domi::TENSORFLOW; + ge::ParserGraphOptimizer parserGraphOptimizer(graph, type); - TEST_F(UtestGraphOptimizer, graph_optimizer) { - ge::ComputeGraphPtr graph = MakeGraph(); - ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW); - EXPECT_NE(iteratorFusionPass.Run(graph), ge::SUCCESS); - } - TEST_F(UtestGraphOptimizer, graph_optimizer_output) { - ge::ComputeGraphPtr graph = MakeGraph(); - domi::FrameworkType type = domi::TENSORFLOW; - ge::ParserGraphOptimizer parserGraphOptimizer(graph, type); - - vector input_anchors; - vector output_anchors; - ge::OpDescPtr fusion_op_desc; - EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors, fusion_op_desc), ge::SUCCESS); - EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors, fusion_op_desc), ge::SUCCESS); - } + vector input_anchors; + vector output_anchors; + ge::OpDescPtr fusion_op_desc; + EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors, fusion_op_desc), ge::SUCCESS); + EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors, fusion_op_desc), ge::SUCCESS); +} } \ No newline at end of file From 85c6f3b3987e552055992ab699a448a77d42afdc Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Mon, 13 Sep 2021 21:05:04 +0800 Subject: [PATCH 15/16] use true type --- parser/tensorflow/graph_insert_trans_op.h | 183 ------------------ .../graph_optimizer_unittest.cc | 6 +- 2 files changed, 3 insertions(+), 186 deletions(-) delete mode 100644 parser/tensorflow/graph_insert_trans_op.h diff --git a/parser/tensorflow/graph_insert_trans_op.h b/parser/tensorflow/graph_insert_trans_op.h deleted file mode 100644 index 37bb010..0000000 --- a/parser/tensorflow/graph_insert_trans_op.h +++ /dev/null @@ -1,183 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_ -#define GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_ -#include -#include -#include -#include "common/fmk_types.h" -#include "framework/omg/parser/parser_types.h" -#include "graph/compute_graph.h" -#include "graph/node.h" -#include "graph/types.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/tensor_utils.h" -#include "register/op_registry.h" - -namespace ge { -enum InFmtSupportEnum { - InFmtSupportUndefined, - InFmtSupportElewise, - InFmtSupport4D, - InFmtSupport5D, - InFmtSupport4D_5D, - InFmtSupportNCHW_NC1HWC0 -}; - -enum InDtSupportEnum { - InDtSupportUndefined = 0, - InDtSupportAll = 1, -}; - -enum OutFmtSupportEnum { - OutFmtSupportUndefined = 0, - OutFmtSupportAsInput = 1, -}; - -enum OutDtSupportEnum { - OutDtSupportUndefined = 0, - OutDtSupportAsInput = 1, -}; - -struct OpSupportTranInfo { - InFmtSupportEnum inputFormatSupportEnum = InFmtSupportUndefined; - InDtSupportEnum inputDataTypeSupportEnum = InDtSupportUndefined; - OutFmtSupportEnum outputFormatSupportEnum = OutFmtSupportUndefined; - OutDtSupportEnum outputDataTypeSupportEnum = OutDtSupportUndefined; - - std::vector inputFormats; - std::vector inputDataTypes; - ge::Format limitOutputFormat = ge::FORMAT_RESERVED; - ge::DataType limitOutputDataType = ge::DT_UNDEFINED; -}; - -extern std::map g_OpSupportTranInfo; - -class OpTransAddSupportReg { - public: - template - OpTransAddSupportReg(const std::string &cceTbeTg, const std::string &opType, - InFmts inputFormats, InDts inputDataTypes, - OutFmts outputormat, OutDts outputDataType) { - auto cceTbeOpType = cceTbeTg + ":" + opType; - g_OpSupportTranInfo.erase(cceTbeOpType); - SetInputFormat(cceTbeOpType, inputFormats); - SetInputDataType(cceTbeOpType, inputDataTypes); - SetOutputFormat(cceTbeOpType, outputormat); - SetOutputDataType(cceTbeOpType, outputDataType); - } - ~OpTransAddSupportReg() = default; - - private: - void SetInputFormat(std::string opType, - const std::vector& supportFormat) { - auto& opInfo = g_OpSupportTranInfo[opType]; - for (auto& format : supportFormat) { - opInfo.inputFormats.push_back(format); - } - } - - void SetInputFormat(std::string opType, ge::Format supportFormat) { - auto& opInfo = g_OpSupportTranInfo[opType]; - opInfo.inputFormats.push_back(supportFormat); - } - - void SetInputFormat(std::string opType, InFmtSupportEnum enumFormat) { - auto& opInfo = g_OpSupportTranInfo[opType]; - opInfo.inputFormatSupportEnum = enumFormat; - switch (enumFormat) { - case InFmtSupportElewise: - opInfo.inputFormats = {ge::FORMAT_FRACTAL_Z, ge::FORMAT_HWCN, - ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC, - ge::FORMAT_NCHW}; - break; - case InFmtSupport4D: - opInfo.inputFormats = {ge::FORMAT_HWCN, ge::FORMAT_NHWC, - ge::FORMAT_NCHW}; - break; - case InFmtSupport5D: - opInfo.inputFormats = {ge::FORMAT_NC1HWC0}; - break; - case InFmtSupport4D_5D: - opInfo.inputFormats = {ge::FORMAT_HWCN, ge::FORMAT_NHWC, - ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0}; - break; - case InFmtSupportNCHW_NC1HWC0: - opInfo.inputFormats = {ge::FORMAT_NC1HWC0, ge::FORMAT_NCHW}; - break; - default: - break; - } - } - - void SetInputDataType(std::string opType, - const std::vector& supportDataType) { - auto& opInfo = g_OpSupportTranInfo[opType]; - for (auto& dataType : supportDataType) { - opInfo.inputDataTypes.push_back(dataType); - } - } - - void SetInputDataType(std::string opType, ge::DataType supportDataType) { - auto& opInfo = g_OpSupportTranInfo[opType]; - opInfo.inputDataTypes.push_back(supportDataType); - } - - void SetInputDataType(std::string opType, InDtSupportEnum enumDataType) { - auto& opInfo = g_OpSupportTranInfo[opType]; - opInfo.inputDataTypeSupportEnum = enumDataType; - } - - void SetOutputFormat(std::string opType, ge::Format limitOutputormat) { - auto& opInfo = g_OpSupportTranInfo[opType]; - opInfo.limitOutputFormat = limitOutputormat; - } - - void SetOutputFormat(std::string opType, OutFmtSupportEnum enumFormat) { - auto& opInfo = g_OpSupportTranInfo[opType]; - opInfo.outputFormatSupportEnum = enumFormat; - } - - void SetOutputDataType(std::string opType, ge::DataType limitOutputDataType) { - auto& opInfo = g_OpSupportTranInfo[opType]; - opInfo.limitOutputDataType = limitOutputDataType; - } - - void SetOutputDataType(std::string opType, OutDtSupportEnum enumDataType) { - auto& opInfo = g_OpSupportTranInfo[opType]; - opInfo.outputDataTypeSupportEnum = enumDataType; - } -}; - -#define TBE_SET_FORMAT_DATAYPE_INFO(cce_tbe, op, inputFormats, inputDataType, \ - outFormats, outputDataTypes) \ - TBE_SET_FORMAT_DATAYPE_INFO_UNIQ_HELPER(__COUNTER__, #cce_tbe, op, \ - inputFormats, inputDataType, \ - outFormats, outputDataTypes) -#define TBE_SET_FORMAT_DATAYPE_INFO_UNIQ_HELPER(ctr, cce_tbe, op, \ - inputFormats, inputDataType, \ - outFormats, outputDataTypes) \ - TBE_SET_FORMAT_DATAYPE_INFO_UNIQ(ctr, cce_tbe, op, inputFormats, \ - inputDataType, outFormats, outputDataTypes) -#define TBE_SET_FORMAT_DATAYPE_INFO_UNIQ(ctr, cce_tbe, op, inputFormats, \ - inputDataType, outFormats, \ - outputDataTypes) \ - OpTransAddSupportReg __gOpTransAddSupportReg##ctr( \ - cce_tbe, op, inputFormats, inputDataType, outFormats, outputDataTypes); -} // namespace domi -#endif // GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_ diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index 94d201d..5139002 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -11,9 +11,9 @@ #undef private namespace ge { class UtestGraphOptimizer : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} + protected: + void SetUp() {} + void TearDown() {} }; namespace { ComputeGraphPtr MakeGraph() { From 630701dac61698ae56d07f6f1a5ab1ee74954342 Mon Sep 17 00:00:00 2001 From: CLAY-panjw <1330286576@qq.com> Date: Mon, 13 Sep 2021 21:06:30 +0800 Subject: [PATCH 16/16] use true type --- parser/tensorflow/graph_optimizer.cc | 53 +--------------------------- 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc index fa7c623..865402b 100644 --- a/parser/tensorflow/graph_optimizer.cc +++ b/parser/tensorflow/graph_optimizer.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,67 +15,16 @@ */ #include "graph_optimizer.h" -#include -#include -#include -#include -#include "./graph_insert_trans_op.h" -#include "cce/cce.h" -#include "cce/dnn.h" -#include "parser/common/acl_graph_parser_util.h" -#include "common/op_map.h" #include "common/op_types.h" #include "common/types_map.h" #include "common/util.h" -#include "framework/common/debug/ge_log.h" #include "framework/omg/parser/parser_inner_ctx.h" -#include "framework/omg/parser/parser_types.h" #include "graph/debug/ge_attr_define.h" -#include "graph/ge_tensor.h" -#include "graph/types.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "graph_functiondef.h" #include "parser/common/acl_graph_parser_util.h" -#include "proto/tensorflow/attr_value.pb.h" #include "register/op_registry.h" -using domi::tensorflow::NodeDef; -using domi::tensorflow::TensorProto; -using domi::tensorflow::TensorShapeProto; -using domi::tensorflow::TensorShapeProto_Dim; - -using ge::FORMAT_NC1HWC0; -using ge::FORMAT_NCHW; -using ge::FORMAT_NHWC; - -using ge::AttrUtils; -using ge::Buffer; -using ge::ComputeGraph; -using ge::ComputeGraphPtr; -using ge::GE_TENSORFLOW_DATA_TYPE_MAP; -using ge::GeShape; -using ge::GeTensorDesc; -using ge::GeTensorPtr; -using ge::GRAPH_SUCCESS; -using ge::GraphToFunctionDef; -using ge::GraphUtils; -using ge::InControlAnchorPtr; -using ge::InDataAnchorPtr; -using ge::is_dataset_op_vec; -using ge::local_framework_op_vec; -using ge::NodePtr; -using ge::OpDesc; -using ge::OpDescPtr; -using ge::OutControlAnchorPtr; -using ge::OutDataAnchorPtr; -using ge::TensorUtils; - -using ge::ATTR_NAME_INPUT_DATATYPE; -using ge::ATTR_NAME_OUTPUT_DATATYPE; - namespace ge { REGISTER_OPTYPE_DEFINE(TF_MAXIMUM_GRAD, "MaximumGrad"); REGISTER_OPTYPE_DEFINE(TF_MATMUL, "Matmul");