
[ME] delete redundant function in check_parameter

tags/v1.1.0
chenzomi 5 years ago
parent commit acadb694aa
28 changed files with 307 additions and 367 deletions
  1. mindspore/_checkparam.py (+24, -67)
  2. mindspore/common/parameter.py (+3, -3)
  3. mindspore/context.py (+6, -16)
  4. mindspore/nn/cell.py (+2, -2)
  5. mindspore/nn/layer/basic.py (+2, -2)
  6. mindspore/nn/layer/conv.py (+11, -11)
  7. mindspore/nn/layer/image.py (+2, -2)
  8. mindspore/nn/layer/pooling.py (+4, -4)
  9. mindspore/nn/layer/quant.py (+5, -5)
  10. mindspore/ops/operations/_grad_ops.py (+25, -25)
  11. mindspore/ops/operations/_inner_ops.py (+3, -3)
  12. mindspore/ops/operations/_quant_ops.py (+14, -19)
  13. mindspore/ops/operations/_thor_ops.py (+1, -1)
  14. mindspore/ops/operations/array_ops.py (+19, -20)
  15. mindspore/ops/operations/comm_ops.py (+2, -2)
  16. mindspore/ops/operations/control_ops.py (+1, -1)
  17. mindspore/ops/operations/debug_ops.py (+2, -2)
  18. mindspore/ops/operations/inner_ops.py (+1, -2)
  19. mindspore/ops/operations/math_ops.py (+16, -16)
  20. mindspore/ops/operations/nn_ops.py (+128, -128)
  21. mindspore/ops/operations/other_ops.py (+13, -13)
  22. mindspore/ops/operations/random_ops.py (+2, -2)
  23. mindspore/ops/primitive.py (+1, -1)
  24. mindspore/train/quant/quant.py (+5, -5)
  25. mindspore/train/summary/_summary_adapter.py (+3, -3)
  26. mindspore/train/summary/summary_record.py (+3, -3)
  27. tests/ut/python/nn/test_checkparameter.py (+2, -2)
  28. tests/ut/python/nn/test_parameter.py (+7, -7)

mindspore/_checkparam.py (+24, -67)

@@ -97,7 +97,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
     Check argument integer.

     Usage:
-    - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
+    - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
     """
     rel_fn = Rel.get_fns(rel)
     type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
@@ -166,12 +166,12 @@ class Validator:
         return arg_value

     @staticmethod
-    def check_integer(arg_name, arg_value, value, rel, prim_name=None):
+    def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
         """
         Checks input integer value `arg_value` compare to `value`.

         Usage:
-        - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
+        - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
         """
         return check_number(arg_value, value, rel, int, arg_name, prim_name)

@@ -187,6 +187,16 @@ class Validator:
         """
         return check_is_number(arg_value, int, arg_name, prim_name)

+    @staticmethod
+    def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
+        """
+        Checks input integer value `arg_value` compare to `value`.
+
+        Usage:
+        - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
+        """
+        return check_number(arg_value, value, Rel.EQ, int, arg_name, prim_name)
+
     @staticmethod
     def check_positive_int(arg_value, arg_name=None, prim_name=None):
         """
@@ -365,6 +375,17 @@ class Validator:
            raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,'
                             f' but got `{arg_value}`.')

+    @staticmethod
+    def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
+        if reg is None:
+            # Named string regular expression
+            reg = r"^\w+[0-9a-zA-Z\_\.]*$"
+        if re.match(reg, target, flag) is None:
+            prim_name = f'in `{prim_name}`' if prim_name else ""
+            raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
+                target, prim_name, reg, flag))
+        return True
+
     @staticmethod
     def check_pad_value_by_mode(pad_mode, padding, prim_name):
         """Validates value of padding according to pad_mode"""
@@ -530,13 +551,6 @@ class Validator:
                             f'{tuple(exp_shape)}, but got {shape}.')


-def check_int_zero_one(input_param):
-    """Judge whether it is 0 or 1."""
-    if input_param in (0, 1):
-        return input_param
-    raise ValueError("The data must be 0 or 1.")
-
-
 def check_input_format(input_param):
     """Judge input format."""
     if input_param == "NCHW":
@@ -544,27 +558,6 @@ def check_input_format(input_param):
     raise ValueError("The data format must be NCHW.")


-def check_padding(padding):
-    """Check padding."""
-    if padding >= 0:
-        return padding
-    raise ValueError("The padding must be at least 0,"" but got padding {}.".format(padding))
-
-
-def check_padmode(mode):
-    """Check padmode."""
-    if mode in ("same", "valid", "pad"):
-        return mode
-    raise ValueError("The pad mode must be same or valid or pad,"" but got mode {}.".format(mode))
-
-
-def check_tensor_supported_type(dtype):
-    """Check tensor dtype."""
-    if dtype in (mstype.int32, mstype.float32):
-        return dtype
-    raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype))
-
-
 def _expand_tuple(n_dimensions):
     """To expand a number to tuple."""

@@ -673,42 +666,6 @@ def check_typename(arg_name, arg_type, valid_types):
                     f' but got {get_typename(arg_type)}.')


-def check_shape(arg_name, arg_value):
-    """Check shape."""
-    # First, check if shape is a tuple
-    if not isinstance(arg_value, tuple):
-        raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},'
-                        f' but got {type(arg_value).__name__}.')
-
-    # Second, wrap arg_value with numpy array so that it can be checked through numpy api
-    arg_value = np.array(arg_value)
-
-    # shape can not be ()
-    if arg_value.size == 0:
-        raise ValueError('Shape can not be empty.')
-
-    # shape's dimension should be 1
-    if arg_value.ndim != 1:
-        raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(arg_value.ndim))
-
-    # Thirdly, check each element's type of the shape
-    valid_types = (int, np.int8, np.int16, np.int32, np.int64,
-                   np.uint8, np.uint16, np.uint32, np.uint64)
-    for dim_size in arg_value:
-        if not isinstance(dim_size, valid_types) or dim_size <= 0:
-            raise ValueError('Every dimension size of the tensor shape should be a positive integer,'
-                             ' but got {}.'.format(dim_size))
-
-
-def _check_str_by_regular(target, reg=None, flag=re.ASCII):
-    if reg is None:
-        # Named string regular expression
-        reg = r"^\w+[0-9a-zA-Z\_\.]*$"
-    if re.match(reg, target, flag) is None:
-        raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
-    return True
-
-
 def args_type_check(*type_args, **type_kwargs):
     """Check whether input data type is correct."""




mindspore/common/parameter.py (+3, -3)

@@ -19,7 +19,7 @@ from .._c_expression import ParamInfo
 from . import dtype as mstype
 from .initializer import initializer, Initializer
 from .tensor import Tensor, MetaTensor
-from .._checkparam import _check_str_by_regular
+from .._checkparam import Validator
 from ..parallel._tensor import _get_slice_index
 from ..parallel._auto_parallel_context import auto_parallel_context
 from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
@@ -263,7 +263,7 @@ class Parameter(MetaTensor):
         Returns:
             Parameter, a new parameter.
         """
-        _check_str_by_regular(prefix)
+        Validator.check_str_by_regular(prefix)
         x = copy(self)
         # pylint: disable=protected-access
         x._param_info = self._param_info.clone()
@@ -446,7 +446,7 @@ class ParameterTuple(tuple):
         Returns:
             Tuple, the new Parameter tuple.
         """
-        _check_str_by_regular(prefix)
+        Validator.check_str_by_regular(prefix)
         new = []
         for x in self:
             x1 = x.clone(prefix, init)


mindspore/context.py (+6, -16)

@@ -23,7 +23,7 @@ from collections import namedtuple
 from types import FunctionType
 from mindspore import log as logger
 from mindspore._c_expression import MSContext, ms_ctx_param
-from mindspore._checkparam import args_type_check
+from mindspore._checkparam import args_type_check, Validator
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
 from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
@@ -35,9 +35,9 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut

 GRAPH_MODE = 0
 PYNATIVE_MODE = 1
-# The max memory size of graph plus variable.
-_DEVICE_APP_MEMORY_SIZE = 31
+_DEVICE_APP_MEMORY_SIZE = 31  # The max memory size of graph plus variable.
+_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
+_k_context = None


 def _make_directory(path):
     """Make directory."""
@@ -223,7 +223,7 @@ class _Context:

     def set_variable_memory_max_size(self, variable_memory_max_size):
         """set values of variable_memory_max_size and graph_memory_max_size"""
-        if not _check_input_format(variable_memory_max_size):
+        if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern):
             raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
         if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
             raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
@@ -235,7 +235,7 @@ class _Context:
         self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)

     def set_max_device_memory(self, max_device_memory):
-        if not _check_input_format(max_device_memory):
+        if not Validator.check_str_by_regular(max_device_memory, _re_pattern):
             raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
         max_device_memory_value = float(max_device_memory[:-2])
         if max_device_memory_value == 0:
@@ -294,16 +294,6 @@ class _Context:
         thread_info.debug_runtime = enable


-def _check_input_format(x):
-    import re
-    pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
-    result = re.match(pattern, x)
-    return result is not None
-
-
-_k_context = None
-
-
 def _context():
     """
     Get the global _context, if context is not created, create a new one.
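For illustration only (a standalone sketch, not MindSpore code): the `_re_pattern` regex that replaces the deleted `_check_input_format` accepts memory sizes written as a positive number of gigabytes, such as "5GB" or "3.5GB", and rejects other strings.

    import re

    # Same pattern the diff adds as _re_pattern in mindspore/context.py.
    _re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'

    def looks_like_memory_size(text):
        """Return True when text matches the GB format, mirroring the old _check_input_format."""
        return re.match(_re_pattern, text) is not None

    assert looks_like_memory_size("5GB")
    assert looks_like_memory_size("3.5GB")
    assert looks_like_memory_size("0.5GB")
    assert not looks_like_memory_size("GB")
    assert not looks_like_memory_size("abc")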


mindspore/nn/cell.py (+2, -2)

@@ -23,7 +23,7 @@ from mindspore import log as logger
 from .. import context
 from ..common import dtype as mstype
 from ..common.api import _executor, _pynative_exec
-from .._checkparam import _check_str_by_regular
+from .._checkparam import Validator
 from ..common.parameter import Parameter, ParameterTuple
 from .._c_expression import init_backend, Cell_
 from ..ops.primitive import Primitive
@@ -715,7 +715,7 @@ class Cell(Cell_):
             recurse (bool): Whether contains the parameters of subcells. Default: True.
         """

-        _check_str_by_regular(prefix)
+        Validator.check_str_by_regular(prefix)
         for name, param in self.parameters_and_names(expand=recurse):
             if prefix != '':
                 param.is_init = False


mindspore/nn/layer/basic.py (+2, -2)

@@ -549,7 +549,7 @@ class Unfold(Cell):

 @constexpr
 def _get_matrix_diag_assist(x_shape, x_dtype):
-    Validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
+    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")
     base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
     assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
     return Tensor(assist, x_dtype)
@@ -557,7 +557,7 @@ def _get_matrix_diag_assist(x_shape, x_dtype):

 @constexpr
 def _get_matrix_diag_part_assist(x_shape, x_dtype):
-    Validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
+    Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist")
     base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
     assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
     return Tensor(assist, x_dtype)


mindspore/nn/layer/conv.py (+11, -11)

@@ -239,8 +239,8 @@ class Conv2d(_Conv):
         """Initialize depthwise conv2d op"""
         if context.get_context("device_target") == "Ascend" and self.group > 1:
             self.dilation = self._dilation
-            Validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
-            Validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
+            Validator.check_equal_int(self.group, self.in_channels, 'group')
+            Validator.check_equal_int(self.group, self.out_channels, 'group')
             self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                   kernel_size=self.kernel_size,
                                                   pad_mode=self.pad_mode,
@@ -384,10 +384,10 @@ class Conv1d(_Conv):
         Validator.check_value_type("stride", stride, [int], self.cls_name)
         Validator.check_value_type("padding", padding, [int], self.cls_name)
         Validator.check_value_type("dilation", dilation, [int], self.cls_name)
-        Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
-        Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
+        Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name)
+        Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name)
         Validator.check_non_negative_int(padding, 'padding', self.cls_name)
-        Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
+        Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name)
         kernel_size = (1, kernel_size)
         stride = (1, stride)
         dilation = (1, dilation)
@@ -395,7 +395,7 @@ class Conv1d(_Conv):
         get_dtype = P.DType()
         if isinstance(weight_init, Tensor):
             weight_init_shape = get_shape(weight_init)
-            Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
+            Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name)
             weight_init_dtype = get_dtype(weight_init)
             weight_init_value = weight_init.asnumpy()
             weight_init_value = np.expand_dims(weight_init_value, 2)
@@ -539,7 +539,7 @@ class Conv2dTranspose(_Conv):
         dilation = twice(dilation)
         Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
         if isinstance(padding, tuple):
-            Validator.check_integer('padding size', len(padding), 4, Rel.EQ, self.cls_name)
+            Validator.check_equal_int(len(padding), 4, 'padding size', self.cls_name)
         # out_channels and in_channels swap.
         # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
         # then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
@@ -703,10 +703,10 @@ class Conv1dTranspose(_Conv):
         Validator.check_value_type("stride", stride, [int], self.cls_name)
         Validator.check_value_type("padding", padding, [int], self.cls_name)
         Validator.check_value_type("dilation", dilation, [int], self.cls_name)
-        Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
-        Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
+        Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name)
+        Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name)
         Validator.check_non_negative_int(padding, 'padding', self.cls_name)
-        Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
+        Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name)
         kernel_size = (1, kernel_size)
         stride = (1, stride)
         dilation = (1, dilation)
@@ -714,7 +714,7 @@ class Conv1dTranspose(_Conv):
         get_dtype = P.DType()
         if isinstance(weight_init, Tensor):
             weight_init_shape = get_shape(weight_init)
-            Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
+            Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name)
             weight_init_dtype = get_dtype(weight_init)
             weight_init_value = weight_init.asnumpy()
             weight_init_value = np.expand_dims(weight_init_value, 2)
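The same call-site migration repeats across the layer code above. A self-contained sketch of the pattern (the helper and values below are illustrative, not taken from this diff): equality checks move to a dedicated helper, and the argument name drops from the first positional slot to an optional trailing one.

    def check_equal_int_sketch(arg_value, value, arg_name=None, prim_name=None):
        # Illustrative only: equality check with the new argument order (value first, name last).
        if arg_value != value:
            raise ValueError(f"For '{prim_name}', `{arg_name}` should be equal to {value}, but got {arg_value}.")
        return arg_value

    # Old call shape:  Validator.check_integer('group', group, in_channels, Rel.EQ)
    # New call shape:  Validator.check_equal_int(group, in_channels, 'group')
    group, in_channels = 4, 4
    check_equal_int_sketch(group, in_channels, 'group', 'Conv2d')  # passes; raises ValueError when group != in_channels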


mindspore/nn/layer/image.py (+2, -2)

@@ -220,7 +220,7 @@ class SSIM(Cell):
         validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
         validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
         self.max_val = max_val
-        self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
+        self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
         self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
         self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
         self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
@@ -298,7 +298,7 @@ class MSSSIM(Cell):
         validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
         self.max_val = max_val
         validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
-        self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
+        self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
         self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
         self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
         self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)


mindspore/nn/layer/pooling.py (+4, -4)

@@ -190,8 +190,8 @@ class MaxPool1d(_PoolNd):
         validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
         validator.check_value_type('stride', stride, [int], self.cls_name)
         self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
-        validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
-        validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
+        validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
+        validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
         self.kernel_size = (1, kernel_size)
         self.stride = (1, stride)
         self.max_pool = P.MaxPool(ksize=self.kernel_size,
@@ -349,8 +349,8 @@ class AvgPool1d(_PoolNd):
         validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
         validator.check_value_type('stride', stride, [int], self.cls_name)
         self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
-        validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
-        validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
+        validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
+        validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
         self.kernel_size = (1, kernel_size)
         self.stride = (1, stride)
         self.avg_pool = P.AvgPool(ksize=self.kernel_size,


mindspore/nn/layer/quant.py (+5, -5)

@@ -323,7 +323,7 @@ class FakeQuantWithMinMax(Cell):
         Validator.check_type("min_init", min_init, [int, float])
         Validator.check_type("max_init", max_init, [int, float])
         Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
-        Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
+        Validator.check_non_negative_int(quant_delay, 'quant_delay')
         self.min_init = min_init
         self.max_init = max_init
         self.num_bits = num_bits
@@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell):

         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
-            Validator.check_integer('group', group, in_channels, Rel.EQ)
-            Validator.check_integer('group', group, out_channels, Rel.EQ)
+            Validator.check_equal_int(group, in_channels, 'group')
+            Validator.check_equal_int(group, out_channels, 'group')
             self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                 kernel_size=self.kernel_size,
                                                 pad_mode=pad_mode,
@@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
             self.bias = None
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
-            Validator.check_integer('group', group, in_channels, Rel.EQ)
-            Validator.check_integer('group', group, out_channels, Rel.EQ)
+            Validator.check_equal_int(group, in_channels, 'group')
+            Validator.check_equal_int(group, out_channels, 'group')
             self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                 kernel_size=self.kernel_size,
                                                 pad_mode=pad_mode,


mindspore/ops/operations/_grad_ops.py (+25, -25)

@@ -931,19 +931,19 @@ class LSTMGradData(PrimitiveWithInfer):
     def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
                     hx_shape, cx_shape, reserve_shape, state_shape):
         # dhy and dcy should be same shape
-        validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
-        validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
-        validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
-        validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
-        validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
+        validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
+        validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
+        validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
+        validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
+        validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)

-        validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
-        validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
+        validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
+        validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)

         # dy: (seq_len, batch_size, hidden_size * num_directions)
-        validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
-        validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
-        validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
+        validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
+        validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
+        validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)

         # (seq_len, batch_size, input_size)
         dx_shape = (y_shape[0], y_shape[1], self.input_size)
@@ -1015,19 +1015,19 @@ class LSTMGrad(PrimitiveWithInfer):
     def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
                     dcy_shape, reserve_shape):
         # dhy and dcy should be same shape
-        validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
-        validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
-        validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
-        validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
-        validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
+        validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
+        validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
+        validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
+        validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
+        validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)

-        validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
-        validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
+        validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
+        validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)

         # dy: (seq_len, batch_size, hidden_size * num_directions)
-        validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
-        validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
-        validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
+        validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
+        validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
+        validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)

         # (seq_len, batch_size, input_size)
         dx_shape = (y_shape[0], y_shape[1], self.input_size)
@@ -1069,7 +1069,7 @@ class DynamicRNNGrad(PrimitiveWithInfer):

     def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
                     c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
-        validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
+        validator.check_equal_int(len(x_shape), 3, "x_shape", self.name)
         num_step, batch_size, input_size = x_shape
         hidden_size = w_shape[-1] // 4
         if w_shape[-1] % 4 != 0:
@@ -1575,7 +1575,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):

     def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
         # dhy and dcy should be same shape
-        validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
+        validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
         validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
         validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
         validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
@@ -1624,7 +1624,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
         self.add_prim_attr("io_format", "HWCN")

     def infer_shape(self, x_shape, h_shape, dgate_shape):
-        validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
+        validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
         validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name)
         validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
         validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
@@ -1656,8 +1656,8 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
         self.add_prim_attr("io_format", "ND")

     def infer_shape(self, dgate_shape, w_shape):
-        validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
-        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
+        validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
+        validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
         validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
         batch_size = dgate_shape[0]
         hidden_size = dgate_shape[1] // 4


mindspore/ops/operations/_inner_ops.py (+3, -3)

@@ -347,7 +347,7 @@ class MatrixDiag(PrimitiveWithInfer):
         return x_dtype

     def infer_shape(self, x_shape, assist_shape):
-        validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name)
+        validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
         validator.check('rank of x', len(x_shape)+1,
                         'rank of assist', len(assist_shape), Rel.LE, self.name)
         validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
@@ -395,7 +395,7 @@ class MatrixDiagPart(PrimitiveWithInfer):
         return x_dtype

     def infer_shape(self, x_shape, assist_shape):
-        validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
+        validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
         validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)

         if assist_shape[-2] < assist_shape[-1]:
@@ -438,7 +438,7 @@ class MatrixSetDiag(PrimitiveWithInfer):
         return x_dtype

     def infer_shape(self, x_shape, diagonal_shape, assist_shape):
-        validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
+        validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
         validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)

         if x_shape[-2] < x_shape[-1]:


mindspore/ops/operations/_quant_ops.py (+14, -19)

@@ -81,11 +81,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
                                 outputs=['min_up', 'max_up'])

     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
-        validator.check_integer("min shape", len(
-            min_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
         return min_shape, max_shape

     def infer_dtype(self, x_type, min_type, max_type):
@@ -147,11 +146,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
         if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
             raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
         if not self.is_ascend:
-            validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
-        validator.check_integer("min shape", len(
-            min_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
         return min_shape, max_shape

     def infer_dtype(self, x_type, min_type, max_type):
@@ -228,9 +226,9 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
                                 outputs=['out'])

     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
         validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
-        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
         return x_shape

     def infer_dtype(self, x_type, min_type, max_type):
@@ -284,8 +282,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
                         x_shape, Rel.EQ, self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
-        validator.check_integer("min shape", len(
-            min_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
         return dout_shape

     def infer_dtype(self, dout_type, x_type, min_type, max_type):
@@ -375,14 +372,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
             raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
         if not self.is_ascend:
-            validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
         if len(x_shape) == 1:
             self.channel_axis = 0
         validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
-        validator.check_integer(
-            "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer(
-            "max shape", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
+        validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
+        validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
         return x_shape

     def infer_dtype(self, x_type, min_type, max_type):
@@ -501,7 +496,7 @@ class BatchNormFold(PrimitiveWithInfer):
     def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
         validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
         validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
         return mean_shape, mean_shape, mean_shape, mean_shape

     def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
@@ -548,7 +543,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
                         "batch_std shape", batch_std_shape, Rel.EQ, self.name)
         validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
                         "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
         return x_shape

     def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
@@ -723,7 +718,7 @@ class BatchNormFold2(PrimitiveWithInfer):
         validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
         validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
                         Rel.EQ, self.name)
-        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
         return x_shape

     def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
@@ -771,7 +766,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
         validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
         validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
                         Rel.EQ, self.name)
-        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
         return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

     def infer_dtype(self, dout_type, x_type, gamma_type,


mindspore/ops/operations/_thor_ops.py (+1, -1)

@@ -520,7 +520,7 @@ class Im2Col(PrimitiveWithInfer):
         self.add_prim_attr('data_format', "NCHW")

     def infer_shape(self, x_shape):
-        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
+        validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
         kernel_size_h = self.kernel_size[0]
         kernel_size_w = self.kernel_size[1]
         stride_h = self.stride[2]


mindspore/ops/operations/array_ops.py (+19, -20)

@@ -583,8 +583,8 @@ class Transpose(PrimitiveWithInfer):

         tmp = list(p_value)
         for i, dim in enumerate(p_value):
-            validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name)
-            validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name)
+            validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
+            validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
             tmp.remove(dim)
             if dim in tmp:
                 raise ValueError('The value of perm is wrong.')
@@ -725,8 +725,8 @@ class Padding(PrimitiveWithInfer):
     def __infer__(self, x):
         validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         x_shape = list(x['shape'])
-        validator.check_integer("rank of x", len(x_shape), 1, Rel.GT, self.name)
-        validator.check_integer("last dim of x", x_shape[-1], 1, Rel.EQ, self.name)
+        validator.check_int(len(x_shape), 1, Rel.GT, "rank of x", self.name)
+        validator.check_int(x_shape[-1], 1, Rel.EQ, "last dim of x", self.name)
         out_shape = x_shape
         out_shape[-1] = self.pad_dim_size
         out = {'shape': out_shape,
@@ -1575,7 +1575,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
         valid_type = [mstype.float16, mstype.float32, mstype.int32]
         validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
         validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
-        validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
         validator.check(f'first shape of input_x', x_shape[0],
                         'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
         num_segments_v = num_segments['value']
@@ -1628,7 +1628,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
         valid_type = [mstype.float16, mstype.float32, mstype.int32]
         validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
         validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
-        validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
+        validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
         validator.check(f'first shape of input_x', x_shape[0],
                         'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
         num_segments_v = num_segments['value']
@@ -1730,7 +1730,7 @@ class ParallelConcat(PrimitiveWithInfer):
         x_shp = values['shape']
         x_type = values['dtype']

-        validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name)
+        validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)

         args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
         validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
@@ -1738,7 +1738,7 @@ class ParallelConcat(PrimitiveWithInfer):
         first_elem = x_shp[0]
         for i, elem in enumerate(x_shp[1:]):
             j = i + 1
-            validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name)
+            validator.check_equal_int(elem[0], 1, f'x_shp[{j}][0]', self.name)
             validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)

         ret_shp = x_shp[0].copy()
@@ -1755,7 +1755,7 @@ class ParallelConcat(PrimitiveWithInfer):
 def _get_pack_shape(x_shape, x_type, axis, prim_name):
     """for pack output shape"""
     validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
-    validator.check_integer("len of input_x", len(x_shape), 1, Rel.GE, prim_name)
+    validator.check_int(len(x_shape), 1, Rel.GE, "len of input_x", prim_name)
     validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
     rank_base = len(x_shape[0])
     N = len(x_shape)
@@ -1871,8 +1871,8 @@ class Unpack(PrimitiveWithInfer):
         validator.check_positive_int(output_num, "output_num", self.name)
         self.add_prim_attr('num', output_num)
         output_valid_check = x_shape[self.axis] - output_num
-        validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
-                                self.name)
+        validator.check_int(output_valid_check, 0, Rel.EQ,
+                            "The dimension which to unpack divides output_num", self.name)
         out_shapes = []
         out_dtypes = []
         out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
@@ -2523,7 +2523,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
         """Initialize ResizeNearestNeighbor"""
         validator.check_value_type("size", size, [tuple, list], self.name)
         validator.check_value_type("align_corners", align_corners, [bool], self.name)
-        validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name)
+        validator.check_equal_int(len(size), 2, "length of size", self.name)
         for i, value in enumerate(size):
             validator.check_non_negative_int(value, f'{i}th value of size', self.name)
         self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
@@ -3134,9 +3134,8 @@ class DepthToSpace(PrimitiveWithInfer):
         for i in range(2):
             out_shape[i + 2] *= self.block_size

-        validator.check_integer('x_shape[1] % (block_size*block_size)',
-                                x_shape[1] % (self.block_size * self.block_size),
-                                0, Rel.EQ, self.name)
+        validator.check_int(x_shape[1] % (self.block_size * self.block_size),
+                            0, Rel.EQ, 'x_shape[1] % (block_size*block_size)', self.name)
         out_shape[1] //= self.block_size * self.block_size
         return out_shape

@@ -3205,7 +3204,7 @@ class SpaceToBatch(PrimitiveWithInfer):
         return x_dtype

     def infer_shape(self, x_shape):
-        validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
+        validator.check_equal_int(len(x_shape), 4, 'rank of input_x', self.name)
         out_shape = copy.deepcopy(x_shape)
         for i in range(2):
             padded = out_shape[i + 2] + self.paddings[i][0] + self.paddings[i][1]
@@ -3367,7 +3366,7 @@ class SpaceToBatchND(PrimitiveWithInfer):

     def infer_shape(self, x_shape):
         x_rank = len(x_shape)
-        validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
+        validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name)
         out_shape = copy.deepcopy(x_shape)

         block_shape_prod = 1
@@ -3460,7 +3459,7 @@ class BatchToSpaceND(PrimitiveWithInfer):

     def infer_shape(self, x_shape):
         x_rank = len(x_shape)
-        validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
+        validator.check_int(x_rank, 4, Rel.EQ, 'x_shape rank', self.name)
         out_shape = copy.deepcopy(x_shape)

         block_shape_prod = 1
@@ -3607,11 +3606,11 @@ class Meshgrid(PrimitiveWithInfer):

     def infer_shape(self, x_shape):
         validator.check_value_type("shape", x_shape, [tuple, list], self.name)
-        validator.check_integer("len of input_x", len(x_shape), 2, Rel.GE, self.name)
+        validator.check_int(len(x_shape), 2, Rel.GE, "len of input_x", self.name)
         n = len(x_shape)
         shape_0 = []
         for s in x_shape:
-            validator.check_integer('each_input_rank', len(s), 1, Rel.EQ, self.name)
+            validator.check_int(len(s), 1, Rel.EQ, 'each_input_rank', self.name)
             shape_0.append(s[0])
         if self.indexing == "xy":
             shape_0[0], shape_0[1] = shape_0[1], shape_0[0]


+ 2
- 2
mindspore/ops/operations/comm_ops.py View File

@@ -204,7 +204,7 @@ class _HostAllGather(PrimitiveWithInfer):
if group is None:
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
for r in group:
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name)
@@ -313,7 +313,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
for r in group:
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name)


+ 1
- 1
mindspore/ops/operations/control_ops.py

@@ -126,7 +126,7 @@ class GeSwitch(PrimitiveWithInfer):
raise NotImplementedError


def infer_shape(self, data, pred):
validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name)
validator.check_equal_int(len(pred), 0, "pred rank", self.name)
return (data, data)


def infer_dtype(self, data_type, pred_type):


+ 2
- 2
mindspore/ops/operations/debug_ops.py

@@ -374,9 +374,9 @@ class Assert(PrimitiveWithInfer):


def infer_shape(self, condition, inputs):
condition_len = len(condition)
validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name)
validator.check_int(condition_len, 1, Rel.LE, "condition's rank", self.name)
if condition_len == 1:
validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name)
validator.check_equal_int(condition[0], 1, "condition[0]", self.name)
return [1]


def infer_dtype(self, condition, inputs):


+ 1
- 2
mindspore/ops/operations/inner_ops.py

@@ -17,7 +17,6 @@


import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, PrimitiveWithInfer


@@ -43,7 +42,7 @@ class ScalarCast(PrimitiveWithInfer):
pass


def __infer__(self, x, t):
validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name)
validator.check_equal_int(len(x['shape']), 0, 'x shape', self.name)
value, to = x['value'], t['value']
if value is not None:
validator.check_value_type("value", value, [numbers.Number, bool], self.name)
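
Where a call site only ever tested equality, the commit switches it to check_equal_int, which fixes the relation to equality; that is also why the hunk above can drop this module's Rel import. A short sketch of the shorthand (the 'x shape' label and the 'ScalarCast' name are used only as an illustration, assuming the same MindSpore checkout):

    from mindspore._checkparam import Validator as validator

    x_shape = ()
    # no Rel needed for a pure equality check
    validator.check_equal_int(len(x_shape), 0, 'x shape', 'ScalarCast')  # returns 0 on success
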


+ 16
- 16
mindspore/ops/operations/math_ops.py

@@ -827,7 +827,7 @@ class AddN(PrimitiveWithInfer):


def infer_shape(self, inputs):
cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
self.add_prim_attr('n', len(inputs))
shp0 = inputs[0]
for i, shp in enumerate(inputs):
@@ -837,7 +837,7 @@ class AddN(PrimitiveWithInfer):
def infer_dtype(self, inputs):
cls_name = self.name
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
args = {}
contains_undetermined = False
for i, dtype in enumerate(inputs):
@@ -910,7 +910,7 @@ class AccumulateNV2(PrimitiveWithInfer):


def infer_shape(self, inputs):
cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
self.add_prim_attr('n', len(inputs))
shp0 = inputs[0]
for i, shp in enumerate(inputs):
@@ -920,7 +920,7 @@ class AccumulateNV2(PrimitiveWithInfer):
def infer_dtype(self, inputs):
cls_name = self.name
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
args = {}
for i, dtype in enumerate(inputs):
args[f"inputs[{i}]"] = dtype
@@ -1488,7 +1488,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nbins, dtype='int32'):
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
validator.check_integer("nbins", nbins, 1, Rel.GE, self.name)
validator.check_int(nbins, 1, Rel.GE, "nbins", self.name)
valid_values = ['int32', 'int64']
self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name)
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
@@ -2810,8 +2810,8 @@ class NPUGetFloatStatus(PrimitiveWithInfer):


def infer_shape(self, x_shape):
cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name)
validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name)
return [8]


def infer_dtype(self, x_dtype):
@@ -2853,8 +2853,8 @@ class NPUClearFloatStatus(PrimitiveWithInfer):


def infer_shape(self, x_shape):
cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name)
validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name)
return [8]


def infer_dtype(self, x_dtype):
@@ -3023,9 +3023,9 @@ class NMSWithMask(PrimitiveWithInfer):


def infer_shape(self, bboxes_shape):
cls_name = self.name
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
validator.check_equal_int(len(bboxes_shape), 2, "bboxes rank", cls_name)
validator.check_positive_int(bboxes_shape[0], "bboxes.shape[0]", cls_name)
validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
validator.check_equal_int(bboxes_shape[1], 5, "bboxes.shape[1]", cls_name)
num = bboxes_shape[0]
return (bboxes_shape, (num,), (num,))


@@ -3572,11 +3572,11 @@ class IFMR(PrimitiveWithInfer):
validator.check_value_type("offset_flag", with_offset, [bool], self.name)


def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape):
validator.check_integer("dims of data_min", len(data_min_shape), 1, Rel.EQ, self.name)
validator.check_integer("data_min[0]", data_min_shape[0], 1, Rel.EQ, self.name)
validator.check_integer("dims of data_max", len(data_max_shape), 1, Rel.EQ, self.name)
validator.check_integer("data_max[0]", data_max_shape[0], 1, Rel.EQ, self.name)
validator.check_integer("dims of cumsum", len(cumsum_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name)
validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name)
validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name)
validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name)
validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name)
return (1,), (1,)


def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):


+ 128
- 128
mindspore/ops/operations/nn_ops.py

@@ -98,7 +98,7 @@ class Flatten(PrimitiveWithInfer):
pass pass


def infer_shape(self, input_x): def infer_shape(self, input_x):
validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:]) prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:])
return input_x[0], prod return input_x[0], prod


@@ -146,7 +146,7 @@ class Softmax(PrimitiveWithInfer):
validator.check_value_type("item of axis", item, [int], self.name) validator.check_value_type("item of axis", item, [int], self.name)


def infer_shape(self, logits): def infer_shape(self, logits):
validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name)
validator.check_int(len(self.axis), 1, Rel.GE, "length of axis", self.name)
rank = len(logits) rank = len(logits)
for axis_v in self.axis: for axis_v in self.axis:
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
@@ -636,7 +636,7 @@ class FusedBatchNorm(Primitive):
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name)
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True self._update_parameter = True
@@ -709,17 +709,17 @@ class FusedBatchNormEx(PrimitiveWithInfer):
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve']) outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name)
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True self._update_parameter = True
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")


def infer_shape(self, input_x, scale, bias, mean, variance): def infer_shape(self, input_x, scale, bias, mean, variance):
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
return (input_x, scale, scale, scale, scale, scale) return (input_x, scale, scale, scale, scale, scale)
@@ -757,7 +757,7 @@ class BNTrainingReduce(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum']) self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])


def infer_shape(self, x_shape): def infer_shape(self, x_shape):
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
return ([x_shape[1]], [x_shape[1]]) return ([x_shape[1]], [x_shape[1]])


def infer_dtype(self, x_type): def infer_dtype(self, x_type):
@@ -822,13 +822,13 @@ class BNTrainingUpdate(PrimitiveWithInfer):
self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate') self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate')


def infer_shape(self, x, sum, square_sum, scale, b, mean, variance): def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name)
validator.check_integer("sum rank", len(sum), 1, Rel.EQ, self.name)
validator.check_integer("square_sum rank", len(square_sum), 1, Rel.EQ, self.name)
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
validator.check_integer("b rank", len(b), 1, Rel.EQ, self.name)
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
validator.check_integer("variance rank", len(variance), 1, Rel.EQ, self.name)
validator.check_equal_int(len(x), 4, "x rank", self.name)
validator.check_equal_int(len(sum), 1, "sum rank", self.name)
validator.check_equal_int(len(square_sum), 1, "square_sum rank", self.name)
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
validator.check_equal_int(len(b), 1, "b rank", self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check_equal_int(len(variance), 1, "variance rank", self.name)
validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name) validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name) validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name)
validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name) validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name)
@@ -904,11 +904,11 @@ class BatchNorm(PrimitiveWithInfer):
outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])


def infer_shape(self, input_x, scale, bias, mean, variance): def infer_shape(self, input_x, scale, bias, mean, variance):
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
if not self.is_training: if not self.is_training:
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
return (input_x, scale, scale, scale, scale) return (input_x, scale, scale, scale, scale)
@@ -1010,7 +1010,7 @@ class Conv2D(PrimitiveWithInfer):
if isinstance(pad, int): if isinstance(pad, int):
pad = (pad,) * 4 pad = (pad,) * 4
else: else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
self.padding = pad self.padding = pad
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)


@@ -1020,15 +1020,15 @@ class Conv2D(PrimitiveWithInfer):
for item in pad: for item in pad:
validator.check_non_negative_int(item, 'pad item', self.name) validator.check_non_negative_int(item, 'pad item', self.name)


self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.group = validator.check_positive_int(group, 'group', self.name) self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('offset_a', 0) self.add_prim_attr('offset_a', 0)


def infer_shape(self, x_shape, w_shape, b_shape=None): def infer_shape(self, x_shape, w_shape, b_shape=None):
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(w_shape), 4, "weight rank", self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
@@ -1150,7 +1150,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
if isinstance(pad, int): if isinstance(pad, int):
pad = (pad,) * 4 pad = (pad,) * 4
else: else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
self.padding = pad self.padding = pad
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0): if pad_mode != 'pad' and pad != (0, 0, 0, 0):
@@ -1158,15 +1158,15 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
if self.pad_mode == 'pad': if self.pad_mode == 'pad':
for item in pad: for item in pad:
validator.check_non_negative_int(item, 'pad item', self.name) validator.check_non_negative_int(item, 'pad item', self.name)
self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name)
self.mode = validator.check_equal_int(mode, 3, "mode", self.name)
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name) self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name)
self.group = validator.check_positive_int(group, "group", self.name) self.group = validator.check_positive_int(group, "group", self.name)
self.add_prim_attr('offset_a', 0) self.add_prim_attr('offset_a', 0)


def infer_shape(self, x_shape, w_shape, b_shape=None): def infer_shape(self, x_shape, w_shape, b_shape=None):
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(w_shape), 4, "weight rank", self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)


@@ -1250,7 +1250,7 @@ class _Pool(PrimitiveWithInfer):
self.add_prim_attr("strides", self.strides) self.add_prim_attr("strides", self.strides)


def infer_shape(self, x_shape): def infer_shape(self, x_shape):
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
batch, channel, input_h, input_w = x_shape batch, channel, input_h, input_w = x_shape
if self.is_maxpoolwithargmax: if self.is_maxpoolwithargmax:
_, kernel_h, kernel_w, _ = self.ksize _, kernel_h, kernel_w, _ = self.ksize
@@ -1536,7 +1536,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
if isinstance(pad, int): if isinstance(pad, int):
pad = (pad,) * 4 pad = (pad,) * 4
else: else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
self.padding = pad self.padding = pad
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0): if pad_mode != 'pad' and pad != (0, 0, 0, 0):
@@ -1547,7 +1547,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):


pad_mode = pad_mode.upper() pad_mode = pad_mode.upper()
self.add_prim_attr('pad_mode', pad_mode) self.add_prim_attr('pad_mode', pad_mode)
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.group = validator.check_positive_int(group, 'group', self.name) self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
if pad_list: if pad_list:
@@ -1624,8 +1624,8 @@ class BiasAdd(PrimitiveWithInfer):
self.add_prim_attr('data_format', 'NCHW') self.add_prim_attr('data_format', 'NCHW')


def infer_shape(self, x_shape, b_shape): def infer_shape(self, x_shape, b_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
validator.check_integer("bias rank", len(b_shape), 1, Rel.EQ, self.name)
validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name) validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name)
return x_shape return x_shape


@@ -2007,10 +2007,10 @@ class RNNTLoss(PrimitiveWithInfer):
outputs=['costs', 'grads']) outputs=['costs', 'grads'])


def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape): def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape):
validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name)
validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name)
validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name)
validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(acts_shape), 4, 'acts_rank', self.name)
validator.check_equal_int(len(labels_shape), 2, 'labels_rank', self.name)
validator.check_equal_int(len(input_length_shape), 1, 'input_length_rank', self.name)
validator.check_equal_int(len(label_length_shape), 1, 'label_length_rank', self.name)
validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name) validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name)
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
@@ -2080,11 +2080,11 @@ class SGD(PrimitiveWithInfer):
def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape, def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
accum_shape, momentum_shape, stat_shape): accum_shape, momentum_shape, stat_shape):
validator.check_positive_int(len(parameters_shape), "parameters rank", self.name) validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name)
validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name)
validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name)
validator.check_int(len(learning_rate_shape), 0, Rel.GE, f'learning rate rank', self.name)
validator.check_positive_int(len(accum_shape), "accumulation rank", self.name) validator.check_positive_int(len(accum_shape), "accumulation rank", self.name)
validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name)
validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name)
validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name)
validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name)
validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name) validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
return parameters_shape return parameters_shape


@@ -2780,17 +2780,17 @@ class LSTM(PrimitiveWithInfer):


def infer_shape(self, x_shape, h_shape, c_shape, w_shape): def infer_shape(self, x_shape, h_shape, c_shape, w_shape):
# (seq, batch_size, feature) # (seq, batch_size, feature)
validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name)
validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 3, "x rank", self.name)
validator.check_equal_int(x_shape[2], self.input_size, "x[2]", self.name)


# h and c should be same shape # h and c should be same shape
validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)
validator.check_equal_int(len(h_shape), 3, "h rank", self.name)
validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name) validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name)


# (num_layers * num_directions, batch, hidden_size) # (num_layers * num_directions, batch, hidden_size)
validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ, self.name)
validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ, self.name)
validator.check_int(h_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h[0]", self.name)
validator.check_equal_int(h_shape[1], x_shape[1], "h[1]", self.name)
validator.check_int(h_shape[2], self.hidden_size, Rel.EQ, "h[2]", self.name)


y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions) y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions)


@@ -2918,7 +2918,7 @@ class Pad(PrimitiveWithInfer):


def infer_shape(self, x): def infer_shape(self, x):
paddings = np.array(self.paddings) paddings = np.array(self.paddings)
validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ, self.name)
validator.check_int(paddings.size, len(x) * 2, Rel.EQ, 'paddings.shape', self.name)
if not np.all(paddings >= 0): if not np.all(paddings >= 0):
raise ValueError('All elements of paddings must be >= 0.') raise ValueError('All elements of paddings must be >= 0.')
y_shape = () y_shape = ()
@@ -2992,7 +2992,7 @@ class MirrorPad(PrimitiveWithInfer):
x_shape = list(input_x['shape']) x_shape = list(input_x['shape'])
paddings_value = paddings['value'].asnumpy() paddings_value = paddings['value'].asnumpy()
paddings_size = paddings_value.size paddings_size = paddings_value.size
validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name)
validator.check_int(paddings_size, len(x_shape) * 2, Rel.EQ, 'paddings.shape', self.name)
if not np.all(paddings_value >= 0): if not np.all(paddings_value >= 0):
raise ValueError('All elements of paddings must be >= 0.') raise ValueError('All elements of paddings must be >= 0.')
adjust = 0 adjust = 0
@@ -3276,7 +3276,7 @@ class FusedSparseAdam(PrimitiveWithInfer):
beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]: if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]:
raise ValueError(f"For '{self.name}', the shape of updates should be [] or " raise ValueError(f"For '{self.name}', the shape of updates should be [] or "
@@ -3409,7 +3409,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]: if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]:
raise ValueError(f"For '{self.name}', the shape of updates should be [] or " raise ValueError(f"For '{self.name}', the shape of updates should be [] or "
@@ -3513,7 +3513,7 @@ class FusedSparseFtrl(PrimitiveWithInfer):
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1: if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return [1], [1], [1] return [1], [1], [1]


@@ -3602,7 +3602,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)


def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
return [1], [1] return [1], [1]


def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
@@ -3869,25 +3869,25 @@ class ApplyAdaMax(PrimitiveWithInfer):
validator.check("v_shape", v_shape, "var_shape", var_shape, Rel.EQ, self.name) validator.check("v_shape", v_shape, "var_shape", var_shape, Rel.EQ, self.name)
validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name) validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name)
beta1_power_shp_len = len(beta1_power_shape) beta1_power_shp_len = len(beta1_power_shape)
validator.check_integer("beta1 power's rank", beta1_power_shp_len, 1, Rel.LE, self.name)
validator.check_int(beta1_power_shp_len, 1, Rel.LE, "beta1 power's rank", self.name)
if beta1_power_shp_len == 1: if beta1_power_shp_len == 1:
validator.check_integer("beta1_power_shape[0]", beta1_power_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta1_power_shape[0], 1, Rel.EQ, "beta1_power_shape[0]", self.name)
lr_shp_len = len(lr_shape) lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1: if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
beta1_shp_len = len(beta1_shape) beta1_shp_len = len(beta1_shape)
validator.check_integer("beta1's rank", beta1_shp_len, 1, Rel.LE, self.name)
validator.check_int(beta1_shp_len, 1, Rel.LE, "beta1's rank", self.name)
if beta1_shp_len == 1: if beta1_shp_len == 1:
validator.check_integer("beta1_shape[0]", beta1_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta1_shape[0], 1, Rel.EQ, "beta1_shape[0]", self.name)
beta2_shp_len = len(beta2_shape) beta2_shp_len = len(beta2_shape)
validator.check_integer("beta2's rank", beta2_shp_len, 1, Rel.LE, self.name)
validator.check_int(beta2_shp_len, 1, Rel.LE, "beta2's rank", self.name)
if beta2_shp_len == 1: if beta2_shp_len == 1:
validator.check_integer("beta2_shape[0]", beta2_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta2_shape[0], 1, Rel.EQ, "beta2_shape[0]", self.name)
epsilon_shp_len = len(epsilon_shape) epsilon_shp_len = len(epsilon_shape)
validator.check_integer("epsilon's rank", epsilon_shp_len, 1, Rel.LE, self.name)
validator.check_int(epsilon_shp_len, 1, Rel.LE, "epsilon's rank", self.name)
if epsilon_shp_len == 1: if epsilon_shp_len == 1:
validator.check_integer("epsilon_shape[0]", epsilon_shape[0], 1, Rel.EQ, self.name)
validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name)
return var_shape, m_shape, v_shape return var_shape, m_shape, v_shape


def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype, def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype,
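
The hunk above shows the recurring idiom this commit rewrites for scalar hyperparameters (lr, beta1, beta2, epsilon and the like): each shape must be 0-D, or 1-D with exactly one element. A condensed sketch of that idiom using the new helpers (the _check_scalar_shape wrapper is hypothetical, written here only to illustrate the two checks; it assumes the same MindSpore checkout):

    from mindspore._checkparam import Validator as validator
    from mindspore._checkparam import Rel

    def _check_scalar_shape(shape, label, prim_name):
        # rank may be at most 1, and a 1-D tensor must hold a single element
        validator.check_int(len(shape), 1, Rel.LE, f"{label}'s rank", prim_name)
        if len(shape) == 1:
            validator.check_int(shape[0], 1, Rel.EQ, f"{label}_shape[0]", prim_name)

    _check_scalar_shape((), "lr", "ApplyAdaMax")    # 0-D scalar passes
    _check_scalar_shape((1,), "lr", "ApplyAdaMax")  # single-element 1-D passes
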
@@ -3985,17 +3985,17 @@ class ApplyAdadelta(PrimitiveWithInfer):
validator.check("accum_update_shape", accum_update_shape, "var_shape", var_shape, Rel.EQ, self.name) validator.check("accum_update_shape", accum_update_shape, "var_shape", var_shape, Rel.EQ, self.name)
validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name) validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name)
lr_shp_len = len(lr_shape) lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1: if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
rho_shp_len = len(rho_shape) rho_shp_len = len(rho_shape)
validator.check_integer("rho's rank", rho_shp_len, 1, Rel.LE, self.name)
validator.check_int(rho_shp_len, 1, Rel.LE, "rho's rank", self.name)
if rho_shp_len == 1: if rho_shp_len == 1:
validator.check_integer("rho_shape[0]", rho_shape[0], 1, Rel.EQ, self.name)
validator.check_int(rho_shape[0], 1, Rel.EQ, "rho_shape[0]", self.name)
epsilon_shp_len = len(epsilon_shape) epsilon_shp_len = len(epsilon_shape)
validator.check_integer("lepsilon's rank", epsilon_shp_len, 1, Rel.LE, self.name)
validator.check_int(epsilon_shp_len, 1, Rel.LE, "lepsilon's rank", self.name)
if epsilon_shp_len == 1: if epsilon_shp_len == 1:
validator.check_integer("epsilon_shape[0]", epsilon_shape[0], 1, Rel.EQ, self.name)
validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name)
return var_shape, accum_shape, accum_update_shape return var_shape, accum_shape, accum_update_shape


def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype, def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype,
@@ -4077,9 +4077,9 @@ class ApplyAdagrad(PrimitiveWithInfer):
validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name) validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name)
validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name) validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name)
lr_shp_len = len(lr_shape) lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1: if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
return var_shape, accum_shape return var_shape, accum_shape


def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
@@ -4161,9 +4161,9 @@ class ApplyAdagradV2(PrimitiveWithInfer):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name) validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name)
lr_shp_len = len(lr_shape) lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1: if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
return var_shape, accum_shape return var_shape, accum_shape


def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
@@ -4249,7 +4249,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name)
if len(var_shape) > 1: if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape return var_shape, accum_shape


@@ -4338,7 +4338,7 @@ class SparseApplyAdagradV2(PrimitiveWithInfer):
validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name)
if len(var_shape) > 1: if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape return var_shape, accum_shape


@@ -4428,17 +4428,17 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name) validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name)
validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name) validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name)
lr_shp_len = len(lr_shape) lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1: if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
l1_shp_len = len(l1_shape) l1_shp_len = len(l1_shape)
validator.check_integer("l1's rank", l1_shp_len, 1, Rel.LE, self.name)
validator.check_int(l1_shp_len, 1, Rel.LE, "l1's rank", self.name)
if l1_shp_len == 1: if l1_shp_len == 1:
validator.check_integer("l1_shape[0]", l1_shape[0], 1, Rel.EQ, self.name)
validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name)
l2_shp_len = len(l2_shape) l2_shp_len = len(l2_shape)
validator.check_integer("l2's rank", l2_shp_len, 1, Rel.LE, self.name)
validator.check_int(l2_shp_len, 1, Rel.LE, "l2's rank", self.name)
if l2_shp_len == 1: if l2_shp_len == 1:
validator.check_integer("l2_shape[0]", l2_shape[0], 1, Rel.EQ, self.name)
validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name)
return var_shape, accum_shape return var_shape, accum_shape


def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype):
@@ -4532,7 +4532,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)


def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)


def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
@@ -4623,21 +4623,21 @@ class ApplyAddSign(PrimitiveWithInfer):
validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name) validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name) validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
lr_shape_len = len(lr_shape) lr_shape_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name)
validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shape_len == 1: if lr_shape_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
alpha_shape_len = len(alpha_shape) alpha_shape_len = len(alpha_shape)
validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name)
validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name)
if alpha_shape_len == 1: if alpha_shape_len == 1:
validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name)
validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name)
sign_decay_shape_len = len(sign_decay_shape) sign_decay_shape_len = len(sign_decay_shape)
validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name)
validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name)
if sign_decay_shape_len == 1: if sign_decay_shape_len == 1:
validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name)
validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name)
beta_shape_len = len(beta_shape) beta_shape_len = len(beta_shape)
validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name)
validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name)
if beta_shape_len == 1: if beta_shape_len == 1:
validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name)
return var_shape, m_shape return var_shape, m_shape


def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype): def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
@@ -4732,21 +4732,21 @@ class ApplyPowerSign(PrimitiveWithInfer):
validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name) validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name) validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
lr_shape_len = len(lr_shape) lr_shape_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name)
validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shape_len == 1: if lr_shape_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
logbase_shape_len = len(logbase_shape) logbase_shape_len = len(logbase_shape)
validator.check_integer("logbase's rank", logbase_shape_len, 1, Rel.LE, self.name)
validator.check_int(logbase_shape_len, 1, Rel.LE, "logbase's rank", self.name)
if logbase_shape_len == 1: if logbase_shape_len == 1:
validator.check_integer("logbase_shape[0]", logbase_shape[0], 1, Rel.EQ, self.name)
validator.check_int(logbase_shape[0], 1, Rel.EQ, "logbase_shape[0]", self.name)
sign_decay_shape_len = len(sign_decay_shape) sign_decay_shape_len = len(sign_decay_shape)
validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name)
validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name)
if sign_decay_shape_len == 1: if sign_decay_shape_len == 1:
validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name)
validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name)
beta_shape_len = len(beta_shape) beta_shape_len = len(beta_shape)
validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name)
validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name)
if beta_shape_len == 1: if beta_shape_len == 1:
validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name)
return var_shape, m_shape return var_shape, m_shape


def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype): def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
@@ -4812,9 +4812,9 @@ class ApplyGradientDescent(PrimitiveWithInfer):
def infer_shape(self, var_shape, alpha_shape, delta_shape): def infer_shape(self, var_shape, alpha_shape, delta_shape):
validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name) validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name)
alpha_shape_len = len(alpha_shape) alpha_shape_len = len(alpha_shape)
validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name)
validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name)
if alpha_shape_len == 1: if alpha_shape_len == 1:
validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name)
validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name)
return var_shape return var_shape


def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype): def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype):
@@ -4887,17 +4887,17 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer):
def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape): def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape):
validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name) validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name)
alpha_shape_len = len(alpha_shape) alpha_shape_len = len(alpha_shape)
validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name)
validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name)
if alpha_shape_len == 1: if alpha_shape_len == 1:
validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name)
validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name)
l1_shape_len = len(l1_shape) l1_shape_len = len(l1_shape)
validator.check_integer("l1's rank", l1_shape_len, 1, Rel.LE, self.name)
validator.check_int(l1_shape_len, 1, Rel.LE, "l1's rank", self.name)
if l1_shape_len == 1: if l1_shape_len == 1:
validator.check_integer("l1_shape[0]", l1_shape[0], 1, Rel.EQ, self.name)
validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name)
l2_shape_len = len(l2_shape) l2_shape_len = len(l2_shape)
validator.check_integer("l2's rank", l2_shape_len, 1, Rel.LE, self.name)
validator.check_int(l2_shape_len, 1, Rel.LE, "l2's rank", self.name)
if l2_shape_len == 1: if l2_shape_len == 1:
validator.check_integer("l2_shape[0]", l2_shape[0], 1, Rel.EQ, self.name)
validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name)
return var_shape return var_shape


def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype): def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype):
@@ -4965,13 +4965,13 @@ class LARSUpdate(PrimitiveWithInfer):
validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ, validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ,
self.name) self.name)
shp_len = len(weight_decay_shape) shp_len = len(weight_decay_shape)
validator.check_integer("weight decay's rank", shp_len, 1, Rel.LE, self.name)
validator.check_int(shp_len, 1, Rel.LE, "weight decay's rank", self.name)
if shp_len == 1: if shp_len == 1:
validator.check_integer("weight_decay_shape[0]", weight_decay_shape[0], 1, Rel.EQ, self.name)
validator.check_int(weight_decay_shape[0], 1, Rel.EQ, "weight_decay_shape[0]", self.name)
shp_len = len(learning_rate_shape) shp_len = len(learning_rate_shape)
validator.check_integer("learning rate's rank", shp_len, 1, Rel.LE, self.name)
validator.check_int(shp_len, 1, Rel.LE, "learning rate's rank", self.name)
if shp_len == 1: if shp_len == 1:
validator.check_integer("learning_rate_shape[0]", learning_rate_shape[0], 1, Rel.EQ, self.name)
validator.check_int(learning_rate_shape[0], 1, Rel.EQ, "learning_rate_shape[0]", self.name)
return weight_shape return weight_shape


def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype, def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype,
@@ -5155,7 +5155,7 @@ class SparseApplyFtrl(PrimitiveWithCheck):
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1: if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)


def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
@@ -5251,7 +5251,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape, linear_shape


@@ -5288,7 +5288,7 @@ class Dropout(PrimitiveWithInfer):
self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)


def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x_shape", self.name)
mask_shape = x_shape
return x_shape, mask_shape


@@ -5352,11 +5352,11 @@ class CTCLoss(PrimitiveWithInfer):
self.ignore_longer_outputs_than_inputs_ = ignore_longer_outputs_than_inputs


def infer_shape(self, inputs, labels_indices, labels_values, sequence_length):
validator.check_integer("inputs rank", len(inputs), 3, Rel.EQ, self.name)
validator.check_integer("labels_indices rank", len(labels_indices), 2, Rel.EQ, self.name)
validator.check_integer("labels_indices dim one", labels_indices[1], 2, Rel.EQ, self.name)
validator.check_integer("labels_values rank", len(labels_values), 1, Rel.EQ, self.name)
validator.check_integer("sequence_length rank", len(sequence_length), 1, Rel.EQ, self.name)
validator.check_int(len(inputs), 3, Rel.EQ, "inputs rank", self.name)
validator.check_int(len(labels_indices), 2, Rel.EQ, "labels_indices rank", self.name)
validator.check_int(labels_indices[1], 2, Rel.EQ, "labels_indices dim one", self.name)
validator.check_int(len(labels_values), 1, Rel.EQ, "labels_values rank", self.name)
validator.check_int(len(sequence_length), 1, Rel.EQ, "sequence_length rank", self.name)
validator.check('labels_indices size', labels_indices[0], 'labels_values size',
labels_values[0], Rel.EQ, self.name)
validator.check('inputs batch_size', inputs[1], 'sequence_length batch_size',
@@ -5422,8 +5422,8 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)


def infer_shape(self, inputs_shape, sequence_length_shape):
validator.check_integer("inputs rank", len(inputs_shape), 3, Rel.EQ, self.name)
validator.check_integer("sequence_length rank", len(sequence_length_shape), 1, Rel.EQ, self.name)
validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name)
validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name)
validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
sequence_length_shape[0], Rel.EQ, self.name)
total_decoded_outputs = -1
@@ -5517,11 +5517,11 @@ class BasicLSTMCell(PrimitiveWithInfer):
self.add_prim_attr("io_format", "ND")


def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name)
validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
validator.check_int(len(x_shape), 2, Rel.EQ, "x rank", self.name)
validator.check_int(len(h_shape), 2, Rel.EQ, "h rank", self.name)
validator.check_int(len(c_shape), 2, Rel.EQ, "c rank", self.name)
validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name)
validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name)
validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name)
@@ -5637,11 +5637,11 @@ class DynamicRNN(PrimitiveWithInfer):
self.add_prim_attr("io_format", "ND")


def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name)
validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name)
validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name)
validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name)
validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name)
validator.check_int(len(h_shape), 3, Rel.EQ, "h_shape", self.name)
validator.check_int(len(c_shape), 3, Rel.EQ, "c_shape", self.name)
if seq_shape is not None:
raise ValueError(f"For {self.name}, seq_shape should be None.")


@@ -5654,7 +5654,7 @@ class DynamicRNN(PrimitiveWithInfer):
validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
input_size + hidden_size, Rel.EQ, self.name)
validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name)
validator.check_int(h_shape[0], 1, Rel.EQ, "h_shape[0]", self.name)
validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
@@ -5754,5 +5754,5 @@ class LRN(PrimitiveWithInfer):
return x_dtype


def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name)
validator.check_int(len(x_shape), 4, Rel.EQ, "x_shape", self.name)
return x_shape
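
The nn_ops.py hunks above all apply the same mechanical rewrite: the descriptive name moves from the first positional argument of the removed check_integer to a trailing argument of check_int, while the value, bound, and relation keep their order. A minimal sketch of the new call order, outside the diff, assuming a build from this tag where Validator and Rel are importable from mindspore._checkparam; the shape and the "ExampleOp" name are placeholders:

from mindspore._checkparam import Rel, Validator

x_shape = (32, 128, 20)  # placeholder input shape
# value first, then the bound, the relation, and finally the descriptive names
Validator.check_int(len(x_shape), 1, Rel.GE, "x_shape", "ExampleOp")
rank = Validator.check_int(len(x_shape), 3, Rel.EQ, "x rank", "ExampleOp")  # returns the checked value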

+ 13
- 13
mindspore/ops/operations/other_ops.py View File

@@ -98,16 +98,16 @@ class BoundingBoxEncode(PrimitiveWithInfer):
validator.check_value_type("means[%d]" % i, value, [float], self.name)
for i, value in enumerate(stds):
validator.check_value_type("stds[%d]" % i, value, [float], self.name)
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
validator.check_equal_int(len(means), 4, "means len", self.name)
validator.check_equal_int(len(stds), 4, "stds len", self.name)


def infer_shape(self, anchor_box, groundtruth_box):
validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ,
self.name)
validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
validator.check("groundtruth_box rank", len(groundtruth_box), "", 2, Rel.EQ, self.name)
validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name)
validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
validator.check_equal_int(groundtruth_box[1], 4, 'groundtruth_box shape[1]', self.name)
return anchor_box


def infer_dtype(self, anchor_box, groundtruth_box):
@@ -153,18 +153,18 @@ class BoundingBoxDecode(PrimitiveWithInfer):
for i, value in enumerate(stds):
validator.check_value_type("stds[%d]" % i, value, [float], self.name)
validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
validator.check_equal_int(len(means), 4, "means len", self.name)
validator.check_equal_int(len(stds), 4, "stds len", self.name)
if max_shape is not None:
validator.check_value_type('max_shape', max_shape, [tuple], self.name)
validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name)
validator.check_equal_int(len(max_shape), 2, "max_shape len", self.name)


def infer_shape(self, anchor_box, deltas):
validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name)
validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
validator.check("deltas rank", len(deltas), "", 2, Rel.EQ, self.name)
validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name)
validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
validator.check_equal_int(deltas[1], 4, 'deltas shape[1]', self.name)
return anchor_box


def infer_dtype(self, anchor_box, deltas):
@@ -272,10 +272,10 @@ class IOU(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])


def infer_shape(self, anchor_boxes, gt_boxes):
validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name)
validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name)
validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name)
validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name)
validator.check_equal_int(gt_boxes[1], 4, 'gt_boxes shape[1]', self.name)
validator.check_equal_int(anchor_boxes[1], 4, 'anchor_boxes shape[1]', self.name)
validator.check_equal_int(len(anchor_boxes), 2, 'anchor_boxes rank', self.name)
validator.check_equal_int(len(gt_boxes), 2, 'gt_boxes rank', self.name)
iou = [gt_boxes[0], anchor_boxes[0]]
return iou
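
Equality checks in other_ops.py go through check_equal_int, which keeps the value-first order of check_int and folds the Rel.EQ relation into the helper itself. A hedged sketch with a made-up (N, 4) box shape; the old call is shown as a comment for contrast:

from mindspore._checkparam import Validator

anchor_box = (256, 4)  # placeholder (N, 4) box shape
# before: validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
Validator.check_equal_int(anchor_box[1], 4, "anchor_box shape[1]", "BoundingBoxEncode")
Validator.check_equal_int(len(anchor_box), 2, "anchor_box rank", "BoundingBoxEncode")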




+ 2
- 2
mindspore/ops/operations/random_ops.py View File

@@ -356,8 +356,8 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
Validator.check_value_type('seed2', seed2, [int], self.name)


def infer_shape(self, x_shape):
Validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
Validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name)
Validator.check_int(len(x_shape), 5, Rel.LE, "input_x rank", self.name)
return ([self.count, len(x_shape)], [self.count])


def infer_dtype(self, x_dtype):


+ 1
- 1
mindspore/ops/primitive.py View File

@@ -227,7 +227,7 @@ class PrimitiveWithCheck(Primitive):
>>> def __init__(self):
>>> pass
>>> def check_shape(self, input_x):
>>> validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
>>> validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
>>>
>>> def check_dtype(self, input_x):
>>> validator.check_subclass("input_x", input_x, mstype.tensor, self.name)


+ 5
- 5
mindspore/train/quant/quant.py View File

@@ -89,12 +89,12 @@ class ConvertToQuantNetwork:


def __init__(self, **kwargs):
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay")
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
self.weight_bits = Validator.check_non_negative_int(kwargs["num_bits"][0], "weights bit")
self.act_bits = Validator.check_int(kwargs["num_bits"][-1], 0, Rel.GE, "activations bit")
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
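
In quant.py the "greater than or equal to zero" checks split two ways: values treated as non-negative move to check_non_negative_int, while the activation-side settings keep an explicit Rel.GE bound via check_int. A rough sketch of the two forms, with placeholder values and the same import assumptions as the earlier sketches:

from mindspore._checkparam import Rel, Validator

quant_delay = (900, 900)  # placeholder (weight, activation) delays
weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")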


+ 3
- 3
mindspore/train/summary/_summary_adapter.py View File

@@ -21,7 +21,7 @@ from PIL import Image


from mindspore import log as logger


from ..._checkparam import _check_str_by_regular
from ..._checkparam import Validator
from ..anf_ir_pb2 import DataType, ModelProto
from ..summary_pb2 import Event


@@ -47,8 +47,8 @@ def get_event_file_name(prefix, suffix):
Returns:
String, the name of event log file.
"""
_check_str_by_regular(prefix)
_check_str_by_regular(suffix)
Validator.check_str_by_regular(prefix)
Validator.check_str_by_regular(suffix)
file_name = ""
time_second = str(int(time.time()))
hostname = platform.node()


+ 3
- 3
mindspore/train/summary/summary_record.py View File

@@ -21,7 +21,7 @@ import threading
from mindspore import log as logger


from ..._c_expression import Tensor
from ..._checkparam import _check_str_by_regular
from ..._checkparam import Validator
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory
from ._summary_adapter import get_event_file_name, package_graph_event
from ._writer_pool import WriterPool
@@ -103,8 +103,8 @@ class SummaryRecord:
self._closed, self._event_writer = False, None
self._mode, self._data_pool = 'train', _dictlist()


_check_str_by_regular(file_prefix)
_check_str_by_regular(file_suffix)
Validator.check_str_by_regular(file_prefix)
Validator.check_str_by_regular(file_suffix)


self.log_path = _make_directory(log_dir)




+ 2
- 2
tests/ut/python/nn/test_checkparameter.py View File

@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test checkparameter """
""" test check parameter """
import pytest
import numpy as np
from mindspore._checkparam import twice, Validator
from mindspore._checkparam import Validator, twice


kernel_size = 5
kernel_size1 = twice(kernel_size)


+ 7
- 7
tests/ut/python/nn/test_parameter.py View File

@@ -18,7 +18,7 @@ import numpy as np
import pytest


from mindspore import context, Tensor, Parameter, ParameterTuple, nn
from mindspore._checkparam import _check_str_by_regular
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer


@@ -124,15 +124,15 @@ def test_check_str_by_regular():
str4 = ".12_sf.asdf"
str5 = "12_sf.a$sdf."
str6 = "12+sf.asdf"
_check_str_by_regular(str1)
_check_str_by_regular(str2)
_check_str_by_regular(str3)
Validator.check_str_by_regular(str1)
Validator.check_str_by_regular(str2)
Validator.check_str_by_regular(str3)
with pytest.raises(ValueError):
_check_str_by_regular(str4)
Validator.check_str_by_regular(str4)
with pytest.raises(ValueError):
_check_str_by_regular(str5)
Validator.check_str_by_regular(str5)
with pytest.raises(ValueError):
_check_str_by_regular(str6)
Validator.check_str_by_regular(str6)


def test_parameter_compute():
para_1 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test1')
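
The summary modules and this test switch from the module-level _check_str_by_regular helper to Validator.check_str_by_regular; the behaviour relied on here is only what test_check_str_by_regular above asserts, namely that a conforming name passes and a string containing a character such as '+' raises ValueError. A small sketch under that assumption:

from mindspore._checkparam import Validator

Validator.check_str_by_regular("checkpoint_1.ckpt")  # letters, digits, '_' and '.' pass
try:
    Validator.check_str_by_regular("12+sf.asdf")  # '+' is rejected, mirroring str6 in the test
except ValueError:
    pass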

