Browse Source

use true types instead of GeAttrValue::

pull/372/head
盛楠 CLAY-panjw 4 years ago
parent
commit
bc207ebc1b
3 changed files with 89 additions and 89 deletions
  1. +1
    -1
      parser/tensorflow/graph_functiondef.cc
  2. +24
    -24
      parser/tensorflow/graph_optimizer.cc
  3. +64
    -64
      tests/depends/graph/src/attr_util_stub.cc

+ 1
- 1
parser/tensorflow/graph_functiondef.cc View File

@@ -424,7 +424,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph
}

// Analysis of nodedef of original tensorflow
ge::GeAttrValue::BYTES nodedef_bytes;
ge::Buffer nodedef_bytes;
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetBytes(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, nodedef_bytes),
PARAM_INVALID, "Get type attr nodedef failed.");
domi::tensorflow::NodeDef node_def_;


+ 24
- 24
parser/tensorflow/graph_optimizer.cc View File

@@ -515,10 +515,10 @@ Status SetNodedefProto(domi::tensorflow::NodeDef &proto, ge::NodePtr n, string o
return SUCCESS;
}

typedef Status (*PIOListHandle)(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
typedef Status (*PIOListHandle)(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc);

Status GatherV2IOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status GatherV2IOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int tparams;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Tparams", tparams)),
@@ -546,7 +546,7 @@ Status GatherV2IOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LI
return SUCCESS;
}

Status ConstIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status ConstIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int dtype;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "dtype", dtype)), return PARAM_INVALID, "Get dtype error.");
@@ -556,7 +556,7 @@ Status ConstIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_
return SUCCESS;
}

Status MaxMinIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status MaxMinIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int attrT;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", attrT)),
@@ -574,7 +574,7 @@ Status MaxMinIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST
return SUCCESS;
}

Status CastIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status CastIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int srcT;
int dstT;
@@ -592,7 +592,7 @@ Status CastIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_I
return SUCCESS;
}

Status AddIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, ge::OpDescPtr &opDesc) {
Status AddIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list, ge::OpDescPtr &opDesc) {
int type;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", type)),
REPORT_CALL_ERROR("E19999", "Get Attr:T from op:%s(%s) failed",
@@ -607,7 +607,7 @@ Status AddIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_IN
return SUCCESS;
}

Status LessIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status LessIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int dtype;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", dtype)),
@@ -622,7 +622,7 @@ Status LessIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_I
return SUCCESS;
}

Status MulIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, ge::OpDescPtr &opDesc) {
Status MulIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list, ge::OpDescPtr &opDesc) {
int dataType;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, ge::ATTR_NAME_T, dataType)),
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ge::ATTR_NAME_T.c_str(),
@@ -638,7 +638,7 @@ Status MulIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_IN
return SUCCESS;
}

Status RealDivIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status RealDivIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int t;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)),
@@ -654,7 +654,7 @@ Status RealDivIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIS
return SUCCESS;
}

Status SelectIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status SelectIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int t;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)),
@@ -671,7 +671,7 @@ Status SelectIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST
return SUCCESS;
}

Status SqrtIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status SqrtIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int dataType;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, ge::ATTR_NAME_T, dataType)),
@@ -687,7 +687,7 @@ Status SqrtIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_I
return SUCCESS;
}

Status TruncatedNormalIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status TruncatedNormalIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int t;
int dtype;
@@ -707,7 +707,7 @@ Status TruncatedNormalIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrVa
return SUCCESS;
}

Status PackIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status PackIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int t;
int n;
@@ -729,7 +729,7 @@ Status PackIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_I
return SUCCESS;
}

Status DropOutGenMaskIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status DropOutGenMaskIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
input_list.push_back(domi::tensorflow::DT_INT64);
input_list.push_back(domi::tensorflow::DT_FLOAT);
@@ -738,7 +738,7 @@ Status DropOutGenMaskIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrVal
return SUCCESS;
}

Status ExpandDimsIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status ExpandDimsIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
int dataType;
int dimType;
@@ -759,7 +759,7 @@ Status ExpandDimsIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::
return SUCCESS;
}

Status SqueezeIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list,
Status SqueezeIOList(std::vector<int64_t> &input_list, std::vector<int64_t> &output_list,
ge::OpDescPtr &opDesc) {
// Set - TENSORFLOW_IN_DATATYPE/TENSORFLOW_OUT_DATATYPE
int dataType;
@@ -785,7 +785,7 @@ Status SqueezeIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIS
return SUCCESS;
}

Status TopKV2IOList(ge::GeAttrValue::LIST_INT &inputList, ge::GeAttrValue::LIST_INT &outputList,
Status TopKV2IOList(std::vector<int64_t> &inputList, std::vector<int64_t> &outputList,
ge::OpDescPtr &opDesc) {
int t;
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)),
@@ -829,8 +829,8 @@ Status CreateNodeDefBytes(ge::NodePtr n, string originalType, map<string, PIOLis
GELOGI("n->GetName() = %s.\n", n->GetName().c_str());
// Set - NodeDef PROTO
domi::tensorflow::NodeDef proto;
ge::GeAttrValue::LIST_INT inputList;
ge::GeAttrValue::LIST_INT outputList;
std::vector<int64_t> inputList;
std::vector<int64_t> outputList;
ret = SetNodedefProto(proto, n, originalType);
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "SetNodedefProto failed.");

@@ -891,7 +891,7 @@ Status CreateNodeDefBytes(ge::NodePtr n, string originalType, map<string, PIOLis
return PARAM_INVALID);

// Set - ATTR_NAME_FRAMEWORK_NODE_DEF
ge::GeAttrValue::BYTES nodeDefBytes;
ge::Buffer nodeDefBytes;
(void)ge::AttrUtils::SetZeroCopyBytes(
opDesc, ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
nodeDefBytes.CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length()));
@@ -1279,7 +1279,7 @@ Status CreateFuncDefBytes(ge::NodePtr n, string original_type, string func_bin_p

GELOGI("len =%d\n", len);

ge::GeAttrValue::BYTES funcDefBytes;
ge::Buffer funcDefBytes;
funcDefBytes = ge::Buffer::CopyFrom((std::uint8_t *)buf, len);
(void)ge::AttrUtils::SetBytes(opDesc, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes);
GELOGI("funcDefBytes.GetSize() =%zu", funcDefBytes.GetSize());
@@ -1453,7 +1453,7 @@ Status CollectNodeFuncs(vector<ge::NodePtr> &nodes, FunctionDefLibrary *library)
GE_CHECK_NOTNULL(node);
OpDescPtr opDef = node->GetOpDesc();
string funcdefStr;
ge::GeAttrValue::BYTES funcDefBytes;
ge::Buffer funcDefBytes;

GE_IF_BOOL_EXEC(
AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib;
@@ -1648,7 +1648,7 @@ Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map<string, ge::NodePtr>
// rebuild output anchor
Status ParserGraphOptimizer::RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> &output_anchors,
ge::OpDescPtr fusion_op_desc) {
ge::GeAttrValue::LIST_INT output_list;
std::vector<int64_t> output_list;
GE_CHECK_NOTNULL(fusion_op_desc);

// create input desc
@@ -1679,7 +1679,7 @@ Status ParserGraphOptimizer::RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> &
// rebuild input desc
Status ParserGraphOptimizer::RebuildInputAnchors(vector<ge::InDataAnchorPtr> &input_anchors,
ge::OpDescPtr fusion_op_desc) {
ge::GeAttrValue::LIST_INT input_list;
std::vector<int64_t> input_list;
GE_CHECK_NOTNULL(fusion_op_desc);
// add input desc
for (auto in_anchor : input_anchors) {


+ 64
- 64
tests/depends/graph/src/attr_util_stub.cc View File

@@ -83,61 +83,61 @@ class GeAttrValueImp {
static map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> attr_val_one_type_map_;
static map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> attr_val_list_type_map_;

static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val);
static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val);
static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val);
static bool SetValue(proto::AttrDef &attr_def, int64_t val);
static bool SetValue(proto::AttrDef &attr_def, float val);
static bool SetValue(proto::AttrDef &attr_def, bool val);
static bool SetValue(proto::AttrDef &attr_def, const std::string &val);
static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val);
static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val);
static bool SetValue(proto::AttrDef &attr_def, const GeTensorDesc &val);
static bool SetValue(proto::AttrDef &attr_def, const Buffer &val);
static bool SetValue(proto::AttrDef &attr_def, const NamedAttrs &val);
static bool SetValue(proto::AttrDef &attr_def, const ComputeGraphPtr &val);
static bool SetValue(proto::AttrDef &attr_def, const vector<int64_t> &val);
static bool SetValue(proto::AttrDef &attr_def, const vector<int32_t> &val);
static bool SetValue(proto::AttrDef &attr_def, const vector<uint32_t> &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<float> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<bool> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<std::string> &val);
static bool SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensorPtr> &value);
static bool SetValue(proto::AttrDef &proto_attr_val, const vector<ConstGeTensorPtr> &value);
static bool SetValue(proto::AttrDef &attr_def, const vector<GeTensor> &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val);
static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BOOL &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<GeTensorDesc> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<Buffer> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<NamedAttrs> &val);
static bool SetValue(proto::AttrDef &attr_def, const std::vector<ComputeGraphPtr> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, int64_t &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, float &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, bool &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, std::string &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensorPtr &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::TENSOR_DESC &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val);
GeTensorDesc &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::NAMED_ATTRS &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val);
NamedAttrs &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ComputeGraphPtr &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_INT &val);
std::vector<int64_t> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_FLOAT &val);
std::vector<float> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_BOOL &val);
std::vector<bool> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_STR &val);
std::vector<std::string> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_TENSOR &val);
std::vector<GeTensorPtr> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector<GeTensor> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_TENSOR_DESC &val);
std::vector<GeTensorDesc> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_BYTES &val);
std::vector<Buffer> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_NAMED_ATTRS &val);
std::vector<NamedAttrs> &val);
static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
GeAttrValue::LIST_GRAPH &val);
std::vector<ComputeGraphPtr> &val);
// Value will be moved
static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer);
static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer);
@@ -246,30 +246,30 @@ GeAttrValue GeAttrValue::Copy() const {
return GRAPH_FAILED; \
}

ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR_DESC>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::GRAPH>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>)
ATTR_VALUE_SET_GET_IMP(std::string)
ATTR_VALUE_SET_GET_IMP(vector<std::string>)
ATTR_VALUE_SET_GET_IMP(int64_t)
ATTR_VALUE_SET_GET_IMP(vector<int64_t>)
ATTR_VALUE_SET_GET_IMP(float) // lint !e524
ATTR_VALUE_SET_GET_IMP(vector<float>)
ATTR_VALUE_SET_GET_IMP(bool)
ATTR_VALUE_SET_GET_IMP(vector<bool>)
ATTR_VALUE_SET_GET_IMP(GeTensorDesc)
ATTR_VALUE_SET_GET_IMP(vector<GeTensorDesc>)
ATTR_VALUE_SET_GET_IMP(GeTensorPtr)
ATTR_VALUE_SET_GET_IMP(vector<GeTensorPtr>)
ATTR_VALUE_SET_GET_IMP(ComputeGraphPtr)
ATTR_VALUE_SET_GET_IMP(vector<ComputeGraphPtr>)
ATTR_VALUE_SET_GET_IMP(Buffer)
ATTR_VALUE_SET_GET_IMP(vector<Buffer>)
ATTR_VALUE_SET_GET_IMP(NamedAttrs)
ATTR_VALUE_SET_GET_IMP(vector<NamedAttrs>)
/*lint -e665*/
ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>)
ATTR_VALUE_SET_GET_IMP(vector<vector<float>>)
/*lint +e665*/
ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665
ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665
ATTR_VALUE_SET_GET_IMP(DataType) // lint !e665

#undef ATTR_VALUE_SET_GET_IMP

@@ -569,7 +569,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeTen
return true;
}

bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::BYTES &value) {
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const Buffer &value) {
if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
return false;
}
@@ -578,7 +578,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue:
return true;
}

bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::BYTES> &value) {
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<Buffer> &value) {
if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) {
return false;
@@ -592,7 +592,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAtt
return true;
}

bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) {
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const NamedAttrs &value) {
if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) {
return false;
}
@@ -606,7 +606,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue:
return true;
}

bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NAMED_ATTRS> &value) {
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<NamedAttrs> &value) {
if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) {
return false;
@@ -822,7 +822,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
return true;
}

bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) {
bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, Buffer &value) {
if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
return false;
}
@@ -833,7 +833,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
}

bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
vector<GeAttrValue::BYTES> &value) {
vector<Buffer> &value) {
value.clear();
if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES,
ListValueItemCheck(bt))) {
@@ -847,7 +847,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
}

bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
GeAttrValue::NAMED_ATTRS &value) {
NamedAttrs &value) {
if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) {
return false;
}
@@ -860,7 +860,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
}

bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
vector<GeAttrValue::NAMED_ATTRS> &value) {
vector<NamedAttrs> &value) {
value.clear();
if (!AttrUtilsHelper::GetValueCheckListType(
proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) {
@@ -868,7 +868,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
}
auto &list = proto_attr_val.list();
for (const auto &item : list.na()) {
value.emplace_back(GeAttrValue::NAMED_ATTRS());
value.emplace_back(NamedAttrs());
if (value.empty()) {
return false;
}
@@ -1107,7 +1107,7 @@ ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc)
ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr)
ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr)
ATTR_UTILS_SET_IMP(Tensor, GeTensor)
ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS)
ATTR_UTILS_SET_GET_IMP(NamedAttrs, NamedAttrs)
ATTR_UTILS_SET_GET_IMP(Bytes, Buffer)
ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr)
/*lint -e665*/
@@ -1124,7 +1124,7 @@ ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector<GeTensorDesc>)
ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>)
ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>)
ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>)
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>)
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<NamedAttrs>)
ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>)
ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>)
ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) // lint !e665


Loading…
Cancel
Save