| @@ -45,8 +45,9 @@ class Embedding(Cell): | |||||
| Inputs: | Inputs: | ||||
| - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of | - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of | ||||
| the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero | |||||
| if larger than vocab_size. | |||||
| the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero | |||||
| if larger than vocab_size. | |||||
| Outputs: | Outputs: | ||||
| Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`. | Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`. | ||||
| @@ -17,6 +17,7 @@ | |||||
| from .. import operations as P | from .. import operations as P | ||||
| from ..operations import _grad_ops as G | from ..operations import _grad_ops as G | ||||
| from ..operations import _inner_ops as inner | |||||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | from ..composite.multitype_ops.zeros_like_impl import zeros_like | ||||
| from ..functional import broadcast_gradient_args | from ..functional import broadcast_gradient_args | ||||
| from .. import functional as F | from .. import functional as F | ||||
| @@ -341,7 +342,7 @@ def get_bprop_sparse_gather_v2(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.Range) | |||||
| @bprop_getters.register(inner.Range) | |||||
| def get_bprop_range(self): | def get_bprop_range(self): | ||||
| """Generate bprop for Range""" | """Generate bprop for Range""" | ||||
| @@ -23,7 +23,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| Diag, DiagPart, DType, ExpandDims, Eye, | Diag, DiagPart, DType, ExpandDims, Eye, | ||||
| Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, | Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, | ||||
| IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | ||||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, | |||||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, | |||||
| SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | ||||
| ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ||||
| Shape, Size, Slice, Split, EmbeddingLookup, | Shape, Size, Slice, Split, EmbeddingLookup, | ||||
| @@ -75,7 +75,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl | |||||
| ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, | ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, | ||||
| ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) | ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) | ||||
| from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, | from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, | ||||
| CheckValid, MakeRefKey, Partial, Depend, CheckBprop, ConfusionMatrix) | |||||
| CheckValid, MakeRefKey, Partial, Depend, CheckBprop) | |||||
| from . import _quant_ops | from . import _quant_ops | ||||
| from ._quant_ops import * | from ._quant_ops import * | ||||
| from .thor_ops import * | from .thor_ops import * | ||||
| @@ -303,13 +303,12 @@ __all__ = [ | |||||
| "Atan", | "Atan", | ||||
| "Atanh", | "Atanh", | ||||
| "BasicLSTMCell", | "BasicLSTMCell", | ||||
| "ConfusionMatrix", | |||||
| "BroadcastTo", | "BroadcastTo", | ||||
| "Range", | |||||
| "DataFormatDimMap", | "DataFormatDimMap", | ||||
| "ApproximateEqual", | "ApproximateEqual", | ||||
| "InplaceUpdate", | "InplaceUpdate", | ||||
| "InTopK", | "InTopK", | ||||
| "DataFormatDimMap" | |||||
| ] | ] | ||||
| __all__.extend(_quant_ops.__all__) | __all__.extend(_quant_ops.__all__) | ||||
| @@ -15,9 +15,10 @@ | |||||
| """Inner operators.""" | """Inner operators.""" | ||||
| from ..._checkparam import Rel | |||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||||
| class ExtractImagePatches(PrimitiveWithInfer): | class ExtractImagePatches(PrimitiveWithInfer): | ||||
| @@ -96,3 +97,61 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||||
| """infer dtype""" | """infer dtype""" | ||||
| validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) | validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) | ||||
| return input_x | return input_x | ||||
| class Range(PrimitiveWithInfer): | |||||
| r""" | |||||
| Creates a sequence of numbers. | |||||
| Set `input_x` as :math:`x_i` for each element, `output` as follows: | |||||
| .. math:: | |||||
| \text{output}(x_i) = x_i * \text{delta} + \text{start} | |||||
| Args: | |||||
| start (float): If `limit` is `None`, the value acts as limit in the range and first entry | |||||
| defaults to `0`. Otherwise, it acts as first entry in the range. | |||||
| limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start` | |||||
| while set the first entry of the range to `0`. It can not be equal to `start`. | |||||
| delta (float): Increment of the range. It can not be equal to zero. Default: 1.0. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32. | |||||
| Outputs: | |||||
| Tensor, has the same shape and dtype as `input_x`. | |||||
| Examples: | |||||
| >>> range = P.Range(1.0, 8.0, 2.0) | |||||
| >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32) | |||||
| >>> range(x) | |||||
| [3, 5, 7, 5] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, start, limit=None, delta=1.0): | |||||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | |||||
| self.delta = validator.check_value_type("delta", delta, [float], self.name) | |||||
| validator.check_value_type("start", start, [float], self.name) | |||||
| if limit is None: | |||||
| self.start = 0.0 | |||||
| self.limit = start | |||||
| self.add_prim_attr("start", self.start) | |||||
| self.add_prim_attr("limit", self.limit) | |||||
| else: | |||||
| validator.check_value_type("limit", limit, [float], self.name) | |||||
| validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name) | |||||
| if self.delta == 0.0: | |||||
| raise ValueError("The input of `delta` can not be equal to zero.") | |||||
| if self.delta > 0.0 and self.start > self.limit: | |||||
| raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, " | |||||
| f"but got start:{self.start}, limit:{self.limit}") | |||||
| if self.delta < 0.0 and self.start < self.limit: | |||||
| raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, " | |||||
| f"but got start:{self.start}, limit:{self.limit}") | |||||
| def infer_shape(self, x_shape): | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name) | |||||
| return x_dtype | |||||
| @@ -556,64 +556,6 @@ class SparseGatherV2(GatherV2): | |||||
| """ | """ | ||||
| class Range(PrimitiveWithInfer): | |||||
| r""" | |||||
| Creates a sequence of numbers. | |||||
| Set `input_x` as :math:`x_i` for each element, `output` as follows: | |||||
| .. math:: | |||||
| \text{output}(x_i) = x_i * \text{delta} + \text{start} | |||||
| Args: | |||||
| start (float): If `limit` is `None`, the value acts as limit in the range and first entry | |||||
| defaults to `0`. Otherwise, it acts as first entry in the range. | |||||
| limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start` | |||||
| while set the first entry of the range to `0`. It can not be equal to `start`. | |||||
| delta (float): Increment of the range. It can not be equal to zero. Default: 1.0. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32. | |||||
| Outputs: | |||||
| Tensor, has the same shape and dtype as `input_x`. | |||||
| Examples: | |||||
| >>> range = P.Range(1.0, 8.0, 2.0) | |||||
| >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32) | |||||
| >>> range(x) | |||||
| [3, 5, 7, 5] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, start, limit=None, delta=1.0): | |||||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | |||||
| self.delta = validator.check_value_type("delta", delta, [float], self.name) | |||||
| validator.check_value_type("start", start, [float], self.name) | |||||
| if limit is None: | |||||
| self.start = 0.0 | |||||
| self.limit = start | |||||
| self.add_prim_attr("start", self.start) | |||||
| self.add_prim_attr("limit", self.limit) | |||||
| else: | |||||
| validator.check_value_type("limit", limit, [float], self.name) | |||||
| validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name) | |||||
| if self.delta == 0.0: | |||||
| raise ValueError("The input of `delta` can not be equal to zero.") | |||||
| if self.delta > 0.0 and self.start > self.limit: | |||||
| raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, " | |||||
| f"but got start:{self.start}, limit:{self.limit}") | |||||
| if self.delta < 0.0 and self.start < self.limit: | |||||
| raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, " | |||||
| f"but got start:{self.start}, limit:{self.limit}") | |||||
| def infer_shape(self, x_shape): | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name) | |||||
| return x_dtype | |||||
| class EmbeddingLookup(PrimitiveWithInfer): | class EmbeddingLookup(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Returns a slice of input tensor based on the specified indices. | Returns a slice of input tensor based on the specified indices. | ||||
| @@ -25,6 +25,7 @@ from mindspore import Tensor | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.operations import _inner_ops as inner | |||||
| from mindspore.ops import prim_attr_register | from mindspore.ops import prim_attr_register | ||||
| from mindspore.ops.primitive import PrimitiveWithInfer | from mindspore.ops.primitive import PrimitiveWithInfer | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| @@ -286,19 +287,10 @@ class SpaceToBatchNDNet(Cell): | |||||
| return self.space_to_batch_nd(x) | return self.space_to_batch_nd(x) | ||||
| class ConfusionMatrixNet(Cell): | |||||
| def __init__(self): | |||||
| super(ConfusionMatrixNet, self).__init__() | |||||
| self.confusion_matrix = P.ConfusionMatrix(4, "int32") | |||||
| def construct(self, x, y): | |||||
| return self.confusion_matrix(x, y) | |||||
| class RangeNet(Cell): | class RangeNet(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(RangeNet, self).__init__() | super(RangeNet, self).__init__() | ||||
| self.range_ops = P.Range(1.0, 8.0, 2.0) | |||||
| self.range_ops = inner.Range(1.0, 8.0, 2.0) | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return self.range_ops(x) | return self.range_ops(x) | ||||
| @@ -344,9 +336,6 @@ test_case_array_ops = [ | |||||
| ('BatchToSpaceNDNet', { | ('BatchToSpaceNDNet', { | ||||
| 'block': BatchToSpaceNDNet(), | 'block': BatchToSpaceNDNet(), | ||||
| 'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}), | 'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}), | ||||
| ('ConfusionMatrixNet', { | |||||
| 'block': ConfusionMatrixNet(), | |||||
| 'desc_inputs': [Tensor([0, 1, 1, 3], ms.int32), Tensor([0, 1, 1, 3], ms.int32)]}), | |||||
| ('RangeNet', { | ('RangeNet', { | ||||
| 'block': RangeNet(), | 'block': RangeNet(), | ||||
| 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}), | 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}), | ||||
| @@ -25,6 +25,7 @@ from mindspore.common import dtype as mstype | |||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.operations import _grad_ops as G | from mindspore.ops.operations import _grad_ops as G | ||||
| from mindspore.ops.operations import _inner_ops as inner | |||||
| from ..ut_filter import non_graph_engine | from ..ut_filter import non_graph_engine | ||||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | from ....mindspore_test_framework.mindspore_test import mindspore_test | ||||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | from ....mindspore_test_framework.pipeline.forward.compile_forward \ | ||||
| @@ -1051,7 +1052,7 @@ test_case_nn_ops = [ | |||||
| 'desc_inputs': [[3, 1, 2], Tensor(np.array([0, 1]).astype(np.int32))], | 'desc_inputs': [[3, 1, 2], Tensor(np.array([0, 1]).astype(np.int32))], | ||||
| 'desc_bprop': [[2, 1, 2]]}), | 'desc_bprop': [[2, 1, 2]]}), | ||||
| ('Range', { | ('Range', { | ||||
| 'block': P.Range(1.0, 5.0), | |||||
| 'block': inner.Range(1.0, 5.0), | |||||
| 'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))], | 'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))], | ||||
| 'desc_bprop': [[10]]}), | 'desc_bprop': [[10]]}), | ||||
| ('UnsortedSegmentSum', { | ('UnsortedSegmentSum', { | ||||
| @@ -1454,7 +1455,7 @@ test_case_array_ops = [ | |||||
| 'desc_inputs': [(Tensor(np.array([1], np.float32)), | 'desc_inputs': [(Tensor(np.array([1], np.float32)), | ||||
| Tensor(np.array([1], np.float32)), | Tensor(np.array([1], np.float32)), | ||||
| Tensor(np.array([1], np.float32)))], | Tensor(np.array([1], np.float32)))], | ||||
| 'desc_bprop': [[3,]]}), | |||||
| 'desc_bprop': [[3, ]]}), | |||||
| ('Pack_0', { | ('Pack_0', { | ||||
| 'block': NetForPackInput(P.Pack()), | 'block': NetForPackInput(P.Pack()), | ||||
| 'desc_inputs': [[2, 2], [2, 2], [2, 2]], | 'desc_inputs': [[2, 2], [2, 2], [2, 2]], | ||||
| @@ -1527,7 +1528,7 @@ test_case_array_ops = [ | |||||
| Tensor(np.array([0, 1, 1]).astype(np.int32))], | Tensor(np.array([0, 1, 1]).astype(np.int32))], | ||||
| 'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}), | 'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}), | ||||
| ('BroadcastTo', { | ('BroadcastTo', { | ||||
| 'block': P.BroadcastTo((2,3)), | |||||
| 'block': P.BroadcastTo((2, 3)), | |||||
| 'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))], | 'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))], | ||||
| 'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}), | 'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}), | ||||
| ('InTopK', { | ('InTopK', { | ||||