From aae06487772daf52b3d091f9951fe9376d57f001 Mon Sep 17 00:00:00 2001 From: lihongkang <[lihongkang1@huawei.com]> Date: Thu, 17 Sep 2020 20:04:30 +0800 Subject: [PATCH] fix bugs --- mindspore/nn/layer/math.py | 4 +- mindspore/ops/_op_impl/tbe/gather_v2.py | 44 ++++++++------------ mindspore/ops/_op_impl/tbe/scatter_update.py | 3 +- mindspore/ops/operations/array_ops.py | 21 +++++++++- mindspore/ops/operations/math_ops.py | 2 + mindspore/ops/operations/nn_ops.py | 24 +++++++---- 6 files changed, 57 insertions(+), 41 deletions(-) diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index af73f7ea80..f83f49552e 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -59,6 +59,7 @@ class ReduceLogSumExp(Cell): >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32)) >>> op = nn.ReduceLogSumExp(keep_dims=True) >>> output = op(input_x, 1) + (3, 1, 5, 6) """ def __init__(self, axis, keep_dims=False): @@ -217,9 +218,10 @@ class LGamma(Cell): Tensor, has the same shape and dtype as the `input_x`. Examples: - >>> input_x = Tensor(np.array(2, 3, 4).astype(np.float32)) + >>> input_x = Tensor(np.array([2, 3, 4]).astype(np.float32)) >>> op = nn.LGamma() >>> output = op(input_x) + [3.5762787e-07 6.9314754e-01 1.7917603e+00] """ def __init__(self): diff --git a/mindspore/ops/_op_impl/tbe/gather_v2.py b/mindspore/ops/_op_impl/tbe/gather_v2.py index 72ba17d942..df9eb882e4 100644 --- a/mindspore/ops/_op_impl/tbe/gather_v2.py +++ b/mindspore/ops/_op_impl/tbe/gather_v2.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""AddN op""" +"""GatherV2 op""" from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType gather_v2_op_info = TBERegOp("GatherV2") \ @@ -23,40 +23,30 @@ gather_v2_op_info = TBERegOp("GatherV2") \ .compute_cost(10) \ .kernel_name("gather_v2_d") \ .partial_flag(True) \ - .attr("axis", "optional", "int", "all") \ + .attr("axis", "required", "int", "all") \ .input(0, "x", False, "required", "all") \ .input(1, "indices", False, "required", "all") \ .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ - .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ - .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \ - .dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I8_5HD) \ - .dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \ - .dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I8_FracZ) \ .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ - .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ - .dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.U8_5HD) \ - .dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.U8_5HD) \ - .dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \ - .dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.U8_FracZ) \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ - .dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD) \ - .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ - .dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ) \ - .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \ .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ - .dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.F16_5HD) \ - .dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \ - .dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \ - .dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.F16_FracZ) \ - .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \ - .dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \ - .dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ - .dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \ + .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/scatter_update.py b/mindspore/ops/_op_impl/tbe/scatter_update.py index 244b8ab21f..ebe2220cd0 100644 --- a/mindspore/ops/_op_impl/tbe/scatter_update.py +++ b/mindspore/ops/_op_impl/tbe/scatter_update.py @@ -26,13 +26,12 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \ .attr("use_locking", "optional", "bool", "all") \ .input(0, "var", False, "required", "all") \ .input(1, "indices", False, "required", "all") \ - .input(1, "updates", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ .output(0, "var", False, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ .get_op_info() diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 638d74f364..8a9ab4a925 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -336,6 +336,7 @@ class IsInstance(PrimitiveWithInfer): Examples: >>> a = 1 >>> result = P.IsInstance()(a, mindspore.int32) + True """ @prim_attr_register @@ -634,6 +635,9 @@ class GatherV2(PrimitiveWithCheck): >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32) >>> axis = 1 >>> out = P.GatherV2()(input_params, input_indices, axis) + [[2.0, 7.0], + [4.0, 54.0], + [2.0, 55.0]] """ @prim_attr_register @@ -940,6 +944,8 @@ class OnesLike(PrimitiveWithInfer): >>> oneslike = P.OnesLike() >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)) >>> output = oneslike(x) + [[1, 1], + [1, 1]] """ @prim_attr_register @@ -970,6 +976,8 @@ class ZerosLike(PrimitiveWithInfer): >>> zeroslike = P.ZerosLike() >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32)) >>> output = zeroslike(x) + [[0.0, 0.0], + [0.0, 0.0]] """ @prim_attr_register @@ -1628,6 +1636,10 @@ class Concat(PrimitiveWithInfer): >>> data2 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)) >>> op = P.Concat() >>> output = op((data1, data2)) + [[0, 1], + [2, 1], + [0, 1], + [2, 1]] """ @prim_attr_register @@ -2502,6 +2514,7 @@ class GatherNd(PrimitiveWithInfer): >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) >>> op = P.GatherNd() >>> output = op(input_x, indices) + [-0.1, 0.5] """ @prim_attr_register @@ -2525,10 +2538,10 @@ class TensorScatterUpdate(PrimitiveWithInfer): Update tensor value using given values, along with the input indices. Inputs: - - **input_x** (Tensor) - The target tensor. + - **input_x** (Tensor) - The target tensor. The dimension of input_x must be equal to indices.shape[-1]. - **indices** (Tensor) - The index of input tensor whose data type is int32. - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, - and update.shape = indices.shape + input_x.shape[1:]. + and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:]. Outputs: Tensor, has the same shape and type as `input_x`. @@ -2539,6 +2552,8 @@ class TensorScatterUpdate(PrimitiveWithInfer): >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) >>> op = P.TensorScatterUpdate() >>> output = op(input_x, indices, update) + [[1.0, 0.3, 3.6], + [0.4, 2.2, -3.2]] """ @prim_attr_register @@ -2591,6 +2606,8 @@ class ScatterUpdate(_ScatterOp): >>> updates = Tensor(np_updates, mindspore.float32) >>> op = P.ScatterUpdate() >>> output = op(input_x, indices, updates) + [[2.0, 1.2, 1.0], + [3.0, 1.2, 1.0]] """ @prim_attr_register diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index c2a516af3b..1f5430e576 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -3487,7 +3487,9 @@ class Eps(PrimitiveWithInfer): Tensor, has the same type and shape as `input_x`, but filled with `input_x` dtype minimum val. Examples: + >>> input_x = Tensor([4, 1, 2, 3], mindspore.float32) >>> out = P.Eps()(input_x) + [1.52587891e-05, 1.52587891e-05, 1.52587891e-05, 1.52587891e-05] """ @prim_attr_register diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index c728bce89a..d54c6333da 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2306,7 +2306,7 @@ class DropoutGenMask(Primitive): Inputs: - **shape** (tuple[int]) - The shape of target mask. - - **keep_prob** (Tensor) - The keep rate, between 0 and 1, e.g. keep_prob = 0.9, + - **keep_prob** (Tensor) - The keep rate, greater than 0 and less equal than 1, e.g. keep_prob = 0.9, means dropping out 10% of input units. Outputs: @@ -2314,9 +2314,10 @@ class DropoutGenMask(Primitive): Examples: >>> dropout_gen_mask = P.DropoutGenMask() - >>> shape = (20, 16, 50) + >>> shape = (2, 4, 5) >>> keep_prob = Tensor(0.5, mindspore.float32) >>> mask = dropout_gen_mask(shape, keep_prob) + [249, 11, 134, 133, 143, 246, 89, 52, 169, 15, 94, 63, 146, 103, 7, 101] """ @prim_attr_register @@ -2338,7 +2339,7 @@ class DropoutDoMask(PrimitiveWithInfer): - **mask** (Tensor) - The mask to be applied on `input_x`, which is the output of `DropoutGenMask`. And the shape of `input_x` must be the same as the value of `DropoutGenMask`'s input `shape`. If input wrong `mask`, the output of `DropoutDoMask` are unpredictable. - - **keep_prob** (Tensor) - The keep rate, between 0 and 1, e.g. keep_prob = 0.9, + - **keep_prob** (Tensor) - The keep rate, greater than 0 and less equal than 1, e.g. keep_prob = 0.9, means dropping out 10% of input units. The value of `keep_prob` is the same as the input `keep_prob` of `DropoutGenMask`. @@ -2346,14 +2347,18 @@ class DropoutDoMask(PrimitiveWithInfer): Tensor, the value that applied dropout on. Examples: - >>> x = Tensor(np.ones([20, 16, 50]), mindspore.float32) - >>> shape = (20, 16, 50) + >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32) + >>> shape = (2, 2, 3) >>> keep_prob = Tensor(0.5, mindspore.float32) >>> dropout_gen_mask = P.DropoutGenMask() >>> dropout_do_mask = P.DropoutDoMask() >>> mask = dropout_gen_mask(shape, keep_prob) >>> output = dropout_do_mask(x, mask, keep_prob) - >>> assert output.shape == (20, 16, 50) + >>> assert output.shape == (2, 2, 3) + [[[2.0, 0.0, 0.0], + [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], + [2.0, 2.0, 2.0]]] """ @prim_attr_register @@ -2401,11 +2406,11 @@ class ResizeBilinear(PrimitiveWithInfer): rescale by `new_height / height`. Default: False. Inputs: - - **input** (Tensor) - Image to be resized. Tensor of shape `(N_i, ..., N_n, height, width)`, - with data type of float32 or float16. + - **input** (Tensor) - Image to be resized. Input images must be a 4-D tensor with shape + [batch, channels, height, width], with data type of float32 or float16. Outputs: - Tensor, resized image. Tensor of shape `(N_i, ..., N_n, new_height, new_width)` in `float32`. + Tensor, resized image. 4-D with shape [batch, channels, new_height, new_width] in `float32`. Examples: >>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32) @@ -2419,6 +2424,7 @@ class ResizeBilinear(PrimitiveWithInfer): pass def infer_shape(self, input_shape): + validator.check("input shape rank", len(input_shape), "", 4, Rel.EQ, self.name) input_shape = list(input_shape) batch, channel, _, _ = input_shape out_shape = [batch, channel]