Browse Source

!372 use true types

Merge pull request !372 from 潘嘉旺/master
pull/375/head
i-robot Gitee 4 years ago
parent
commit
d53425bcbb
10 changed files with 137 additions and 2209 deletions
  1. +1
    -1
      parser/tensorflow/graph_functiondef.cc
  2. +0
    -183
      parser/tensorflow/graph_insert_trans_op.h
  3. +5
    -1901
      parser/tensorflow/graph_optimizer.cc
  4. +1
    -54
      parser/tensorflow/graph_optimizer.h
  5. +0
    -2
      parser/tensorflow/iterator_fusion_pass.cc
  6. +2
    -3
      parser/tensorflow/iterator_fusion_pass.h
  7. +1
    -1
      parser/tensorflow/tensorflow_parser.cc
  8. +64
    -64
      tests/depends/graph/src/attr_util_stub.cc
  9. +1
    -0
      tests/ut/parser/CMakeLists.txt
  10. +62
    -0
      tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc

+ 1
- 1
parser/tensorflow/graph_functiondef.cc View File

@@ -424,7 +424,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph
}

// Analysis of nodedef of original tensorflow
ge::GeAttrValue::BYTES nodedef_bytes;
ge::Buffer nodedef_bytes;
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetBytes(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, nodedef_bytes),
PARAM_INVALID, "Get type attr nodedef failed.");
domi::tensorflow::NodeDef node_def_;


+ 0
- 183
parser/tensorflow/graph_insert_trans_op.h View File

@@ -1,183 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_
#define GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_
#include <map>
#include <string>
#include <vector>
#include "common/fmk_types.h"
#include "framework/omg/parser/parser_types.h"
#include "graph/compute_graph.h"
#include "graph/node.h"
#include "graph/types.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "register/op_registry.h"

namespace ge {
enum InFmtSupportEnum {
InFmtSupportUndefined,
InFmtSupportElewise,
InFmtSupport4D,
InFmtSupport5D,
InFmtSupport4D_5D,
InFmtSupportNCHW_NC1HWC0
};

enum InDtSupportEnum {
InDtSupportUndefined = 0,
InDtSupportAll = 1,
};

enum OutFmtSupportEnum {
OutFmtSupportUndefined = 0,
OutFmtSupportAsInput = 1,
};

enum OutDtSupportEnum {
OutDtSupportUndefined = 0,
OutDtSupportAsInput = 1,
};

struct OpSupportTranInfo {
InFmtSupportEnum inputFormatSupportEnum = InFmtSupportUndefined;
InDtSupportEnum inputDataTypeSupportEnum = InDtSupportUndefined;
OutFmtSupportEnum outputFormatSupportEnum = OutFmtSupportUndefined;
OutDtSupportEnum outputDataTypeSupportEnum = OutDtSupportUndefined;

std::vector<ge::Format> inputFormats;
std::vector<ge::DataType> inputDataTypes;
ge::Format limitOutputFormat = ge::FORMAT_RESERVED;
ge::DataType limitOutputDataType = ge::DT_UNDEFINED;
};

extern std::map<std::string, OpSupportTranInfo> g_OpSupportTranInfo;

class OpTransAddSupportReg {
public:
template <class InFmts, class InDts, class OutFmts, class OutDts>
OpTransAddSupportReg(const std::string &cceTbeTg, const std::string &opType,
InFmts inputFormats, InDts inputDataTypes,
OutFmts outputormat, OutDts outputDataType) {
auto cceTbeOpType = cceTbeTg + ":" + opType;
g_OpSupportTranInfo.erase(cceTbeOpType);
SetInputFormat(cceTbeOpType, inputFormats);
SetInputDataType(cceTbeOpType, inputDataTypes);
SetOutputFormat(cceTbeOpType, outputormat);
SetOutputDataType(cceTbeOpType, outputDataType);
}
~OpTransAddSupportReg() = default;

private:
void SetInputFormat(std::string opType,
const std::vector<ge::Format>& supportFormat) {
auto& opInfo = g_OpSupportTranInfo[opType];
for (auto& format : supportFormat) {
opInfo.inputFormats.push_back(format);
}
}

void SetInputFormat(std::string opType, ge::Format supportFormat) {
auto& opInfo = g_OpSupportTranInfo[opType];
opInfo.inputFormats.push_back(supportFormat);
}

void SetInputFormat(std::string opType, InFmtSupportEnum enumFormat) {
auto& opInfo = g_OpSupportTranInfo[opType];
opInfo.inputFormatSupportEnum = enumFormat;
switch (enumFormat) {
case InFmtSupportElewise:
opInfo.inputFormats = {ge::FORMAT_FRACTAL_Z, ge::FORMAT_HWCN,
ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC,
ge::FORMAT_NCHW};
break;
case InFmtSupport4D:
opInfo.inputFormats = {ge::FORMAT_HWCN, ge::FORMAT_NHWC,
ge::FORMAT_NCHW};
break;
case InFmtSupport5D:
opInfo.inputFormats = {ge::FORMAT_NC1HWC0};
break;
case InFmtSupport4D_5D:
opInfo.inputFormats = {ge::FORMAT_HWCN, ge::FORMAT_NHWC,
ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0};
break;
case InFmtSupportNCHW_NC1HWC0:
opInfo.inputFormats = {ge::FORMAT_NC1HWC0, ge::FORMAT_NCHW};
break;
default:
break;
}
}

void SetInputDataType(std::string opType,
const std::vector<ge::DataType>& supportDataType) {
auto& opInfo = g_OpSupportTranInfo[opType];
for (auto& dataType : supportDataType) {
opInfo.inputDataTypes.push_back(dataType);
}
}

void SetInputDataType(std::string opType, ge::DataType supportDataType) {
auto& opInfo = g_OpSupportTranInfo[opType];
opInfo.inputDataTypes.push_back(supportDataType);
}

void SetInputDataType(std::string opType, InDtSupportEnum enumDataType) {
auto& opInfo = g_OpSupportTranInfo[opType];
opInfo.inputDataTypeSupportEnum = enumDataType;
}

void SetOutputFormat(std::string opType, ge::Format limitOutputormat) {
auto& opInfo = g_OpSupportTranInfo[opType];
opInfo.limitOutputFormat = limitOutputormat;
}

void SetOutputFormat(std::string opType, OutFmtSupportEnum enumFormat) {
auto& opInfo = g_OpSupportTranInfo[opType];
opInfo.outputFormatSupportEnum = enumFormat;
}

void SetOutputDataType(std::string opType, ge::DataType limitOutputDataType) {
auto& opInfo = g_OpSupportTranInfo[opType];
opInfo.limitOutputDataType = limitOutputDataType;
}

void SetOutputDataType(std::string opType, OutDtSupportEnum enumDataType) {
auto& opInfo = g_OpSupportTranInfo[opType];
opInfo.outputDataTypeSupportEnum = enumDataType;
}
};

#define TBE_SET_FORMAT_DATAYPE_INFO(cce_tbe, op, inputFormats, inputDataType, \
outFormats, outputDataTypes) \
TBE_SET_FORMAT_DATAYPE_INFO_UNIQ_HELPER(__COUNTER__, #cce_tbe, op, \
inputFormats, inputDataType, \
outFormats, outputDataTypes)
#define TBE_SET_FORMAT_DATAYPE_INFO_UNIQ_HELPER(ctr, cce_tbe, op, \
inputFormats, inputDataType, \
outFormats, outputDataTypes) \
TBE_SET_FORMAT_DATAYPE_INFO_UNIQ(ctr, cce_tbe, op, inputFormats, \
inputDataType, outFormats, outputDataTypes)
#define TBE_SET_FORMAT_DATAYPE_INFO_UNIQ(ctr, cce_tbe, op, inputFormats, \
inputDataType, outFormats, \
outputDataTypes) \
OpTransAddSupportReg __gOpTransAddSupportReg##ctr( \
cce_tbe, op, inputFormats, inputDataType, outFormats, outputDataTypes);
} // namespace domi
#endif // GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_

+ 5
- 1901
parser/tensorflow/graph_optimizer.cc
File diff suppressed because it is too large
View File


+ 1
- 54
parser/tensorflow/graph_optimizer.h View File

@@ -35,67 +35,15 @@ namespace ge {
class ParserGraphOptimizer {
public:
explicit ParserGraphOptimizer(ge::ComputeGraphPtr graph, domi::FrameworkType type = domi::TENSORFLOW)
: graph_(graph), fmktype_(type), local_fmk_op_flag_(false) {}
: graph_(graph), fmktype_(type) {}

~ParserGraphOptimizer() {}

domi::Status Optimize();

domi::Status OptimizeAfterCal();

domi::Status FusionFmkop();

inline bool IsHCOMOp(const string &op_type) {
return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) ||
(op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) ||
(op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter");
}

void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; }

const bool GetLocalFmkopFlag() const { return local_fmk_op_flag_; }

void SetFuncBinPath(std::string isFuncBinPath) { func_bin_path_ = isFuncBinPath; }
const std::string GetFuncBinPath() const { return func_bin_path_; }

domi::Status InsertHWCK2FZ(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype,
enum ge::Format dstInFormat, enum ge::DataType dstInDatatype);

domi::Status Insert4DTo5DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format src_out_format, enum ge::DataType src_out_data_type,
enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type);

domi::Status InsertFZ2HWCK(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype,
enum ge::Format dstInFormat, enum ge::DataType dstInDatatype);

domi::Status Insert5DTo4DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format src_out_format, enum ge::DataType src_out_data_type,
enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type);

ge::OpDescPtr CreateCastOp(enum ge::DataType input_datatype, enum ge::DataType output_datatype, ge::Format format);

ge::OpDescPtr CreatePermuteOp(enum ge::Format input_format, enum ge::Format output_format);

ge::OpDescPtr CreateTransDataOp(enum ge::Format input_format);

domi::Status NewNodeAddEdges(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, ge::NodePtr first,
ge::NodePtr second, ge::NodePtr third);

domi::Status InsertVar5DTo4D(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor,
enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype,
enum ge::Format dstInFormat, enum ge::DataType dstInDatatype);

ge::OpDescPtr CreateTranslateOp(enum ge::Format inFormat, ge::DataType inDatatype, enum ge::Format outFormat,
ge::DataType outDatatype);

private:
ge::ComputeGraphPtr graph_;
domi::FrameworkType fmktype_;
// local fmkop flag
bool local_fmk_op_flag_;
std::string func_bin_path_;

domi::Status FindFmkNodeCluser(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map);

@@ -122,7 +70,6 @@ class ParserGraphOptimizer {
vector<ge::InControlAnchorPtr> &input_control_anchors,
vector<ge::OutControlAnchorPtr> &output_control_anchors, ge::NodePtr fusion_node);

domi::Status MakeTfProtoDef();
};
} // namespace ge
#endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_

+ 0
- 2
parser/tensorflow/iterator_fusion_pass.cc View File

@@ -32,8 +32,6 @@ Status IteratorFusionPass::Run(ge::ComputeGraphPtr graph) {
REPORT_CALL_ERROR("E19999", "New ParserGraphOptimizer failed");
return FAILED;
}

graph_optimizer->SetLocalFmkopFlag(local_fmk_op_flag_);
return graph_optimizer->FusionFmkop();
}
} // namespace ge

+ 2
- 3
parser/tensorflow/iterator_fusion_pass.h View File

@@ -23,8 +23,8 @@
namespace ge {
class IteratorFusionPass : public GraphPass {
public:
IteratorFusionPass(domi::FrameworkType type, bool local_fmk_op_flag)
: fmk_type_(type), local_fmk_op_flag_(local_fmk_op_flag) {}
IteratorFusionPass(domi::FrameworkType type)
: fmk_type_(type) {}

virtual ~IteratorFusionPass() {}

@@ -32,7 +32,6 @@ class IteratorFusionPass : public GraphPass {

private:
domi::FrameworkType fmk_type_;
bool local_fmk_op_flag_;
};
} // namespace ge



+ 1
- 1
parser/tensorflow/tensorflow_parser.cc View File

@@ -2375,7 +2375,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
ge::parser::PassManager iterator_fusion_pass;
try {
(void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass",
new ge::IteratorFusionPass(domi::TENSORFLOW, false));
new ge::IteratorFusionPass(domi::TENSORFLOW));
} catch (std::bad_alloc &e) {
GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs.");
return INTERNAL_ERROR;


+ 64
- 64
tests/depends/graph/src/attr_util_stub.cc View File

@@ -83,61 +83,61 @@ class GeAttrValueImp {
static map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> attr_val_one_type_map_;
static map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> attr_val_list_type_map_;

static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val);
static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val);
static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val);
static bool SetValue(proto::AttrDef &attr_def, int64_t val);
static bool SetValue(proto::AttrDef &attr_def, float val);
static bool SetValue(proto::AttrDef &attr_def, bool val);
static bool SetValue(proto::AttrDef &attr_def, const std::string &val);
static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val);
static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val);
static bool SetValue(proto::AttrDef &attr_def, const GeTensorDesc &val);
static bool SetValue(proto::AttrDef &attr_def, const Buffer &val);
static bool SetValue(proto::AttrDef &attr_def, const NamedAttrs &val);
static bool SetValue(proto::AttrDef &attr_def, const ComputeGraphPtr &val);
static bool SetValue(proto::AttrDef &attr_def, const vector<int64_t> &val);
static bool SetValue(proto::AttrDef &attr_def, const vector<int32_t> &val);
static bool SetValue(proto::AttrDef &attr_def, const vector<uint32_t> &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<float> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<bool> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<std::string> &val);
static bool SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensorPtr> &value);
static bool SetValue(proto::AttrDef &proto_attr_val, const vector<ConstGeTensorPtr> &value);
static bool SetValue(proto::AttrDef &attr_def, const vector<GeTensor> &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BOOL &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<GeTensorDesc> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<Buffer> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<NamedAttrs> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<ComputeGraphPtr> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, int64_t &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, float &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, bool &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, std::string &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensorPtr &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::TENSOR_DESC &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val);
GeTensorDesc &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::NAMED_ATTRS &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val);
NamedAttrs &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ComputeGraphPtr &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_INT &val);
std::vector<int64_t> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_FLOAT &val);
std::vector<float> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_BOOL &val);
std::vector<bool> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_STR &val);
std::vector<std::string> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_TENSOR &val);
std::vector<GeTensorPtr> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector<GeTensor> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_TENSOR_DESC &val);
std::vector<GeTensorDesc> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_BYTES &val);
std::vector<Buffer> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_NAMED_ATTRS &val);
std::vector<NamedAttrs> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_GRAPH &val);
std::vector<ComputeGraphPtr> &val);
// Value will be moved
static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer);
static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer);
@@ -246,30 +246,30 @@ GeAttrValue GeAttrValue::Copy() const {
return GRAPH_FAILED; \
}

ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR_DESC>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::GRAPH>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>)
ATTR_VALUE_SET_GET_IMP(std::string)
ATTR_VALUE_SET_GET_IMP(vector<std::string>)
ATTR_VALUE_SET_GET_IMP(int64_t)
ATTR_VALUE_SET_GET_IMP(vector<int64_t>)
ATTR_VALUE_SET_GET_IMP(float) // lint !e524
ATTR_VALUE_SET_GET_IMP(vector<float>)
ATTR_VALUE_SET_GET_IMP(bool)
ATTR_VALUE_SET_GET_IMP(vector<bool>)
ATTR_VALUE_SET_GET_IMP(GeTensorDesc)
ATTR_VALUE_SET_GET_IMP(vector<GeTensorDesc>)
ATTR_VALUE_SET_GET_IMP(GeTensorPtr)
ATTR_VALUE_SET_GET_IMP(vector<GeTensorPtr>)
ATTR_VALUE_SET_GET_IMP(ComputeGraphPtr)
ATTR_VALUE_SET_GET_IMP(vector<ComputeGraphPtr>)
ATTR_VALUE_SET_GET_IMP(Buffer)
ATTR_VALUE_SET_GET_IMP(vector<Buffer>)
ATTR_VALUE_SET_GET_IMP(NamedAttrs)
ATTR_VALUE_SET_GET_IMP(vector<NamedAttrs>)
/*lint -e665*/
ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>)
ATTR_VALUE_SET_GET_IMP(vector<vector<float>>)
/*lint +e665*/
ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665
ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665
ATTR_VALUE_SET_GET_IMP(DataType) // lint !e665

#undef ATTR_VALUE_SET_GET_IMP

@@ -569,7 +569,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeTen
return true;
}

bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::BYTES &value) {
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const Buffer &value) {
if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
return false;
}
@@ -578,7 +578,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue:
return true;
}

bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::BYTES> &value) {
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<Buffer> &value) {
if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) {
return false;
@@ -592,7 +592,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAtt
return true;
}

bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) {
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const NamedAttrs &value) {
if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) {
return false;
}
@@ -606,7 +606,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue:
return true;
}

bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NAMED_ATTRS> &value) {
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<NamedAttrs> &value) {
if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) {
return false;
@@ -822,7 +822,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
return true;
}

bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) {
bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, Buffer &value) {
if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
return false;
}
@@ -833,7 +833,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
}

bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
vector<GeAttrValue::BYTES> &value) {
vector<Buffer> &value) {
value.clear();
if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES,
ListValueItemCheck(bt))) {
@@ -847,7 +847,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
}

bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
GeAttrValue::NAMED_ATTRS &value) {
NamedAttrs &value) {
if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) {
return false;
}
@@ -860,7 +860,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
}

bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
vector<GeAttrValue::NAMED_ATTRS> &value) {
vector<NamedAttrs> &value) {
value.clear();
if (!AttrUtilsHelper::GetValueCheckListType(
proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) {
@@ -868,7 +868,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
}
auto &list = proto_attr_val.list();
for (const auto &item : list.na()) {
value.emplace_back(GeAttrValue::NAMED_ATTRS());
value.emplace_back(NamedAttrs());
if (value.empty()) {
return false;
}
@@ -1107,7 +1107,7 @@ ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc)
ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr)
ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr)
ATTR_UTILS_SET_IMP(Tensor, GeTensor)
ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS)
ATTR_UTILS_SET_GET_IMP(NamedAttrs, NamedAttrs)
ATTR_UTILS_SET_GET_IMP(Bytes, Buffer)
ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr)
/*lint -e665*/
@@ -1124,7 +1124,7 @@ ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector<GeTensorDesc>)
ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>)
ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>)
ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>)
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>)
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<NamedAttrs>)
ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>)
ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>)
ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) // lint !e665


+ 1
- 0
tests/ut/parser/CMakeLists.txt View File

@@ -314,6 +314,7 @@ set(PARSER_UT_FILES
"testcase/onnx_parser_testcase/message2operator_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc"
"testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc"
)

############ libut_parser_common.a ############


+ 62
- 0
tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc View File

@@ -0,0 +1,62 @@
#include <gtest/gtest.h>
#include <iostream>
#include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "ut/parser/parser_ut_utils.h"
#include "common/util.h"
#include "tensorflow/iterator_fusion_pass.h"
#include "parser/common/acl_graph_parser_util.h"
#define private public
#include "tensorflow/graph_optimizer.h"
#undef private
namespace ge {
class UtestGraphOptimizer : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};
namespace {
ComputeGraphPtr MakeGraph() {
ge::ut::GraphBuilder builder("graph");
std::string name = "graph";
std::string original_type;
original_type = "IteratorV2"; //
auto data1 = builder.AddNode(name + "_" + original_type, ge::parser::FRAMEWORKOP, 1, 1);
ge::AttrUtils::SetStr(data1->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type);
original_type = "IteratorGetNext";
auto data2 = builder.AddNode(name + "_" + original_type + "2", ge::parser::FRAMEWORKOP, 1, 2);
ge::AttrUtils::SetStr(data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type);
string nodefStr;
AttrUtils::SetZeroCopyBytes(
data2->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length()));
original_type = "IteratorGetNext";
auto data3 = builder.AddNode(name + "_" + original_type + "3", ge::parser::FRAMEWORKOP, 2, 1);
ge::AttrUtils::SetStr(data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type);
AttrUtils::SetZeroCopyBytes(
data3->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length()));

builder.AddDataEdge(data1, 0, data2, 0);
builder.AddDataEdge(data2, 0, data3, 0);
builder.AddDataEdge(data2, 1, data3, 1);
return builder.GetGraph();
}
}
TEST_F(UtestGraphOptimizer, graph_optimizer) {
ge::ComputeGraphPtr graph = MakeGraph();
ge::IteratorFusionPass iteratorFusionPass(domi::TENSORFLOW);
EXPECT_NE(iteratorFusionPass.Run(graph), ge::SUCCESS);
}
TEST_F(UtestGraphOptimizer, graph_optimizer_output) {
ge::ComputeGraphPtr graph = MakeGraph();
domi::FrameworkType type = domi::TENSORFLOW;
ge::ParserGraphOptimizer parserGraphOptimizer(graph, type);

vector<ge::InDataAnchorPtr> input_anchors;
vector<ge::OutDataAnchorPtr> output_anchors;
ge::OpDescPtr fusion_op_desc;
EXPECT_NE(parserGraphOptimizer.RebuildInputAnchors(input_anchors, fusion_op_desc), ge::SUCCESS);
EXPECT_NE(parserGraphOptimizer.RebuildOutputAnchors(output_anchors, fusion_op_desc), ge::SUCCESS);
}
}

Loading…
Cancel
Save