
!6501 fix bugs of op DropoutDoMask, Zeroslike, LGamma, GatherV2 and TensorScatterUpdate etc.

Merge pull request !6501 from lihongkang/v2_master
tags/v1.0.0
mindspore-ci-bot committed 5 years ago
commit 53cc0230b8
6 changed files with 57 additions and 41 deletions
  1. +3  -1   mindspore/nn/layer/math.py
  2. +17 -27  mindspore/ops/_op_impl/tbe/gather_v2.py
  3. +1  -2   mindspore/ops/_op_impl/tbe/scatter_update.py
  4. +19 -2   mindspore/ops/operations/array_ops.py
  5. +2  -0   mindspore/ops/operations/math_ops.py
  6. +15 -9   mindspore/ops/operations/nn_ops.py

+3 -1  mindspore/nn/layer/math.py

@@ -59,6 +59,7 @@ class ReduceLogSumExp(Cell):
      >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
      >>> op = nn.ReduceLogSumExp(keep_dims=True)
      >>> output = op(input_x, 1)
+     (3, 1, 5, 6)
  """

  def __init__(self, axis, keep_dims=False):
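The added `(3, 1, 5, 6)` is the reduced shape: with keep_dims=True, axis 1 collapses from 4 to 1. A quick cross-check using SciPy's logsumexp as a stand-in reference (SciPy is not used by MindSpore here, just convenient for verification):

    import numpy as np
    from scipy.special import logsumexp  # reference only; not part of this PR

    x = np.random.randn(3, 4, 5, 6).astype(np.float32)
    # ReduceLogSumExp with keep_dims=True collapses axis 1: (3, 4, 5, 6) -> (3, 1, 5, 6)
    print(logsumexp(x, axis=1, keepdims=True).shape)  # (3, 1, 5, 6)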
@@ -217,9 +218,10 @@ class LGamma(Cell):
      Tensor, has the same shape and dtype as the `input_x`.

  Examples:
-     >>> input_x = Tensor(np.array(2, 3, 4).astype(np.float32))
+     >>> input_x = Tensor(np.array([2, 3, 4]).astype(np.float32))
      >>> op = nn.LGamma()
      >>> output = op(input_x)
+     [3.5762787e-07 6.9314754e-01 1.7917603e+00]
  """

  def __init__(self):
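The new expected output can be sanity-checked directly: lgamma(n) = log((n-1)!), so [2, 3, 4] maps to [log 1, log 2, log 6], and the 3.58e-07 in the docstring is float32 noise around 0. A sketch using scipy.special.gammaln as the reference (an assumption of this note, not something the PR uses):

    import numpy as np
    from scipy.special import gammaln  # reference for log(Gamma(x))

    x = np.array([2.0, 3.0, 4.0], dtype=np.float32)
    # lgamma(n) = log((n-1)!): expect [log 1, log 2, log 6]
    print(gammaln(x))  # [0.         0.69314718 1.79175947], matching the docstring up to float32 rounding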


+17 -27  mindspore/ops/_op_impl/tbe/gather_v2.py

@@ -13,7 +13,7 @@
  # limitations under the License.
  # ============================================================================

- """AddN op"""
+ """GatherV2 op"""
  from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

  gather_v2_op_info = TBERegOp("GatherV2") \
@@ -23,40 +23,30 @@ gather_v2_op_info = TBERegOp("GatherV2") \
      .compute_cost(10) \
      .kernel_name("gather_v2_d") \
      .partial_flag(True) \
-     .attr("axis", "optional", "int", "all") \
+     .attr("axis", "required", "int", "all") \
      .input(0, "x", False, "required", "all") \
      .input(1, "indices", False, "required", "all") \
      .output(0, "y", False, "required", "all") \
+     .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
+     .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
      .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
-     .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
-     .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
-     .dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I8_5HD) \
-     .dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \
-     .dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I8_FracZ) \
      .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
-     .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
-     .dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.U8_5HD) \
-     .dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.U8_5HD) \
-     .dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \
-     .dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.U8_FracZ) \
      .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
-     .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
-     .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
-     .dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD) \
-     .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
-     .dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ) \
-     .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
+     .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
+     .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
+     .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
+     .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
+     .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
      .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
-     .dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.F16_5HD) \
-     .dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \
-     .dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \
-     .dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.F16_FracZ) \
-     .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
      .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
-     .dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \
-     .dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \
-     .dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \
-     .dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \
+     .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
+     .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
+     .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
+     .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
+     .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
+     .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
+     .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
+     .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
      .get_op_info()
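The reworked table drops the 5HD/FracZ layouts and instead registers every supported dtype (including the new u32/i16/u16/i64/u64 entries) against both int32 and int64 indices, while `axis` becomes a required attribute. A hedged usage sketch of what the wider registration enables (MindSpore 1.x-era API; Ascend/TBE backend assumed):

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P

    params = Tensor(np.arange(12).reshape(3, 4).astype(np.float16))
    indices = Tensor(np.array([2, 0]), mindspore.int64)  # int64 indices, now registered
    out = P.GatherV2()(params, indices, 0)  # gathers rows 2 and 0; axis is now required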




+1 -2  mindspore/ops/_op_impl/tbe/scatter_update.py

@@ -26,13 +26,12 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \
      .attr("use_locking", "optional", "bool", "all") \
      .input(0, "var", False, "required", "all") \
      .input(1, "indices", False, "required", "all") \
-     .input(1, "updates", False, "required", "all") \
+     .input(2, "updates", False, "required", "all") \
      .output(0, "var", False, "required", "all") \
      .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
      .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
      .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
      .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
-     .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
      .get_op_info()
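The bug fixed here is the duplicated input slot: `updates` had been registered at index 1, the same slot as `indices`, and input slots must be unique and consecutive. For reference, ScatterUpdate's intended semantics sketched in NumPy (a reference model with illustrative data, not the TBE kernel):

    import numpy as np

    def scatter_update_ref(var, indices, updates):
        # ScatterUpdate overwrites the rows of `var` selected by `indices`
        # with the corresponding rows of `updates` and returns the result.
        out = var.copy()
        out[indices] = updates
        return out

    var = np.ones((2, 3), dtype=np.float32)
    print(scatter_update_ref(var, np.array([0]), np.array([[2.0, 1.2, 1.0]], dtype=np.float32)))
    # [[2.  1.2 1. ]
    #  [1.  1.  1. ]]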




+19 -2  mindspore/ops/operations/array_ops.py

@@ -336,6 +336,7 @@ class IsInstance(PrimitiveWithInfer):
  Examples:
      >>> a = 1
      >>> result = P.IsInstance()(a, mindspore.int32)
+     True
  """

  @prim_attr_register
@@ -634,6 +635,9 @@ class GatherV2(PrimitiveWithCheck):
      >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
      >>> axis = 1
      >>> out = P.GatherV2()(input_params, input_indices, axis)
+     [[2.0, 7.0],
+      [4.0, 54.0],
+      [2.0, 55.0]]
  """

  @prim_attr_register
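The newly documented output is plain axis-wise gathering; `P.GatherV2()(params, indices, axis)` corresponds to NumPy's `np.take(params, indices, axis)`. A sketch that reproduces the documented output with data chosen for illustration (the docstring's actual `input_params` sits outside this hunk):

    import numpy as np

    params = np.array([[1.0, 2.0, 7.0, 42.0],
                       [3.0, 4.0, 54.0, 22.0],
                       [2.0, 2.0, 55.0, 3.0]], dtype=np.float32)  # illustrative data
    print(np.take(params, [1, 2], axis=1))
    # [[ 2.  7.]
    #  [ 4. 54.]
    #  [ 2. 55.]]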
@@ -940,6 +944,8 @@ class OnesLike(PrimitiveWithInfer):
      >>> oneslike = P.OnesLike()
      >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
      >>> output = oneslike(x)
+     [[1, 1],
+      [1, 1]]
  """

  @prim_attr_register
@@ -970,6 +976,8 @@ class ZerosLike(PrimitiveWithInfer):
      >>> zeroslike = P.ZerosLike()
      >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
      >>> output = zeroslike(x)
+     [[0.0, 0.0],
+      [0.0, 0.0]]
  """

  @prim_attr_register
@@ -1628,6 +1636,10 @@ class Concat(PrimitiveWithInfer):
      >>> data2 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
      >>> op = P.Concat()
      >>> output = op((data1, data2))
+     [[0, 1],
+      [2, 1],
+      [0, 1],
+      [2, 1]]
  """

  @prim_attr_register
@@ -2502,6 +2514,7 @@ class GatherNd(PrimitiveWithInfer):
      >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
      >>> op = P.GatherNd()
      >>> output = op(input_x, indices)
+     [-0.1, 0.5]
  """

  @prim_attr_register
@@ -2525,10 +2538,10 @@ class TensorScatterUpdate(PrimitiveWithInfer):
  Update tensor value using given values, along with the input indices.

  Inputs:
-     - **input_x** (Tensor) - The target tensor.
+     - **input_x** (Tensor) - The target tensor. The dimension of input_x must be equal to indices.shape[-1].
      - **indices** (Tensor) - The index of input tensor whose data type is int32.
      - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
-       and update.shape = indices.shape + input_x.shape[1:].
+       and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].

  Outputs:
      Tensor, has the same shape and type as `input_x`.
@@ -2539,6 +2552,8 @@ class TensorScatterUpdate(PrimitiveWithInfer):
      >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
      >>> op = P.TensorScatterUpdate()
      >>> output = op(input_x, indices, update)
+     [[1.0, 0.3, 3.6],
+      [0.4, 2.2, -3.2]]
  """

  @prim_attr_register
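The corrected shape rule can be checked mechanically; a NumPy sketch of the rule that reproduces the docstring example (reference semantics only, not the kernel):

    import numpy as np

    def tensor_scatter_update_ref(input_x, indices, update):
        # Corrected rule from this hunk:
        #   update.shape == indices.shape[:-1] + input_x.shape[indices.shape[-1]:]
        k = indices.shape[-1]
        assert update.shape == indices.shape[:-1] + input_x.shape[k:]
        out = input_x.copy()
        flat_idx = indices.reshape(-1, k)
        flat_upd = update.reshape((-1,) + input_x.shape[k:])
        for idx, val in zip(flat_idx, flat_upd):
            out[tuple(idx)] = val
        return out

    x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=np.float32)
    print(tensor_scatter_update_ref(x, np.array([[0, 0], [1, 1]]), np.array([1.0, 2.2], dtype=np.float32)))
    # [[ 1.   0.3  3.6]
    #  [ 0.4  2.2 -3.2]]  -- matches the docstring output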
@@ -2591,6 +2606,8 @@ class ScatterUpdate(_ScatterOp):
      >>> updates = Tensor(np_updates, mindspore.float32)
      >>> op = P.ScatterUpdate()
      >>> output = op(input_x, indices, updates)
+     [[2.0, 1.2, 1.0],
+      [3.0, 1.2, 1.0]]
  """

  @prim_attr_register


+2 -0  mindspore/ops/operations/math_ops.py

@@ -3487,7 +3487,9 @@ class Eps(PrimitiveWithInfer):
      Tensor, has the same type and shape as `input_x`, but filled with `input_x` dtype minimum val.

  Examples:
      >>> input_x = Tensor([4, 1, 2, 3], mindspore.float32)
+     >>> out = P.Eps()(input_x)
+     [1.52587891e-05, 1.52587891e-05, 1.52587891e-05, 1.52587891e-05]
  """

@prim_attr_register
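The printed value is 2^-16; a one-line check (the dtype-to-value mapping itself is taken from the example above, not derived here):

    print(2.0 ** -16)  # 1.52587890625e-05, the float32 value Eps reports above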


+15 -9  mindspore/ops/operations/nn_ops.py

@@ -2306,7 +2306,7 @@ class DropoutGenMask(Primitive):

  Inputs:
      - **shape** (tuple[int]) - The shape of target mask.
-     - **keep_prob** (Tensor) - The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
+     - **keep_prob** (Tensor) - The keep rate, greater than 0 and less equal than 1, e.g. keep_prob = 0.9,
        means dropping out 10% of input units.

  Outputs:
@@ -2314,9 +2314,10 @@ class DropoutGenMask(Primitive):

  Examples:
      >>> dropout_gen_mask = P.DropoutGenMask()
-     >>> shape = (20, 16, 50)
+     >>> shape = (2, 4, 5)
      >>> keep_prob = Tensor(0.5, mindspore.float32)
      >>> mask = dropout_gen_mask(shape, keep_prob)
+     [249, 11, 134, 133, 143, 246, 89, 52, 169, 15, 94, 63, 146, 103, 7, 101]
  """

  @prim_attr_register
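The 16 numbers in the new example are the raw bytes of the bit-packed mask: 2*4*5 = 40 positions, padded up to a 128-bit block, is 128 bits = 16 uint8 values. The padding granularity is inferred from the example rather than stated in this diff:

    import math

    shape = (2, 4, 5)
    positions = math.prod(shape)                    # 40 mask bits needed
    padded_bits = math.ceil(positions / 128) * 128  # pad to a 128-bit block -> 128
    print(padded_bits // 8)                         # 16 uint8 values, as in the example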
@@ -2338,7 +2339,7 @@ class DropoutDoMask(PrimitiveWithInfer):
      - **mask** (Tensor) - The mask to be applied on `input_x`, which is the output of `DropoutGenMask`. And the
        shape of `input_x` must be the same as the value of `DropoutGenMask`'s input `shape`. If input wrong `mask`,
        the output of `DropoutDoMask` are unpredictable.
-     - **keep_prob** (Tensor) - The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
+     - **keep_prob** (Tensor) - The keep rate, greater than 0 and less equal than 1, e.g. keep_prob = 0.9,
        means dropping out 10% of input units. The value of `keep_prob` is the same as the input `keep_prob` of
        `DropoutGenMask`.
@@ -2346,14 +2347,18 @@ class DropoutDoMask(PrimitiveWithInfer):
      Tensor, the value that applied dropout on.

  Examples:
-     >>> x = Tensor(np.ones([20, 16, 50]), mindspore.float32)
-     >>> shape = (20, 16, 50)
+     >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
+     >>> shape = (2, 2, 3)
      >>> keep_prob = Tensor(0.5, mindspore.float32)
      >>> dropout_gen_mask = P.DropoutGenMask()
      >>> dropout_do_mask = P.DropoutDoMask()
      >>> mask = dropout_gen_mask(shape, keep_prob)
      >>> output = dropout_do_mask(x, mask, keep_prob)
-     >>> assert output.shape == (20, 16, 50)
+     >>> assert output.shape == (2, 2, 3)
+     [[[2.0, 0.0, 0.0],
+       [0.0, 0.0, 0.0]],
+      [[0.0, 0.0, 0.0],
+       [2.0, 2.0, 2.0]]]
  """

  @prim_attr_register
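The 2.0 entries in the DropoutDoMask output are inverted-dropout scaling: kept units are multiplied by 1/keep_prob so the expected value of the output matches the input. A boolean-mask sketch of that behavior (the real op consumes the bit-packed mask from DropoutGenMask instead):

    import numpy as np

    def dropout_do_mask_ref(x, keep_mask, keep_prob):
        # Kept units are scaled by 1 / keep_prob (inverted dropout).
        return x * keep_mask / keep_prob

    x = np.ones((2, 2, 3), dtype=np.float32)
    mask = np.zeros_like(x)
    mask[0, 0, 0] = 1.0
    mask[1, 1, :] = 1.0  # one of many possible masks; reproduces the docstring pattern
    print(dropout_do_mask_ref(x, mask, 0.5))  # kept entries become 2.0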
@@ -2401,11 +2406,11 @@ class ResizeBilinear(PrimitiveWithInfer):
      rescale by `new_height / height`. Default: False.

  Inputs:
-     - **input** (Tensor) - Image to be resized. Tensor of shape `(N_i, ..., N_n, height, width)`,
-       with data type of float32 or float16.
+     - **input** (Tensor) - Image to be resized. Input images must be a 4-D tensor with shape
+       [batch, channels, height, width], with data type of float32 or float16.

  Outputs:
-     Tensor, resized image. Tensor of shape `(N_i, ..., N_n, new_height, new_width)` in `float32`.
+     Tensor, resized image. 4-D with shape [batch, channels, new_height, new_width] in `float32`.

  Examples:
      >>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
@@ -2419,6 +2424,7 @@ class ResizeBilinear(PrimitiveWithInfer):
      pass

  def infer_shape(self, input_shape):
+     validator.check("input shape rank", len(input_shape), "", 4, Rel.EQ, self.name)
      input_shape = list(input_shape)
      batch, channel, _, _ = input_shape
      out_shape = [batch, channel]
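The added validator line makes the 4-D requirement fail fast at shape inference instead of deep inside the kernel; in effect (a plain-Python paraphrase of the check, not the framework's validator API):

    def check_resize_bilinear_rank(input_shape):
        # After this commit, ResizeBilinear only accepts [batch, channels, height, width].
        if len(input_shape) != 4:
            raise ValueError("ResizeBilinear expects a 4-D input, got rank %d" % len(input_shape))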

