@@ -13,10 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """math"""
-import math
 import numpy as np
 from mindspore.ops import operations as P
-from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.tensor import Tensor
 from mindspore.common._decorator import deprecated
 from mindspore.ops.primitive import constexpr
@@ -25,7 +23,6 @@ from ..cell import Cell
 from ...common import dtype as mstype
 from ..._checkparam import Validator as validator

 __all__ = ['ReduceLogSumExp',
            'Range',
            'LGamma',
@@ -140,37 +137,15 @@ class Range(Cell):

     def __init__(self, start, limit=None, delta=1):
         super(Range, self).__init__()
-        validator.check_value_type("start", start, [int, float], self.cls_name)
-        validator.check_value_type("delta", delta, [int, float], self.cls_name)
         if delta == 0:
             raise ValueError("The input of `delta` can not be equal to zero.")
-        if limit is not None:
-            validator.check_value_type("limit", limit, [int, float], self.cls_name)
-            if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int):
-                self.dtype = mstype.int32
-            else:
-                self.dtype = mstype.float32
-        else:
-            if isinstance(start, int) and isinstance(delta, int):
-                self.dtype = mstype.int32
-            else:
-                self.dtype = mstype.float32
-        if isinstance(start, int):
-            start = float(start)
-        if isinstance(limit, int):
-            limit = float(limit)
-        if isinstance(delta, int):
-            delta = float(delta)
-        self.range_x = inner.Range(start, limit, delta)
-        if limit is None:
-            length_input = math.ceil(start / delta)
-        else:
-            length_input = math.ceil((limit - start) / delta)
-        self.input_tensor = Tensor(list(range(length_input)), self.dtype)
+        data = np.arange(start, limit, delta)
+        if data.dtype == np.float:
+            self.ms_dtype = mstype.float32
+        else:
+            self.ms_dtype = mstype.int32
+        self.result_tensor = Tensor(data, dtype=self.ms_dtype)

     def construct(self):
-        range_out = self.range_x(self.input_tensor)
-        return range_out
+        return self.result_tensor


 class LGamma(Cell):
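
After this hunk, `nn.Range` is a constant cell: the sequence is materialized once with `np.arange` in `__init__`, and `construct` simply returns the cached tensor, so no `inner.Range` kernel runs at execution time. A minimal sketch of the resulting behavior, assuming a MindSpore build that includes this patch:

```python
# Sketch: nn.Range precomputes its output at construction time.
import mindspore.nn as nn

net = nn.Range(0, 10, 2)
print(net())                      # [0 2 4 6 8], int32 since np.arange(0, 10, 2) is integral
print(nn.Range(0.0, 2.0, 0.5)())  # [0.  0.5 1.  1.5], float32 for a floating-point range
```

The `limit=None` default still works because numpy interprets `np.arange(start, None, delta)` as a range over `[0, start)`.
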
@@ -16,6 +16,7 @@
 """Inner operators."""
 import numpy as np
+from mindspore.common import Tensor
 from ..._checkparam import Rel
 from ..._checkparam import Validator as validator
 from ... import context
@@ -25,6 +26,7 @@ from ..operations.math_ops import _infer_shape_reduce
 from ...communication.management import GlobalComm
 from .. import signature as sig


 class ExtractImagePatches(PrimitiveWithInfer):
     """
     Extracts patches from images.
@@ -164,6 +166,9 @@ class Range(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
         return x_dtype

+    def infer_value(self, x_value):
+        return Tensor(np.arange(self.start, self.limit, self.delta), dtype=x_value.dtype)
+

 class Quant(PrimitiveWithInfer):
     r"""
@@ -408,6 +413,7 @@ class Send(PrimitiveWithInfer):
     >>> net = Net()
     >>> output = net(input_)
     """
+    @prim_attr_register
     def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
         self.rank = dest_rank
@@ -464,6 +470,7 @@ class Receive(PrimitiveWithInfer):
     >>> net = Net()
     >>> output = net()
     """
+    @prim_attr_register
     def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
         self.rank = src_rank
@@ -2391,6 +2391,7 @@ class Pack(PrimitiveWithInfer):
     Same as operator Stack. Pack will be deprecated in the future.
     Please use Stack instead.
     """
+    @deprecated("1.1", "Stack", True)
     @prim_attr_register
     def __init__(self, axis=0):
@@ -2469,6 +2470,7 @@ class Unpack(PrimitiveWithInfer):
     Same as operator Unstack. Unpack will be deprecated in the future.
     Please use Unstack instead.
     """
+    @deprecated("1.1", "Unstack", True)
     @prim_attr_register
     def __init__(self, axis=0):
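
For both classes the addition is the same: `@deprecated(version, substitute, ...)` keeps the old name functional while steering users to the new one. A hedged sketch of what callers see, assuming a MindSpore release (1.1+) where both spellings exist:

```python
# Sketch: Pack and Stack are the same primitive; Pack now emits a deprecation
# notice pointing at Stack (likewise Unpack -> Unstack).
from mindspore.ops import operations as P

stack = P.Stack(axis=0)  # preferred spelling
pack = P.Pack(axis=0)    # still works, but warns that Pack is deprecated since 1.1
```
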
@@ -3491,7 +3493,6 @@ class ScatterUpdate(_ScatterOp_Dynamic):
         self.add_prim_attr('side_effect_mem', True)

-

 class ScatterNdUpdate(_ScatterNdOp):
     r"""
     Updates tensor values by using input indices and value.
@@ -5250,3 +5251,11 @@ class Range(PrimitiveWithCheck):
         valid_dtypes = [mstype.int32, mstype.float32]
         inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
         validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)
+
+    def infer_value(self, start_value, limit_value, delta_value):
+        if start_value is not None and limit_value is not None and delta_value is not None:
+            start = np.asscalar(start_value.asnumpy())
+            limit = np.asscalar(limit_value.asnumpy())
+            delta = np.asscalar(delta_value.asnumpy())
+            return Tensor(np.arange(start, limit, delta), dtype=start_value.dtype)
+        return None
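
This mirrors the `inner.Range` change for the public, tensor-input `Range`: when `start`, `limit`, and `delta` arrive as known constants, the op folds to a tensor at compile time; otherwise `infer_value` returns `None` and the kernel runs as usual. A short sketch, assuming a build where `P.Range` takes scalar tensors:

```python
# Sketch of the tensor-input Range primitive served by the new infer_value.
import mindspore as ms
from mindspore.ops import operations as P

range_op = P.Range()
start = ms.Tensor(2, ms.int32)
limit = ms.Tensor(10, ms.int32)
delta = ms.Tensor(3, ms.int32)
print(range_op(start, limit, delta))  # [2 5 8]; constant inputs can be folded via infer_value
```
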
@@ -15,7 +15,7 @@
 import numpy as np
 import mindspore as ms
-import mindspore.nn as nn
+from mindspore.common import dtype as mstype
 from mindspore import context, Tensor, Parameter
 from mindspore.nn import Cell, Momentum
 from mindspore.ops import operations as P
@@ -48,18 +48,25 @@ class Net(Cell):
     def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None):
         super().__init__()
         self.mul = P.Mul().shard(strategy1)
-        self.range = nn.Range(start, limit, delta)
-        self.range.range_x.shard(strategy2)
+        if isinstance(start, float):
+            self.type = mstype.float32
+        else:
+            self.type = mstype.int32
+        self.start = Tensor(start, self.type)
+        self.limit = Tensor(limit, self.type)
+        self.delta = Tensor(delta, self.type)
+        self.range = P.Range()
+        self.range.shard(strategy2)
         self.mul2 = P.Mul().shard(strategy3)
         self.weight = Parameter(weight, "w")

     def construct(self, x, b):
-        r_out = self.range()
+        r_out = self.range(self.start, self.limit, self.delta)
         out = self.mul(x, self.weight)
         out = self.mul2(out, r_out)
         return out


 dev_num = 4
 _x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32)
 _b = Tensor(np.ones([8]), dtype=ms.float32)
@@ -98,5 +105,5 @@ def test_range2():

 def test_range3():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2)
-    net = Net(_w1, 4.0, None, 0.5)
+    net = Net(_w1, 0.0, 4.0, 0.5)
     compile_net(net)
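
Since the rewritten `Net` wraps `start`/`limit`/`delta` in tensors, `limit=None` is no longer representable, so `test_range3` switches to the explicit range `(0.0, 4.0, 0.5)`: eight values, matching the trailing dimension of `_x` and `_b`. A hypothetical standalone check of that arithmetic (`_w1` and `compile_net` are defined elsewhere in this test file; the stub weight shape is an assumption):

```python
# Hypothetical smoke test of the updated Net; the stub weight stands in for _w1.
import numpy as np
import mindspore as ms
from mindspore import Tensor

assert np.arange(0.0, 4.0, 0.5).size == 8          # r_out broadcasts against [16, 8]
_w1_stub = Tensor(np.ones([8]), dtype=ms.float32)  # assumption: same shape as the real _w1
net = Net(_w1_stub, 0.0, 4.0, 0.5)
```
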