Browse Source

!7981 fix bugs of op Dropout, Slice, StridedSlice and BatchMatMul etc.

Merge pull request !7981 from lihongkang/v2_master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
5e8858e561
4 changed files with 30 additions and 8 deletions
  1. +7
    -5
      mindspore/nn/layer/basic.py
  2. +2
    -0
      mindspore/nn/layer/image.py
  3. +3
    -3
      mindspore/ops/operations/array_ops.py
  4. +18
    -0
      mindspore/ops/operations/math_ops.py

+ 7
- 5
mindspore/nn/layer/basic.py View File

@@ -75,11 +75,12 @@ class Dropout(Cell):
Examples: Examples:
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32) >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
>>> net = nn.Dropout(keep_prob=0.8) >>> net = nn.Dropout(keep_prob=0.8)
>>> net.set_train()
>>> net(x) >>> net(x)
[[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]],
[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]]]
[[[0., 1.25, 0.],
[1.25, 1.25, 1.25]],
[[1.25, 1.25, 1.25],
[1.25, 1.25, 1.25]]]
""" """


def __init__(self, keep_prob=0.5, dtype=mstype.float32): def __init__(self, keep_prob=0.5, dtype=mstype.float32):
@@ -287,7 +288,8 @@ class ClipByNorm(Cell):
>>> net = nn.ClipByNorm() >>> net = nn.ClipByNorm()
>>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
>>> clip_norm = Tensor(np.array([100]).astype(np.float32)) >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
>>> net(input, clip_norm)
>>> net(input, clip_norm).shape
(4, 16)


""" """




+ 2
- 0
mindspore/nn/layer/image.py View File

@@ -447,6 +447,8 @@ class CentralCrop(Cell):
>>> net = nn.CentralCrop(central_fraction=0.5) >>> net = nn.CentralCrop(central_fraction=0.5)
>>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32) >>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
>>> output = net(image) >>> output = net(image)
>>> output.shape
(4, 3, 2, 2)
""" """


def __init__(self, central_fraction): def __init__(self, central_fraction):


+ 3
- 3
mindspore/ops/operations/array_ops.py View File

@@ -1941,7 +1941,7 @@ class Slice(PrimitiveWithInfer):
""" """
Slices a tensor in the specified shape. Slices a tensor in the specified shape.


Args:
Inputs:
x (Tensor): The target tensor. x (Tensor): The target tensor.
begin (tuple): The beginning of the slice. Only constant value is allowed. begin (tuple): The beginning of the slice. Only constant value is allowed.
size (tuple): The size of the slice. Only constant value is allowed. size (tuple): The size of the slice. Only constant value is allowed.
@@ -2262,8 +2262,8 @@ class StridedSlice(PrimitiveWithInfer):
validator.check_value_type("strides", strides_v, [tuple], self.name) validator.check_value_type("strides", strides_v, [tuple], self.name)


if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)): if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)):
raise ValueError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")
raise TypeError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")


if tuple(filter(lambda x: x == 0, strides_v)): if tuple(filter(lambda x: x == 0, strides_v)):
raise ValueError(f"For '{self.name}', the strides cannot contain 0, but got strides: {strides_v}.") raise ValueError(f"For '{self.name}', the strides cannot contain 0, but got strides: {strides_v}.")


+ 18
- 0
mindspore/ops/operations/math_ops.py View File

@@ -724,11 +724,29 @@ class BatchMatMul(MatMul):
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32) >>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = P.BatchMatMul() >>> batmatmul = P.BatchMatMul()
>>> output = batmatmul(input_x, input_y) >>> output = batmatmul(input_x, input_y)
[[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]

[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]]
>>> >>>
>>> input_x = Tensor(np.ones(shape=[2, 4, 3, 1]), mindspore.float32) >>> input_x = Tensor(np.ones(shape=[2, 4, 3, 1]), mindspore.float32)
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32) >>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = P.BatchMatMul(transpose_a=True) >>> batmatmul = P.BatchMatMul(transpose_a=True)
>>> output = batmatmul(input_x, input_y) >>> output = batmatmul(input_x, input_y)
[[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]

[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]]
""" """


@prim_attr_register @prim_attr_register


Loading…
Cancel
Save