Browse Source

!11430 add del_attr function

From: @changzherui
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
a58413479e
4 changed files with 30 additions and 4 deletions
  1. +6
    -0
      mindspore/ccsrc/pybind_api/ir/primitive_py.cc
  2. +2
    -0
      mindspore/ccsrc/pybind_api/ir/primitive_py.h
  3. +5
    -0
      mindspore/core/ir/primitive.h
  4. +17
    -4
      mindspore/ops/primitive.py

+ 6
- 0
mindspore/ccsrc/pybind_api/ir/primitive_py.cc View File

@@ -283,6 +283,11 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
(void)this->AddAttr(attr_name, converted_ret);
}

void PrimitivePy::DelPyAttr(const py::str &name) {
std::string attr_name = name;
(void)this->DelAttr(attr_name);
}

py::dict PrimitivePy::GetAttrDict() {
py::dict attr_dict;
for (auto &attr : attrs_) {
@@ -378,6 +383,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
.def(py::init<py::str &, py::object>())
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
.def("del_attr", &PrimitivePy::DelPyAttr, "del primitive attr")
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
.def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.")


+ 2
- 0
mindspore/ccsrc/pybind_api/ir/primitive_py.h View File

@@ -49,6 +49,8 @@ class PrimitivePy : public Primitive {

void AddPyAttr(const py::str &name, const py::object &obj);

void DelPyAttr(const py::str &name);

py::dict GetAttrDict();
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }


+ 5
- 0
mindspore/core/ir/primitive.h View File

@@ -60,6 +60,11 @@ class Primitive : public Named {
return *this;
}

Primitive &DelAttr(const std::string &name) {
attrs_.erase(name);
return *this;
}

Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
attrs_[attr.first] = attr.second;


+ 17
- 4
mindspore/ops/primitive.py View File

@@ -103,6 +103,19 @@ class Primitive(Primitive_):
self.add_attr(name, value)
return self

def del_prim_attr(self, name):
"""
Del primitive attribute.

Args:
name (str): Attribute Name.
"""
if name in self.__dict__ and name in self.attrs:
del self.__dict__[name]
del self.attrs[name]
self.del_attr(name)
return self

def set_stage(self, stage):
"""
Add stage id to primitive attribute.
@@ -191,7 +204,7 @@ class Primitive(Primitive_):

def init_prim_io_names(self, inputs, outputs):
"""
Initializes the name of inputs and outpus of Tensor or attributes.
Initializes the name of inputs and outputs of Tensor or attributes.

Args:
inputs (list[str]): list of inputs names.
@@ -222,9 +235,9 @@ class Primitive(Primitive_):
class PrimitiveWithCheck(Primitive):
"""
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments
but used the infer method registed in c++ source codes.
but used the infer method registered in c++ source codes.

There are three methods can be overide to define the check logic of the primitive: __check__(), check_shape(),
There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(),
check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.
@@ -301,7 +314,7 @@ class PrimitiveWithInfer(Primitive):
"""
PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference in python.

There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(),
There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(),
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer
logic of the shape and type. The infer_value() is used for constant propagation.


Loading…
Cancel
Save