diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h index 6a6bfd3dcf..7000c15c37 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h @@ -214,8 +214,10 @@ enum PrimType { PrimType_ResizeGrad = 187, PrimType_Splice = 188, PrimType_LogSoftmax = 189, + PrimType_Call = 190, + PrimType_Custom = 191, PrimType_MIN = PrimType_NONE, - PrimType_MAX = PrimType_LogSoftmax + 1 + PrimType_MAX = PrimType_Custom + 1 }; void RegInfer(int prim_type, InferShape func); diff --git a/mindspore/core/ops/custom.cc b/mindspore/core/ops/custom.cc index dbcd5f528b..e01cf2f17a 100644 --- a/mindspore/core/ops/custom.cc +++ b/mindspore/core/ops/custom.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,19 +15,41 @@ */ #include "ops/custom.h" -#include "utils/check_convert_utils.h" -#include "abstract/primitive_infer_map.h" -#include "ops/op_utils.h" +#include +#include namespace mindspore { namespace ops { -void Custom::Init(const std::vector &custom) { this->set_custom(custom); } +void Custom::Init(const std::string &type, const std::map> &attrs) { + this->set_type(type); + this->set_attr(attrs); +} + +void Custom::set_type(const std::string &type) { this->AddAttr(kType, MakeValue(type)); } -void Custom::set_custom(const std::vector &custom) { this->AddAttr(kCustom, MakeValue(custom)); } +std::string Custom::get_type() const { + auto value_ptr = this->GetAttr(kType); + return GetValue(value_ptr); +} + +void Custom::set_attr(const std::map> &attrs) { + ValuePtrList value_ptr_list; + for (const auto &attr : attrs) { + value_ptr_list.emplace_back(MakeValue(attr.first)); + value_ptr_list.emplace_back(MakeValue>(attr.second)); + } + this->AddAttr(kAttr, MakeValue(value_ptr_list)); +} -std::vector Custom::get_custom() const { - auto value_ptr = this->GetAttr(kCustom); - return GetValue>(value_ptr); +std::map> Custom::get_attr() const { + std::map> attrs; + auto value_ptr_list = GetValue(this->GetAttr(kAttr)); + for (size_t i = 0; i < value_ptr_list.size(); i += 2) { + auto key = GetValue(value_ptr_list[i]); + auto value = GetValue>(value_ptr_list[i + 1]); + attrs[key] = value; + } + return attrs; } REGISTER_PRIMITIVE_C(kNameCustom, Custom); } // namespace ops diff --git a/mindspore/core/ops/custom.h b/mindspore/core/ops/custom.h index 9773f3c7dc..52a5204945 100644 --- a/mindspore/core/ops/custom.h +++ b/mindspore/core/ops/custom.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,15 @@ #ifndef MINDSPORE_CORE_OPS_CUSTOM_H_ #define MINDSPORE_CORE_OPS_CUSTOM_H_ -#include - +#include +#include #include +#include +#include +#include #include "ops/primitive_c.h" -#include "abstract/abstract_value.h" -#include "utils/check_convert_utils.h" +#include "ops/op_utils.h" +#include "ir/anf.h" namespace mindspore { namespace ops { @@ -29,11 +32,13 @@ constexpr auto kNameCustom = "Custom"; class Custom : public PrimitiveC { public: Custom() : PrimitiveC(kNameCustom) {} - ~Custom() = default; + ~Custom() override = default; MS_DECLARE_PARENT(Custom, PrimitiveC); - void Init(const std::vector &custom); - void set_custom(const std::vector &custom); - std::vector get_custom() const; + void Init(const std::string &type, const std::map> &attrs); + void set_type(const std::string &type); + std::string get_type() const; + void set_attr(const std::map> &attrs); + std::map> get_attr() const; }; using PrimCustomPtr = std::shared_ptr; diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h index 0d3f5923ea..a1bf7bbb1e 100644 --- a/mindspore/core/ops/op_utils.h +++ b/mindspore/core/ops/op_utils.h @@ -31,6 +31,7 @@ constexpr auto kActivation = "activation"; constexpr auto kActivationType = "activation_type"; constexpr auto kAddress = "address"; constexpr auto kAlignCorners = "align_corners"; +constexpr auto kAttr = "attr"; constexpr auto kAspectRatios = "aspect_ratios"; constexpr auto kAxes = "axes"; constexpr auto kAxis = "axis"; diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index b959d66e6d..62f1ad4141 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -208,6 +208,7 @@ union PrimitiveType { Splice, LogSoftmax, Call, + Custom, } table Abs { @@ -1103,3 +1104,8 @@ table LogSoftmax { table Call { } + +table Custom { + type: string; + attr: [Attribute]; +} diff --git a/mindspore/lite/schema/ops_types.fbs b/mindspore/lite/schema/ops_types.fbs index 6d9e6ebfed..7ef27613c3 100644 --- a/mindspore/lite/schema/ops_types.fbs +++ b/mindspore/lite/schema/ops_types.fbs @@ -142,3 +142,8 @@ table Vec { table Vec2D { data: [Vec]; } + +table Attribute { + name: string; + data: [ubyte]; +} diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc index 4efb906c84..5b17ecef5b 100644 --- a/mindspore/lite/src/ops/ops_def.cc +++ b/mindspore/lite/src/ops/ops_def.cc @@ -207,6 +207,7 @@ OP_TYPE(ResizeGrad) OP_TYPE(Splice) OP_TYPE(LogSoftmax) OP_TYPE(Call) +OP_TYPE(Custom) OP_TYPE_DEF_END(PrimitiveType) OP_SCHEMA_DEF(Abs) @@ -1102,3 +1103,8 @@ OP_SCHEMA_DEF_END(LogSoftmax) OP_SCHEMA_DEF(Call) OP_SCHEMA_DEF_END(Call) + +OP_SCHEMA_DEF_ONLY(Custom) +OP_ATTR_ONLY(type, string) +OP_ATTR_ONLY(attr, [Attribute]) +OP_SCHEMA_DEF_ONLY_END(Custom) diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index d2808e3732..eaddf86232 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -975,6 +975,39 @@ RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator); RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator); RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator); RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator); + +schema::PrimitiveT *CustomPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + auto *schema_op = new (std::nothrow) schema::CustomT(); + if (schema_op == nullptr) { + return nullptr; + } + if (ms_primc->GetAttr("type") != nullptr) { + schema_op->type = ms_primc->get_type(); + } + if (ms_primc->GetAttr("attr") != nullptr) { + auto attr_map = ms_primc->get_attr(); + for (const auto &attr_item : attr_map) { + auto *attr = new (std::nothrow) schema::AttributeT(); + if (attr == nullptr) { + return nullptr; + } + attr->name = attr_item.first; + attr->data = attr_item.second; + schema_op->attr.emplace_back(attr); + } + } + + auto *prim = new (std::nothrow) schema::PrimitiveT(); + if (prim == nullptr) { + return nullptr; + } + prim->value.value = schema_op; + prim->value.type = schema::PrimitiveType_Custom; + return prim; +} + +RegistryMSOps g_CustomPrimitiveCreatorRegistry("Custom", CustomPrimitiveCreator); } // namespace lite } // namespace mindspore