| @@ -214,8 +214,10 @@ enum PrimType { | |||||
| PrimType_ResizeGrad = 187, | PrimType_ResizeGrad = 187, | ||||
| PrimType_Splice = 188, | PrimType_Splice = 188, | ||||
| PrimType_LogSoftmax = 189, | PrimType_LogSoftmax = 189, | ||||
| PrimType_Call = 190, | |||||
| PrimType_Custom = 191, | |||||
| PrimType_MIN = PrimType_NONE, | PrimType_MIN = PrimType_NONE, | ||||
| PrimType_MAX = PrimType_LogSoftmax + 1 | |||||
| PrimType_MAX = PrimType_Custom + 1 | |||||
| }; | }; | ||||
| void RegInfer(int prim_type, InferShape func); | void RegInfer(int prim_type, InferShape func); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -15,19 +15,41 @@ | |||||
| */ | */ | ||||
| #include "ops/custom.h" | #include "ops/custom.h" | ||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include <memory> | |||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| void Custom::Init(const std::vector<int64_t> &custom) { this->set_custom(custom); } | |||||
| void Custom::Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs) { | |||||
| this->set_type(type); | |||||
| this->set_attr(attrs); | |||||
| } | |||||
| void Custom::set_type(const std::string &type) { this->AddAttr(kType, MakeValue(type)); } | |||||
| void Custom::set_custom(const std::vector<int64_t> &custom) { this->AddAttr(kCustom, MakeValue(custom)); } | |||||
| std::string Custom::get_type() const { | |||||
| auto value_ptr = this->GetAttr(kType); | |||||
| return GetValue<std::string>(value_ptr); | |||||
| } | |||||
| void Custom::set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs) { | |||||
| ValuePtrList value_ptr_list; | |||||
| for (const auto &attr : attrs) { | |||||
| value_ptr_list.emplace_back(MakeValue<std::string>(attr.first)); | |||||
| value_ptr_list.emplace_back(MakeValue<std::vector<uint8_t>>(attr.second)); | |||||
| } | |||||
| this->AddAttr(kAttr, MakeValue(value_ptr_list)); | |||||
| } | |||||
| std::vector<int64_t> Custom::get_custom() const { | |||||
| auto value_ptr = this->GetAttr(kCustom); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| std::map<std::string, std::vector<uint8_t>> Custom::get_attr() const { | |||||
| std::map<std::string, std::vector<uint8_t>> attrs; | |||||
| auto value_ptr_list = GetValue<ValuePtrList>(this->GetAttr(kAttr)); | |||||
| for (size_t i = 0; i < value_ptr_list.size(); i += 2) { | |||||
| auto key = GetValue<std::string>(value_ptr_list[i]); | |||||
| auto value = GetValue<std::vector<uint8_t>>(value_ptr_list[i + 1]); | |||||
| attrs[key] = value; | |||||
| } | |||||
| return attrs; | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameCustom, Custom); | REGISTER_PRIMITIVE_C(kNameCustom, Custom); | ||||
| } // namespace ops | } // namespace ops | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -16,12 +16,15 @@ | |||||
| #ifndef MINDSPORE_CORE_OPS_CUSTOM_H_ | #ifndef MINDSPORE_CORE_OPS_CUSTOM_H_ | ||||
| #define MINDSPORE_CORE_OPS_CUSTOM_H_ | #define MINDSPORE_CORE_OPS_CUSTOM_H_ | ||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include "ops/primitive_c.h" | #include "ops/primitive_c.h" | ||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "ir/anf.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| @@ -29,11 +32,13 @@ constexpr auto kNameCustom = "Custom"; | |||||
| class Custom : public PrimitiveC { | class Custom : public PrimitiveC { | ||||
| public: | public: | ||||
| Custom() : PrimitiveC(kNameCustom) {} | Custom() : PrimitiveC(kNameCustom) {} | ||||
| ~Custom() = default; | |||||
| ~Custom() override = default; | |||||
| MS_DECLARE_PARENT(Custom, PrimitiveC); | MS_DECLARE_PARENT(Custom, PrimitiveC); | ||||
| void Init(const std::vector<int64_t> &custom); | |||||
| void set_custom(const std::vector<int64_t> &custom); | |||||
| std::vector<int64_t> get_custom() const; | |||||
| void Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs); | |||||
| void set_type(const std::string &type); | |||||
| std::string get_type() const; | |||||
| void set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs); | |||||
| std::map<std::string, std::vector<uint8_t>> get_attr() const; | |||||
| }; | }; | ||||
| using PrimCustomPtr = std::shared_ptr<Custom>; | using PrimCustomPtr = std::shared_ptr<Custom>; | ||||
| @@ -31,6 +31,7 @@ constexpr auto kActivation = "activation"; | |||||
| constexpr auto kActivationType = "activation_type"; | constexpr auto kActivationType = "activation_type"; | ||||
| constexpr auto kAddress = "address"; | constexpr auto kAddress = "address"; | ||||
| constexpr auto kAlignCorners = "align_corners"; | constexpr auto kAlignCorners = "align_corners"; | ||||
| constexpr auto kAttr = "attr"; | |||||
| constexpr auto kAspectRatios = "aspect_ratios"; | constexpr auto kAspectRatios = "aspect_ratios"; | ||||
| constexpr auto kAxes = "axes"; | constexpr auto kAxes = "axes"; | ||||
| constexpr auto kAxis = "axis"; | constexpr auto kAxis = "axis"; | ||||
| @@ -208,6 +208,7 @@ union PrimitiveType { | |||||
| Splice, | Splice, | ||||
| LogSoftmax, | LogSoftmax, | ||||
| Call, | Call, | ||||
| Custom, | |||||
| } | } | ||||
| table Abs { | table Abs { | ||||
| @@ -1103,3 +1104,8 @@ table LogSoftmax { | |||||
| table Call { | table Call { | ||||
| } | } | ||||
| table Custom { | |||||
| type: string; | |||||
| attr: [Attribute]; | |||||
| } | |||||
| @@ -142,3 +142,8 @@ table Vec { | |||||
| table Vec2D { | table Vec2D { | ||||
| data: [Vec]; | data: [Vec]; | ||||
| } | } | ||||
| table Attribute { | |||||
| name: string; | |||||
| data: [ubyte]; | |||||
| } | |||||
| @@ -207,6 +207,7 @@ OP_TYPE(ResizeGrad) | |||||
| OP_TYPE(Splice) | OP_TYPE(Splice) | ||||
| OP_TYPE(LogSoftmax) | OP_TYPE(LogSoftmax) | ||||
| OP_TYPE(Call) | OP_TYPE(Call) | ||||
| OP_TYPE(Custom) | |||||
| OP_TYPE_DEF_END(PrimitiveType) | OP_TYPE_DEF_END(PrimitiveType) | ||||
| OP_SCHEMA_DEF(Abs) | OP_SCHEMA_DEF(Abs) | ||||
| @@ -1102,3 +1103,8 @@ OP_SCHEMA_DEF_END(LogSoftmax) | |||||
| OP_SCHEMA_DEF(Call) | OP_SCHEMA_DEF(Call) | ||||
| OP_SCHEMA_DEF_END(Call) | OP_SCHEMA_DEF_END(Call) | ||||
| OP_SCHEMA_DEF_ONLY(Custom) | |||||
| OP_ATTR_ONLY(type, string) | |||||
| OP_ATTR_ONLY(attr, [Attribute]) | |||||
| OP_SCHEMA_DEF_ONLY_END(Custom) | |||||
| @@ -975,6 +975,39 @@ RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator); | |||||
| RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator); | RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator); | ||||
| RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator); | RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator); | ||||
| RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator); | RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator); | ||||
| schema::PrimitiveT *CustomPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node); | |||||
| auto *schema_op = new (std::nothrow) schema::CustomT(); | |||||
| if (schema_op == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| if (ms_primc->GetAttr("type") != nullptr) { | |||||
| schema_op->type = ms_primc->get_type(); | |||||
| } | |||||
| if (ms_primc->GetAttr("attr") != nullptr) { | |||||
| auto attr_map = ms_primc->get_attr(); | |||||
| for (const auto &attr_item : attr_map) { | |||||
| auto *attr = new (std::nothrow) schema::AttributeT(); | |||||
| if (attr == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| attr->name = attr_item.first; | |||||
| attr->data = attr_item.second; | |||||
| schema_op->attr.emplace_back(attr); | |||||
| } | |||||
| } | |||||
| auto *prim = new (std::nothrow) schema::PrimitiveT(); | |||||
| if (prim == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| prim->value.value = schema_op; | |||||
| prim->value.type = schema::PrimitiveType_Custom; | |||||
| return prim; | |||||
| } | |||||
| RegistryMSOps g_CustomPrimitiveCreatorRegistry("Custom", CustomPrimitiveCreator); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||