|
|
|
@@ -384,9 +384,11 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): |
|
|
|
>>> @add.register("Tensor", "Tensor") |
|
|
|
... def add_tensor(x, y): |
|
|
|
... return tensor_add(x, y) |
|
|
|
>>> add(1, 2) |
|
|
|
>>> ourput = add(1, 2) |
|
|
|
>>> print(output) |
|
|
|
3 |
|
|
|
>>> add(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) |
|
|
|
>>> output = add(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) |
|
|
|
>>> print(output) |
|
|
|
Tensor(shape=[], dtype=Float32, 3) |
|
|
|
""" |
|
|
|
|
|
|
|
@@ -470,11 +472,13 @@ class HyperMap(HyperMap_): |
|
|
|
... return F.square(x) |
|
|
|
>>> |
|
|
|
>>> common_map = HyperMap() |
|
|
|
>>> common_map(square, nest_tensor_list) |
|
|
|
>>> output = common_map(square, nest_tensor_list) |
|
|
|
>>> print(output) |
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)), |
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)) |
|
|
|
>>> square_map = HyperMap(square) |
|
|
|
>>> square_map(nest_tensor_list) |
|
|
|
>>> output = square_map(nest_tensor_list) |
|
|
|
>>> print(output) |
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)), |
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)) |
|
|
|
""" |
|
|
|
@@ -531,10 +535,12 @@ class Map(Map_): |
|
|
|
... return F.square(x) |
|
|
|
>>> |
|
|
|
>>> common_map = Map() |
|
|
|
>>> common_map(square, tensor_list) |
|
|
|
>>> output = common_map(square, tensor_list) |
|
|
|
>>> print(output) |
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9)) |
|
|
|
>>> square_map = Map(square) |
|
|
|
>>> square_map(tensor_list) |
|
|
|
>>> output = square_map(tensor_list) |
|
|
|
>>> print(output) |
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9)) |
|
|
|
""" |
|
|
|
|
|
|
|
|