| @@ -45,8 +45,9 @@ class Embedding(Cell): | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of | |||
| the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero | |||
| if larger than vocab_size. | |||
| the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero | |||
| if larger than vocab_size. | |||
| Outputs: | |||
| Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`. | |||
| @@ -17,6 +17,7 @@ | |||
| from .. import operations as P | |||
| from ..operations import _grad_ops as G | |||
| from ..operations import _inner_ops as inner | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..functional import broadcast_gradient_args | |||
| from .. import functional as F | |||
| @@ -341,7 +342,7 @@ def get_bprop_sparse_gather_v2(self): | |||
| return bprop | |||
| @bprop_getters.register(P.Range) | |||
| @bprop_getters.register(inner.Range) | |||
| def get_bprop_range(self): | |||
| """Generate bprop for Range""" | |||
| @@ -23,7 +23,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| Diag, DiagPart, DType, ExpandDims, Eye, | |||
| Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, | |||
| IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | |||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, | |||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, | |||
| SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | |||
| ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | |||
| Shape, Size, Slice, Split, EmbeddingLookup, | |||
| @@ -75,7 +75,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl | |||
| ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, | |||
| ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) | |||
| from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, | |||
| CheckValid, MakeRefKey, Partial, Depend, CheckBprop, ConfusionMatrix) | |||
| CheckValid, MakeRefKey, Partial, Depend, CheckBprop) | |||
| from . import _quant_ops | |||
| from ._quant_ops import * | |||
| from .thor_ops import * | |||
| @@ -303,13 +303,12 @@ __all__ = [ | |||
| "Atan", | |||
| "Atanh", | |||
| "BasicLSTMCell", | |||
| "ConfusionMatrix", | |||
| "BroadcastTo", | |||
| "Range", | |||
| "DataFormatDimMap", | |||
| "ApproximateEqual", | |||
| "InplaceUpdate", | |||
| "InTopK", | |||
| "DataFormatDimMap" | |||
| ] | |||
| __all__.extend(_quant_ops.__all__) | |||
| @@ -15,9 +15,10 @@ | |||
| """Inner operators.""" | |||
| from ..._checkparam import Rel | |||
| from ..._checkparam import Validator as validator | |||
| from ...common import dtype as mstype | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||
| class ExtractImagePatches(PrimitiveWithInfer): | |||
| @@ -96,3 +97,61 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||
| """infer dtype""" | |||
| validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) | |||
| return input_x | |||
| class Range(PrimitiveWithInfer): | |||
| r""" | |||
| Creates a sequence of numbers. | |||
| Set `input_x` as :math:`x_i` for each element, `output` as follows: | |||
| .. math:: | |||
| \text{output}(x_i) = x_i * \text{delta} + \text{start} | |||
| Args: | |||
| start (float): If `limit` is `None`, the value acts as limit in the range and first entry | |||
| defaults to `0`. Otherwise, it acts as first entry in the range. | |||
| limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start` | |||
| while set the first entry of the range to `0`. It can not be equal to `start`. | |||
| delta (float): Increment of the range. It can not be equal to zero. Default: 1.0. | |||
| Inputs: | |||
| - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32. | |||
| Outputs: | |||
| Tensor, has the same shape and dtype as `input_x`. | |||
| Examples: | |||
| >>> range = P.Range(1.0, 8.0, 2.0) | |||
| >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32) | |||
| >>> range(x) | |||
| [3, 5, 7, 5] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, start, limit=None, delta=1.0): | |||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | |||
| self.delta = validator.check_value_type("delta", delta, [float], self.name) | |||
| validator.check_value_type("start", start, [float], self.name) | |||
| if limit is None: | |||
| self.start = 0.0 | |||
| self.limit = start | |||
| self.add_prim_attr("start", self.start) | |||
| self.add_prim_attr("limit", self.limit) | |||
| else: | |||
| validator.check_value_type("limit", limit, [float], self.name) | |||
| validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name) | |||
| if self.delta == 0.0: | |||
| raise ValueError("The input of `delta` can not be equal to zero.") | |||
| if self.delta > 0.0 and self.start > self.limit: | |||
| raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, " | |||
| f"but got start:{self.start}, limit:{self.limit}") | |||
| if self.delta < 0.0 and self.start < self.limit: | |||
| raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, " | |||
| f"but got start:{self.start}, limit:{self.limit}") | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name) | |||
| return x_dtype | |||
| @@ -556,64 +556,6 @@ class SparseGatherV2(GatherV2): | |||
| """ | |||
| class Range(PrimitiveWithInfer): | |||
| r""" | |||
| Creates a sequence of numbers. | |||
| Set `input_x` as :math:`x_i` for each element, `output` as follows: | |||
| .. math:: | |||
| \text{output}(x_i) = x_i * \text{delta} + \text{start} | |||
| Args: | |||
| start (float): If `limit` is `None`, the value acts as limit in the range and first entry | |||
| defaults to `0`. Otherwise, it acts as first entry in the range. | |||
| limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start` | |||
| while set the first entry of the range to `0`. It can not be equal to `start`. | |||
| delta (float): Increment of the range. It can not be equal to zero. Default: 1.0. | |||
| Inputs: | |||
| - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32. | |||
| Outputs: | |||
| Tensor, has the same shape and dtype as `input_x`. | |||
| Examples: | |||
| >>> range = P.Range(1.0, 8.0, 2.0) | |||
| >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32) | |||
| >>> range(x) | |||
| [3, 5, 7, 5] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, start, limit=None, delta=1.0): | |||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | |||
| self.delta = validator.check_value_type("delta", delta, [float], self.name) | |||
| validator.check_value_type("start", start, [float], self.name) | |||
| if limit is None: | |||
| self.start = 0.0 | |||
| self.limit = start | |||
| self.add_prim_attr("start", self.start) | |||
| self.add_prim_attr("limit", self.limit) | |||
| else: | |||
| validator.check_value_type("limit", limit, [float], self.name) | |||
| validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name) | |||
| if self.delta == 0.0: | |||
| raise ValueError("The input of `delta` can not be equal to zero.") | |||
| if self.delta > 0.0 and self.start > self.limit: | |||
| raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, " | |||
| f"but got start:{self.start}, limit:{self.limit}") | |||
| if self.delta < 0.0 and self.start < self.limit: | |||
| raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, " | |||
| f"but got start:{self.start}, limit:{self.limit}") | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name) | |||
| return x_dtype | |||
| class EmbeddingLookup(PrimitiveWithInfer): | |||
| """ | |||
| Returns a slice of input tensor based on the specified indices. | |||
| @@ -25,6 +25,7 @@ from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.nn import Cell | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.ops import prim_attr_register | |||
| from mindspore.ops.primitive import PrimitiveWithInfer | |||
| import mindspore.context as context | |||
| @@ -286,19 +287,10 @@ class SpaceToBatchNDNet(Cell): | |||
| return self.space_to_batch_nd(x) | |||
| class ConfusionMatrixNet(Cell): | |||
| def __init__(self): | |||
| super(ConfusionMatrixNet, self).__init__() | |||
| self.confusion_matrix = P.ConfusionMatrix(4, "int32") | |||
| def construct(self, x, y): | |||
| return self.confusion_matrix(x, y) | |||
| class RangeNet(Cell): | |||
| def __init__(self): | |||
| super(RangeNet, self).__init__() | |||
| self.range_ops = P.Range(1.0, 8.0, 2.0) | |||
| self.range_ops = inner.Range(1.0, 8.0, 2.0) | |||
| def construct(self, x): | |||
| return self.range_ops(x) | |||
| @@ -344,9 +336,6 @@ test_case_array_ops = [ | |||
| ('BatchToSpaceNDNet', { | |||
| 'block': BatchToSpaceNDNet(), | |||
| 'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}), | |||
| ('ConfusionMatrixNet', { | |||
| 'block': ConfusionMatrixNet(), | |||
| 'desc_inputs': [Tensor([0, 1, 1, 3], ms.int32), Tensor([0, 1, 1, 3], ms.int32)]}), | |||
| ('RangeNet', { | |||
| 'block': RangeNet(), | |||
| 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}), | |||
| @@ -25,6 +25,7 @@ from mindspore.common import dtype as mstype | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| @@ -1051,7 +1052,7 @@ test_case_nn_ops = [ | |||
| 'desc_inputs': [[3, 1, 2], Tensor(np.array([0, 1]).astype(np.int32))], | |||
| 'desc_bprop': [[2, 1, 2]]}), | |||
| ('Range', { | |||
| 'block': P.Range(1.0, 5.0), | |||
| 'block': inner.Range(1.0, 5.0), | |||
| 'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))], | |||
| 'desc_bprop': [[10]]}), | |||
| ('UnsortedSegmentSum', { | |||
| @@ -1454,7 +1455,7 @@ test_case_array_ops = [ | |||
| 'desc_inputs': [(Tensor(np.array([1], np.float32)), | |||
| Tensor(np.array([1], np.float32)), | |||
| Tensor(np.array([1], np.float32)))], | |||
| 'desc_bprop': [[3,]]}), | |||
| 'desc_bprop': [[3, ]]}), | |||
| ('Pack_0', { | |||
| 'block': NetForPackInput(P.Pack()), | |||
| 'desc_inputs': [[2, 2], [2, 2], [2, 2]], | |||
| @@ -1527,7 +1528,7 @@ test_case_array_ops = [ | |||
| Tensor(np.array([0, 1, 1]).astype(np.int32))], | |||
| 'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}), | |||
| ('BroadcastTo', { | |||
| 'block': P.BroadcastTo((2,3)), | |||
| 'block': P.BroadcastTo((2, 3)), | |||
| 'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))], | |||
| 'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}), | |||
| ('InTopK', { | |||