|
|
|
@@ -341,22 +341,42 @@ class GradOperation(GradOperation_): |
|
|
|
|
|
|
|
class MultitypeFuncGraph(MultitypeFuncGraph_): |
|
|
|
""" |
|
|
|
Generate multiply graph. |
|
|
|
Generate overloaded functions. |
|
|
|
|
|
|
|
MultitypeFuncGraph is a class used to generate graphs for function with different type as input. |
|
|
|
MultitypeFuncGraph is a class used to generate overloaded functions with different type as inputs. |
|
|
|
Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator |
|
|
|
for the function to be registed. And the object can be called with different type of inputs, |
|
|
|
and work with `HyperMap` and `Map`. |
|
|
|
|
|
|
|
Args: |
|
|
|
name (str): Operator name. |
|
|
|
read_value (bool): If the registered function not need to set value on Parameter, |
|
|
|
and all inputs will pass by value. Set `read_value` to True. Default: False. |
|
|
|
and all inputs will pass by value, set `read_value` to True. Default: False. |
|
|
|
|
|
|
|
Raises: |
|
|
|
ValueError: Cannot find matching fn for the given args. |
|
|
|
ValueError: Cannot find matching functions for the given args. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> # `add` is a metagraph object which will add two objects according to |
|
|
|
>>> # input type using ".register" decorator. |
|
|
|
>>> from mindspore import Tensor |
|
|
|
>>> from mindspore.ops import Primitive, operations as P |
|
|
|
>>> from mindspore import dtype as mstype |
|
|
|
>>> |
|
|
|
>>> scala_add = Primitive('scala_add') |
|
|
|
>>> tensor_add = P.TensorAdd() |
|
|
|
>>> |
|
|
|
>>> add = MultitypeFuncGraph('add') |
|
|
|
>>> @add.register("Number", "Number") |
|
|
|
... def add_scala(x, y): |
|
|
|
... return scala_add(x, y) |
|
|
|
>>> @add.register("Tensor", "Tensor") |
|
|
|
... def add_tensor(x, y): |
|
|
|
... return tensor_add(x, y) |
|
|
|
>>> add(1, 2) |
|
|
|
3 |
|
|
|
>>> add(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) |
|
|
|
Tensor(shape=[], dtype=Float32, 3) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, name, read_value=False): |
|
|
|
@@ -378,9 +398,25 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): |
|
|
|
raise ValueError("Cannot find fn match given args.") |
|
|
|
|
|
|
|
def register(self, *type_names): |
|
|
|
"""Register a function for the given type string.""" |
|
|
|
""" |
|
|
|
Register a function for the given type string. |
|
|
|
|
|
|
|
Args: |
|
|
|
type_names (Union[str, :class:`mindspore.dtype`]): Inputs type names or types list. |
|
|
|
|
|
|
|
Return: |
|
|
|
decorator, a decorator to register the function to run, when called under the |
|
|
|
types described in `type_names`. |
|
|
|
""" |
|
|
|
def deco(fn): |
|
|
|
types = tuple(map(mstype.typing.str_to_type, type_names)) |
|
|
|
def convert_type(type_input): |
|
|
|
if isinstance(type_input, str): |
|
|
|
return mstype.typing.str_to_type(type_input) |
|
|
|
if not isinstance(type_input, mstype.Type): |
|
|
|
raise TypeError(f"MultitypeFuncGraph register only support str or {mstype.Type}") |
|
|
|
return type_input |
|
|
|
|
|
|
|
types = tuple(map(convert_type, type_names)) |
|
|
|
self.register_fn(type_names, fn) |
|
|
|
self.entries.append((types, fn)) |
|
|
|
return fn |
|
|
|
@@ -391,11 +427,12 @@ class HyperMap(HyperMap_): |
|
|
|
""" |
|
|
|
Hypermap will apply the set operation on input sequences. |
|
|
|
|
|
|
|
Which will apply the operations of every elements of the sequence. |
|
|
|
Apply the operations to every elements of the sequence or nested sequence. Different |
|
|
|
from `Map`, the `HyperMap` supports to apply on nested structure. |
|
|
|
|
|
|
|
Args: |
|
|
|
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`, |
|
|
|
the operations should be putted in the first input of the instance. |
|
|
|
the operations should be put in the first input of the instance. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences, |
|
|
|
@@ -405,8 +442,28 @@ class HyperMap(HyperMap_): |
|
|
|
If `ops` is not `None`, the first input is the operation, and the other is inputs. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
sequence, the output will be same type and same length of sequence from input and the value of each element |
|
|
|
is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`. |
|
|
|
Sequence or nested sequence, the sequence of output after applying the function. |
|
|
|
e.g. `operation(args[0][i], args[1][i])`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore import dtype as mstype |
|
|
|
>>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)), |
|
|
|
... (Tensor(3, mstype.float32), Tensor(4, mstype.float32))) |
|
|
|
>>> # square all the tensor in the nested list |
|
|
|
>>> |
|
|
|
>>> square = MultitypeFuncGraph('square') |
|
|
|
>>> @square.register("Tensor") |
|
|
|
... def square_tensor(x): |
|
|
|
... return F.square(x) |
|
|
|
>>> |
|
|
|
>>> common_map = HyperMap() |
|
|
|
>>> common_map(square, nest_tensor_list) |
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)), |
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)) |
|
|
|
>>> square_map = HyperMap(square) |
|
|
|
>>> square_map(nest_tensor_list) |
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)), |
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, ops=None): |
|
|
|
@@ -434,11 +491,11 @@ class Map(Map_): |
|
|
|
""" |
|
|
|
Map will apply the set operation on input sequences. |
|
|
|
|
|
|
|
Which will apply the operations of every elements of the sequence. |
|
|
|
Apply the operations to every elements of the sequence. |
|
|
|
|
|
|
|
Args: |
|
|
|
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`, |
|
|
|
the operations should be putted in the first input of the instance. |
|
|
|
the operations should be put in the first input of the instance. Default: None |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences, |
|
|
|
@@ -448,8 +505,24 @@ class Map(Map_): |
|
|
|
If `ops` is not `None`, the first input is the operation, and the other is inputs. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
sequence, the output will be same type and same length of sequence from input and the value of each element |
|
|
|
is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`. |
|
|
|
Sequence, the sequence of output after applying the function. e.g. `operation(args[0][i], args[1][i])`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore import dtype as mstype |
|
|
|
>>> tensor_list = (Tensor(1, mstype.float32), Tensor(2, mstype.float32), Tensor(3, mstype.float32)) |
|
|
|
>>> # square all the tensor in the list |
|
|
|
>>> |
|
|
|
>>> square = MultitypeFuncGraph('square') |
|
|
|
>>> @square.register("Tensor") |
|
|
|
>>> def square_tensor(x): |
|
|
|
... return F.square(x) |
|
|
|
>>> |
|
|
|
>>> common_map = Map() |
|
|
|
>>> common_map(square, tensor_list) |
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9)) |
|
|
|
>>> square_map = Map(square) |
|
|
|
>>> square_map(tensor_list) |
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9)) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, ops=None): |
|
|
|
|