
[ME] change `check_float_legal_value` to `check_is_float` and add `check_is_int`

tags/v1.1.0
chenzomi committed 5 years ago · commit cabb387545
11 changed files with 98 additions and 74 deletions

  1. mindspore/_checkparam.py  +42 -10
  2. mindspore/nn/dynamic_lr.py  +6 -6
  3. mindspore/nn/layer/conv.py  +4 -4
  4. mindspore/nn/learning_rate_schedule.py  +5 -5
  5. mindspore/nn/probability/distribution/distribution.py  +1 -2
  6. mindspore/ops/operations/_quant_ops.py  +9 -16
  7. mindspore/ops/operations/_thor_ops.py  +1 -1
  8. mindspore/ops/operations/array_ops.py  +12 -12
  9. mindspore/ops/operations/math_ops.py  +1 -1
  10. mindspore/ops/operations/nn_ops.py  +6 -6
  11. mindspore/ops/operations/random_ops.py  +11 -11

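The rename also flips the call signature: callers that previously passed `(arg_name, arg_value, prim_name)` to `check_float_legal_value` now pass `(arg_value, arg_name, prim_name)` to `check_is_float`. A minimal before/after sketch, assuming the `Validator as validator` import used elsewhere in the tree; the values are illustrative, and the `learning_rate` call mirrors the `dynamic_lr.py` hunk below:

```python
from mindspore._checkparam import Validator as validator

learning_rate = 0.1

# Before this commit: name first, value second, optional prim_name last.
# validator.check_float_legal_value('learning_rate', learning_rate, None)

# After this commit: value first, name second; prim_name stays optional.
validator.check_is_float(learning_rate, 'learning_rate')

# The new integer counterpart follows the same calling convention.
decay_steps = 10
validator.check_is_int(decay_steps, 'decay_steps')
```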
mindspore/_checkparam.py  +42 -10

@@ -111,6 +111,24 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
return arg_value


def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
"""
Checks whether the input value is a number of the given type and is a legal value (not inf or nan).

Usage:
- number = check_is_number(number, int)
- number = check_is_number(number, int, "bias")
- number = check_is_number(number, int, "bias", "bias_class")
"""
prim_name = f'in \'{prim_name}\'' if prim_name else ''
arg_name = f'\'{arg_name}\'' if arg_name else 'Input value'
if isinstance(arg_value, arg_type):
if math.isinf(arg_value) or math.isnan(arg_value):
raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.')
return arg_value
raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')


class Validator:
"""validator for checking input parameters"""

@@ -140,6 +158,18 @@ class Validator:
f' with type `{type(arg_value).__name__}`.')
return arg_value

@staticmethod
def check_is_int(arg_value, arg_name=None, prim_name=None):
"""
Checks whether the input value is an int.

Usage:
- number = check_is_int(number)
- number = check_is_int(number, "bias")
- number = check_is_int(number, "bias", "bias_class")
"""
return check_is_number(arg_value, int, arg_name, prim_name)

@staticmethod
def check_positive_int(arg_value, arg_name=None, prim_name=None):
"""
@@ -184,6 +214,18 @@ class Validator:
"""
return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)

@staticmethod
def check_is_float(arg_value, arg_name=None, prim_name=None):
"""
Checks whether the input value is a float.

Usage:
- number = check_is_float(number)
- number = check_is_float(number, "bias")
- number = check_is_float(number, "bias", "bias_class")
"""
return check_is_number(arg_value, float, arg_name, prim_name)

@staticmethod
def check_positive_float(arg_value, arg_name=None, prim_name=None):
"""
@@ -453,16 +495,6 @@ class Validator:
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')

@staticmethod
def check_float_legal_value(arg_name, arg_value, prim_name):
"""Checks whether a legal value of float type"""
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
if isinstance(arg_value, float):
if math.isinf(arg_value) or math.isnan(arg_value):
raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.")
return arg_value
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")

@staticmethod
def check_reduce_shape(ori_shape, shape, axis, prim_name):
"""Checks whether shape is ori_shape reduced on axis"""


mindspore/nn/dynamic_lr.py  +6 -6

@@ -53,7 +53,7 @@ def piecewise_constant_lr(milestone, learning_rates):
last_item = 0
for i, item in enumerate(milestone):
validator.check_positive_int(item, f'milestone[{i}]')
validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None)
validator.check_is_float(learning_rates[i], f'learning_rates[{i}]')
if item < last_item:
raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
lr += [learning_rates[i]] * (item - last_item)
@@ -67,9 +67,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
validator.check_positive_int(decay_epoch, 'decay_epoch')
validator.check_positive_float(learning_rate, 'learning_rate')
validator.check_float_legal_value('learning_rate', learning_rate, None)
validator.check_is_float(learning_rate, 'learning_rate')
validator.check_positive_float(decay_rate, 'decay_rate')
validator.check_float_legal_value('decay_rate', decay_rate, None)
validator.check_is_float(decay_rate, 'decay_rate')
validator.check_value_type('is_stair', is_stair, [bool], None)


@@ -235,7 +235,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_positive_float(max_lr, 'max_lr')
validator.check_float_legal_value('max_lr', max_lr, None)
validator.check_is_float(max_lr, 'max_lr')
validator.check_positive_int(total_step, 'total_step')
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
validator.check_positive_int(decay_epoch, 'decay_epoch')
@@ -300,12 +300,12 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
"""
validator.check_positive_float(learning_rate, 'learning_rate')
validator.check_float_legal_value('learning_rate', learning_rate, None)
validator.check_is_float(learning_rate, 'learning_rate')
if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.")
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_positive_float(power, 'power')
validator.check_float_legal_value('power', power, None)
validator.check_is_float(power, 'power')
validator.check_positive_int(total_step, 'total_step')
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
validator.check_positive_int(decay_epoch, 'decay_epoch')


mindspore/nn/layer/conv.py  +4 -4

@@ -55,11 +55,11 @@ class _Conv(Cell):
self.weight_init = weight_init
self.bias_init = bias_init
if isinstance(padding, int):
Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
self.padding = padding
elif isinstance(padding, tuple):
for pad in padding:
Validator.check_integer('padding item', pad, 0, Rel.GE, self.cls_name)
Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
self.padding = padding
else:
raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
@@ -386,7 +386,7 @@ class Conv1d(_Conv):
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
kernel_size = (1, kernel_size)
stride = (1, stride)
@@ -705,7 +705,7 @@ class Conv1dTranspose(_Conv):
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
kernel_size = (1, kernel_size)
stride = (1, stride)
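Beyond the float/int helpers, the remaining files swap the generic `check_integer(name, value, bound, Rel.GE/GT, prim_name)` calls for the more specific `check_non_negative_int` and `check_positive_int`, which also move to value-first argument order. A hedged sketch of the equivalence; the `'Conv2d'` and `'Split'` prim names and values are illustrative, the real call sites pass `self.cls_name` or `self.name`:

```python
from mindspore._checkparam import Validator

padding = 0

# Old style: explicit bound plus relation, name first.
# Validator.check_integer('padding', padding, 0, Rel.GE, 'Conv2d')

# New style: the bound and relation are implied by the method name, value first.
Validator.check_non_negative_int(padding, 'padding', 'Conv2d')   # padding >= 0

output_num = 2
Validator.check_positive_int(output_num, 'output_num', 'Split')  # output_num > 0
```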


mindspore/nn/learning_rate_schedule.py  +5 -5

@@ -46,9 +46,9 @@ class LearningRateSchedule(Cell):
def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name):
validator.check_positive_int(decay_steps, 'decay_steps', cls_name)
validator.check_positive_float(learning_rate, 'learning_rate', cls_name)
validator.check_float_legal_value('learning_rate', learning_rate, cls_name)
validator.check_is_float(learning_rate, 'learning_rate', cls_name)
validator.check_positive_float(decay_rate, 'decay_rate', cls_name)
validator.check_float_legal_value('decay_rate', decay_rate, cls_name)
validator.check_is_float(decay_rate, 'decay_rate', cls_name)
validator.check_value_type('is_stair', is_stair, [bool], cls_name)


@@ -256,7 +256,7 @@ class CosineDecayLR(LearningRateSchedule):
raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_positive_float(max_lr, 'max_lr', self.cls_name)
validator.check_float_legal_value('max_lr', max_lr, self.cls_name)
validator.check_is_float(max_lr, 'max_lr', self.cls_name)
validator.check_positive_int(decay_steps, "decay_steps", self.cls_name)
if min_lr >= max_lr:
raise ValueError('`max_lr` should be greater than `min_lr`.')
@@ -319,7 +319,7 @@ class PolynomialDecayLR(LearningRateSchedule):
def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False):
super(PolynomialDecayLR, self).__init__()
validator.check_positive_float(learning_rate, 'learning_rate')
validator.check_float_legal_value('learning_rate', learning_rate, None)
validator.check_is_float(learning_rate, 'learning_rate')
if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.")
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
@@ -327,7 +327,7 @@ class PolynomialDecayLR(LearningRateSchedule):
validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name)
validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
validator.check_positive_float(power, 'power', self.cls_name)
validator.check_float_legal_value('power', power, self.cls_name)
validator.check_is_float(power, 'power', self.cls_name)

self.decay_steps = decay_steps
self.start_learning_rate = learning_rate


mindspore/nn/probability/distribution/distribution.py  +1 -2

@@ -17,7 +17,6 @@ from mindspore import context
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common import get_seed
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
raise_not_implemented_util
@@ -64,7 +63,7 @@ class Distribution(Cell):
if seed is None:
seed = 0
validator.check_value_type('name', name, [str], type(self).__name__)
validator.check_integer('seed', seed, 0, Rel.GE, name)
validator.check_non_negative_int(seed, 'seed', name)

self._name = name
self._seed = seed


mindspore/ops/operations/_quant_ops.py  +9 -16

@@ -141,7 +141,7 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
if self.is_ascend:
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
else:
self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])

@@ -226,10 +226,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_integer(
'quant_delay', quant_delay, 0, Rel.GE, self.name)
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['out'])

@@ -275,8 +273,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")

self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
@@ -371,14 +368,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_integer(
'quant_delay', quant_delay, 0, Rel.GE, self.name)
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
if self.is_ascend:
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
else:
self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

def infer_shape(self, x_shape, min_shape, max_shape):
@@ -433,16 +428,14 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")

self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
self.init_prim_io_names(
inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])



mindspore/ops/operations/_thor_ops.py  +1 -1

@@ -516,7 +516,7 @@ class Im2Col(PrimitiveWithInfer):
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
if self.pad_mode == 'pad':
validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
validator.check_non_negative_int(self.pad, 'pad', self.name)
self.add_prim_attr('data_format', "NCHW")

def infer_shape(self, x_shape):


mindspore/ops/operations/array_ops.py  +12 -12

@@ -763,7 +763,7 @@ class Split(PrimitiveWithInfer):
x_shape = list(x['shape'])
dim = len(x_shape)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name)
validator.check_positive_int(self.output_num, "output_num", self.name)
output_valid_check = x_shape[self.axis] % self.output_num
if output_valid_check != 0:
raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
@@ -846,7 +846,7 @@ class TruncatedNormal(PrimitiveWithInfer):
shape_value = shape['value']
validator.check_value_type("shape", shape_value, [tuple], self.name)
for i, value in enumerate(shape_value):
validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT, self.name)
validator.check_positive_int(value, f'{i}th value of shape', self.name)
out = {'shape': shape_value,
'dtype': mstype.tensor_type(self.dtype),
'value': None}
@@ -2180,13 +2180,13 @@ class StridedSlice(PrimitiveWithInfer):
shrink_axis_mask=0):
"""Initialize StrideSlice"""
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
validator.check_integer('begin_mask', begin_mask, 0, Rel.GE, self.name)
validator.check_integer('end_mask', end_mask, 0, Rel.GE, self.name)
validator.check_integer('ellipsis_mask', ellipsis_mask, 0, Rel.GE, self.name)
validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
validator.check_non_negative_int(end_mask, 'end_mask', self.name)
validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name)
if len(tuple(filter(lambda x: x == '1', bin(ellipsis_mask)[-1:1:-1]))) > 1:
raise ValueError(f"For '{self.name}', only support one ellipsis in the index, but got {end_mask}.")
validator.check_integer('new_axis_mask', new_axis_mask, 0, Rel.GE, self.name)
validator.check_integer('shrink_axis_mask', shrink_axis_mask, 0, Rel.GE, self.name)
validator.check_non_negative_int(new_axis_mask, 'new_axis_mask', self.name)
validator.check_non_negative_int(shrink_axis_mask, 'shrink_axis_mask', self.name)

def __infer__(self, x, begin, end, strides):
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
@@ -2507,7 +2507,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
validator.check_value_type("align_corners", align_corners, [bool], self.name)
validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name)
for i, value in enumerate(size):
validator.check_integer(f'{i}th value of size', value, 0, Rel.GE, self.name)
validator.check_non_negative_int(value, f'{i}th value of size', self.name)
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])

def infer_shape(self, x):
@@ -3176,7 +3176,7 @@ class SpaceToBatch(PrimitiveWithInfer):
self.block_size = block_size
validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name)
for elem in itertools.chain(*paddings):
validator.check_integer('paddings element', elem, 0, Rel.GE, self.name)
validator.check_non_negative_int(elem, 'paddings element', self.name)
validator.check_value_type('paddings element', elem, [int], self.name)
self.paddings = paddings

@@ -3248,7 +3248,7 @@ class BatchToSpace(PrimitiveWithInfer):
validator.check_value_type('crops type', crops, [list, tuple], self.name)
validator.check('crops shape', np.array(crops).shape, '', (2, 2))
for elem in itertools.chain(*crops):
validator.check_integer('crops element', elem, 0, Rel.GE, self.name)
validator.check_non_negative_int(elem, 'crops element', self.name)
validator.check_value_type('crops element', elem, [int], self.name)
self.crops = crops

@@ -3333,7 +3333,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name)
validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
for elem in itertools.chain(*paddings):
validator.check_integer('paddings element', elem, 0, Rel.GE, self.name)
validator.check_non_negative_int(elem, 'paddings element', self.name)
validator.check_value_type('paddings element', elem, [int], self.name)
self.paddings = paddings
block_shape_append = [1] + list(self.block_shape)
@@ -3426,7 +3426,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
validator.check('crops length', len(crops), '', 2, Rel.EQ, self.name)
validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
for elem in itertools.chain(*crops):
validator.check_integer('crops element', elem, 0, Rel.GE, self.name)
validator.check_non_negative_int(elem, 'crops element', self.name)
validator.check_value_type('crops element', elem, [int], self.name)
self.crops = crops
block_shape_append = [1] + list(self.block_shape)


mindspore/ops/operations/math_ops.py  +1 -1

@@ -3019,7 +3019,7 @@ class NMSWithMask(PrimitiveWithInfer):
def infer_shape(self, bboxes_shape):
cls_name = self.name
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
validator.check_integer("bboxes.shape[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
validator.check_positive_int(bboxes_shape[0], "bboxes.shape[0]", cls_name)
validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
num = bboxes_shape[0]
return (bboxes_shape, (num,), (num,))


mindspore/ops/operations/nn_ops.py  +6 -6

@@ -1001,7 +1001,7 @@ class Conv2D(PrimitiveWithInfer):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
if self.pad_mode == 'pad':
for item in pad:
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
validator.check_non_negative_int(item, 'pad item', self.name)

self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
self.add_prim_attr('data_format', "NCHW")
@@ -1139,7 +1139,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
if self.pad_mode == 'pad':
for item in pad:
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
validator.check_non_negative_int(item, 'pad item', self.name)
self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name)
self.add_prim_attr('data_format', "NCHW")
self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name)
@@ -1525,7 +1525,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
if self.pad_mode == 'pad':
for item in pad:
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
validator.check_non_negative_int(item, 'pad item', self.name)

pad_mode = pad_mode.upper()
self.add_prim_attr('pad_mode', pad_mode)
@@ -1534,7 +1534,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
self.add_prim_attr('data_format', "NCHW")
if pad_list:
for x in pad_list:
validator.check_integer('element of pad_list', x, 0, Rel.GE, self.name)
validator.check_non_negative_int(x, 'element of pad_list', self.name)
self.pad_list = pad_list

def __infer__(self, doutput, w, x_size):
@@ -2568,7 +2568,7 @@ class OneHot(PrimitiveWithInfer):
indices_shp = indices['shape']
validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name)
depth_val = depth['value']
validator.check_integer("depth", depth_val, 0, Rel.GE, self.name)
validator.check_non_negative_int(depth_val, "depth", self.name)
# create new dimension at end if self.axis is -1
_ = indices_shp.insert(self.axis, depth_val) if self.axis >= 0 else indices_shp.append(depth_val)

@@ -5722,7 +5722,7 @@ class LRN(PrimitiveWithInfer):
validator.check_value_type("beta", beta, [float], self.name)
validator.check_value_type("norm_region", norm_region, [str], self.name)
validator.check_string(norm_region, ['ACROSS_CHANNELS'], 'norm_region', self.name)
validator.check_integer("depth_radius", depth_radius, 0, Rel.GE, self.name)
validator.check_non_negative_int(depth_radius, "depth_radius", self.name)

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)


mindspore/ops/operations/random_ops.py  +11 -11

@@ -44,8 +44,8 @@ class StandardNormal(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize StandardNormal"""
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

def __infer__(self, shape):
shape_v = shape["value"]
@@ -141,8 +141,8 @@ class Gamma(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize Gamma"""
self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

def __infer__(self, shape, alpha, beta):
shape_v = shape["value"]
@@ -193,8 +193,8 @@ class Poisson(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize Poisson"""
self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

def __infer__(self, shape, mean):
shape_v = shape["value"]
@@ -249,8 +249,8 @@ class UniformInt(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize UniformInt"""
self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

def __infer__(self, shape, minval, maxval):
shape_v = shape["value"]
@@ -296,8 +296,8 @@ class UniformReal(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize UniformReal"""
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

def __infer__(self, shape):
shape_v = shape["value"]
@@ -449,7 +449,7 @@ class Multinomial(PrimitiveWithInfer):
def __init__(self, seed=0):
"""init"""
Validator.check_value_type("seed", seed, [int], self.name)
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
Validator.check_non_negative_int(seed, "seed", self.name)
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])

def __infer__(self, inputs, num_samples):

