From 87a63c1159044049d58dca70cf6e5893b515237f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=AC=91=E5=A4=A9?= Date: Fri, 24 Sep 2021 02:47:07 +0000 Subject: [PATCH] !375 for update libs * for update libs --- metadef | 2 +- tests/depends/graph/src/attr_util_stub.cc | 1717 ++++++--------------- tests/st/CMakeLists.txt | 23 +- tests/ut/parser/CMakeLists.txt | 23 +- 4 files changed, 482 insertions(+), 1283 deletions(-) diff --git a/metadef b/metadef index 60df4b3..ccb536e 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 60df4b39a6f639c21dd7deb220b93345451938f5 +Subproject commit ccb536ecd63934d67be51ed4f96ae8e67cef7c69 diff --git a/tests/depends/graph/src/attr_util_stub.cc b/tests/depends/graph/src/attr_util_stub.cc index 5bb3cc7..f772260 100644 --- a/tests/depends/graph/src/attr_util_stub.cc +++ b/tests/depends/graph/src/attr_util_stub.cc @@ -18,17 +18,20 @@ #include #include #include "external/graph/graph.h" -#include "utils/attr_utils.h" +#include "graph/utils/attr_utils.h" #include "framework/common/debug/ge_log.h" #include "graph/model_serialize.h" #include "graph/ge_tensor_impl.h" #include "graph/buffer_impl.h" #include "graph/op_desc_impl.h" #include "proto/ge_ir.pb.h" -#include "detail/model_serialize_imp.h" -#include "debug/ge_attr_define.h" +#include "graph/detail/model_serialize_imp.h" +#include "graph/debug/ge_attr_define.h" #include "debug/ge_log.h" #include "debug/ge_util.h" +#include "graph/utils/tensor_utils.h" +#include "graph/serialization/attr_serializer_registry.h" +#include "graph/serialization/tensor_desc_serializer.h" using std::map; using std::string; @@ -36,1454 +39,612 @@ using std::vector; using std::set; namespace ge { -NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } - -NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) - : named_attrs_(owner, proto_msg) {} // lint !e1744 - void NamedAttrs::SetName(const std::string &name) { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_name(name); - } + name_ = name; } string NamedAttrs::GetName() const { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->name(); - } - return string(); + return name_; } -GeAttrValue NamedAttrs::GetItem(const string &key) const { - GeAttrValue value; +AnyValue NamedAttrs::GetItem(const string &key) const { + AnyValue value; (void)GetAttr(key, value); return value; } -ProtoAttrMapHelper NamedAttrs::MutableAttrMap() { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); +ProtoAttrMap &NamedAttrs::MutableAttrMap() { + return attrs_; +} + +ConstProtoAttrMap &NamedAttrs::GetAttrMap() const { + return attrs_; +} + +bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const string &name) { + if (!obj) { + return false; } - return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); + return obj->HasAttr(name); } -ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value) { + int64_t int64_val = 0; + if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { + return false; + } + if (int64_val > INT32_MAX) { + REPORT_INNER_ERROR("E19999", "%ld int64_t value cannot cast to int32_t", int64_val); + GELOGE(GRAPH_FAILED, "[Check][Param] %ld int64_t value cannot cast to int32_t", int64_val); + return false; } - return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); + value = static_cast(int64_val); + return true; } -class GeAttrValueImp { - public: - static map attr_val_one_type_map_; - static map attr_val_list_type_map_; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value) { + int64_t int64_val = 0; + if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { + return false; + } + if (int64_val > UINT32_MAX) { + REPORT_INNER_ERROR("E19999", "%ld int64_t value cannot cast to uint32_t", int64_val); + GELOGE(GRAPH_FAILED, "[Check][Param] %ld int64_t value cannot cast to uint32_t", int64_val); + return false; + } + // 老版本中,只判断了上限,没有判断下限,因此小于0时,这里不会报错 + // 这里维持老版本的做法,在第一次上库做完后,补上小于0的判断 + value = static_cast(int64_val); + return true; +} - static bool SetValue(proto::AttrDef &attr_def, int64_t val); - static bool SetValue(proto::AttrDef &attr_def, float val); - static bool SetValue(proto::AttrDef &attr_def, bool val); - static bool SetValue(proto::AttrDef &attr_def, const std::string &val); - static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val); - static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val); - static bool SetValue(proto::AttrDef &attr_def, const GeTensorDesc &val); - static bool SetValue(proto::AttrDef &attr_def, const Buffer &val); - static bool SetValue(proto::AttrDef &attr_def, const NamedAttrs &val); - static bool SetValue(proto::AttrDef &attr_def, const ComputeGraphPtr &val); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); - static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); - static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); - static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); - static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); - static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); - static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); - static bool SetValue(proto::AttrDef &attr_def, const std::vector &val); +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { + if (org_op_desc == nullptr) { + REPORT_INNER_ERROR("E19999", "org_op_desc is null, check invalid"); + GELOGE(GRAPH_FAILED, "[Check][Param] org_op_desc is null"); + return nullptr; + } + std::shared_ptr op_def; + op_def = ComGraphMakeShared(); + if (op_def == nullptr) { + REPORT_CALL_ERROR("E19999", "create proto::OpDef failed."); + GELOGE(GRAPH_FAILED, "[Create][OpDef] proto::OpDef make shared failed"); + return nullptr; // lint !e665 + } + ModelSerializeImp imp; + (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, int64_t &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, float &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, bool &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, std::string &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensorPtr &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeTensorDesc &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - NamedAttrs &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ComputeGraphPtr &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - std::vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - std::vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - std::vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - std::vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - std::vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - std::vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - std::vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - std::vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - std::vector &val); - // Value will be moved - static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer); - static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer); - // Value will be moved - static bool SetZeroCopyListBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - vector &list_buffer); - static bool GetZeroCopyListBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - vector &list_buffer); + imp.SetProtobufOwner(op_def); + OpDescPtr op_desc = nullptr; + GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), + REPORT_CALL_ERROR("E19999", "UnserializeOpDesc failed"); + return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed"); + op_desc->extAttrs_ = org_op_desc->extAttrs_; - static bool SetValue(proto::AttrDef &attr_def, const vector> &value); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - vector> &value); + // This function may be called by some passes of fusion engine, in this condition, do not need these attribute + if (op_desc->impl_ == nullptr) { + REPORT_INNER_ERROR("E19999", "op_desc impl is nullptr, check invalid"); + GELOGE(GRAPH_FAILED, "[Check][Param] Op desc impl is nullptr."); + return nullptr; + } + if (!op_desc->impl_->input_name_idx_.empty()) { + op_desc->impl_->input_name_idx_.clear(); + } + if (!op_desc->impl_->output_name_idx_.empty()) { + op_desc->impl_->output_name_idx_.clear(); + } + if (!op_desc->impl_->optional_input_names_.empty()) { + op_desc->impl_->optional_input_names_.clear(); + } - static bool SetValue(proto::AttrDef &attr_def, const vector> &value); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - vector> &value); + return op_desc; +} - static bool SetValue(proto::AttrDef &attr_def, const vector &value); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - vector &value); +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { + if (org_op_desc == nullptr || org_op_desc->impl_ == nullptr) { + REPORT_INNER_ERROR("E19999", "org_op_desc is null, check invalid"); + GELOGE(GRAPH_FAILED, "[Check][Param] org_op_desc is null"); + return nullptr; + } + std::shared_ptr op_def = ComGraphMakeShared(); + if (op_def == nullptr) { + REPORT_CALL_ERROR("E19999", "create proto::OpDef failed"); + GELOGE(GRAPH_FAILED, "[Create][OpDef] proto::OpDef make shared failed"); + return nullptr; + } + ModelSerializeImp imp; + (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - static bool SetValue(proto::AttrDef &attr_def, const ge::DataType &value); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ge::DataType &value); -}; + imp.SetProtobufOwner(op_def); + OpDescPtr op_desc = nullptr; + if (!imp.UnserializeOpDesc(op_desc, *op_def)) { + REPORT_CALL_ERROR("E19999", "UnserializeOpDesc failed."); + return nullptr; + } -map GeAttrValueImp::attr_val_one_type_map_ = { - {proto::AttrDef::kI, GeAttrValue::VT_INT}, - {proto::AttrDef::kF, GeAttrValue::VT_FLOAT}, - {proto::AttrDef::kB, GeAttrValue::VT_BOOL}, - {proto::AttrDef::kS, GeAttrValue::VT_STRING}, - {proto::AttrDef::kT, GeAttrValue::VT_TENSOR}, - {proto::AttrDef::kTd, GeAttrValue::VT_TENSOR_DESC}, - {proto::AttrDef::kG, GeAttrValue::VT_GRAPH}, - {proto::AttrDef::kBt, GeAttrValue::VT_BYTES}, - {proto::AttrDef::kFunc, GeAttrValue::VT_NAMED_ATTRS}, - {proto::AttrDef::kListListInt, GeAttrValue::VT_LIST_LIST_INT}, - {proto::AttrDef::kListListFloat, GeAttrValue::VT_LIST_LIST_FLOAT}, - {proto::AttrDef::kDt, GeAttrValue::VT_DATA_TYPE}, -}; -map GeAttrValueImp::attr_val_list_type_map_ = { - {proto::AttrDef_ListValue_ListValueType_VT_LIST_INT, GeAttrValue::VT_LIST_INT}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT, GeAttrValue::VT_LIST_FLOAT}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_BOOL, GeAttrValue::VT_LIST_BOOL}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_STRING, GeAttrValue::VT_LIST_STRING}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, GeAttrValue::VT_LIST_TENSOR}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, GeAttrValue::VT_LIST_TENSOR_DESC}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, GeAttrValue::VT_LIST_GRAPH}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, GeAttrValue::VT_LIST_BYTES}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, GeAttrValue::VT_LIST_NAMED_ATTRS}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, GeAttrValue::VT_LIST_DATA_TYPE}, -}; + op_desc->extAttrs_ = org_op_desc->extAttrs_; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue::GeAttrValue() { value_.InitDefault(); } + if (op_desc->impl_ == nullptr) { + REPORT_INNER_ERROR("E19999", "op desc impl is nullptr, check invalid"); + GELOGE(GRAPH_FAILED, "[Check][Param] op desc impl is null."); + return nullptr; + } + op_desc->impl_->input_name_idx_.insert(org_op_desc->impl_->input_name_idx_.begin(), + org_op_desc->impl_->input_name_idx_.end()); + op_desc->impl_->optional_input_names_.insert(org_op_desc->impl_->optional_input_names_.begin(), + org_op_desc->impl_->optional_input_names_.end()); + op_desc->impl_->output_name_idx_.insert(org_op_desc->impl_->output_name_idx_.begin(), + org_op_desc->impl_->output_name_idx_.end()); -GeAttrValue::GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val) : value_(proto_owner, val) {} + op_desc->impl_->infer_func_ = org_op_desc->impl_->infer_func_; + op_desc->impl_->infer_format_func_ = org_op_desc->impl_->infer_format_func_; + op_desc->impl_->verifier_func_ = org_op_desc->impl_->verifier_func_; -GeAttrValue::ValueType GeAttrValue::GetValueType() const { - auto proto_msg = value_.GetProtoMsg(); - if (proto_msg != nullptr) { - auto val_case = proto_msg->value_case(); - if (val_case != proto::AttrDef::kList) { - auto it = GeAttrValueImp::attr_val_one_type_map_.find(val_case); - if (it != GeAttrValueImp::attr_val_one_type_map_.end()) { - return it->second; - } - } else { - auto it = GeAttrValueImp::attr_val_list_type_map_.find(proto_msg->list().val_type()); - if (it != GeAttrValueImp::attr_val_list_type_map_.end()) { - return it->second; - } - } - } - return GeAttrValue::VT_NONE; + return op_desc; } -bool GeAttrValue::IsEmpty() const { return GetValueType() == VT_NONE; } +template +bool SetAttrValue(AttrStore &attrs, const string &name, T &&value) { + return attrs.SetByName(name, std::forward(value)); +} -GeAttrValue GeAttrValue::Copy() const { - GeAttrValue valueRet; - auto proto_msg = value_.GetProtoMsg(); - auto proto_msg_ret = valueRet.value_.GetProtoMsg(); - if (proto_msg != nullptr && proto_msg_ret != nullptr) { - *proto_msg_ret = *proto_msg; +template +bool GetAttrValue(const AttrStore &attrs, const string &name, T &value) { + auto p = attrs.GetByName(name); + if (p == nullptr) { + return false; } - return valueRet; + value = *p; + return true; } -#define ATTR_VALUE_SET_GET_IMP(type) \ - graphStatus GeAttrValue::SetValue(const type &val) { \ - auto proto_msg = value_.GetProtoMsg(); \ - if (proto_msg) { \ - if (GeAttrValueImp::SetValue(*proto_msg, val)) { \ - return GRAPH_SUCCESS; \ - } \ - } \ - return GRAPH_FAILED; \ - } \ - \ - graphStatus GeAttrValue::GetValue(type &val) const { \ - auto proto_msg = value_.GetProtoMsg(); \ - if (proto_msg) { \ - if (GeAttrValueImp::GetValue(*proto_msg, value_.GetProtoOwner(), val)) { \ - return GRAPH_SUCCESS; \ - } \ - } \ - return GRAPH_FAILED; \ +template::type> +RT *SetAndGetAttrValue(AttrStore &attrs, const string &name, T &&value) { + if (!attrs.SetByName(name, std::forward(value))) { + return nullptr; } + return attrs.MutableGetByName(name); +} -ATTR_VALUE_SET_GET_IMP(std::string) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(int64_t) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(float) // lint !e524 -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(bool) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeTensorDesc) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeTensorPtr) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(ComputeGraphPtr) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(Buffer) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(NamedAttrs) -ATTR_VALUE_SET_GET_IMP(vector) -/*lint -e665*/ -ATTR_VALUE_SET_GET_IMP(vector>) -ATTR_VALUE_SET_GET_IMP(vector>) -/*lint +e665*/ -ATTR_VALUE_SET_GET_IMP(vector) // lint !e665 -ATTR_VALUE_SET_GET_IMP(DataType) // lint !e665 - -#undef ATTR_VALUE_SET_GET_IMP +#define SET_ATTR_FUNC(type_name, type) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY \ + bool AttrUtils::Set##type_name(AttrHolderAdapter &&obj, const string &name, const type &value) { \ + if (obj->HasAttr("test_fail")) { \ + return false; \ + } \ + return SetAttrValue(obj->MutableAttrMap(), name, value); \ + } -graphStatus GeAttrValue::MutableTensor(GeTensorPtr &tensor) { return GetValue(tensor); } +#define GET_ATTR_FUNC(type_name, type) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY \ + bool AttrUtils::Get##type_name(ConstAttrHolderAdapter &&obj, const string &name, type &value) { \ + return GetAttrValue(obj->GetAttrMap(), name, value); \ + } -graphStatus GeAttrValue::MutableListTensor(vector &list_tensor) { return GetValue(list_tensor); } +#define SET_GET_FUNC(type_name, type) \ + SET_ATTR_FUNC(type_name, type) \ + GET_ATTR_FUNC(type_name, type) -class AttrUtilsHelper { - public: - inline static bool GetValueCheckType(const proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) { - if (attr_def.value_case() != proto_case) { - GELOGW("[Check][Type] Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case); - return false; - } - return true; - } +bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value) { + return SetAttrValue(obj->MutableAttrMap(), name, value); +} +bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value) { + return GetAttrValue(obj->GetAttrMap(), name, value); +} +SET_GET_FUNC(Int, int64_t) +SET_GET_FUNC(Float, float) +SET_GET_FUNC(ListFloat, vector) +SET_GET_FUNC(Bool, bool) +SET_GET_FUNC(ListBool, vector) +SET_GET_FUNC(Str, string) +SET_GET_FUNC(ListStr, vector) +SET_GET_FUNC(TensorDesc, GeTensorDesc) +SET_GET_FUNC(ListTensorDesc, vector) +SET_GET_FUNC(NamedAttrs, NamedAttrs) +SET_GET_FUNC(ListNamedAttrs, vector) +SET_GET_FUNC(DataType, DataType) +SET_GET_FUNC(ListDataType, vector) +SET_GET_FUNC(ListListInt, vector>) +SET_GET_FUNC(ListListFloat, vector>) - inline static bool GetValueCheckListType( - const proto::AttrDef &attr_def, proto::AttrDef_ListValue_ListValueType proto_list_case, - const std::function item_check_fun) { - if (attr_def.value_case() != proto::AttrDef::kList) { - GELOGW("[Check][ListType] Check ListType Failed, value_case %u", attr_def.value_case()); - return false; - } - auto &list = attr_def.list(); - if (list.val_type() == proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE) { - return item_check_fun(attr_def); - } - if (list.val_type() != proto_list_case) { - GELOGW("[Check][ListType] Check ListType Failed, val_type %u, expected %u", list.val_type(), proto_list_case); - return false; - } - return true; - } - inline static bool SetValueCheckType(proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) { - if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto_case) { - GELOGW("[Check][Type] Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case); - return false; - } - return true; +bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value) { + return SetListInt(std::move(obj), name, std::vector(value.begin(), value.end())); +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetListInt(AttrUtils::AttrHolderAdapter &&obj, const string &name, const vector &value) { + return SetListInt(std::move(obj), name, std::vector(value.begin(), value.end())); +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value) { + return SetListInt(std::move(obj), name, std::vector(value)); +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value) { + value.clear(); + vector int64_list; + if (!GetListInt(std::move(obj), name, int64_list)) { + return false; } - inline static bool SetValueCheckAndSetListType(proto::AttrDef &attr_def, - proto::AttrDef_ListValue_ListValueType proto_list_case) { - if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto::AttrDef::kList) { - GELOGW("[Check][Type] Check Type Failed, value_case %u", attr_def.value_case()); - return false; - } - auto list = attr_def.mutable_list(); - if (list == nullptr) { - REPORT_INNER_ERROR("E19999", "attrdef list is nullptr"); - GELOGE(GRAPH_FAILED, "[Check][Param] attrdef list is nullptr"); - return false; - } - if (list->val_type() != proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE && - list->val_type() != proto_list_case) { - GELOGW("[Check][ListType] Check ListType Failed, val_type %d, expected %d", - static_cast(list->val_type()), static_cast(proto_list_case)); + for (size_t i = 0; i < int64_list.size(); ++i) { + if (int64_list[i] > INT32_MAX) { + REPORT_INNER_ERROR("E19999", "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); + GELOGE(GRAPH_FAILED, "[Check][Param] index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); return false; } - list->set_val_type(proto_list_case); - return true; } - - static bool GetAttrMapItem(const AttrHolder *obj, const string &name, const proto::AttrDef *&attr_def) { - if (obj == nullptr) { - REPORT_INNER_ERROR("E19999", "param obj is nullptr, check invalid"); - GELOGE(FAILED, "[Check][Param] %s obj is nullptr", name.c_str()); - return false; - } - auto attr_map = obj->GetAttrMap().GetProtoMsg(); - if (attr_map == nullptr) { - REPORT_CALL_ERROR("E19999", "proto msg is nullptr, check invalid."); - GELOGE(FAILED, "[Get][ProtoMsg] %s attr map is nullptr", name.c_str()); - return false; - } - auto it = attr_map->find(name); - if (it == attr_map->end()) { - return false; - } - attr_def = &it->second; - return true; + value.insert(value.begin(), int64_list.begin(), int64_list.end()); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value) { + value.clear(); + vector int64_list; + if (!GetListInt(std::move(obj), name, int64_list)) { + return false; } - inline static bool MutableAttrMapItem(AttrHolder *obj, const string &name, proto::AttrDef *&attr_def) { - if (obj == nullptr) { - REPORT_INNER_ERROR("E19999", "param obj is nullptr, check invalid."); - GELOGE(FAILED, "[Check][Param] %s obj is nullptr", name.c_str()); - return false; - } - auto attr_map = obj->MutableAttrMap().GetProtoMsg(); - if (attr_map == nullptr) { - REPORT_CALL_ERROR("E19999", "proto msg is nullptr, check invalid."); - GELOGE(FAILED, "[Get][ProtoMsg] %s attr map is nullptr", name.c_str()); + for (size_t i = 0; i < int64_list.size(); ++i) { + if (int64_list[i] > UINT32_MAX) { + REPORT_INNER_ERROR("E19999", "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); + GELOGE(GRAPH_FAILED, "[Check][Param] index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); return false; } - // Get or add - attr_def = &((*attr_map)[name]); - return true; - } -}; - -#define ATTR_VALUE_IMP_SET_ONE(ValType, proto_case, protoItem) \ - bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ - return false; \ - } \ - proto_attr_val.set_##protoItem(value); \ - return true; \ - } - -#define ATTR_VALUE_IMP_SET_LIST(ValType, proto_list_case, protoItem) \ - bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, \ - proto::AttrDef_ListValue_ListValueType_##proto_list_case)) { \ - return false; \ - } \ - auto list = proto_attr_val.mutable_list(); \ - list->clear_##protoItem(); \ - for (const auto &item : value) { \ - list->add_##protoItem(item); \ - } \ - return true; \ - } - -ATTR_VALUE_IMP_SET_ONE(int64_t, kI, i) -ATTR_VALUE_IMP_SET_ONE(float, kF, f) -ATTR_VALUE_IMP_SET_ONE(const string &, kS, s) -ATTR_VALUE_IMP_SET_ONE(bool, kB, b) - -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_FLOAT, f) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_STRING, s) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_BOOL, b) - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensorDesc &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { - return false; + // 老版本中,只判断了上限,没有判断下限,因此小于0时,这里不会报错 + // 这里维持老版本的做法,在第一次上库做完后,补上小于0的判断 } - if (value.impl_ == nullptr) { + value.insert(value.begin(), int64_list.begin(), int64_list.end()); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetTensor(AttrUtils::AttrHolderAdapter &&obj, const string &name, const GeTensor &value) { + // 当前GeTensor的拷贝赋值、拷贝构造函数均不是深拷贝,因此无法使用默认的方法SetAttr + if (!obj->MutableAttrMap().SetByName(name, GeTensor())) { return false; } - - auto proto_msg = value.impl_->tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { + auto tensor = obj->MutableAttrMap().MutableGetByName(name); + if (tensor == nullptr) { return false; } - *proto_attr_val.mutable_td() = *proto_msg; + TensorUtils::CopyTensor(value, *tensor); return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value) { + return SetTensor(std::move(obj), name, *value); +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value) { + return SetTensor(std::move(obj), name, *value); +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetListTensor(AttrUtils::AttrHolderAdapter &&obj, const string &name, const vector &value) { + std::vector tensors(value.size()); + if (!obj->MutableAttrMap().SetByName(name, tensors)) { return false; } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_td(); - for (const auto &item : value) { - if (item.impl_ == nullptr) { - return false; - } - auto proto_msg = item.impl_->tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - proto_attr_val.clear_list(); - return false; - } - *list->add_td() = *proto_msg; + auto attr_tensors = obj->MutableAttrMap().MutableGetByName>(name); + if (attr_tensors == nullptr) { + return false; + } + for (size_t i = 0; i < value.size(); ++i) { + TensorUtils::CopyTensor(value[i], (*attr_tensors)[i]); } return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ConstGeTensorPtr &value) { - if (value) { - return SetValue(proto_attr_val, *value); - } else { - return SetValue(proto_attr_val, GeTensor()); - } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value) { + vector tensors(value.size()); + std::copy(value.begin(), value.end(), tensors.begin()); + return SetListTensor(std::move(obj), name, tensors); } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensor &val) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value) { + std::vector tensors(value.size()); + if (!obj->MutableAttrMap().SetByName(name, tensors)) { return false; } - if (val.impl_ == nullptr) { + auto attr_tensors = obj->MutableAttrMap().MutableGetByName>(name); + if (attr_tensors == nullptr) { return false; } - if (val.impl_->tensor_def_.GetProtoOwner() != nullptr) { - auto proto_msg = val.impl_->tensor_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - REPORT_CALL_ERROR("E19999", "Proto msg is nullptr"); - GELOGE(FAILED, "[Get][ProtoMsg] Proto msg is nullptr"); - return false; - } - *proto_attr_val.mutable_t() = *proto_msg; - } else { - auto tensor = proto_attr_val.mutable_t(); - if (tensor == nullptr) { - REPORT_INNER_ERROR("E19999", "tensor is nullptr"); - GELOGE(FAILED, "[Check][Param] tensor is nullptr"); - return false; - } - if (val.impl_ != nullptr && val.impl_->tensor_data_.impl_ != nullptr && - val.impl_->tensor_data_.impl_->tensor_descriptor_.GetProtoMsg() != nullptr) { - tensor->mutable_desc()->CopyFrom(*(val.impl_->tensor_data_.impl_->tensor_descriptor_.GetProtoMsg())); - } - tensor->set_data(val.GetData().data(), val.GetData().size()); + for (size_t i = 0; i < value.size(); ++i) { + TensorUtils::CopyTensor(*(value[i]), (*attr_tensors)[i]); } return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - vector constList(value.size()); - std::copy(value.begin(), value.end(), constList.begin()); - return SetValue(proto_attr_val, constList); +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, + std::initializer_list &&value) { + return SetListTensor(std::move(obj), name, vector(value)); } -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { +// 所有权UT测试,不能把属性上的GeTensor给错误释放了 +// 而且这里的行为与老版本是不一样的,老版本中,即使属性的owner生命周期结束析构了,通过本接口获取的value仍然是可用的 +// 但是新接口中,owner没有转移,owner析构后,value指向的内存就被释放了,这里需要排查 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value) { + auto tensor = obj->MutableAttrMap().MutableGetByName(name); + if (tensor == nullptr) { return false; } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_t(); - for (const auto &item : value) { - if (item == nullptr || item->impl_ == nullptr) { - REPORT_INNER_ERROR("E19999", "ConstGeTensorPtr in param value is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] AttrUtils::SetListTensor item is nullptr"); - proto_attr_val.clear_list(); - return false; - } - if (item->impl_->tensor_def_.GetProtoOwner() != nullptr) { - auto proto_msg = item->impl_->tensor_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - REPORT_CALL_ERROR("E19999", "proto msg is nullptr, check invalid."); - GELOGE(FAILED, "[Get][ProtoMsg] Proto msg is nullptr"); - proto_attr_val.clear_list(); - return false; - } - *list->add_t() = *proto_msg; - } else { - auto tensor = list->add_t(); - if (tensor == nullptr) { - REPORT_INNER_ERROR("E19999", "tensor is nullptr"); - GELOGE(FAILED, "[Check][Param] tensor is nullptr"); - proto_attr_val.clear_list(); - return false; - } - if (item->impl_->tensor_data_.impl_ != nullptr && - item->impl_->tensor_data_.impl_->tensor_descriptor_.GetProtoMsg() != nullptr) { - tensor->mutable_desc()->CopyFrom(*(item->impl_->tensor_data_.impl_->tensor_descriptor_.GetProtoMsg())); - } - tensor->set_data(item->GetData().data(), item->GetData().size()); - } - } + value = std::shared_ptr(tensor, [](GeTensor *){}); return true; } -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value) { + auto tensor = obj->GetAttrMap().GetByName(name); + if (tensor == nullptr) { return false; } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_t(); - for (const auto &item : value) { - if (item.impl_ != nullptr && item.impl_->tensor_def_.GetProtoOwner() != nullptr) { - auto proto_msg = item.impl_->tensor_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - REPORT_CALL_ERROR("E19999", "Proto msg is nullptr"); - GELOGE(FAILED, "[Get][ProtoMsg] Proto msg is nullptr"); - proto_attr_val.clear_list(); - return false; - } - *list->add_t() = *proto_msg; - } else { - auto tensor = list->add_t(); - if (tensor == nullptr) { - REPORT_INNER_ERROR("E19999", "tensor is nullptr"); - GELOGE(FAILED, "[Check][Param] tensor is nullptr"); - proto_attr_val.clear_list(); - return false; - } - if (item.impl_ != nullptr && item.impl_->tensor_data_.impl_ != nullptr && - item.impl_->tensor_data_.impl_->tensor_descriptor_.GetProtoMsg() != nullptr) { - tensor->mutable_desc()->CopyFrom(*(item.impl_->tensor_data_.impl_->tensor_descriptor_.GetProtoMsg())); - } - tensor->set_data(item.GetData().data(), item.GetData().size()); - } - } + value = std::shared_ptr(tensor, [](const GeTensor *){}); return true; } -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const Buffer &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value) { + auto tensors = obj->GetAttrMap().GetByName>(name); + if (tensors == nullptr) { return false; } - size_t val_size = value.GetSize(); - proto_attr_val.set_bt(value.GetData(), val_size); + value.resize(tensors->size()); + for (size_t i = 0; i < tensors->size(); ++i) { + value[i] = std::shared_ptr(&(*tensors)[i], [](const GeTensor *){}); + } return true; } -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value) { + auto tensors = obj->MutableAttrMap().MutableGetByName>(name); + if (tensors == nullptr) { return false; } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_bt(); - for (const auto &item : value) { - list->add_bt(item.GetData(), item.GetSize()); + value.resize(tensors->size()); + for (size_t i = 0; i < tensors->size(); ++i) { + value[i] = std::shared_ptr(&(*tensors)[i], [](GeTensor *){}); } return true; } -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const NamedAttrs &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetGraph(AttrUtils::AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value) { + proto::GraphDef *graph_def = SetAndGetAttrValue(obj->MutableAttrMap(), name, proto::GraphDef()); + if (graph_def == nullptr) { return false; } - auto proto_msg = value.named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - REPORT_CALL_ERROR("E19999", "proto msg is nullptr"); - GELOGE(FAILED, "[Get][ProtoMsg] Proto msg is nullptr"); + ModelSerializeImp imp; + if (!imp.SerializeGraph(value, graph_def)) { + REPORT_CALL_ERROR("E19999", "SerializeGraph failed when add ComputeGraph to attr %s", name.c_str()); + GELOGE(GRAPH_FAILED, "[Serialize][Graph] Failed when add ComputeGraph to attr %s", name.c_str()); + obj->MutableAttrMap().Delete(name); return false; } - *proto_attr_val.mutable_func() = *proto_msg; return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetListGraph(AttrUtils::AttrHolderAdapter &&obj, const string &name, + const vector &value) { + std::vector graphs(value.size()); + if (!obj->MutableAttrMap().SetByName(name, graphs)) { + return false; + } + auto attr_graphs = obj->MutableAttrMap().MutableGetByName>(name); + if (attr_graphs == nullptr) { return false; } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_na(); - for (const auto &item : value) { - auto proto_msg = item.named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - proto_attr_val.clear_list(); + for (size_t i = 0; i < value.size(); ++i) { + ModelSerializeImp imp; + if (!imp.SerializeGraph(value[i], &attr_graphs->at(i))) { + REPORT_CALL_ERROR("E19999", "SerializeGraph failed when add ComputeGraph to attr %s", name.c_str()); + GELOGE(GRAPH_FAILED, "[Serialize][Graph] Failed when add ComputeGraph to attr %s", name.c_str()); + obj->MutableAttrMap().Delete(name); return false; } - *list->add_na() = *proto_msg; } return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::ComputeGraphPtr &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetGraph(AttrUtils::ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value) { + auto attr_graph_def = obj->GetAttrMap().GetByName(name); + if (attr_graph_def == nullptr) { + return false; + } + // 这里延续了老代码实现,先拷贝构造一个ComputeGraph,然后做反序列化,感觉直接把attr_graph_def传进去应该就可以了? + // 下一步对这里做整改,直接传入attr_graph_def,避免这一次拷贝 + auto graph_def = ComGraphMakeShared(*attr_graph_def); + if (graph_def == nullptr) { + REPORT_CALL_ERROR("E19999", "create proto::GraphDef failed."); + GELOGE(GRAPH_FAILED, "[Create][GraphDef] proto::GraphDef make shared failed"); return false; } + ModelSerializeImp imp; - if (!imp.SerializeGraph(value, proto_attr_val.mutable_g())) { - REPORT_CALL_ERROR("E19999", "SerializeGraph failed"); - GELOGE(GRAPH_FAILED, "[Serialize][Graph] Failed"); - proto_attr_val.clear_g(); + imp.SetProtobufOwner(graph_def); + if (!imp.UnserializeGraph(value, *graph_def)) { + REPORT_CALL_ERROR("E19999", "UnserializeGraph failed when get attr ComputeGraph by name %s", name.c_str()); + GELOGE(GRAPH_FAILED, "[Unserialize][Graph] Failed when get attr ComputeGraph by name %s", name.c_str()); return false; } + return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetListGraph(AttrUtils::ConstAttrHolderAdapter &&obj, const string &name, + vector &value) { + auto graph_defs = obj->GetAttrMap().GetByName>(name); + if (graph_defs == nullptr) { return false; } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_g(); - ModelSerializeImp imp; - for (const auto &item : value) { - if (!imp.SerializeGraph(item, list->add_g())) { - REPORT_CALL_ERROR("E19999", "SerializeGraph failed."); - GELOGE(GRAPH_FAILED, "[Serialize][Graph] failed"); - proto_attr_val.clear_list(); - return false; + value.resize(graph_defs->size()); + for (size_t i = 0; i < graph_defs->size(); ++i) { + std::shared_ptr graph_def; + graph_def = ComGraphMakeShared(graph_defs->at(i)); + if (graph_def == nullptr) { + REPORT_CALL_ERROR("E19999", "create proto::GraphDef failed."); + GELOGE(GRAPH_FAILED, "[Create][GraphDef] proto::GraphDef make shared failed"); + graph_def = nullptr; + return false; // lint !e665 + } else { + ComputeGraphPtr graph = nullptr; + ModelSerializeImp imp; + imp.SetProtobufOwner(static_cast(graph_def)); + if (!imp.UnserializeGraph(graph, *graph_def)) { + REPORT_CALL_ERROR("E19999", "UnserializeGraph failed."); + GELOGE(GRAPH_FAILED, "[Unserialize][Graph] Failed"); + return false; + } // lint !e514 + value[i] = graph; } } return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector> &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetBytes(AttrUtils::AttrHolderAdapter &&obj, const string &name, const Buffer &value) { + auto buffer = SetAndGetAttrValue(obj->MutableAttrMap(), name, Buffer()); + if (buffer == nullptr) { return false; } - proto_attr_val.clear_list_list_int(); - auto list_list_int = proto_attr_val.mutable_list_list_int(); - GE_CHECK_NOTNULL_EXEC(list_list_int, return false); - for (auto &list_int : value) { - auto list_item = list_list_int->add_list_list_i(); - GE_CHECK_NOTNULL_EXEC(list_item, return false); - for (auto &int_item : list_int) { - list_item->add_list_i(int_item); - } - } + BufferUtils::CopyFrom(value, *buffer); return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector> &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kListListFloat)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &value) { + auto buffer = obj->GetAttrMap().GetByName(name); + if (buffer == nullptr) { return false; } - proto_attr_val.clear_list_list_float(); - auto list_list_float = proto_attr_val.mutable_list_list_float(); - GE_CHECK_NOTNULL_EXEC(list_list_float, return false); - for (auto &list_float : value) { - auto list_item = list_list_float->add_list_list_f(); - GE_CHECK_NOTNULL_EXEC(list_item, return false); - for (auto &float_item : list_float) { - list_item->add_list_f(float_item); - } - } + BufferUtils::CopyFrom(*buffer, value); return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetListBytes(AttrUtils::AttrHolderAdapter &&obj, const string &name, const vector &value) { + std::vector buffers(value.size()); + auto attr_buffers = SetAndGetAttrValue(obj->MutableAttrMap(), name, buffers); + if (attr_buffers == nullptr) { return false; } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_dt(); - for (const auto &item : value) { - list->add_dt(static_cast(item)); + + for (size_t i = 0; i < value.size(); ++i) { + BufferUtils::CopyFrom(value[i], (*attr_buffers)[i]); } + return true; } - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::DataType &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetListBytes(AttrUtils::ConstAttrHolderAdapter &&obj, const string &name, vector &value) { + auto buffers = obj->GetAttrMap().GetByName>(name); + if (buffers == nullptr) { return false; } - proto_attr_val.set_dt(static_cast(value)); - + value.resize(buffers->size()); + for (size_t i = 0; i < buffers->size(); ++i) { + BufferUtils::CopyFrom(buffers->at(i), value[i]); + } return true; } -#define ATTR_VALUE_IMP_GET_ONE(ValType, proto_case, protoItem) \ - bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ValType value) { \ - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ - return false; \ - } \ - value = proto_attr_val.protoItem(); \ - return true; \ - } - -#define ListValueItemCheck(protoItem) \ - [](const proto::AttrDef &proto_attr_val) { return proto_attr_val.list().protoItem##_size() > 0; } - -#define ATTR_VALUE_IMP_GET_LIST(ValType, proto_list_case, protoItem) \ - bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, vector &value) { \ - value.clear(); \ - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, \ - proto::AttrDef_ListValue_ListValueType_##proto_list_case, \ - ListValueItemCheck(protoItem))) { \ - return false; \ - } \ - auto &list = proto_attr_val.list(); \ - for (const auto &item : list.protoItem()) { \ - value.push_back(item); \ - } \ - return true; \ - } - -ATTR_VALUE_IMP_GET_ONE(int64_t &, kI, i) -ATTR_VALUE_IMP_GET_ONE(float &, kF, f) -ATTR_VALUE_IMP_GET_ONE(string &, kS, s) -ATTR_VALUE_IMP_GET_ONE(bool &, kB, b) - -ATTR_VALUE_IMP_GET_LIST(int64_t, VT_LIST_INT, i) -ATTR_VALUE_IMP_GET_LIST(float, VT_LIST_FLOAT, f) -ATTR_VALUE_IMP_GET_LIST(string, VT_LIST_STRING, s) -ATTR_VALUE_IMP_GET_LIST(bool, VT_LIST_BOOL, b) - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeTensorDesc &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { - return false; - } - if (value.impl_ == nullptr) { - return false; - } - auto proto_msg = value.impl_->tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = proto_attr_val.td(); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - if (!AttrUtilsHelper::GetValueCheckListType( - proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, ListValueItemCheck(td))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.td()) { - value.emplace_back(GeTensorDesc()); - if (value.back().impl_ == nullptr) { - return false; - } - auto proto_msg = value.back().impl_->tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = item; - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - GeTensorPtr &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { - return false; - } - value = std::shared_ptr( - new (std::nothrow) GeTensor(proto_owner, const_cast(proto_attr_val).mutable_t())); - GE_CHK_BOOL_RET_STATUS(value != nullptr, false, "[Check][Param] value is nullptr"); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, - ListValueItemCheck(t))) { - return false; - } - auto list = const_cast(proto_attr_val).mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - for (auto &item : *(list->mutable_t())) { - std::shared_ptr temp_value = std::shared_ptr(new (std::nothrow) GeTensor(proto_owner, &item)); - if (temp_value == nullptr) { - REPORT_CALL_ERROR("E19999", "create GeTensor failed."); - GELOGE(false, "[Create][GeTensor] failed."); - return false; - } - value.push_back(temp_value); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, Buffer &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - auto &proto_val = proto_attr_val.bt(); - GE_LOGI_IF(proto_val.size() == 0, "size res is 0."); - value = Buffer::CopyFrom(reinterpret_cast(proto_val.data()), proto_val.size()); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, - ListValueItemCheck(bt))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.bt()) { - value.push_back(Buffer::CopyFrom((const uint8_t *)item.data(), item.size())); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - NamedAttrs &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { - return false; - } - auto proto_msg = value.named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = proto_attr_val.func(); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType( - proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.na()) { - value.emplace_back(NamedAttrs()); - if (value.empty()) { - return false; - } - auto proto_msg = value.back().named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = item; - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ComputeGraphPtr &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { - return false; - } - ComputeGraphPtr graph = nullptr; - std::shared_ptr graph_def; - graph_def = ComGraphMakeShared(proto_attr_val.g()); - if (graph_def == nullptr) { - REPORT_CALL_ERROR("E19999", "create proto::GraphDef failed."); - GELOGE(GRAPH_FAILED, "[Create][GraphDef] proto::GraphDef make shared failed"); - graph_def = nullptr; - return false; // lint !e665 - } else { - ModelSerializeImp imp; - imp.SetProtobufOwner(graph_def); - if (!imp.UnserializeGraph(graph, *graph_def)) { - REPORT_CALL_ERROR("E19999", "UnserializeGraph failed."); - GELOGE(GRAPH_FAILED, "[Unserialize][Graph] Failed"); - return false; - } // lint !e514 - value = graph; - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, - ListValueItemCheck(g))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.g()) { - std::shared_ptr graph_def; - graph_def = ComGraphMakeShared(item); - if (graph_def == nullptr) { - REPORT_CALL_ERROR("E19999", "create proto::GraphDef failed."); - GELOGE(GRAPH_FAILED, "[Create][GraphDef] proto::GraphDef make shared failed"); - graph_def = nullptr; - return false; // lint !e665 - } else { - ComputeGraphPtr graph = nullptr; - ModelSerializeImp imp; - imp.SetProtobufOwner(graph_def); - if (!imp.UnserializeGraph(graph, *graph_def)) { - REPORT_CALL_ERROR("E19999", "UnserializeGraph failed."); - GELOGE(GRAPH_FAILED, "[Unserialize][Graph] Failed"); - return false; - } // lint !e514 - value.push_back(graph); - } - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector> &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { - return false; - } - - auto &list_listint = proto_attr_val.list_list_int().list_list_i(); - for (auto &list_int : list_listint) { - vector list_item(list_int.list_i().size()); - if (!list_int.list_i().empty()) { - (void)std::copy(list_int.list_i().begin(), list_int.list_i().end(), list_item.begin()); - } - value.push_back(list_item); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector> &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kListListFloat)) { - return false; - } - - auto &list_list_float = proto_attr_val.list_list_float().list_list_f(); - for (auto &list_float : list_list_float) { - vector list_item(list_float.list_f().size()); - if (!list_float.list_f().empty()) { - (void)std::copy(list_float.list_f().begin(), list_float.list_f().end(), list_item.begin()); - } - value.push_back(list_item); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, - ListValueItemCheck(dt))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.dt()) { - value.emplace_back(static_cast(item)); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ge::DataType &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { - return false; - } - value = static_cast(proto_attr_val.dt()); - return true; -} - -GE_FUNC_HOST_VISIBILITY bool GeAttrValueImp::SetZeroCopyBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - Buffer &&buffer) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - if (buffer.impl_ == nullptr) { - return false; - } - auto proto_msg = buffer.impl_->data_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - proto_attr_val.set_bt(std::move(*proto_msg->mutable_bt())); - return true; -} - -bool GeAttrValueImp::GetZeroCopyBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - Buffer &buffer) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - buffer = Buffer(proto_owner, &const_cast(proto_attr_val)); - return true; -} - -bool GeAttrValueImp::SetZeroCopyListBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &list_buffer) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_bt(); - for (auto &item : list_buffer) { - if (item.impl_ == nullptr) { - return false; - } - auto proto_msg = item.impl_->data_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - list->add_bt(std::move(*proto_msg->mutable_bt())); - } - return true; -} - -bool GeAttrValueImp::GetZeroCopyListBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - vector &list_buffer) { - list_buffer.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, - ListValueItemCheck(bt))) { - return false; - } - auto list = const_cast(proto_attr_val).mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - for (auto &item : *(list->mutable_bt())) { - list_buffer.emplace_back(Buffer(proto_owner, &item)); - } - return true; -} - -bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const string &name) { - if (!obj) { - return false; - } - return obj->HasAttr(name); -} - -#define ATTR_UTILS_SET_IMP(FuncName, Type) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Set##FuncName( \ - AttrHolderAdapter &&obj, const string &name, const Type &value) { \ - if (obj->HasAttr("test_fail")) { \ - return false; \ - } \ - \ - proto::AttrDef *proto_attr_val = nullptr; \ - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ - return false; \ - } \ - if (!GeAttrValueImp::SetValue(*proto_attr_val, value)) { \ - GELOGW("[Set][Value] Set" #FuncName " failed key %s", name.c_str()); \ - return false; \ - } \ - return true; \ - } - -#define ATTR_UTILS_GET_IMP(FuncName, Type) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Get##FuncName(ConstAttrHolderAdapter &&obj, \ - const string &name, Type &value) { \ - const proto::AttrDef *proto_attr_val = nullptr; \ - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ - return false; \ - } \ - if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value)) { \ - GELOGW("[Get][Value] Get" #FuncName " failed key %s", name.c_str()); \ - return false; \ - } \ - return true; \ - } - -#define ATTR_UTILS_SET_GET_IMP(FuncName, Type) \ - ATTR_UTILS_SET_IMP(FuncName, Type) \ - ATTR_UTILS_GET_IMP(FuncName, Type) - -ATTR_UTILS_SET_GET_IMP(Int, int64_t) -ATTR_UTILS_SET_GET_IMP(Float, float) -ATTR_UTILS_SET_GET_IMP(Bool, bool) -ATTR_UTILS_SET_GET_IMP(Str, string) -ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) -ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) -ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) -ATTR_UTILS_SET_IMP(Tensor, GeTensor) -ATTR_UTILS_SET_GET_IMP(NamedAttrs, NamedAttrs) -ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) -ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) -/*lint -e665*/ -ATTR_UTILS_SET_GET_IMP(ListListInt, vector>) -/*lint +e665*/ -ATTR_UTILS_SET_GET_IMP(ListInt, vector) -ATTR_UTILS_SET_IMP(ListInt, vector) -ATTR_UTILS_SET_IMP(ListInt, vector) -ATTR_UTILS_SET_GET_IMP(ListFloat, vector) -ATTR_UTILS_SET_GET_IMP(ListListFloat, vector>) -ATTR_UTILS_SET_GET_IMP(ListBool, vector) -ATTR_UTILS_SET_GET_IMP(ListStr, vector) -ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector) -ATTR_UTILS_SET_IMP(ListTensor, vector) -ATTR_UTILS_SET_IMP(ListTensor, vector) -ATTR_UTILS_SET_IMP(ListTensor, vector) -ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector) -ATTR_UTILS_SET_GET_IMP(ListBytes, vector) -ATTR_UTILS_SET_GET_IMP(ListGraph, vector) -ATTR_UTILS_SET_GET_IMP(ListDataType, vector) // lint !e665 -ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665 - -bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, - std::initializer_list &&value) { - return SetListTensor(std::move(obj), name, vector(value)); -} - -bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value) { - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - GeTensorPtr tensor; - if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) { - return false; - } - value = tensor; - return true; -} - -bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value) { - value.clear(); - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - vector tensor; - if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) { - return false; - } - value.insert(value.begin(), tensor.begin(), tensor.end()); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj, - const string &name, GeTensorPtr &value) { - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value); -} - -bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value) { - value.clear(); - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value); -} - -bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value) { - proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::SetValue(*proto_attr_val, value); +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer) { + // Value will be shared + return SetAttrValue(obj->MutableAttrMap(), name, std::move(buffer)); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, - int32_t &value) { - int64_t int64_val = 0; - if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { - return false; - } - if (int64_val > INT32_MAX) { - REPORT_INNER_ERROR("E19999", "%ld int64_t value cannot cast to int32_t", int64_val); - GELOGE(GRAPH_FAILED, "[Check][Param] %ld int64_t value cannot cast to int32_t", int64_val); - return false; - } - value = static_cast(int64_val); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, - uint32_t &value) { - int64_t int64_val = 0; - if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { - return false; - } - if (int64_val > UINT32_MAX) { - REPORT_INNER_ERROR("E19999", "%ld int64_t value cannot cast to uint32_t", int64_val); - GELOGE(GRAPH_FAILED, "[Check][Param] %ld int64_t value cannot cast to uint32_t", int64_val); - return false; - } - value = static_cast(int64_val); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, - const string &name, vector &value) { - value.clear(); - vector int64_list; - if (!GetListInt(std::move(obj), name, int64_list)) { - return false; - } - - for (size_t i = 0; i < int64_list.size(); ++i) { - if (int64_list[i] > INT32_MAX) { - REPORT_INNER_ERROR("E19999", "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); - GELOGE(GRAPH_FAILED, "[Check][Param] index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); - return false; - } - } - value.insert(value.begin(), int64_list.begin(), int64_list.end()); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, - const string &name, vector &value) { - value.clear(); - vector int64_list; - if (!GetListInt(std::move(obj), name, int64_list)) { - return false; - } - - for (size_t i = 0; i < int64_list.size(); ++i) { - if (int64_list[i] > UINT32_MAX) { - REPORT_INNER_ERROR("E19999", "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); - GELOGE(GRAPH_FAILED, "[Check][Param] index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); - return false; - } - } - value.insert(value.begin(), int64_list.begin(), int64_list.end()); - return true; -} - -bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value) { - if (obj) { - vector bytes_vals; - for (auto &item : value) { - ModelSerialize serialize; - auto buffer = serialize.SerializeOpDesc(item); - if (buffer.GetSize() == 0) { - return false; - } - bytes_vals.push_back(buffer); - } - return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); - } - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, - const string &name, - const vector &value) { - if (obj) { - vector bytes_vals; - for (auto &item : value) { - ModelSerialize serialize; - auto buffer = serialize.SerializeOpDesc(item); - if (buffer.GetSize() == 0) { - return false; - } - bytes_vals.push_back(buffer); - } - return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); - } - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(ConstAttrHolderAdapter &&obj, - const string &name, - vector &value) { - value.clear(); - - vector bytes_vals; - if (!GetZeroCopyListBytes(std::move(obj), name, bytes_vals)) { - return false; - } - for (const auto &item : bytes_vals) { - ModelSerialize serialize; - auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732 - value.push_back(op_desc); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj, - const string &name, Buffer &&buffer) { - // Value will be moved - proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::SetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), std::move(buffer)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, - const string &name, Buffer &buffer) { - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), buffer); +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer) { + // Value will be shared + return GetAttrValue(obj->GetAttrMap(), name, buffer); } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector &list_buffer) { - // Value will be moved - proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::SetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); + // Value will be shared + return SetAttrValue(obj->MutableAttrMap(), name, list_buffer); } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &list_buffer) { - list_buffer.clear(); - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); + // Value will be shared + return GetAttrValue>(obj->GetAttrMap(), name, list_buffer); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { - if (org_op_desc == nullptr) { - REPORT_INNER_ERROR("E19999", "org_op_desc is null, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] org_op_desc is null"); - return nullptr; - } - std::shared_ptr op_def; - op_def = ComGraphMakeShared(); - if (op_def == nullptr) { - REPORT_CALL_ERROR("E19999", "create proto::OpDef failed."); - GELOGE(GRAPH_FAILED, "[Create][OpDef] proto::OpDef make shared failed"); - return nullptr; // lint !e665 - } - ModelSerializeImp imp; - (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), - REPORT_CALL_ERROR("E19999", "UnserializeOpDesc failed"); - return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed"); - op_desc->extAttrs_ = org_op_desc->extAttrs_; - - // This function may be called by some passes of fusion engine, in this condition, do not need these attribute - if (op_desc->impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "Op desc is nullptr."); - return nullptr; - } - if (!op_desc->impl_->input_name_idx_.empty()) { - op_desc->impl_->input_name_idx_.clear(); - } - if (!op_desc->impl_->output_name_idx_.empty()) { - op_desc->impl_->output_name_idx_.clear(); - } - if (!op_desc->impl_->optional_input_names_.empty()) { - op_desc->impl_->optional_input_names_.clear(); +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +std::map AttrUtils::GetAllAttrs(ConstAttrHolderAdapter &&obj) { + auto holder = obj.get(); + if (holder == nullptr) { + std::map empty; + return empty; } - - return op_desc; + return holder->GetAllAttrs(); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { - if (org_op_desc == nullptr || org_op_desc->impl_ == nullptr) { - REPORT_INNER_ERROR("E19999", "org_op_desc is null, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] org_op_desc is null"); - return nullptr; - } - std::shared_ptr op_def = ComGraphMakeShared(); - if (op_def == nullptr) { - REPORT_CALL_ERROR("E19999", "create proto::OpDef failed"); - GELOGE(GRAPH_FAILED, "[Create][OpDef] proto::OpDef make shared failed"); - return nullptr; - } - ModelSerializeImp imp; - (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), - REPORT_CALL_ERROR("E19999", "UnserializeOpDesc failed."); - return op_desc, "[Unserialize][OpDesc] failed"); +std::string AttrUtils::GetAttrsStrAfterRid(ConstAttrHolderAdapter &&obj, const set &un_compute_attrs) { - op_desc->extAttrs_ = org_op_desc->extAttrs_; - - if (op_desc->impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "op desc is null."); - return nullptr; - } - op_desc->impl_->input_name_idx_.insert(org_op_desc->impl_->input_name_idx_.begin(), - org_op_desc->impl_->input_name_idx_.end()); - op_desc->impl_->optional_input_names_.insert(org_op_desc->impl_->optional_input_names_.begin(), - org_op_desc->impl_->optional_input_names_.end()); - op_desc->impl_->output_name_idx_.insert(org_op_desc->impl_->output_name_idx_.begin(), - org_op_desc->impl_->output_name_idx_.end()); - - op_desc->impl_->infer_func_ = org_op_desc->impl_->infer_func_; - op_desc->impl_->infer_format_func_ = org_op_desc->impl_->infer_format_func_; - op_desc->impl_->verifier_func_ = org_op_desc->impl_->verifier_func_; - - return op_desc; -} -std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) { - auto holder = obj.get(); - if (holder == nullptr) { + std::map attr_map = GetAllAttrs(std::move(obj)); + if (attr_map.empty()) { return ""; } - auto attrs_map = holder->GetAttrMap(); - if (attrs_map.GetProtoMsg() == nullptr) { - return ""; - } - std::map ordered_attrs; - for (auto &attr : *(attrs_map.GetProtoMsg())) { - if (attr.second.has_t()) { - // print tensor desc message as an ordered string. - auto tensor_def = attr.second.t(); - string ordered_tensor_desc; - (void)google::protobuf::TextFormat::PrintToString(tensor_def.desc(), &ordered_tensor_desc); - ordered_attrs[attr.first] = ordered_tensor_desc + tensor_def.data(); - } else if (attr.second.has_td()) { - // print tensor desc message as an ordered string. - string ordered_attr; - (void)google::protobuf::TextFormat::PrintToString(attr.second, &ordered_attr); - ordered_attrs[attr.first] = ordered_attr; - } else { - ordered_attrs[attr.first] = attr.second.SerializeAsString(); + for (auto &attr : attr_map) { + proto::AttrDef attr_def; + auto *serializer = AttrSerializerRegistry::GetInstance().GetSerializer(attr.second.GetValueTypeId()); + if (serializer == nullptr || serializer->Serialize(attr.second, attr_def) != GRAPH_SUCCESS) { + ordered_attrs[attr.first] = ""; + continue; } + + ordered_attrs[attr.first] = attr_def.SerializeAsString(); } std::stringstream ss; for (auto &attr : ordered_attrs) { + if (un_compute_attrs.find(attr.first) != un_compute_attrs.end()) { + continue; + } ss << attr.first << ":" << attr.second << ";"; } return ss.str(); } +std::string AttrUtils::GetAllAttrsStr(ConstAttrHolderAdapter &&obj) { -std::string AttrUtils::GetAttrsStrAfterRid(AttrUtils::ConstAttrHolderAdapter &&obj, - const set &un_compute_attrs) { - auto holder = obj.get(); - if (holder == nullptr) { - return ""; - } - auto attrs_map = holder->GetAttrMap(); - if (attrs_map.GetProtoMsg() == nullptr) { + std::map attr_map = GetAllAttrs(std::move(obj)); + if (attr_map.empty()) { return ""; } - std::map ordered_attrs; - for (auto &attr : *(attrs_map.GetProtoMsg())) { - ordered_attrs[attr.first] = attr.second.SerializeAsString(); + for (auto &attr : attr_map) { + proto::AttrDef attr_def; + auto *serializer = AttrSerializerRegistry::GetInstance().GetSerializer(attr.second.GetValueTypeId()); + if (serializer == nullptr || serializer->Serialize(attr.second, attr_def) != GRAPH_SUCCESS) { + ordered_attrs[attr.first] = ""; + continue; + } + + if (attr_def.has_t()) { + // print tensor desc message as an ordered string. + std::string ordered_tensor_desc; + (void)google::protobuf::TextFormat::PrintToString(attr_def.t().desc(), &ordered_tensor_desc); + ordered_attrs[attr.first] = ordered_tensor_desc + attr_def.t().data(); + } else if (attr_def.has_td()) { + // print tensor desc message as an ordered string. + string ordered_attr; + (void)google::protobuf::TextFormat::PrintToString(attr_def.td(), &ordered_attr); + ordered_attrs[attr.first] = ordered_attr; + } else { + ordered_attrs[attr.first] = attr_def.SerializeAsString(); + } } std::stringstream ss; for (auto &attr : ordered_attrs) { - if (un_compute_attrs.find(attr.first) != un_compute_attrs.end()) { - continue; - } ss << attr.first << ":" << attr.second << ";"; } - return ss.str(); } } // namespace ge diff --git a/tests/st/CMakeLists.txt b/tests/st/CMakeLists.txt index a835c25..5f7bf5a 100644 --- a/tests/st/CMakeLists.txt +++ b/tests/st/CMakeLists.txt @@ -79,6 +79,23 @@ set(MATEDEF_SRC_FILES "${PARSER_DIR}/metadef/graph/detail/attributes_holder.cc" "${PARSER_DIR}/metadef/graph/format_refiner.cc" "${PARSER_DIR}/metadef/graph/ge_attr_define.cc" + "${PARSER_DIR}/metadef/graph/any_value.cc" + "${PARSER_DIR}/metadef/graph/attr_store.cc" + "${PARSER_DIR}/metadef/graph/serialization/attr_serializer_registry.cc" + "${PARSER_DIR}/metadef/graph/serialization/attr_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/bool_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/buffer_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/data_type_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/float_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/graph_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/int_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/list_list_float_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/list_list_int_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/list_value_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/named_attrs_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/string_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/tensor_desc_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/tensor_serializer.cc" "${PARSER_DIR}/metadef/graph/ge_tensor.cc" "${PARSER_DIR}/metadef/graph/gnode.cc" "${PARSER_DIR}/metadef/graph/graph.cc" @@ -119,6 +136,7 @@ include_directories(${PARSER_DIR}/metadef/inc) include_directories(${PARSER_DIR}/metadef/inc/graph) include_directories(${PARSER_DIR}/metadef/inc/external) include_directories(${PARSER_DIR}/metadef/inc/external/graph) +include_directories(${PARSER_DIR}/metadef/register/op_tiling) include_directories(${PARSER_DIR}/metadef/graph) include_directories(${PARSER_DIR}/metadef/third_party) include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc) @@ -165,8 +183,8 @@ set(REGISTER_SRC_FILES "${PARSER_DIR}/metadef/register/infer_data_slice_registry.cc" "${PARSER_DIR}/metadef/register/ops_kernel_builder_registry.cc" "${PARSER_DIR}/metadef/register/op_kernel_registry.cpp" - "${PARSER_DIR}/metadef/register/op_tiling.cpp" - "${PARSER_DIR}/metadef/register/op_tiling_registry.cpp" + "${PARSER_DIR}/metadef/register/op_tiling/op_tiling.cc" + "${PARSER_DIR}/metadef/register/op_tiling/op_tiling_registry.cc" "${PARSER_DIR}/metadef/register/register.cpp" "${PARSER_DIR}/metadef/register/register_format_transfer.cc" "${PARSER_DIR}/metadef/register/register_pass.cpp" @@ -184,6 +202,7 @@ include_directories(${CMAKE_CURRENT_LIST_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) include_directories(${PARSER_DIR}/metadef) include_directories(${PARSER_DIR}/metadef/graph) +include_directories(${PARSER_DIR}/metadef/register) include_directories(${PARSER_DIR}/metadef/inc) include_directories(${PARSER_DIR}/metadef/inc/external) include_directories(${PARSER_DIR}/metadef/inc/register) diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 8deb30c..58da242 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -81,6 +81,23 @@ set(MATEDEF_SRC_FILES "${PARSER_DIR}/metadef/graph/detail/attributes_holder.cc" "${PARSER_DIR}/metadef/graph/format_refiner.cc" "${PARSER_DIR}/metadef/graph/ge_attr_define.cc" + "${PARSER_DIR}/metadef/graph/any_value.cc" + "${PARSER_DIR}/metadef/graph/attr_store.cc" + "${PARSER_DIR}/metadef/graph/serialization/attr_serializer_registry.cc" + "${PARSER_DIR}/metadef/graph/serialization/attr_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/bool_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/buffer_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/data_type_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/float_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/graph_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/int_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/list_list_float_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/list_list_int_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/list_value_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/named_attrs_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/string_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/tensor_desc_serializer.cc" + "${PARSER_DIR}/metadef/graph/serialization/tensor_serializer.cc" "${PARSER_DIR}/metadef/graph/ge_tensor.cc" "${PARSER_DIR}/metadef/graph/gnode.cc" "${PARSER_DIR}/metadef/graph/graph.cc" @@ -122,6 +139,7 @@ include_directories(${PARSER_DIR}/metadef/inc/graph) include_directories(${PARSER_DIR}/metadef/inc/external) include_directories(${PARSER_DIR}/metadef/inc/external/graph) include_directories(${PARSER_DIR}/metadef/graph) +include_directories(${PARSER_DIR}/metadef/register) include_directories(${PARSER_DIR}/metadef/third_party) include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc) include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external) @@ -167,8 +185,8 @@ set(REGISTER_SRC_FILES "${PARSER_DIR}/metadef/register/infer_data_slice_registry.cc" "${PARSER_DIR}/metadef/register/ops_kernel_builder_registry.cc" "${PARSER_DIR}/metadef/register/op_kernel_registry.cpp" - "${PARSER_DIR}/metadef/register/op_tiling.cpp" - "${PARSER_DIR}/metadef/register/op_tiling_registry.cpp" + "${PARSER_DIR}/metadef/register/op_tiling/op_tiling.cc" + "${PARSER_DIR}/metadef/register/op_tiling/op_tiling_registry.cc" "${PARSER_DIR}/metadef/register/register.cpp" "${PARSER_DIR}/metadef/register/register_format_transfer.cc" "${PARSER_DIR}/metadef/register/register_pass.cpp" @@ -186,6 +204,7 @@ include_directories(${CMAKE_CURRENT_LIST_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) include_directories(${PARSER_DIR}/metadef) include_directories(${PARSER_DIR}/metadef/graph) +include_directories(${PARSER_DIR}/metadef/register) include_directories(${PARSER_DIR}/metadef/inc) include_directories(${PARSER_DIR}/metadef/inc/external) include_directories(${PARSER_DIR}/metadef/inc/register)