| @@ -13,10 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """math""" | """math""" | ||||
| import math | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.operations import _inner_ops as inner | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common._decorator import deprecated | from mindspore.common._decorator import deprecated | ||||
| from mindspore.ops.primitive import constexpr | from mindspore.ops.primitive import constexpr | ||||
| @@ -25,7 +23,6 @@ from ..cell import Cell | |||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| __all__ = ['ReduceLogSumExp', | __all__ = ['ReduceLogSumExp', | ||||
| 'Range', | 'Range', | ||||
| 'LGamma', | 'LGamma', | ||||
| @@ -140,37 +137,15 @@ class Range(Cell): | |||||
| def __init__(self, start, limit=None, delta=1): | def __init__(self, start, limit=None, delta=1): | ||||
| super(Range, self).__init__() | super(Range, self).__init__() | ||||
| validator.check_value_type("start", start, [int, float], self.cls_name) | |||||
| validator.check_value_type("delta", delta, [int, float], self.cls_name) | |||||
| if delta == 0: | |||||
| raise ValueError("The input of `delta` can not be equal to zero.") | |||||
| if limit is not None: | |||||
| validator.check_value_type("limit", limit, [int, float], self.cls_name) | |||||
| if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int): | |||||
| self.dtype = mstype.int32 | |||||
| else: | |||||
| self.dtype = mstype.float32 | |||||
| else: | |||||
| if isinstance(start, int) and isinstance(delta, int): | |||||
| self.dtype = mstype.int32 | |||||
| else: | |||||
| self.dtype = mstype.float32 | |||||
| if isinstance(start, int): | |||||
| start = float(start) | |||||
| if isinstance(limit, int): | |||||
| limit = float(limit) | |||||
| if isinstance(delta, int): | |||||
| delta = float(delta) | |||||
| self.range_x = inner.Range(start, limit, delta) | |||||
| if limit is None: | |||||
| length_input = math.ceil(start / delta) | |||||
| data = np.arange(start, limit, delta) | |||||
| if data.dtype == np.float: | |||||
| self.ms_dtype = mstype.float32 | |||||
| else: | else: | ||||
| length_input = math.ceil((limit - start) / delta) | |||||
| self.input_tensor = Tensor(list(range(length_input)), self.dtype) | |||||
| self.ms_dtype = mstype.int32 | |||||
| self.result_tensor = Tensor(data, dtype=self.ms_dtype) | |||||
| def construct(self): | def construct(self): | ||||
| range_out = self.range_x(self.input_tensor) | |||||
| return range_out | |||||
| return self.result_tensor | |||||
| class LGamma(Cell): | class LGamma(Cell): | ||||
| @@ -16,6 +16,7 @@ | |||||
| """Inner operators.""" | """Inner operators.""" | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore.common import Tensor | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ... import context | from ... import context | ||||
| @@ -25,6 +26,7 @@ from ..operations.math_ops import _infer_shape_reduce | |||||
| from ...communication.management import GlobalComm | from ...communication.management import GlobalComm | ||||
| from .. import signature as sig | from .. import signature as sig | ||||
| class ExtractImagePatches(PrimitiveWithInfer): | class ExtractImagePatches(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Extracts patches from images. | Extracts patches from images. | ||||
| @@ -164,6 +166,9 @@ class Range(PrimitiveWithInfer): | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name) | validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name) | ||||
| return x_dtype | return x_dtype | ||||
| def infer_value(self, x_value): | |||||
| return Tensor(np.arange(self.start, self.limit, self.delta), dtype=x_value.dtype) | |||||
| class Quant(PrimitiveWithInfer): | class Quant(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| @@ -408,6 +413,7 @@ class Send(PrimitiveWithInfer): | |||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> output = net(input_) | >>> output = net(input_) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP): | def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP): | ||||
| self.rank = dest_rank | self.rank = dest_rank | ||||
| @@ -464,6 +470,7 @@ class Receive(PrimitiveWithInfer): | |||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> output = net() | >>> output = net() | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP): | def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP): | ||||
| self.rank = src_rank | self.rank = src_rank | ||||
| @@ -2391,6 +2391,7 @@ class Pack(PrimitiveWithInfer): | |||||
| Same as operator Stack. Pack will be deprecated in the future. | Same as operator Stack. Pack will be deprecated in the future. | ||||
| Please use Stack instead. | Please use Stack instead. | ||||
| """ | """ | ||||
| @deprecated("1.1", "Stack", True) | @deprecated("1.1", "Stack", True) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, axis=0): | def __init__(self, axis=0): | ||||
| @@ -2469,6 +2470,7 @@ class Unpack(PrimitiveWithInfer): | |||||
| Same as operator Unstack. Unpack will be deprecated in the future. | Same as operator Unstack. Unpack will be deprecated in the future. | ||||
| Please use Unstack instead. | Please use Unstack instead. | ||||
| """ | """ | ||||
| @deprecated("1.1", "Unstack", True) | @deprecated("1.1", "Unstack", True) | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, axis=0): | def __init__(self, axis=0): | ||||
| @@ -3491,7 +3493,6 @@ class ScatterUpdate(_ScatterOp_Dynamic): | |||||
| self.add_prim_attr('side_effect_mem', True) | self.add_prim_attr('side_effect_mem', True) | ||||
| class ScatterNdUpdate(_ScatterNdOp): | class ScatterNdUpdate(_ScatterNdOp): | ||||
| r""" | r""" | ||||
| Updates tensor values by using input indices and value. | Updates tensor values by using input indices and value. | ||||
| @@ -5250,3 +5251,11 @@ class Range(PrimitiveWithCheck): | |||||
| valid_dtypes = [mstype.int32, mstype.float32] | valid_dtypes = [mstype.int32, mstype.float32] | ||||
| inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype} | inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype} | ||||
| validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name) | validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name) | ||||
| def infer_value(self, start_value, limit_value, delat_value): | |||||
| if start_value is not None and limit_value is not None and delat_value is not None: | |||||
| start = np.asscalar(start_value.asnumpy()) | |||||
| limit = np.asscalar(limit_value.asnumpy()) | |||||
| delat = np.asscalar(delat_value.asnumpy()) | |||||
| return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype) | |||||
| return None | |||||
| @@ -15,7 +15,7 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.nn as nn | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore import context, Tensor, Parameter | from mindspore import context, Tensor, Parameter | ||||
| from mindspore.nn import Cell, Momentum | from mindspore.nn import Cell, Momentum | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| @@ -48,18 +48,25 @@ class Net(Cell): | |||||
| def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None): | def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None): | ||||
| super().__init__() | super().__init__() | ||||
| self.mul = P.Mul().shard(strategy1) | self.mul = P.Mul().shard(strategy1) | ||||
| self.range = nn.Range(start, limit, delta) | |||||
| self.range.range_x.shard(strategy2) | |||||
| if isinstance(start, float): | |||||
| self.type = mstype.float32 | |||||
| else: | |||||
| self.type = mstype.int32 | |||||
| self.start = Tensor(start, self.type) | |||||
| self.limit = Tensor(limit, self.type) | |||||
| self.delta = Tensor(delta, self.type) | |||||
| self.range = P.Range() | |||||
| self.range.shard(strategy2) | |||||
| self.mul2 = P.Mul().shard(strategy3) | self.mul2 = P.Mul().shard(strategy3) | ||||
| self.weight = Parameter(weight, "w") | self.weight = Parameter(weight, "w") | ||||
| def construct(self, x, b): | def construct(self, x, b): | ||||
| r_out = self.range() | |||||
| r_out = self.range(self.start, self.limit, self.delta) | |||||
| out = self.mul(x, self.weight) | out = self.mul(x, self.weight) | ||||
| out = self.mul2(out, r_out) | out = self.mul2(out, r_out) | ||||
| return out | return out | ||||
| dev_num = 4 | dev_num = 4 | ||||
| _x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32) | _x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32) | ||||
| _b = Tensor(np.ones([8]), dtype=ms.float32) | _b = Tensor(np.ones([8]), dtype=ms.float32) | ||||
| @@ -98,5 +105,5 @@ def test_range2(): | |||||
| def test_range3(): | def test_range3(): | ||||
| context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2) | context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2) | ||||
| net = Net(_w1, 4.0, None, 0.5) | |||||
| net = Net(_w1, 0.0, 4.0, 0.5) | |||||
| compile_net(net) | compile_net(net) | ||||