Browse Source

!11430 add del_attr function

From: @changzherui
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
a58413479e
4 changed files with 30 additions and 4 deletions
  1. +6
    -0
      mindspore/ccsrc/pybind_api/ir/primitive_py.cc
  2. +2
    -0
      mindspore/ccsrc/pybind_api/ir/primitive_py.h
  3. +5
    -0
      mindspore/core/ir/primitive.h
  4. +17
    -4
      mindspore/ops/primitive.py

+ 6
- 0
mindspore/ccsrc/pybind_api/ir/primitive_py.cc View File

@@ -283,6 +283,11 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
(void)this->AddAttr(attr_name, converted_ret); (void)this->AddAttr(attr_name, converted_ret);
} }


void PrimitivePy::DelPyAttr(const py::str &name) {
std::string attr_name = name;
(void)this->DelAttr(attr_name);
}

py::dict PrimitivePy::GetAttrDict() { py::dict PrimitivePy::GetAttrDict() {
py::dict attr_dict; py::dict attr_dict;
for (auto &attr : attrs_) { for (auto &attr : attrs_) {
@@ -378,6 +383,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
.def(py::init<py::str &, py::object>()) .def(py::init<py::str &, py::object>())
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
.def("del_attr", &PrimitivePy::DelPyAttr, "del primitive attr")
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
.def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.") .def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.")


+ 2
- 0
mindspore/ccsrc/pybind_api/ir/primitive_py.h View File

@@ -49,6 +49,8 @@ class PrimitivePy : public Primitive {


void AddPyAttr(const py::str &name, const py::object &obj); void AddPyAttr(const py::str &name, const py::object &obj);


void DelPyAttr(const py::str &name);

py::dict GetAttrDict(); py::dict GetAttrDict();
void set_hook(const py::function &hook) { hook_ = hook; } void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; } py::function hook() const { return hook_; }


+ 5
- 0
mindspore/core/ir/primitive.h View File

@@ -60,6 +60,11 @@ class Primitive : public Named {
return *this; return *this;
} }


Primitive &DelAttr(const std::string &name) {
attrs_.erase(name);
return *this;
}

Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) { Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) { for (auto &attr : attrs) {
attrs_[attr.first] = attr.second; attrs_[attr.first] = attr.second;


+ 17
- 4
mindspore/ops/primitive.py View File

@@ -103,6 +103,19 @@ class Primitive(Primitive_):
self.add_attr(name, value) self.add_attr(name, value)
return self return self


def del_prim_attr(self, name):
"""
Del primitive attribute.

Args:
name (str): Attribute Name.
"""
if name in self.__dict__ and name in self.attrs:
del self.__dict__[name]
del self.attrs[name]
self.del_attr(name)
return self

def set_stage(self, stage): def set_stage(self, stage):
""" """
Add stage id to primitive attribute. Add stage id to primitive attribute.
@@ -191,7 +204,7 @@ class Primitive(Primitive_):


def init_prim_io_names(self, inputs, outputs): def init_prim_io_names(self, inputs, outputs):
""" """
Initializes the name of inputs and outpus of Tensor or attributes.
Initializes the name of inputs and outputs of Tensor or attributes.


Args: Args:
inputs (list[str]): list of inputs names. inputs (list[str]): list of inputs names.
@@ -222,9 +235,9 @@ class Primitive(Primitive_):
class PrimitiveWithCheck(Primitive): class PrimitiveWithCheck(Primitive):
""" """
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments
but used the infer method registed in c++ source codes.
but used the infer method registered in c++ source codes.


There are three methods can be overide to define the check logic of the primitive: __check__(), check_shape(),
There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(),
check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called. check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation. the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.
@@ -301,7 +314,7 @@ class PrimitiveWithInfer(Primitive):
""" """
PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference in python. PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference in python.


There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(),
There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(),
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer
logic of the shape and type. The infer_value() is used for constant propagation. logic of the shape and type. The infer_value() is used for constant propagation.


Loading…
Cancel
Save