Browse Source

check parameter types of uniform

tags/v1.0.0
peixu_ren 5 years ago
parent
commit
64aff52ffc
2 changed files with 11 additions and 0 deletions
  1. +7
    -0
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  2. +4
    -0
      mindspore/ops/composite/random_ops.py

+ 7
- 0
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -131,6 +131,13 @@ def is_same_type(inst, type_):
return inst == type_


@constexpr
def check_valid_type(data_type, value_type, name):
if not data_type in value_type:
raise TypeError(
f"For {name}, valid type include {value_type}, {data_type} is invalid")


def slice_expand(input_slices, shape):
"""
Converts slice to indices.


+ 4
- 0
mindspore/ops/composite/random_ops.py View File

@@ -92,6 +92,9 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):
If dtype is int32, only one number is allowed.
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
Must be non-negative. Default: 0.
dtype (mindspore.dtype): type of the Uniform distribution. If it is int32, it generates numbers from discrete
uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only
supports these two data types. Default: mstype.float32.

Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of minval and maxval.
@@ -112,6 +115,7 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):
"""
minval_dtype = F.dtype(minval)
maxval_dtype = F.dtype(maxval)
const_utils.check_valid_type(dtype, [mstype.int32, mstype.float32], 'uniform')
const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
const_utils.check_non_negative("seed", seed, "uniform")


Loading…
Cancel
Save