Merge pull request !5645 from vlne-v1/ref_demotags/v1.0.0
| @@ -20,7 +20,6 @@ from mindspore.common.tensor import Tensor | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype | from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype | ||||
| def scalar_add(x, y): | def scalar_add(x, y): | ||||
| """Implement `scalar_add`.""" | """Implement `scalar_add`.""" | ||||
| return x + y | return x + y | ||||
| @@ -117,25 +116,6 @@ def bool_or(x, y): | |||||
| return x or y | return x or y | ||||
| def vm_compare(*args): | |||||
| """Implement `vm_compare` for tensor.""" | |||||
| obj_str = args[-1] | |||||
| if obj_str == "shape": | |||||
| fn = getattr(args[0].asnumpy(), obj_str) | |||||
| return fn | |||||
| if len(args) == 2: | |||||
| fn = getattr(args[0].asnumpy(), obj_str) | |||||
| return Tensor(fn()) | |||||
| if isinstance(args[0], Tensor): | |||||
| fn = getattr(args[0].asnumpy(), obj_str) | |||||
| y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1] | |||||
| else: | |||||
| obj_str = "__r" + obj_str[2:] | |||||
| fn = getattr(args[1].asnumpy(), obj_str) | |||||
| y = args[0] | |||||
| return Tensor(np.array(fn(y))) | |||||
| def make_list(*xs): | def make_list(*xs): | ||||
| """Implement `make_list`.""" | """Implement `make_list`.""" | ||||
| return list(xs) | return list(xs) | ||||
| @@ -262,6 +262,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||||
| std::set<size_t> write_indices; | std::set<size_t> write_indices; | ||||
| std::vector<TypePtr> input_types; | std::vector<TypePtr> input_types; | ||||
| op_inputs.push_back(NewValueNode(function)); | op_inputs.push_back(NewValueNode(function)); | ||||
| auto cast_type = parse::GetMixedPrecisionTargetType(func_graph); | |||||
| // Assume, the write input of op is always the first input. We check if any write op, | // Assume, the write input of op is always the first input. We check if any write op, | ||||
| // and add cast op on other inputs to keep the same type with assigned parameter. | // and add cast op on other inputs to keep the same type with assigned parameter. | ||||
| for (size_t i = 0; i < args_spec_list.size(); ++i) { | for (size_t i = 0; i < args_spec_list.size(); ++i) { | ||||
| @@ -280,7 +281,6 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||||
| TypePtr type = args_spec_list[i]->BuildType(); | TypePtr type = args_spec_list[i]->BuildType(); | ||||
| if (type && type->isa<RefType>()) { | if (type && type->isa<RefType>()) { | ||||
| auto cast_type = parse::GetMixedPrecisionTargetType(func_graph); | |||||
| if (sig == SignatureEnumRW::kRWRead) { | if (sig == SignatureEnumRW::kRWRead) { | ||||
| auto source_tensor_type = type->cast<TensorTypePtr>(); | auto source_tensor_type = type->cast<TensorTypePtr>(); | ||||
| if (source_tensor_type != nullptr) { | if (source_tensor_type != nullptr) { | ||||
| @@ -300,8 +300,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||||
| MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but " | MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but " | ||||
| << type->ToString(); | << type->ToString(); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type " | |||||
| << args_spec_list[i]->ToString(); | |||||
| MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " abs " | |||||
| << args_spec_list[i]->ToString() << " type " << type->ToString(); | |||||
| input_types.push_back(type); | input_types.push_back(type); | ||||
| op_inputs.push_back(param); | op_inputs.push_back(param); | ||||
| } | } | ||||
| @@ -305,9 +305,6 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||||
| dic[ATTR_SHAPE] = shape; | dic[ATTR_SHAPE] = shape; | ||||
| dic[ATTR_DTYPE] = arg_slice->BuildType(); | dic[ATTR_DTYPE] = arg_slice->BuildType(); | ||||
| dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue()); | dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue()); | ||||
| } else if (abs_base->isa<AbstractRef>()) { | |||||
| auto value = abs_base->cast<AbstractRefPtr>()->ref(); | |||||
| dic = ConvertAbstractToPython(value); | |||||
| } else if (abs_base->isa<AbstractEllipsis>()) { | } else if (abs_base->isa<AbstractEllipsis>()) { | ||||
| dic[ATTR_SHAPE] = py::none(); | dic[ATTR_SHAPE] = py::none(); | ||||
| dic[ATTR_DTYPE] = py::ellipsis(); | dic[ATTR_DTYPE] = py::ellipsis(); | ||||
| @@ -23,7 +23,7 @@ namespace mindspore { | |||||
| REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { | REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { | ||||
| // Define python "MetaFuncGraph_" class | // Define python "MetaFuncGraph_" class | ||||
| (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") | (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") | ||||
| .def(py::init<std::string &>()); | |||||
| .def("set_signatures", &MetaFuncGraph::set_signatures, "Set primitive inputs signature."); | |||||
| // Define python "FuncGraph" class | // Define python "FuncGraph" class | ||||
| (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph") | (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph") | ||||
| .def(py::init()) | .def(py::init()) | ||||
| @@ -48,22 +48,9 @@ void SyncData(const py::object &arg) { | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| std::map<std::string, py::object> PrimitivePy::hook_grad_; | std::map<std::string, py::object> PrimitivePy::hook_grad_; | ||||
| static ValuePtr PyArgToValue(const py::object &arg) { | |||||
| if (py::isinstance<SignatureEnumKind>(arg) && | |||||
| py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { | |||||
| return nullptr; | |||||
| } | |||||
| return parse::data_converter::PyDataToValue(arg); | |||||
| } | |||||
| void PrimitivePy::set_signatures( | |||||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { | |||||
| signatures_.clear(); | |||||
| for (auto &signature : signatures) { | |||||
| auto [name, rw, kind, arg_default, dtype] = signature; | |||||
| auto default_value = PyArgToValue(arg_default); | |||||
| signatures_.emplace_back(name, rw, kind, default_value, dtype); | |||||
| } | |||||
| void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) { | |||||
| signatures_ = signatures; | |||||
| set_has_signature(true); | set_has_signature(true); | ||||
| } | } | ||||
| @@ -42,9 +42,7 @@ class PrimitivePy : public Primitive { | |||||
| MS_DECLARE_PARENT(PrimitivePy, Primitive); | MS_DECLARE_PARENT(PrimitivePy, Primitive); | ||||
| py::function GetBpropFunction(); | py::function GetBpropFunction(); | ||||
| void set_signatures( | |||||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> | |||||
| signatures); | |||||
| void set_signatures(const std::vector<Signature> &signatures); | |||||
| const std::vector<Signature> &signatures() const { return signatures_; } | const std::vector<Signature> &signatures() const { return signatures_; } | ||||
| @@ -17,12 +17,26 @@ | |||||
| #include "ir/signature.h" | #include "ir/signature.h" | ||||
| #include "pybind11/operators.h" | #include "pybind11/operators.h" | ||||
| #include "pybind_api/api_register.h" | #include "pybind_api/api_register.h" | ||||
| #include "pipeline/jit/parse/data_converter.h" | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| static ValuePtr PyArgToValue(const py::object &arg) { | |||||
| if (py::isinstance<SignatureEnumKind>(arg) && | |||||
| py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { | |||||
| return nullptr; | |||||
| } | |||||
| return parse::data_converter::PyDataToValue(arg); | |||||
| } | |||||
| // Bind SignatureEnumRW as a python class. | // Bind SignatureEnumRW as a python class. | ||||
| REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { | REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { | ||||
| (void)py::class_<Signature>(*m, "Signature") | |||||
| .def(py::init([](std::string name, SignatureEnumRW rw, SignatureEnumKind kind, | |||||
| py::object arg_default, SignatureEnumDType dtype) { | |||||
| auto default_value = PyArgToValue(arg_default); | |||||
| return Signature(name, rw, kind, default_value, dtype); | |||||
| })); | |||||
| (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic()) | (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic()) | ||||
| .value("RW_READ", SignatureEnumRW::kRWRead) | .value("RW_READ", SignatureEnumRW::kRWRead) | ||||
| .value("RW_WRITE", SignatureEnumRW::kRWWrite) | .value("RW_WRITE", SignatureEnumRW::kRWWrite) | ||||
| @@ -393,3 +393,24 @@ class SparseTensor: | |||||
| @property | @property | ||||
| def dense_shape(self): | def dense_shape(self): | ||||
| return self.__dense_shape | return self.__dense_shape | ||||
| def _vm_compare(*args): | |||||
| """Implement `vm_compare` for tensor.""" | |||||
| obj_str = args[-1] | |||||
| if obj_str == "shape": | |||||
| fn = getattr(args[0].asnumpy(), obj_str) | |||||
| return fn | |||||
| if len(args) == 2: | |||||
| fn = getattr(args[0].asnumpy(), obj_str) | |||||
| return Tensor(fn()) | |||||
| if isinstance(args[0], Tensor): | |||||
| fn = getattr(args[0].asnumpy(), obj_str) | |||||
| y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1] | |||||
| else: | |||||
| obj_str = "__r" + obj_str[2:] | |||||
| fn = getattr(args[1].asnumpy(), obj_str) | |||||
| y = args[0] | |||||
| return Tensor(np.array(fn(y))) | |||||
| tensor_operator_registry.register('vm_compare', _vm_compare) | |||||
| @@ -34,14 +34,17 @@ from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register | |||||
| from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | ||||
| from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType | from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType | ||||
| from .primitive import constexpr | from .primitive import constexpr | ||||
| from .._c_expression import signature_rw, signature_kind | |||||
| from . import composite, operations, functional | |||||
| from . import signature | |||||
| __primitive__ = [ | __primitive__ = [ | ||||
| "prim_attr_register", "Primitive", "PrimitiveWithInfer", | |||||
| "signature_rw", "signature_kind" | |||||
| "prim_attr_register", "Primitive", "PrimitiveWithInfer", "signature" | |||||
| ] | ] | ||||
| __all__ = ["get_vm_impl_fn", "vm_impl_registry", | __all__ = ["get_vm_impl_fn", "vm_impl_registry", | ||||
| "op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType", | "op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType", | ||||
| "constexpr"] | "constexpr"] | ||||
| __all__.extend(__primitive__) | __all__.extend(__primitive__) | ||||
| __all__.extend(composite.__all__) | |||||
| __all__.extend(operations.__all__) | |||||
| __all__.extend(functional.__all__) | |||||
| @@ -25,9 +25,8 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, Mult | |||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.api import ms_function, _pynative_exec, _wrap_func | from ...common.api import ms_function, _pynative_exec, _wrap_func | ||||
| from .. import functional as F | from .. import functional as F | ||||
| from ...common.parameter import Parameter | |||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| from .. import signature as sig | |||||
| __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] | __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] | ||||
| @@ -348,6 +347,8 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): | |||||
| Args: | Args: | ||||
| name (str): Operator name. | name (str): Operator name. | ||||
| read_value (bool): If the registered function not need to set value on Parameter, | |||||
| and all inputs will pass by value. Set `read_value` to True. Default: False. | |||||
| Raises: | Raises: | ||||
| ValueError: Cannot find matching fn for the given args. | ValueError: Cannot find matching fn for the given args. | ||||
| @@ -358,16 +359,15 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): | |||||
| >>> add = MultitypeFuncGraph('add') | >>> add = MultitypeFuncGraph('add') | ||||
| """ | """ | ||||
| def __init__(self, name): | |||||
| def __init__(self, name, read_value=False): | |||||
| MultitypeFuncGraph_.__init__(self, name) | MultitypeFuncGraph_.__init__(self, name) | ||||
| self.entries = list() | self.entries = list() | ||||
| if read_value: | |||||
| self.set_signatures(( | |||||
| sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),)) | |||||
| def __call__(self, *args): | def __call__(self, *args): | ||||
| def unwrap(arg): | |||||
| if isinstance(arg, Parameter): | |||||
| return arg.data | |||||
| return arg | |||||
| types = tuple(map(lambda arg: mstype.get_py_obj_dtype(unwrap(arg)), args)) | |||||
| types = tuple(map(mstype.get_py_obj_dtype, args)) | |||||
| for sigs, fn in self.entries: | for sigs, fn in self.entries: | ||||
| if len(sigs) != len(types): | if len(sigs) != len(types): | ||||
| continue | continue | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| add = base.MultitypeFuncGraph('add') | |||||
| add = base.MultitypeFuncGraph('add', True) | |||||
| """`add` is a metafuncgraph object which will add two objects according to input type using ".register" decorator.""" | """`add` is a metafuncgraph object which will add two objects according to input type using ".register" decorator.""" | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| div = base.MultitypeFuncGraph("div") | |||||
| div = base.MultitypeFuncGraph("div", True) | |||||
| """ | """ | ||||
| div is a metafuncgraph object which will div two objects according to input type | div is a metafuncgraph object which will div two objects according to input type | ||||
| using ".register" decorator | using ".register" decorator | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| equal = base.MultitypeFuncGraph("equal") | |||||
| equal = base.MultitypeFuncGraph("equal", True) | |||||
| """ | """ | ||||
| equal is a metafuncgraph object which will determine if two objects are equal according to input type | equal is a metafuncgraph object which will determine if two objects are equal according to input type | ||||
| using ".register" decorator | using ".register" decorator | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| floordiv = base.MultitypeFuncGraph("floordiv") | |||||
| floordiv = base.MultitypeFuncGraph("floordiv", True) | |||||
| """ | """ | ||||
| `floordiv` is a metafuncgraph object which will compute the floordiv of two objects | `floordiv` is a metafuncgraph object which will compute the floordiv of two objects | ||||
| using ".register" decorator. | using ".register" decorator. | ||||
| @@ -19,7 +19,7 @@ from .. import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| getitem = base.MultitypeFuncGraph('getitem') | |||||
| getitem = base.MultitypeFuncGraph('getitem', True) | |||||
| """ | """ | ||||
| getitem is a metafuncgraph object which will get item from an object according to input type | getitem is a metafuncgraph object which will get item from an object according to input type | ||||
| using ".register" decorator. | using ".register" decorator. | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import functional as F | |||||
| # greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type | # greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type | ||||
| # using ".register" decorator | # using ".register" decorator | ||||
| greater_equal = base.MultitypeFuncGraph("greater_equal") | |||||
| greater_equal = base.MultitypeFuncGraph("greater_equal", True) | |||||
| @greater_equal.register("Number", "Number") | @greater_equal.register("Number", "Number") | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import functional as F | |||||
| # greater is a metafuncgraph object which will determine if two objects are greater according to input type | # greater is a metafuncgraph object which will determine if two objects are greater according to input type | ||||
| # using ".register" decorator | # using ".register" decorator | ||||
| greater = base.MultitypeFuncGraph("greater") | |||||
| greater = base.MultitypeFuncGraph("greater", True) | |||||
| @greater.register("Number", "Number") | @greater.register("Number", "Number") | ||||
| @@ -19,7 +19,7 @@ from . import _constexpr_utils as const_utils | |||||
| from ... import functional as F | from ... import functional as F | ||||
| from ...composite import base | from ...composite import base | ||||
| in_ = base.MultitypeFuncGraph("in") | |||||
| in_ = base.MultitypeFuncGraph("in", True) | |||||
| """ | """ | ||||
| in_ is a metafuncgraph object which will determine if a in b | in_ is a metafuncgraph object which will determine if a in b | ||||
| using ".register" decorator | using ".register" decorator | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import functional as F | |||||
| # less_equal is a metagraph object which will determine if two objects are less_equal according to input type | # less_equal is a metagraph object which will determine if two objects are less_equal according to input type | ||||
| # using ".register" decorator | # using ".register" decorator | ||||
| less_equal = base.MultitypeFuncGraph("less_equal") | |||||
| less_equal = base.MultitypeFuncGraph("less_equal", True) | |||||
| @less_equal.register("Number", "Number") | @less_equal.register("Number", "Number") | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import functional as F | |||||
| # less is a metafuncgraph object which will determine if two objects are less according to input type | # less is a metafuncgraph object which will determine if two objects are less according to input type | ||||
| # using ".register" decorator | # using ".register" decorator | ||||
| less = base.MultitypeFuncGraph("less") | |||||
| less = base.MultitypeFuncGraph("less", True) | |||||
| @less.register("Number", "Number") | @less.register("Number", "Number") | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import functional as F | |||||
| # logical_not is a metagraph object which will generate function according to input type | # logical_not is a metagraph object which will generate function according to input type | ||||
| # using ".register" decorator | # using ".register" decorator | ||||
| logical_not = base.MultitypeFuncGraph("logical_not") | |||||
| logical_not = base.MultitypeFuncGraph("logical_not", True) | |||||
| @logical_not.register("Number") | @logical_not.register("Number") | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import functional as F | |||||
| # logical_and is a metagraph object which will generate function according to input type | # logical_and is a metagraph object which will generate function according to input type | ||||
| # using ".register" decorator | # using ".register" decorator | ||||
| logical_and = base.MultitypeFuncGraph("logical_and") | |||||
| logical_and = base.MultitypeFuncGraph("logical_and", True) | |||||
| @logical_and.register("Number", "Number") | @logical_and.register("Number", "Number") | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import functional as F | |||||
| # logical_or is a metagraph object which will generate function according to input type | # logical_or is a metagraph object which will generate function according to input type | ||||
| # using ".register" decorator | # using ".register" decorator | ||||
| logical_or = base.MultitypeFuncGraph("logical_or") | |||||
| logical_or = base.MultitypeFuncGraph("logical_or", True) | |||||
| @logical_or.register("Number", "Number") | @logical_or.register("Number", "Number") | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| mod = base.MultitypeFuncGraph("mod") | |||||
| mod = base.MultitypeFuncGraph("mod", True) | |||||
| """ | """ | ||||
| `mod` is a metafuncgraph object which will compute the mod of two objects | `mod` is a metafuncgraph object which will compute the mod of two objects | ||||
| using ".register" decorator. | using ".register" decorator. | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| mul = base.MultitypeFuncGraph("mul") | |||||
| mul = base.MultitypeFuncGraph("mul", True) | |||||
| """ | """ | ||||
| `mul` is a metafuncgraph object which will multiply two objects according to input type | `mul` is a metafuncgraph object which will multiply two objects according to input type | ||||
| using ".register" decorator. | using ".register" decorator. | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| negative = base.MultitypeFuncGraph("negative") | |||||
| negative = base.MultitypeFuncGraph("negative", True) | |||||
| """ | """ | ||||
| `negative` is a metafuncgraph object which will give the negative of an object according to its input type | `negative` is a metafuncgraph object which will give the negative of an object according to its input type | ||||
| using ".register" decorator. | using ".register" decorator. | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| not_equal = base.MultitypeFuncGraph("not_equal") | |||||
| not_equal = base.MultitypeFuncGraph("not_equal", True) | |||||
| """ | """ | ||||
| not_equal is a metafuncgraph object which will determine if two objects are not_equal according to input type | not_equal is a metafuncgraph object which will determine if two objects are not_equal according to input type | ||||
| using ".register" decorator | using ".register" decorator | ||||
| @@ -22,7 +22,7 @@ from ... import functional as F | |||||
| from ... import operations as P | from ... import operations as P | ||||
| ones_like_leaf = base.MultitypeFuncGraph('ones_like_leaf') | |||||
| ones_like_leaf = base.MultitypeFuncGraph('ones_like_leaf', True) | |||||
| """ | """ | ||||
| `ones_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type | `ones_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type | ||||
| using ".register" decorator. | using ".register" decorator. | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| pow_ = base.MultitypeFuncGraph("pow") | |||||
| pow_ = base.MultitypeFuncGraph("pow", True) | |||||
| """ | """ | ||||
| `pow` is a metafuncgraph object which will compute the pow of two objects | `pow` is a metafuncgraph object which will compute the pow of two objects | ||||
| using ".register" decorator. | using ".register" decorator. | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| sub = base.MultitypeFuncGraph("sub") | |||||
| sub = base.MultitypeFuncGraph("sub", True) | |||||
| """ | """ | ||||
| `sub` is a metafuncgraph object which will compute the subtraction of two objects | `sub` is a metafuncgraph object which will compute the subtraction of two objects | ||||
| using ".register" decorator. | using ".register" decorator. | ||||
| @@ -18,7 +18,7 @@ from mindspore.ops.composite import base | |||||
| # uadd is a metagraph object which will return operation result regarding input | # uadd is a metagraph object which will return operation result regarding input | ||||
| # using ".register" decorator | # using ".register" decorator | ||||
| uadd = base.MultitypeFuncGraph("uadd") | |||||
| uadd = base.MultitypeFuncGraph("uadd", True) | |||||
| @uadd.register("Tensor") | @uadd.register("Tensor") | ||||
| @uadd.register("Number") | @uadd.register("Number") | ||||
| @@ -19,7 +19,7 @@ from ...composite import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf') | |||||
| zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True) | |||||
| """ | """ | ||||
| `zeros_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type | `zeros_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type | ||||
| using ".register" decorator. | using ".register" decorator. | ||||
| @@ -21,7 +21,6 @@ from mindspore.common._register_for_tensor import tensor_operator_registry | |||||
| from .primitive import Primitive | from .primitive import Primitive | ||||
| from . import operations as P | from . import operations as P | ||||
| from .operations import _grad_ops | from .operations import _grad_ops | ||||
| from .._extends import builtin_operations as BP | |||||
| typeof = Primitive('typeof') | typeof = Primitive('typeof') | ||||
| hastype = Primitive('hastype') | hastype = Primitive('hastype') | ||||
| @@ -182,5 +181,6 @@ tensor_operator_registry.register('__gt__', tensor_gt) | |||||
| tensor_operator_registry.register('__ge__', tensor_ge) | tensor_operator_registry.register('__ge__', tensor_ge) | ||||
| tensor_operator_registry.register('shape', shape) | tensor_operator_registry.register('shape', shape) | ||||
| # support GE backend for no compare operators | # support GE backend for no compare operators | ||||
| tensor_operator_registry.register('vm_compare', BP.vm_compare) | |||||
| tensor_operator_registry.register('cast', cast) | tensor_operator_registry.register('cast', cast) | ||||
| __all__ = [name for name in dir() if name[0] != "_"] | |||||
| @@ -15,8 +15,7 @@ | |||||
| """Operators for gradients.""" | """Operators for gradients.""" | ||||
| from ..._c_expression import signature_rw as sig_rw | |||||
| from ..._c_expression import signature_kind as sig_kind | |||||
| from .. import signature as sig | |||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| from ..._checkparam import Validator as validator, Rel | from ..._checkparam import Validator as validator, Rel | ||||
| from .._utils import get_concat_offset | from .._utils import get_concat_offset | ||||
| @@ -1500,7 +1499,7 @@ class RefToEmbed(Primitive): | |||||
| >>> return key, self.weight | >>> return key, self.weight | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('variable', sig_rw.RW_REF, sig_kind.KIND_POSITIONAL_KEYWORD), | |||||
| sig.make_sig('variable', sig.sig_rw.RW_REF), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -28,10 +28,7 @@ import numpy as np | |||||
| from .._utils import get_concat_offset | from .._utils import get_concat_offset | ||||
| from ..operations.math_ops import _infer_shape_reduce | from ..operations.math_ops import _infer_shape_reduce | ||||
| from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | ||||
| from ..._c_expression import signature_dtype as sig_dtype | |||||
| from ..._c_expression import signature_kind as sig_kind | |||||
| from ..._c_expression import signature_rw as sig_rw | |||||
| from ..._c_expression import typing | |||||
| from .. import signature as sig | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| @@ -44,9 +41,9 @@ class _ScatterOp(PrimitiveWithInfer): | |||||
| Define Scatter operators | Define Scatter operators | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), | |||||
| ('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('updates', dtype=sig.sig_dtype.T) | |||||
| ) | ) | ||||
| def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | ||||
| @@ -1396,7 +1393,7 @@ class Tile(PrimitiveWithInfer): | |||||
| validator.check_value_type("shape", multiples_v, [tuple], self.name) | validator.check_value_type("shape", multiples_v, [tuple], self.name) | ||||
| for i, multiple in enumerate(multiples_v): | for i, multiple in enumerate(multiples_v): | ||||
| validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name) | validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name) | ||||
| validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name) | |||||
| validator.check_value_type("x[\'dtype\']", x["dtype"], mstype.tensor_type, self.name) | |||||
| len_sub = len(multiples_v) - len(x_shp) | len_sub = len(multiples_v) - len(x_shp) | ||||
| multiples_w = None | multiples_w = None | ||||
| if len_sub == 0: | if len_sub == 0: | ||||
| @@ -18,9 +18,7 @@ | |||||
| import copy | import copy | ||||
| import numpy as np | import numpy as np | ||||
| from ... import context | from ... import context | ||||
| from ..._c_expression import signature_rw as sig_rw | |||||
| from ..._c_expression import signature_kind as sig_kind | |||||
| from ..._c_expression import signature_dtype as sig_dtype | |||||
| from .. import signature as sig | |||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| @@ -68,7 +66,7 @@ class _BinaryOp(PrimitiveWithInfer): | |||||
| Define binary operators. | Define binary operators. | ||||
| """ | """ | ||||
| __mindspore_signature__ = (sig_dtype.T, sig_dtype.T) | |||||
| __mindspore_signature__ = (sig.sig_dtype.T, sig.sig_dtype.T) | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -186,8 +184,8 @@ class AssignAdd(PrimitiveWithInfer): | |||||
| >>> net(value) | >>> net(value) | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('value', dtype=sig.sig_dtype.T) | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -237,8 +235,8 @@ class AssignSub(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('value', dtype=sig.sig_dtype.T) | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -264,8 +262,8 @@ class _Reduce(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('input_x', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD), | |||||
| ('axis', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, ()), | |||||
| sig.make_sig('input_x'), | |||||
| sig.make_sig('axis', default=()) | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -22,9 +22,7 @@ from functools import reduce | |||||
| import numpy as np | import numpy as np | ||||
| from ... import context | from ... import context | ||||
| from ..._c_expression import signature_rw as sig_rw | |||||
| from ..._c_expression import signature_kind as sig_kind | |||||
| from ..._c_expression import signature_dtype as sig_dtype | |||||
| from .. import signature as sig | |||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| @@ -679,11 +677,11 @@ class FusedBatchNormEx(PrimitiveWithInfer): | |||||
| >>> output = op(input_x, scale, bias, mean, variance) | >>> output = op(input_x, scale, bias, mean, variance) | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('input_x', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), | |||||
| ('scale', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('bias', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('mean', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('variance', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| sig.make_sig('input_x', dtype=sig.sig_dtype.T2), | |||||
| sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -1722,13 +1720,11 @@ class ApplyMomentum(PrimitiveWithInfer): | |||||
| Please refer to the usage in nn.ApplyMomentum. | Please refer to the usage in nn.ApplyMomentum. | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T1), | |||||
| ('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2) | |||||
| sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('gradient', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('momentum', dtype=sig.sig_dtype.T2), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -3146,23 +3142,17 @@ class FusedSparseAdam(PrimitiveWithInfer): | |||||
| >>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices) | >>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices) | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('beta2_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta1_power', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta2_power', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta1', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta2', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('epsilon', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -3285,23 +3275,17 @@ class FusedSparseLazyAdam(PrimitiveWithInfer): | |||||
| >>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices) | >>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices) | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('beta2_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta1_power', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta2_power', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta1', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta2', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('epsilon', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -3394,11 +3378,11 @@ class FusedSparseFtrl(PrimitiveWithInfer): | |||||
| >>> output = net(grad, indices) | >>> output = net(grad, indices) | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -3492,13 +3476,13 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): | |||||
| >>> output = net(grad, indices) | >>> output = net(grad, indices) | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('l1', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('l2', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -3754,16 +3738,15 @@ class ApplyAdaMax(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T1), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), | |||||
| ('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), | |||||
| ('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4), | |||||
| ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T5), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta1_power', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T2), | |||||
| sig.make_sig('beta1', dtype=sig.sig_dtype.T3), | |||||
| sig.make_sig('beta2', dtype=sig.sig_dtype.T4), | |||||
| sig.make_sig('epsilon', dtype=sig.sig_dtype.T5), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -3873,14 +3856,13 @@ class ApplyAdadelta(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum_update', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), | |||||
| ('rho', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), | |||||
| ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum_update', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('rho', dtype=sig.sig_dtype.T2), | |||||
| sig.make_sig('epsilon', dtype=sig.sig_dtype.T3), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -3971,10 +3953,10 @@ class ApplyAdagrad(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -4054,10 +4036,10 @@ class ApplyAdagradV2(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -4137,10 +4119,10 @@ class SparseApplyAdagrad(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -4224,10 +4206,10 @@ class SparseApplyAdagradV2(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -4313,12 +4295,12 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), | |||||
| ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), | |||||
| ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('l1', dtype=sig.sig_dtype.T2), | |||||
| sig.make_sig('l2', dtype=sig.sig_dtype.T3), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -4418,13 +4400,13 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), | |||||
| ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), | |||||
| ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('l1', dtype=sig.sig_dtype.T2), | |||||
| sig.make_sig('l2', dtype=sig.sig_dtype.T3), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T4), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -4508,14 +4490,13 @@ class ApplyAddSign(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), | |||||
| ('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), | |||||
| ('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T3), | |||||
| ('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('alpha', dtype=sig.sig_dtype.T2), | |||||
| sig.make_sig('sign_decay', dtype=sig.sig_dtype.T3), | |||||
| sig.make_sig('beta', dtype=sig.sig_dtype.T3), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -4618,14 +4599,13 @@ class ApplyPowerSign(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| sig_dtype.T), | |||||
| ('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('lr', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('logbase', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('sign_decay', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('beta', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -4704,9 +4684,9 @@ class ApplyGradientDescent(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), | |||||
| ('delta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('alpha', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('delta', dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -4777,11 +4757,11 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), | |||||
| ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), | |||||
| ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), | |||||
| ('delta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('alpha', dtype=sig.sig_dtype.T1), | |||||
| sig.make_sig('l1', dtype=sig.sig_dtype.T2), | |||||
| sig.make_sig('l2', dtype=sig.sig_dtype.T3), | |||||
| sig.make_sig('delta', dtype=sig.sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -5032,11 +5012,11 @@ class SparseApplyFtrl(PrimitiveWithCheck): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -5126,11 +5106,11 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('grad', dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -15,9 +15,7 @@ | |||||
| """Other operators.""" | """Other operators.""" | ||||
| import functools | import functools | ||||
| from ..._c_expression import signature_rw as sig_rw | |||||
| from ..._c_expression import signature_kind as sig_kind | |||||
| from ..._c_expression import signature_dtype as sig_dtype | |||||
| from .. import signature as sig | |||||
| from ..._checkparam import Validator as validator, Rel | from ..._checkparam import Validator as validator, Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| @@ -53,8 +51,8 @@ class Assign(Primitive): | |||||
| >>> net(x) | >>> net(x) | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||||
| sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('value', dtype=sig.sig_dtype.T) | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -14,17 +14,13 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """primitive""" | """primitive""" | ||||
| import inspect | import inspect | ||||
| import copy | import copy | ||||
| from mindspore.common.api import _wrap_func | from mindspore.common.api import _wrap_func | ||||
| from mindspore.common._register_for_tensor import tensor_operator_registry | from mindspore.common._register_for_tensor import tensor_operator_registry | ||||
| from mindspore import context | from mindspore import context | ||||
| from .._c_expression import Primitive_, real_run_op, prim_type | from .._c_expression import Primitive_, real_run_op, prim_type | ||||
| from .._c_expression import signature_rw as sig_rw | |||||
| from .._c_expression import signature_kind as sig_kind | |||||
| from .._c_expression import signature_dtype as sig_dtype | |||||
| from . import signature as sig | |||||
| class Primitive(Primitive_): | class Primitive(Primitive_): | ||||
| """ | """ | ||||
| @@ -54,24 +50,21 @@ class Primitive(Primitive_): | |||||
| self._update_parameter = False | self._update_parameter = False | ||||
| Primitive_.__init__(self, name, self) | Primitive_.__init__(self, name, self) | ||||
| if hasattr(self.__class__, '__mindspore_signature__'): | if hasattr(self.__class__, '__mindspore_signature__'): | ||||
| sig = self._fill_signature(self.__class__.__mindspore_signature__) | |||||
| self.set_signatures(sig) | |||||
| out = self._fill_signature(self.__class__.__mindspore_signature__) | |||||
| self.set_signatures(out) | |||||
| def _fill_signature(self, signatures): | def _fill_signature(self, signatures): | ||||
| """fills signature.""" | """fills signature.""" | ||||
| signatures_new = [] | signatures_new = [] | ||||
| for signature in signatures: | for signature in signatures: | ||||
| if isinstance(signature, sig_dtype): | |||||
| signatures_new.append(("argument", sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, | |||||
| sig_kind.KIND_EMPTY_DEFAULT_VALUE, signature)) | |||||
| if isinstance(signature, sig.Signature): | |||||
| signatures_new.append(signature) | |||||
| elif isinstance(signature, sig.sig_dtype): | |||||
| signatures_new.append(sig.make_sig(dtype=signature)) | |||||
| else: | else: | ||||
| if len(signature) < 3: | if len(signature) < 3: | ||||
| raise ValueError(f"[Internal Error]Signature for one parameter len must > 3, but {signature}") | raise ValueError(f"[Internal Error]Signature for one parameter len must > 3, but {signature}") | ||||
| if len(signature) == 3: | |||||
| signature += (sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T_EMPTY_DEFAULT_VALUE) | |||||
| if len(signature) == 4: | |||||
| signature += (sig_dtype.T_EMPTY_DEFAULT_VALUE,) | |||||
| signatures_new.append(signature) | |||||
| signatures_new.append(sig.make_sig(*signature)) | |||||
| return tuple(signatures_new) | return tuple(signatures_new) | ||||
| def _clone(self): | def _clone(self): | ||||
| @@ -0,0 +1,54 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """signature""" | |||||
| from .._c_expression import signature_rw as sig_rw | |||||
| from .._c_expression import signature_kind as sig_kind | |||||
| from .._c_expression import signature_dtype as sig_dtype | |||||
| from .._c_expression import Signature | |||||
| def make_sig(name="var", rw=sig_rw.RW_READ, | |||||
| kind=sig_kind.KIND_POSITIONAL_KEYWORD, | |||||
| default=sig_kind.KIND_EMPTY_DEFAULT_VALUE, | |||||
| dtype=sig_dtype.T_EMPTY_DEFAULT_VALUE): | |||||
| """ | |||||
| Make signature for one argument. | |||||
| See `ApplyMomentum` in `mindspore.ops.operation.nn_ops` as a example. | |||||
| Args: | |||||
| name (bool): Argument name. Default: "var". | |||||
| rw (:class:`mindspore.ops.signature.sig_rw`): Tag the argument attribute for write and read. Choose in | |||||
| [sig_rw.RW_READ, sig_rw.RW_WRITE, sig_rw.RW_REF]`, tag if the argument will update the input. | |||||
| `sig_rw.RW_READ` for read only argument and `sig_rw.RW_WRITE` for write only argument. `sig_rw.RW_READ` | |||||
| for the argument both need read and write. Default: sig_rw.RW_READ. | |||||
| kind (:class:`mindspore.ops.signature.kind`): Choose in `[signature_kind.KIND_POSITIONAL_KEYWORD, | |||||
| signature_kind.KIND_VAR_POSITIONAL, signature_kind.KIND_KEYWORD_ONLY, signature_kind.KIND_VAR_KEYWARD]`. | |||||
| The meaning is the same as python argument kind, please refer to the python document. | |||||
| Default: sig_kind.KIND_POSITIONAL_KEYWORD. | |||||
| default (Any): The default value of argument or `sig_kind.KIND_EMPTY_DEFAULT_VALUE` for no default value. | |||||
| Default: sig_kind.KIND_EMPTY_DEFAULT_VALUE. | |||||
| dtype (:class:`mindspore.ops.signature.sig_dtype`): Choose in `signature_dtype.T` or | |||||
| `signature_dtype.T1` to `signature_dtype.T9` or `sig_dtype.T_EMPTY_DEFAULT_VALUE` for no constraints. | |||||
| If the signature of one argument is the same as another argument, we will perform auto type convert | |||||
| between them. If any `sig_rw.RW_WRITE` argument, we will try to convert the other arguments to the | |||||
| `sig_rw.RW_WRITE` argument. Default: sig_dtype.T_EMPTY_DEFAULT_VALUE. | |||||
| Returns: | |||||
| :class:`mindspore.ops.signature.Signature`, signature for one argument. | |||||
| """ | |||||
| return Signature(name, rw, kind, default, dtype) | |||||
| @@ -136,13 +136,15 @@ class NetForCast(nn.Cell): | |||||
| super(NetForCast, self).__init__() | super(NetForCast, self).__init__() | ||||
| self.concat = P.Concat() | self.concat = P.Concat() | ||||
| self.x1 = Tensor(1.0, mstype.float32) | self.x1 = Tensor(1.0, mstype.float32) | ||||
| self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2') | |||||
| def construct(self, x0): | def construct(self, x0): | ||||
| x = self.x1 * x0 | |||||
| x = self.x1 * x0 * self.x2 | |||||
| return x | return x | ||||
| def test_cast(): | def test_cast(): | ||||
| context.set_context(save_graphs=True) | |||||
| x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01) | x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01) | ||||
| net = NetForCast() | net = NetForCast() | ||||
| net.add_flags_recursive(fp16=True) | net.add_flags_recursive(fp16=True) | ||||
| @@ -16,9 +16,7 @@ | |||||
| import functools | import functools | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| from mindspore._c_expression import signature_dtype as sig_dtype | |||||
| from mindspore._c_expression import signature_kind as sig_kind | |||||
| from mindspore._c_expression import signature_rw as sig_rw | |||||
| from mindspore.ops.signature import sig_rw, sig_dtype, make_sig | |||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| @@ -126,9 +124,9 @@ class CustomOP(PrimitiveWithInfer): | |||||
| class CustomOP2(PrimitiveWithInfer): | class CustomOP2(PrimitiveWithInfer): | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| ('p1', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('p2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('p3', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| make_sig('p1', sig_rw.RW_WRITE, dtype=sig_dtype.T), | |||||
| make_sig('p2', dtype=sig_dtype.T), | |||||
| make_sig('p3', dtype=sig_dtype.T), | |||||
| ) | ) | ||||
| @prim_attr_register | @prim_attr_register | ||||