From 8c6475fd0be4fdcf7cf83ca15178d8e0e9185e01 Mon Sep 17 00:00:00 2001
From: Wei Luning
Date: Tue, 8 Sep 2020 20:56:17 +0800
Subject: [PATCH] add composite op doc

---
 .../operator/composite/multitype_funcgraph.cc |  33 +++---
 .../operator/composite/multitype_funcgraph.h  |   1 -
 mindspore/ccsrc/frontend/optimizer/clean.cc   |   4 +-
 mindspore/common/tensor.py                    |   2 +-
 mindspore/nn/layer/container.py               |   4 +-
 mindspore/ops/composite/base.py               | 101 +++++++++++++++---
 .../pynative_mode/ops/test_multitype.py       |  25 +++++
 7 files changed, 133 insertions(+), 37 deletions(-)

diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc
index 1768bbd90f..7ddb47e948 100644
--- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc
@@ -27,6 +27,7 @@
 #include "utils/ms_context.h"
 #include "pybind_api/api_register.h"
 #include "ir/signature.h"
+#include "ir/dtype.h"
 #include "debug/trace.h"
 
 namespace mindspore {
@@ -57,31 +58,27 @@ void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &
   fn_cache_py_[types] = py_fn;
 }
 
-void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) {
+void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
   TypePtrList types;
-  for (auto &type_name : types_name) {
-    auto type_ptr = StringToType(type_name);
-    if (type_ptr == nullptr) {
-      MS_LOG(EXCEPTION) << type_name << " convert from string error ";
+  for (size_t it = 0; it < tuple.size(); ++it) {
+    py::object type_in = tuple[it];
+    TypePtr type_ptr = nullptr;
+    if (py::isinstance<py::str>(type_in)) {
+      auto type_name = type_in.cast<std::string>();
+      type_ptr = StringToType(type_name);
+      if (type_ptr == nullptr) {
+        MS_LOG(EXCEPTION) << type_name << " convert from string error ";
+      }
+    } else if (py::isinstance<Type>(type_in)) {
+      type_ptr = type_in.cast<TypePtr>();
+    } else {
+      MS_LOG(EXCEPTION) << "Register must be string or `mindspore.dtype.Type`";
     }
     types.push_back(type_ptr);
   }
   Register(types, py_fn);
 }
 
-void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
-  std::vector<std::string> types_name;
-  for (size_t it = 0; it < tuple.size(); ++it) {
-    py::object name_py = tuple[it];
-    if (py::isinstance<py::str>(name_py)) {
-      types_name.push_back(name_py.cast<std::string>());
-      continue;
-    }
-    MS_LOG(EXCEPTION) << "Register must be string";
-  }
-  Register(types_name, py_fn);
-}
-
 // Return Exact match if exists, else return non ambiguous sub class match
 // Return py::none() if matching is ambiguous
 const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h
index 9bcfdb2ee2..15d8449cd7 100644
--- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h
+++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h
@@ -44,7 +44,6 @@ class MultitypeFuncGraph : public MetaFuncGraph {
   // Register a method which specialize based on types vectors;
   virtual void Register(const TypePtrList &types, specialize_fn s_fn);
   virtual void Register(const TypePtrList &types, const py::function &py_fn);
-  virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn);
   virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
 
   FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
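
Note on the PyRegister rework above: a registration tuple may now mix
type-name strings and `mindspore.dtype` types, since each tuple element is
converted independently. A minimal Python-side sketch of the intended usage
(the `mul` graph and `mul_number` function below are illustrative only, not
part of this patch):

    >>> from mindspore import dtype as mstype
    >>> from mindspore.ops import composite as C
    >>> mul = C.MultitypeFuncGraph('mul')
    >>> # a Type object and a type-name string dispatch the same way
    >>> @mul.register(mstype.number, "Number")
    ... def mul_number(x, y):
    ...     return x * y
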
diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc
index 1a39c3cc47..1059fc2f63 100644
--- a/mindspore/ccsrc/frontend/optimizer/clean.cc
+++ b/mindspore/ccsrc/frontend/optimizer/clean.cc
@@ -396,7 +396,9 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
       MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract "
                     << node->abstract()->ToString() << " with " << ret->ToString();
       node->set_abstract(ret);
-      changed = true;
+      if (ret->cast<AbstractTuplePtr>()->size() > 0) {
+        changed = true;
+      }
     }
   }
   return changed;
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index e03e0b5b02..643e2873d6 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -293,7 +293,7 @@ class RowTensor:
     The dense tensor dense represented by an RowTensor slices has
     `dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.
 
-    RowTensor can only be used in the `Cell`'s contruct method.
+    RowTensor can only be used in the `Cell`'s construct method.
 
     It is not supported in pynative mode at the moment.
 
diff --git a/mindspore/nn/layer/container.py b/mindspore/nn/layer/container.py
index 62a50d1255..994fe173be 100644
--- a/mindspore/nn/layer/container.py
+++ b/mindspore/nn/layer/container.py
@@ -46,7 +46,6 @@ class _CellListBase():
     by iterator or subscript , it will be interpretated as a list of cells.
     """
     def __init__(self):
-        super(_CellListBase, self).__init__()
         self.__cell_as_list__ = True
 
     @abstractmethod
@@ -177,7 +176,8 @@ class CellList(_CellListBase, Cell):
         (2): ReLU<>
         >
     """
     def __init__(self, *args):
-        super(CellList, self).__init__()
+        _CellListBase.__init__(self)
+        Cell.__init__(self)
         if len(args) == 1:
             self.extend(args[0])
 
diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py
index 99c37c6988..557ce5f189 100644
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -341,22 +341,42 @@ class GradOperation(GradOperation_):
 
 class MultitypeFuncGraph(MultitypeFuncGraph_):
     """
-    Generate multiply graph.
+    Generate overloaded functions.
 
-    MultitypeFuncGraph is a class used to generate graphs for function with different type as input.
+    MultitypeFuncGraph is a class used to generate overloaded functions with different types as inputs.
+    Initialize a `MultitypeFuncGraph` object with a name, and use `register` with input types as the
+    decorator for the function to be registered. The object can then be called with different types
+    of inputs, and works with `HyperMap` and `Map`.
 
     Args:
         name (str): Operator name.
         read_value (bool): If the registered function not need to set value on Parameter,
-            and all inputs will pass by value. Set `read_value` to True. Default: False.
+            and all inputs will be passed by value, set `read_value` to True. Default: False.
 
     Raises:
-        ValueError: Cannot find matching fn for the given args.
+        ValueError: If no matching function can be found for the given arguments.
 
     Examples:
         >>> # `add` is a metagraph object which will add two objects according to
         >>> # input type using ".register" decorator.
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import Primitive, operations as P
+        >>> from mindspore import dtype as mstype
+        >>>
+        >>> scala_add = Primitive('scala_add')
+        >>> tensor_add = P.TensorAdd()
+        >>>
         >>> add = MultitypeFuncGraph('add')
+        >>> @add.register("Number", "Number")
+        ... def add_scala(x, y):
+        ...     return scala_add(x, y)
+        >>> @add.register("Tensor", "Tensor")
+        ... def add_tensor(x, y):
+        ...     return tensor_add(x, y)
+        >>> add(1, 2)
+        3
+        >>> add(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
+        Tensor(shape=[], dtype=Float32, 3)
     """
 
     def __init__(self, name, read_value=False):
@@ -378,9 +398,25 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
         raise ValueError("Cannot find fn match given args.")
 
     def register(self, *type_names):
-        """Register a function for the given type string."""
+        """
+        Register a function for the given type string.
+
+        Args:
+            type_names (Union[str, :class:`mindspore.dtype`]): Type names or types of the inputs.
+
+        Returns:
+            decorator, a decorator to register the function to run when called with the
+            types described in `type_names`.
+        """
         def deco(fn):
-            types = tuple(map(mstype.typing.str_to_type, type_names))
+            def convert_type(type_input):
+                if isinstance(type_input, str):
+                    return mstype.typing.str_to_type(type_input)
+                if not isinstance(type_input, mstype.Type):
+                    raise TypeError(f"MultitypeFuncGraph register only support str or {mstype.Type}")
+                return type_input
+
+            types = tuple(map(convert_type, type_names))
             self.register_fn(type_names, fn)
             self.entries.append((types, fn))
             return fn
@@ -391,11 +427,12 @@
 class HyperMap(HyperMap_):
     """
     Hypermap will apply the set operation on input sequences.
 
-    Which will apply the operations of every elements of the sequence.
+    Apply the operations to every element of the sequence or nested sequence. Different
+    from `Map`, `HyperMap` supports applying to nested structures.
 
     Args:
         ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
-            the operations should be putted in the first input of the instance.
+            the operations should be put in the first input of the instance.
 
     Inputs:
         - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
@@ -405,8 +442,28 @@
           If `ops` is not `None`, the first input is the operation, and the other is inputs.
 
     Outputs:
-        sequence, the output will be same type and same length of sequence from input and the value of each element
-        is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`.
+        Sequence or nested sequence, the sequence of output after applying the function.
+        e.g. `operation(args[0][i], args[1][i])`.
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore.ops import functional as F
+        >>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
+        ...                     (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
+        >>> # square all the tensors in the nested list
+        >>>
+        >>> square = MultitypeFuncGraph('square')
+        >>> @square.register("Tensor")
+        ... def square_tensor(x):
+        ...     return F.square(x)
+        >>>
+        >>> common_map = HyperMap()
+        >>> common_map(square, nest_tensor_list)
+        ((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
+        (Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)))
+        >>> square_map = HyperMap(square)
+        >>> square_map(nest_tensor_list)
+        ((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
+        (Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)))
     """
 
     def __init__(self, ops=None):
@@ -434,11 +491,11 @@
 class Map(Map_):
     """
     Map will apply the set operation on input sequences.
 
-    Which will apply the operations of every elements of the sequence.
+    Apply the operations to every element of the sequence.
 
     Args:
         ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
-            the operations should be putted in the first input of the instance.
+            the operations should be put in the first input of the instance. Default: None.
 
     Inputs:
         - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
           and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence
           `(args[0][i], args[1][i])` will be the input of the operation.
           If `ops` is not `None`, the first input is the operation, and the other is inputs.
 
     Outputs:
-        sequence, the output will be same type and same length of sequence from input and the value of each element
-        is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`.
+        Sequence, the sequence of output after applying the function. e.g. `operation(args[0][i], args[1][i])`.
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore.ops import functional as F
+        >>> tensor_list = (Tensor(1, mstype.float32), Tensor(2, mstype.float32), Tensor(3, mstype.float32))
+        >>> # square all the tensors in the list
+        >>>
+        >>> square = MultitypeFuncGraph('square')
+        >>> @square.register("Tensor")
+        ... def square_tensor(x):
+        ...     return F.square(x)
+        >>>
+        >>> common_map = Map()
+        >>> common_map(square, tensor_list)
+        (Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
+        >>> square_map = Map(square)
+        >>> square_map(tensor_list)
+        (Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
     """
 
     def __init__(self, ops=None):
diff --git a/tests/ut/python/pynative_mode/ops/test_multitype.py b/tests/ut/python/pynative_mode/ops/test_multitype.py
index 24d7edcc0b..67da12b088 100644
--- a/tests/ut/python/pynative_mode/ops/test_multitype.py
+++ b/tests/ut/python/pynative_mode/ops/test_multitype.py
@@ -21,6 +21,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.ops import Primitive
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
+from mindspore import dtype as mstype
 from ...ut_filter import non_graph_engine
 
 tensor_add = P.TensorAdd()
@@ -62,3 +63,27 @@ def test_multitype_tuple():
 
 def test_multitype_scalar():
     mainf(1, 2)
+
+
+add2 = C.MultitypeFuncGraph('add2')
+@add2.register(mstype.number, mstype.number)
+def add_scala2(x, y):
+    return scala_add(x, y)
+
+
+@add2.register(mstype.tensor, mstype.tensor)
+def add_tensor2(x, y):
+    return tensor_add(x, y)
+
+
+@ms_function
+def mainf2(x, y):
+    return add2(x, y)
+
+
+@non_graph_engine
+def test_multitype_tensor_by_type():
+    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    out = mainf2(tensor1, tensor2)
+    print(out)
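
Beyond the tensor-dispatch test above, the error path added in `register`
can be exercised the same way; a minimal sketch (the `bad` graph and
`bad_fn` names are illustrative only, and the exact exception text comes
from the f-string in base.py):

    >>> bad = C.MultitypeFuncGraph('bad')
    >>> @bad.register(123)  # neither a str nor a mindspore.dtype type,
    ...                     # so convert_type raises TypeError here
    ... def bad_fn(x):
    ...     return x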