Browse Source

fix Randperm expample and infer check

pull/15805/head
yanzhenxiang2020 4 years ago
parent
commit
8e00f48deb
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      mindspore/ops/operations/inner_ops.py

+ 4
- 4
mindspore/ops/operations/inner_ops.py View File

@@ -77,8 +77,7 @@ class Randperm(PrimitiveWithInfer):
dtype (mindspore.dtype): The type of output. Default: mindspore.int32. dtype (mindspore.dtype): The type of output. Default: mindspore.int32.


Inputs: Inputs:
- **n** (Tensor[int]) - The input tensor with shape: (1,) and the number must be in (0, `max_length`].
Default: 1.
- **n** (Tensor[int32]) - The input tensor with shape: (1,) and the number must be in [0, `max_length`].


Outputs: Outputs:
- **output** (Tensor) - The output Tensor with shape: (`max_length`,) and type: `dtype`. - **output** (Tensor) - The output Tensor with shape: (`max_length`,) and type: `dtype`.
@@ -87,6 +86,7 @@ class Randperm(PrimitiveWithInfer):
TypeError: If neither `max_length` nor `pad` is an int. TypeError: If neither `max_length` nor `pad` is an int.
TypeError: If `n` is not a Tensor. TypeError: If `n` is not a Tensor.
TypeError: If `n` has non-Int elements. TypeError: If `n` has non-Int elements.
TypeError: If `n` has negative elements.


Supported Platforms: Supported Platforms:
``Ascend`` ``Ascend``
@@ -96,7 +96,7 @@ class Randperm(PrimitiveWithInfer):
>>> n = Tensor([20], dtype=mindspore.int32) >>> n = Tensor([20], dtype=mindspore.int32)
>>> output = randperm(n) >>> output = randperm(n)
>>> print(output) >>> print(output)
[15 6 11 19 14 16 9 5 13 18 4 10 8 0 17 2 14 1 12 3 7
[15 6 11 19 14 16 9 5 13 18 4 10 8 0 17 2 1 12 3 7
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1] -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
""" """


@@ -105,7 +105,7 @@ class Randperm(PrimitiveWithInfer):
"""Initialize Randperm""" """Initialize Randperm"""
validator.check_value_type("pad", pad, [int], self.name) validator.check_value_type("pad", pad, [int], self.name)
validator.check_value_type("max_length", max_length, [int], self.name) validator.check_value_type("max_length", max_length, [int], self.name)
validator.check_int(max_length, 1, Rel.GE, "1", self.name)
validator.check_int(max_length, 1, Rel.GE, "max_length", self.name)
self.dtype = dtype self.dtype = dtype
self.max_length = max_length self.max_length = max_length
self.init_prim_io_names(inputs=[], outputs=['output']) self.init_prim_io_names(inputs=[], outputs=['output'])


Loading…
Cancel
Save