Browse Source

!7723 fix bugs of op ArgMinWithValue, LRN and AvgPool1d

Merge pull request !7723 from lihongkang/v2_master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
2e6794fcec
3 changed files with 8 additions and 3 deletions
  1. +1
    -1
      mindspore/nn/layer/pooling.py
  2. +2
    -1
      mindspore/ops/operations/array_ops.py
  3. +5
    -1
      mindspore/ops/operations/nn_ops.py

+ 1
- 1
mindspore/nn/layer/pooling.py View File

@@ -334,7 +334,7 @@ class AvgPool1d(_PoolNd):
Tensor of shape :math:`(N, C_{out}, L_{out})`. Tensor of shape :math:`(N, C_{out}, L_{out})`.


Examples: Examples:
>>> pool = nn.AvgPool1d(kernel_size=6, strides=1)
>>> pool = nn.AvgPool1d(kernel_size=6, stride=1)
>>> x = Tensor(np.random.randint(0, 10, [1, 3, 6]), mindspore.float32) >>> x = Tensor(np.random.randint(0, 10, [1, 3, 6]), mindspore.float32)
>>> output = pool(x) >>> output = pool(x)
>>> output.shape >>> output.shape


+ 2
- 1
mindspore/ops/operations/array_ops.py View File

@@ -1376,8 +1376,9 @@ class ArgMinWithValue(PrimitiveWithInfer):
- output_x (Tensor) - The minimum value of input tensor, with the same shape as index. - output_x (Tensor) - The minimum value of input tensor, with the same shape as index.


Examples: Examples:
>>> input_x = Tensor(np.random.rand(5))
>>> input_x = Tensor(np.random.rand(5), mindspore.float32)
>>> index, output = P.ArgMinWithValue()(input_x) >>> index, output = P.ArgMinWithValue()(input_x)
0 0.0496291
""" """


@prim_attr_register @prim_attr_register


+ 5
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -5740,9 +5740,13 @@ class LRN(PrimitiveWithInfer):
Tensor, with the same shape and data type as the input tensor. Tensor, with the same shape and data type as the input tensor.


Examples: Examples:
>>> x = Tensor(np.random.rand(1, 10, 4, 4)), mindspore.float32)
>>> x = Tensor(np.random.rand(1, 2, 2, 2), mindspore.float32)
>>> lrn = P.LRN() >>> lrn = P.LRN()
>>> lrn(x) >>> lrn(x)
[[[[0.18990143 0.59475636]
[0.6291904 0.1371534 ]]
[[0.6258911 0.4964315 ]
[0.3141494 0.43636137]]]]
""" """
@prim_attr_register @prim_attr_register
def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"): def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"):


Loading…
Cancel
Save