Browse Source

!2929 fix LRN

Merge pull request !2929 from jiangjinsheng/issue_fix4
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
2042bb8911
2 changed files with 4 additions and 1 deletions
  1. +3
    -1
      mindspore/ops/operations/array_ops.py
  2. +1
    -0
      mindspore/ops/operations/nn_ops.py

+ 3
- 1
mindspore/ops/operations/array_ops.py View File

@@ -988,7 +988,7 @@ class InvertPermutation(PrimitiveWithInfer):
values can not be negative.

Inputs:
- **input_x** (Union(tuple[int]) - The input tuple is constructed by multiple
- **input_x** (Union(tuple[int], list[int]) - The input is constructed by multiple
integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices.
The values must include 0. There can be no duplicate values or negative values.
Only constant value is allowed.
@@ -1016,6 +1016,8 @@ class InvertPermutation(PrimitiveWithInfer):
validator.check_value_type("shape", x_shp, [tuple, list], self.name)
if mstype.issubclass_(x['dtype'], mstype.tensor):
raise ValueError(f'For \'{self.name}\' the input value must be non-Tensor.')
for i, value in enumerate(x_value):
validator.check_value_type("input[%d]" % i, value, [int], self.name)
z = [x_value[i] for i in range(len(x_value))]
z.sort()



+ 1
- 0
mindspore/ops/operations/nn_ops.py View File

@@ -4974,6 +4974,7 @@ class LRN(PrimitiveWithInfer):
validator.check_value_type("beta", beta, [float], self.name)
validator.check_value_type("norm_region", norm_region, [str], self.name)
validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name)
validator.check_integer("depth_radius", depth_radius, 0, Rel.GE, self.name)

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)


Loading…
Cancel
Save