Browse Source

modify custom

pull/15631/head
liuyu 4 years ago
parent
commit
4e07e5bf18
8 changed files with 99 additions and 19 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h
  2. +31
    -9
      mindspore/core/ops/custom.cc
  3. +14
    -9
      mindspore/core/ops/custom.h
  4. +1
    -0
      mindspore/core/ops/op_utils.h
  5. +6
    -0
      mindspore/lite/schema/ops.fbs
  6. +5
    -0
      mindspore/lite/schema/ops_types.fbs
  7. +6
    -0
      mindspore/lite/src/ops/ops_def.cc
  8. +33
    -0
      mindspore/lite/src/ops/ops_utils.cc

+ 3
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h View File

@@ -214,8 +214,10 @@ enum PrimType {
PrimType_ResizeGrad = 187, PrimType_ResizeGrad = 187,
PrimType_Splice = 188, PrimType_Splice = 188,
PrimType_LogSoftmax = 189, PrimType_LogSoftmax = 189,
PrimType_Call = 190,
PrimType_Custom = 191,
PrimType_MIN = PrimType_NONE, PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_LogSoftmax + 1
PrimType_MAX = PrimType_Custom + 1
}; };


void RegInfer(int prim_type, InferShape func); void RegInfer(int prim_type, InferShape func);


+ 31
- 9
mindspore/core/ops/custom.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -15,19 +15,41 @@
*/ */


#include "ops/custom.h" #include "ops/custom.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include <memory>
#include <map>


namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
void Custom::Init(const std::vector<int64_t> &custom) { this->set_custom(custom); }
void Custom::Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs) {
this->set_type(type);
this->set_attr(attrs);
}

void Custom::set_type(const std::string &type) { this->AddAttr(kType, MakeValue(type)); }


void Custom::set_custom(const std::vector<int64_t> &custom) { this->AddAttr(kCustom, MakeValue(custom)); }
std::string Custom::get_type() const {
auto value_ptr = this->GetAttr(kType);
return GetValue<std::string>(value_ptr);
}

void Custom::set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs) {
ValuePtrList value_ptr_list;
for (const auto &attr : attrs) {
value_ptr_list.emplace_back(MakeValue<std::string>(attr.first));
value_ptr_list.emplace_back(MakeValue<std::vector<uint8_t>>(attr.second));
}
this->AddAttr(kAttr, MakeValue(value_ptr_list));
}


std::vector<int64_t> Custom::get_custom() const {
auto value_ptr = this->GetAttr(kCustom);
return GetValue<std::vector<int64_t>>(value_ptr);
std::map<std::string, std::vector<uint8_t>> Custom::get_attr() const {
std::map<std::string, std::vector<uint8_t>> attrs;
auto value_ptr_list = GetValue<ValuePtrList>(this->GetAttr(kAttr));
for (size_t i = 0; i < value_ptr_list.size(); i += 2) {
auto key = GetValue<std::string>(value_ptr_list[i]);
auto value = GetValue<std::vector<uint8_t>>(value_ptr_list[i + 1]);
attrs[key] = value;
}
return attrs;
} }
REGISTER_PRIMITIVE_C(kNameCustom, Custom); REGISTER_PRIMITIVE_C(kNameCustom, Custom);
} // namespace ops } // namespace ops


+ 14
- 9
mindspore/core/ops/custom.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -16,12 +16,15 @@


#ifndef MINDSPORE_CORE_OPS_CUSTOM_H_ #ifndef MINDSPORE_CORE_OPS_CUSTOM_H_
#define MINDSPORE_CORE_OPS_CUSTOM_H_ #define MINDSPORE_CORE_OPS_CUSTOM_H_
#include <memory>
#include <string>
#include <utility>
#include <vector> #include <vector>
#include <map>
#include <memory>
#include <unordered_map>
#include "ops/primitive_c.h" #include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "ir/anf.h"


namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
@@ -29,11 +32,13 @@ constexpr auto kNameCustom = "Custom";
class Custom : public PrimitiveC { class Custom : public PrimitiveC {
public: public:
Custom() : PrimitiveC(kNameCustom) {} Custom() : PrimitiveC(kNameCustom) {}
~Custom() = default;
~Custom() override = default;
MS_DECLARE_PARENT(Custom, PrimitiveC); MS_DECLARE_PARENT(Custom, PrimitiveC);
void Init(const std::vector<int64_t> &custom);
void set_custom(const std::vector<int64_t> &custom);
std::vector<int64_t> get_custom() const;
void Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs);
void set_type(const std::string &type);
std::string get_type() const;
void set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs);
std::map<std::string, std::vector<uint8_t>> get_attr() const;
}; };


using PrimCustomPtr = std::shared_ptr<Custom>; using PrimCustomPtr = std::shared_ptr<Custom>;


+ 1
- 0
mindspore/core/ops/op_utils.h View File

@@ -31,6 +31,7 @@ constexpr auto kActivation = "activation";
constexpr auto kActivationType = "activation_type"; constexpr auto kActivationType = "activation_type";
constexpr auto kAddress = "address"; constexpr auto kAddress = "address";
constexpr auto kAlignCorners = "align_corners"; constexpr auto kAlignCorners = "align_corners";
constexpr auto kAttr = "attr";
constexpr auto kAspectRatios = "aspect_ratios"; constexpr auto kAspectRatios = "aspect_ratios";
constexpr auto kAxes = "axes"; constexpr auto kAxes = "axes";
constexpr auto kAxis = "axis"; constexpr auto kAxis = "axis";


+ 6
- 0
mindspore/lite/schema/ops.fbs View File

@@ -208,6 +208,7 @@ union PrimitiveType {
Splice, Splice,
LogSoftmax, LogSoftmax,
Call, Call,
Custom,
} }


table Abs { table Abs {
@@ -1103,3 +1104,8 @@ table LogSoftmax {


table Call { table Call {
} }

table Custom {
type: string;
attr: [Attribute];
}

+ 5
- 0
mindspore/lite/schema/ops_types.fbs View File

@@ -142,3 +142,8 @@ table Vec {
table Vec2D { table Vec2D {
data: [Vec]; data: [Vec];
} }

table Attribute {
name: string;
data: [ubyte];
}

+ 6
- 0
mindspore/lite/src/ops/ops_def.cc View File

@@ -207,6 +207,7 @@ OP_TYPE(ResizeGrad)
OP_TYPE(Splice) OP_TYPE(Splice)
OP_TYPE(LogSoftmax) OP_TYPE(LogSoftmax)
OP_TYPE(Call) OP_TYPE(Call)
OP_TYPE(Custom)
OP_TYPE_DEF_END(PrimitiveType) OP_TYPE_DEF_END(PrimitiveType)


OP_SCHEMA_DEF(Abs) OP_SCHEMA_DEF(Abs)
@@ -1102,3 +1103,8 @@ OP_SCHEMA_DEF_END(LogSoftmax)


OP_SCHEMA_DEF(Call) OP_SCHEMA_DEF(Call)
OP_SCHEMA_DEF_END(Call) OP_SCHEMA_DEF_END(Call)

OP_SCHEMA_DEF_ONLY(Custom)
OP_ATTR_ONLY(type, string)
OP_ATTR_ONLY(attr, [Attribute])
OP_SCHEMA_DEF_ONLY_END(Custom)

+ 33
- 0
mindspore/lite/src/ops/ops_utils.cc View File

@@ -975,6 +975,39 @@ RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator);
RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator); RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator);
RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator); RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator);
RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator); RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator);

schema::PrimitiveT *CustomPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);
auto *schema_op = new (std::nothrow) schema::CustomT();
if (schema_op == nullptr) {
return nullptr;
}
if (ms_primc->GetAttr("type") != nullptr) {
schema_op->type = ms_primc->get_type();
}
if (ms_primc->GetAttr("attr") != nullptr) {
auto attr_map = ms_primc->get_attr();
for (const auto &attr_item : attr_map) {
auto *attr = new (std::nothrow) schema::AttributeT();
if (attr == nullptr) {
return nullptr;
}
attr->name = attr_item.first;
attr->data = attr_item.second;
schema_op->attr.emplace_back(attr);
}
}

auto *prim = new (std::nothrow) schema::PrimitiveT();
if (prim == nullptr) {
return nullptr;
}
prim->value.value = schema_op;
prim->value.type = schema::PrimitiveType_Custom;
return prim;
}

RegistryMSOps g_CustomPrimitiveCreatorRegistry("Custom", CustomPrimitiveCreator);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore




Loading…
Cancel
Save