Merge pull request !7388 from chenzhongming/zomi_mastertags/v1.1.0
| @@ -97,7 +97,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N | |||
| Check argument integer. | |||
| Usage: | |||
| - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 | |||
| - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0 | |||
| """ | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) | |||
| @@ -166,12 +166,12 @@ class Validator: | |||
| return arg_value | |||
| @staticmethod | |||
| def check_integer(arg_name, arg_value, value, rel, prim_name=None): | |||
| def check_int(arg_value, value, rel, arg_name=None, prim_name=None): | |||
| """ | |||
| Checks input integer value `arg_value` compare to `value`. | |||
| Usage: | |||
| - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 | |||
| - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0 | |||
| """ | |||
| return check_number(arg_value, value, rel, int, arg_name, prim_name) | |||
| @@ -187,6 +187,16 @@ class Validator: | |||
| """ | |||
| return check_is_number(arg_value, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_equal_int(arg_value, value, arg_name=None, prim_name=None): | |||
| """ | |||
| Checks input integer value `arg_value` compare to `value`. | |||
| Usage: | |||
| - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0 | |||
| """ | |||
| return check_number(arg_value, value, Rel.EQ, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_positive_int(arg_value, arg_name=None, prim_name=None): | |||
| """ | |||
| @@ -365,6 +375,17 @@ class Validator: | |||
| raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,' | |||
| f' but got `{arg_value}`.') | |||
| @staticmethod | |||
| def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): | |||
| if reg is None: | |||
| # Named string regular expression | |||
| reg = r"^\w+[0-9a-zA-Z\_\.]*$" | |||
| if re.match(reg, target, flag) is None: | |||
| prim_name = f'in `{prim_name}`' if prim_name else "" | |||
| raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format( | |||
| target, prim_name, reg, flag)) | |||
| return True | |||
| @staticmethod | |||
| def check_pad_value_by_mode(pad_mode, padding, prim_name): | |||
| """Validates value of padding according to pad_mode""" | |||
| @@ -530,13 +551,6 @@ class Validator: | |||
| f'{tuple(exp_shape)}, but got {shape}.') | |||
| def check_int_zero_one(input_param): | |||
| """Judge whether it is 0 or 1.""" | |||
| if input_param in (0, 1): | |||
| return input_param | |||
| raise ValueError("The data must be 0 or 1.") | |||
| def check_input_format(input_param): | |||
| """Judge input format.""" | |||
| if input_param == "NCHW": | |||
| @@ -544,27 +558,6 @@ def check_input_format(input_param): | |||
| raise ValueError("The data format must be NCHW.") | |||
| def check_padding(padding): | |||
| """Check padding.""" | |||
| if padding >= 0: | |||
| return padding | |||
| raise ValueError("The padding must be at least 0,"" but got padding {}.".format(padding)) | |||
| def check_padmode(mode): | |||
| """Check padmode.""" | |||
| if mode in ("same", "valid", "pad"): | |||
| return mode | |||
| raise ValueError("The pad mode must be same or valid or pad,"" but got mode {}.".format(mode)) | |||
| def check_tensor_supported_type(dtype): | |||
| """Check tensor dtype.""" | |||
| if dtype in (mstype.int32, mstype.float32): | |||
| return dtype | |||
| raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype)) | |||
| def _expand_tuple(n_dimensions): | |||
| """To expand a number to tuple.""" | |||
| @@ -673,42 +666,6 @@ def check_typename(arg_name, arg_type, valid_types): | |||
| f' but got {get_typename(arg_type)}.') | |||
| def check_shape(arg_name, arg_value): | |||
| """Check shape.""" | |||
| # First, check if shape is a tuple | |||
| if not isinstance(arg_value, tuple): | |||
| raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},' | |||
| f' but got {type(arg_value).__name__}.') | |||
| # Second, wrap arg_value with numpy array so that it can be checked through numpy api | |||
| arg_value = np.array(arg_value) | |||
| # shape can not be () | |||
| if arg_value.size == 0: | |||
| raise ValueError('Shape can not be empty.') | |||
| # shape's dimension should be 1 | |||
| if arg_value.ndim != 1: | |||
| raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(arg_value.ndim)) | |||
| # Thirdly, check each element's type of the shape | |||
| valid_types = (int, np.int8, np.int16, np.int32, np.int64, | |||
| np.uint8, np.uint16, np.uint32, np.uint64) | |||
| for dim_size in arg_value: | |||
| if not isinstance(dim_size, valid_types) or dim_size <= 0: | |||
| raise ValueError('Every dimension size of the tensor shape should be a positive integer,' | |||
| ' but got {}.'.format(dim_size)) | |||
| def _check_str_by_regular(target, reg=None, flag=re.ASCII): | |||
| if reg is None: | |||
| # Named string regular expression | |||
| reg = r"^\w+[0-9a-zA-Z\_\.]*$" | |||
| if re.match(reg, target, flag) is None: | |||
| raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) | |||
| return True | |||
| def args_type_check(*type_args, **type_kwargs): | |||
| """Check whether input data type is correct.""" | |||
| @@ -19,7 +19,7 @@ from .._c_expression import ParamInfo | |||
| from . import dtype as mstype | |||
| from .initializer import initializer, Initializer | |||
| from .tensor import Tensor, MetaTensor | |||
| from .._checkparam import _check_str_by_regular | |||
| from .._checkparam import Validator | |||
| from ..parallel._tensor import _get_slice_index | |||
| from ..parallel._auto_parallel_context import auto_parallel_context | |||
| from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched | |||
| @@ -263,7 +263,7 @@ class Parameter(MetaTensor): | |||
| Returns: | |||
| Parameter, a new parameter. | |||
| """ | |||
| _check_str_by_regular(prefix) | |||
| Validator.check_str_by_regular(prefix) | |||
| x = copy(self) | |||
| # pylint: disable=protected-access | |||
| x._param_info = self._param_info.clone() | |||
| @@ -446,7 +446,7 @@ class ParameterTuple(tuple): | |||
| Returns: | |||
| Tuple, the new Parameter tuple. | |||
| """ | |||
| _check_str_by_regular(prefix) | |||
| Validator.check_str_by_regular(prefix) | |||
| new = [] | |||
| for x in self: | |||
| x1 = x.clone(prefix, init) | |||
| @@ -23,7 +23,7 @@ from collections import namedtuple | |||
| from types import FunctionType | |||
| from mindspore import log as logger | |||
| from mindspore._c_expression import MSContext, ms_ctx_param | |||
| from mindspore._checkparam import args_type_check | |||
| from mindspore._checkparam import args_type_check, Validator | |||
| from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ | |||
| _reset_auto_parallel_context | |||
| from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context | |||
| @@ -35,9 +35,9 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut | |||
| GRAPH_MODE = 0 | |||
| PYNATIVE_MODE = 1 | |||
| # The max memory size of graph plus variable. | |||
| _DEVICE_APP_MEMORY_SIZE = 31 | |||
| _DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable. | |||
| _re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' | |||
| _k_context = None | |||
| def _make_directory(path): | |||
| """Make directory.""" | |||
| @@ -223,7 +223,7 @@ class _Context: | |||
| def set_variable_memory_max_size(self, variable_memory_max_size): | |||
| """set values of variable_memory_max_size and graph_memory_max_size""" | |||
| if not _check_input_format(variable_memory_max_size): | |||
| if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern): | |||
| raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") | |||
| if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: | |||
| raise ValueError("Context param variable_memory_max_size should be less than 31GB.") | |||
| @@ -235,7 +235,7 @@ class _Context: | |||
| self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_) | |||
| def set_max_device_memory(self, max_device_memory): | |||
| if not _check_input_format(max_device_memory): | |||
| if not Validator.check_str_by_regular(max_device_memory, _re_pattern): | |||
| raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") | |||
| max_device_memory_value = float(max_device_memory[:-2]) | |||
| if max_device_memory_value == 0: | |||
| @@ -294,16 +294,6 @@ class _Context: | |||
| thread_info.debug_runtime = enable | |||
| def _check_input_format(x): | |||
| import re | |||
| pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' | |||
| result = re.match(pattern, x) | |||
| return result is not None | |||
| _k_context = None | |||
| def _context(): | |||
| """ | |||
| Get the global _context, if context is not created, create a new one. | |||
| @@ -23,7 +23,7 @@ from mindspore import log as logger | |||
| from .. import context | |||
| from ..common import dtype as mstype | |||
| from ..common.api import _executor, _pynative_exec | |||
| from .._checkparam import _check_str_by_regular | |||
| from .._checkparam import Validator | |||
| from ..common.parameter import Parameter, ParameterTuple | |||
| from .._c_expression import init_backend, Cell_ | |||
| from ..ops.primitive import Primitive | |||
| @@ -715,7 +715,7 @@ class Cell(Cell_): | |||
| recurse (bool): Whether contains the parameters of subcells. Default: True. | |||
| """ | |||
| _check_str_by_regular(prefix) | |||
| Validator.check_str_by_regular(prefix) | |||
| for name, param in self.parameters_and_names(expand=recurse): | |||
| if prefix != '': | |||
| param.is_init = False | |||
| @@ -549,7 +549,7 @@ class Unfold(Cell): | |||
| @constexpr | |||
| def _get_matrix_diag_assist(x_shape, x_dtype): | |||
| Validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist") | |||
| Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist") | |||
| base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1) | |||
| assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],)) | |||
| return Tensor(assist, x_dtype) | |||
| @@ -557,7 +557,7 @@ def _get_matrix_diag_assist(x_shape, x_dtype): | |||
| @constexpr | |||
| def _get_matrix_diag_part_assist(x_shape, x_dtype): | |||
| Validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist") | |||
| Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist") | |||
| base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1) | |||
| assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape) | |||
| return Tensor(assist, x_dtype) | |||
| @@ -239,8 +239,8 @@ class Conv2d(_Conv): | |||
| """Initialize depthwise conv2d op""" | |||
| if context.get_context("device_target") == "Ascend" and self.group > 1: | |||
| self.dilation = self._dilation | |||
| Validator.check_integer('group', self.group, self.in_channels, Rel.EQ) | |||
| Validator.check_integer('group', self.group, self.out_channels, Rel.EQ) | |||
| Validator.check_equal_int(self.group, self.in_channels, 'group') | |||
| Validator.check_equal_int(self.group, self.out_channels, 'group') | |||
| self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=self.pad_mode, | |||
| @@ -384,10 +384,10 @@ class Conv1d(_Conv): | |||
| Validator.check_value_type("stride", stride, [int], self.cls_name) | |||
| Validator.check_value_type("padding", padding, [int], self.cls_name) | |||
| Validator.check_value_type("dilation", dilation, [int], self.cls_name) | |||
| Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) | |||
| Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) | |||
| Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name) | |||
| Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name) | |||
| Validator.check_non_negative_int(padding, 'padding', self.cls_name) | |||
| Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) | |||
| Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name) | |||
| kernel_size = (1, kernel_size) | |||
| stride = (1, stride) | |||
| dilation = (1, dilation) | |||
| @@ -395,7 +395,7 @@ class Conv1d(_Conv): | |||
| get_dtype = P.DType() | |||
| if isinstance(weight_init, Tensor): | |||
| weight_init_shape = get_shape(weight_init) | |||
| Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name) | |||
| Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name) | |||
| weight_init_dtype = get_dtype(weight_init) | |||
| weight_init_value = weight_init.asnumpy() | |||
| weight_init_value = np.expand_dims(weight_init_value, 2) | |||
| @@ -539,7 +539,7 @@ class Conv2dTranspose(_Conv): | |||
| dilation = twice(dilation) | |||
| Validator.check_value_type('padding', padding, (int, tuple), self.cls_name) | |||
| if isinstance(padding, tuple): | |||
| Validator.check_integer('padding size', len(padding), 4, Rel.EQ, self.cls_name) | |||
| Validator.check_equal_int(len(padding), 4, 'padding size', self.cls_name) | |||
| # out_channels and in_channels swap. | |||
| # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel, | |||
| # then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel. | |||
| @@ -703,10 +703,10 @@ class Conv1dTranspose(_Conv): | |||
| Validator.check_value_type("stride", stride, [int], self.cls_name) | |||
| Validator.check_value_type("padding", padding, [int], self.cls_name) | |||
| Validator.check_value_type("dilation", dilation, [int], self.cls_name) | |||
| Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) | |||
| Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) | |||
| Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name) | |||
| Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name) | |||
| Validator.check_non_negative_int(padding, 'padding', self.cls_name) | |||
| Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) | |||
| Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name) | |||
| kernel_size = (1, kernel_size) | |||
| stride = (1, stride) | |||
| dilation = (1, dilation) | |||
| @@ -714,7 +714,7 @@ class Conv1dTranspose(_Conv): | |||
| get_dtype = P.DType() | |||
| if isinstance(weight_init, Tensor): | |||
| weight_init_shape = get_shape(weight_init) | |||
| Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name) | |||
| Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name) | |||
| weight_init_dtype = get_dtype(weight_init) | |||
| weight_init_value = weight_init.asnumpy() | |||
| weight_init_value = np.expand_dims(weight_init_value, 2) | |||
| @@ -220,7 +220,7 @@ class SSIM(Cell): | |||
| validator.check_value_type('max_val', max_val, [int, float], self.cls_name) | |||
| validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) | |||
| self.max_val = max_val | |||
| self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) | |||
| self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name) | |||
| self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) | |||
| self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) | |||
| self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) | |||
| @@ -298,7 +298,7 @@ class MSSSIM(Cell): | |||
| validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) | |||
| self.max_val = max_val | |||
| validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name) | |||
| self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) | |||
| self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name) | |||
| self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) | |||
| self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) | |||
| self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) | |||
| @@ -190,8 +190,8 @@ class MaxPool1d(_PoolNd): | |||
| validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name) | |||
| validator.check_value_type('stride', stride, [int], self.cls_name) | |||
| self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name) | |||
| validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name) | |||
| validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name) | |||
| validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name) | |||
| validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name) | |||
| self.kernel_size = (1, kernel_size) | |||
| self.stride = (1, stride) | |||
| self.max_pool = P.MaxPool(ksize=self.kernel_size, | |||
| @@ -349,8 +349,8 @@ class AvgPool1d(_PoolNd): | |||
| validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name) | |||
| validator.check_value_type('stride', stride, [int], self.cls_name) | |||
| self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name) | |||
| validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name) | |||
| validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name) | |||
| validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name) | |||
| validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name) | |||
| self.kernel_size = (1, kernel_size) | |||
| self.stride = (1, stride) | |||
| self.avg_pool = P.AvgPool(ksize=self.kernel_size, | |||
| @@ -323,7 +323,7 @@ class FakeQuantWithMinMax(Cell): | |||
| Validator.check_type("min_init", min_init, [int, float]) | |||
| Validator.check_type("max_init", max_init, [int, float]) | |||
| Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) | |||
| Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE) | |||
| Validator.check_non_negative_int(quant_delay, 'quant_delay') | |||
| self.min_init = min_init | |||
| self.max_init = max_init | |||
| self.num_bits = num_bits | |||
| @@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell): | |||
| # initialize convolution op and Parameter | |||
| if context.get_context('device_target') == "Ascend" and group > 1: | |||
| Validator.check_integer('group', group, in_channels, Rel.EQ) | |||
| Validator.check_integer('group', group, out_channels, Rel.EQ) | |||
| Validator.check_equal_int(group, in_channels, 'group') | |||
| Validator.check_equal_int(group, out_channels, 'group') | |||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| @@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell): | |||
| self.bias = None | |||
| # initialize convolution op and Parameter | |||
| if context.get_context('device_target') == "Ascend" and group > 1: | |||
| Validator.check_integer('group', group, in_channels, Rel.EQ) | |||
| Validator.check_integer('group', group, out_channels, Rel.EQ) | |||
| Validator.check_equal_int(group, in_channels, 'group') | |||
| Validator.check_equal_int(group, out_channels, 'group') | |||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| @@ -931,19 +931,19 @@ class LSTMGradData(PrimitiveWithInfer): | |||
| def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape, | |||
| hx_shape, cx_shape, reserve_shape, state_shape): | |||
| # dhy and dcy should be same shape | |||
| validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name) | |||
| validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name) | |||
| validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name) | |||
| validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name) | |||
| validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name) | |||
| validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name) | |||
| validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name) | |||
| validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name) | |||
| validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name) | |||
| validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name) | |||
| validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) | |||
| validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name) | |||
| validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name) | |||
| validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name) | |||
| # dy: (seq_len, batch_size, hidden_size * num_directions) | |||
| validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name) | |||
| validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name) | |||
| validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name) | |||
| validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name) | |||
| validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name) | |||
| # (seq_len, batch_size, input_size) | |||
| dx_shape = (y_shape[0], y_shape[1], self.input_size) | |||
| @@ -1015,19 +1015,19 @@ class LSTMGrad(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape, | |||
| dcy_shape, reserve_shape): | |||
| # dhy and dcy should be same shape | |||
| validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name) | |||
| validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name) | |||
| validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name) | |||
| validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name) | |||
| validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name) | |||
| validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name) | |||
| validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name) | |||
| validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name) | |||
| validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name) | |||
| validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name) | |||
| validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) | |||
| validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name) | |||
| validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name) | |||
| validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name) | |||
| # dy: (seq_len, batch_size, hidden_size * num_directions) | |||
| validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name) | |||
| validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name) | |||
| validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name) | |||
| validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name) | |||
| validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name) | |||
| # (seq_len, batch_size, input_size) | |||
| dx_shape = (y_shape[0], y_shape[1], self.input_size) | |||
| @@ -1069,7 +1069,7 @@ class DynamicRNNGrad(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape, | |||
| c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape): | |||
| validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(x_shape), 3, "x_shape", self.name) | |||
| num_step, batch_size, input_size = x_shape | |||
| hidden_size = w_shape[-1] // 4 | |||
| if w_shape[-1] % 4 != 0: | |||
| @@ -1575,7 +1575,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer): | |||
| def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape): | |||
| # dhy and dcy should be same shape | |||
| validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(c_shape), 2, "c rank", self.name) | |||
| validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name) | |||
| validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name) | |||
| validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name) | |||
| @@ -1624,7 +1624,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer): | |||
| self.add_prim_attr("io_format", "HWCN") | |||
| def infer_shape(self, x_shape, h_shape, dgate_shape): | |||
| validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(x_shape), 2, "x rank", self.name) | |||
| validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name) | |||
| validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name) | |||
| validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name) | |||
| @@ -1656,8 +1656,8 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer): | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, dgate_shape, w_shape): | |||
| validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name) | |||
| validator.check_equal_int(len(w_shape), 2, "w rank", self.name) | |||
| validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) | |||
| batch_size = dgate_shape[0] | |||
| hidden_size = dgate_shape[1] // 4 | |||
| @@ -347,7 +347,7 @@ class MatrixDiag(PrimitiveWithInfer): | |||
| return x_dtype | |||
| def infer_shape(self, x_shape, assist_shape): | |||
| validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name) | |||
| validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name) | |||
| validator.check('rank of x', len(x_shape)+1, | |||
| 'rank of assist', len(assist_shape), Rel.LE, self.name) | |||
| validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension', | |||
| @@ -395,7 +395,7 @@ class MatrixDiagPart(PrimitiveWithInfer): | |||
| return x_dtype | |||
| def infer_shape(self, x_shape, assist_shape): | |||
| validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) | |||
| validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) | |||
| validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) | |||
| if assist_shape[-2] < assist_shape[-1]: | |||
| @@ -438,7 +438,7 @@ class MatrixSetDiag(PrimitiveWithInfer): | |||
| return x_dtype | |||
| def infer_shape(self, x_shape, diagonal_shape, assist_shape): | |||
| validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) | |||
| validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) | |||
| validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) | |||
| if x_shape[-2] < x_shape[-1]: | |||
| @@ -81,11 +81,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer): | |||
| outputs=['min_up', 'max_up']) | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | |||
| validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) | |||
| validator.check("min shape", min_shape, "max shape", | |||
| max_shape, Rel.EQ, self.name) | |||
| validator.check_integer("min shape", len( | |||
| min_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(min_shape), 1, "min shape", self.name) | |||
| return min_shape, max_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| @@ -147,11 +146,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): | |||
| if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank: | |||
| raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'") | |||
| if not self.is_ascend: | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | |||
| validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) | |||
| validator.check("min shape", min_shape, "max shape", | |||
| max_shape, Rel.EQ, self.name) | |||
| validator.check_integer("min shape", len( | |||
| min_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(min_shape), 1, "min shape", self.name) | |||
| return min_shape, max_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| @@ -228,9 +226,9 @@ class FakeQuantPerLayer(PrimitiveWithInfer): | |||
| outputs=['out']) | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | |||
| validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) | |||
| validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) | |||
| validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(min_shape), 1, "min shape", self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| @@ -284,8 +282,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): | |||
| x_shape, Rel.EQ, self.name) | |||
| validator.check("min shape", min_shape, "max shape", | |||
| max_shape, Rel.EQ, self.name) | |||
| validator.check_integer("min shape", len( | |||
| min_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(min_shape), 1, "min shape", self.name) | |||
| return dout_shape | |||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | |||
| @@ -375,14 +372,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer): | |||
| if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank: | |||
| raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'") | |||
| if not self.is_ascend: | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | |||
| validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) | |||
| if len(x_shape) == 1: | |||
| self.channel_axis = 0 | |||
| validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) | |||
| validator.check_integer( | |||
| "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||
| validator.check_integer( | |||
| "max shape", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||
| validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name) | |||
| validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| @@ -501,7 +496,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): | |||
| validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name) | |||
| validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) | |||
| validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name) | |||
| return mean_shape, mean_shape, mean_shape, mean_shape | |||
| def infer_dtype(self, x_type, mean_type, variance_type, global_step_type): | |||
| @@ -548,7 +543,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer): | |||
| "batch_std shape", batch_std_shape, Rel.EQ, self.name) | |||
| validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], | |||
| "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) | |||
| validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name) | |||
| return x_shape | |||
| def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type, | |||
| @@ -723,7 +718,7 @@ class BatchNormFold2(PrimitiveWithInfer): | |||
| validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis], | |||
| Rel.EQ, self.name) | |||
| validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type, | |||
| @@ -771,7 +766,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer): | |||
| validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name) | |||
| validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis], | |||
| Rel.EQ, self.name) | |||
| validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name) | |||
| return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape | |||
| def infer_dtype(self, dout_type, x_type, gamma_type, | |||
| @@ -520,7 +520,7 @@ class Im2Col(PrimitiveWithInfer): | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(x_shape), 4, "x rank", self.name) | |||
| kernel_size_h = self.kernel_size[0] | |||
| kernel_size_w = self.kernel_size[1] | |||
| stride_h = self.stride[2] | |||
| @@ -583,8 +583,8 @@ class Transpose(PrimitiveWithInfer): | |||
| tmp = list(p_value) | |||
| for i, dim in enumerate(p_value): | |||
| validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name) | |||
| validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name) | |||
| validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name) | |||
| validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name) | |||
| tmp.remove(dim) | |||
| if dim in tmp: | |||
| raise ValueError('The value of perm is wrong.') | |||
| @@ -725,8 +725,8 @@ class Padding(PrimitiveWithInfer): | |||
| def __infer__(self, x): | |||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||
| x_shape = list(x['shape']) | |||
| validator.check_integer("rank of x", len(x_shape), 1, Rel.GT, self.name) | |||
| validator.check_integer("last dim of x", x_shape[-1], 1, Rel.EQ, self.name) | |||
| validator.check_int(len(x_shape), 1, Rel.GT, "rank of x", self.name) | |||
| validator.check_int(x_shape[-1], 1, Rel.EQ, "last dim of x", self.name) | |||
| out_shape = x_shape | |||
| out_shape[-1] = self.pad_dim_size | |||
| out = {'shape': out_shape, | |||
| @@ -1575,7 +1575,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer): | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||
| validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | |||
| validator.check(f'first shape of input_x', x_shape[0], | |||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | |||
| num_segments_v = num_segments['value'] | |||
| @@ -1628,7 +1628,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer): | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||
| validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | |||
| validator.check(f'first shape of input_x', x_shape[0], | |||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | |||
| num_segments_v = num_segments['value'] | |||
| @@ -1730,7 +1730,7 @@ class ParallelConcat(PrimitiveWithInfer): | |||
| x_shp = values['shape'] | |||
| x_type = values['dtype'] | |||
| validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name) | |||
| validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name) | |||
| args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||
| @@ -1738,7 +1738,7 @@ class ParallelConcat(PrimitiveWithInfer): | |||
| first_elem = x_shp[0] | |||
| for i, elem in enumerate(x_shp[1:]): | |||
| j = i + 1 | |||
| validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(elem[0], 1, f'x_shp[{j}][0]', self.name) | |||
| validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name) | |||
| ret_shp = x_shp[0].copy() | |||
| @@ -1755,7 +1755,7 @@ class ParallelConcat(PrimitiveWithInfer): | |||
| def _get_pack_shape(x_shape, x_type, axis, prim_name): | |||
| """for pack output shape""" | |||
| validator.check_value_type("shape", x_shape, [tuple, list], prim_name) | |||
| validator.check_integer("len of input_x", len(x_shape), 1, Rel.GE, prim_name) | |||
| validator.check_int(len(x_shape), 1, Rel.GE, "len of input_x", prim_name) | |||
| validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name) | |||
| rank_base = len(x_shape[0]) | |||
| N = len(x_shape) | |||
| @@ -1871,8 +1871,8 @@ class Unpack(PrimitiveWithInfer): | |||
| validator.check_positive_int(output_num, "output_num", self.name) | |||
| self.add_prim_attr('num', output_num) | |||
| output_valid_check = x_shape[self.axis] - output_num | |||
| validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ, | |||
| self.name) | |||
| validator.check_int(output_valid_check, 0, Rel.EQ, | |||
| "The dimension which to unpack divides output_num", self.name) | |||
| out_shapes = [] | |||
| out_dtypes = [] | |||
| out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:] | |||
| @@ -2523,7 +2523,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): | |||
| """Initialize ResizeNearestNeighbor""" | |||
| validator.check_value_type("size", size, [tuple, list], self.name) | |||
| validator.check_value_type("align_corners", align_corners, [bool], self.name) | |||
| validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(size), 2, "length of size", self.name) | |||
| for i, value in enumerate(size): | |||
| validator.check_non_negative_int(value, f'{i}th value of size', self.name) | |||
| self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) | |||
| @@ -3134,9 +3134,8 @@ class DepthToSpace(PrimitiveWithInfer): | |||
| for i in range(2): | |||
| out_shape[i + 2] *= self.block_size | |||
| validator.check_integer('x_shape[1] % (block_size*block_size)', | |||
| x_shape[1] % (self.block_size * self.block_size), | |||
| 0, Rel.EQ, self.name) | |||
| validator.check_int(x_shape[1] % (self.block_size * self.block_size), | |||
| 0, Rel.EQ, 'x_shape[1] % (block_size*block_size)', self.name) | |||
| out_shape[1] //= self.block_size * self.block_size | |||
| return out_shape | |||
| @@ -3205,7 +3204,7 @@ class SpaceToBatch(PrimitiveWithInfer): | |||
| return x_dtype | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(x_shape), 4, 'rank of input_x', self.name) | |||
| out_shape = copy.deepcopy(x_shape) | |||
| for i in range(2): | |||
| padded = out_shape[i + 2] + self.paddings[i][0] + self.paddings[i][1] | |||
| @@ -3367,7 +3366,7 @@ class SpaceToBatchND(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| x_rank = len(x_shape) | |||
| validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name) | |||
| out_shape = copy.deepcopy(x_shape) | |||
| block_shape_prod = 1 | |||
| @@ -3460,7 +3459,7 @@ class BatchToSpaceND(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| x_rank = len(x_shape) | |||
| validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name) | |||
| validator.check_int(x_rank, 4, Rel.EQ, 'x_shape rank', self.name) | |||
| out_shape = copy.deepcopy(x_shape) | |||
| block_shape_prod = 1 | |||
| @@ -3607,11 +3606,11 @@ class Meshgrid(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| validator.check_value_type("shape", x_shape, [tuple, list], self.name) | |||
| validator.check_integer("len of input_x", len(x_shape), 2, Rel.GE, self.name) | |||
| validator.check_int(len(x_shape), 2, Rel.GE, "len of input_x", self.name) | |||
| n = len(x_shape) | |||
| shape_0 = [] | |||
| for s in x_shape: | |||
| validator.check_integer('each_input_rank', len(s), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(s), 1, Rel.EQ, 'each_input_rank', self.name) | |||
| shape_0.append(s[0]) | |||
| if self.indexing == "xy": | |||
| shape_0[0], shape_0[1] = shape_0[1], shape_0[0] | |||
| @@ -204,7 +204,7 @@ class _HostAllGather(PrimitiveWithInfer): | |||
| if group is None: | |||
| raise ValueError(f"For '{self.name}' group must be set.") | |||
| validator.check_value_type('group', group, (tuple, list), self.name) | |||
| validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | |||
| validator.check_int(len(group), 2, Rel.GE, "group size", self.name) | |||
| for r in group: | |||
| validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name) | |||
| validator.check_value_type("rank_id", r, (int,), self.name) | |||
| @@ -313,7 +313,7 @@ class _HostReduceScatter(PrimitiveWithInfer): | |||
| raise ValueError(f"For '{self.name}' group must be set.") | |||
| validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) | |||
| validator.check_value_type('group', group, (tuple, list), self.name) | |||
| validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | |||
| validator.check_int(len(group), 2, Rel.GE, "group size", self.name) | |||
| for r in group: | |||
| validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name) | |||
| validator.check_value_type("rank_id", r, (int,), self.name) | |||
| @@ -126,7 +126,7 @@ class GeSwitch(PrimitiveWithInfer): | |||
| raise NotImplementedError | |||
| def infer_shape(self, data, pred): | |||
| validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(pred), 0, "pred rank", self.name) | |||
| return (data, data) | |||
| def infer_dtype(self, data_type, pred_type): | |||
| @@ -374,9 +374,9 @@ class Assert(PrimitiveWithInfer): | |||
| def infer_shape(self, condition, inputs): | |||
| condition_len = len(condition) | |||
| validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name) | |||
| validator.check_int(condition_len, 1, Rel.LE, "condition's rank", self.name) | |||
| if condition_len == 1: | |||
| validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(condition[0], 1, "condition[0]", self.name) | |||
| return [1] | |||
| def infer_dtype(self, condition, inputs): | |||
| @@ -17,7 +17,6 @@ | |||
| import numbers | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common.dtype import tensor, dtype_to_pytype | |||
| from ..primitive import prim_attr_register, PrimitiveWithInfer | |||
| @@ -43,7 +42,7 @@ class ScalarCast(PrimitiveWithInfer): | |||
| pass | |||
| def __infer__(self, x, t): | |||
| validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(x['shape']), 0, 'x shape', self.name) | |||
| value, to = x['value'], t['value'] | |||
| if value is not None: | |||
| validator.check_value_type("value", value, [numbers.Number, bool], self.name) | |||
| @@ -827,7 +827,7 @@ class AddN(PrimitiveWithInfer): | |||
| def infer_shape(self, inputs): | |||
| cls_name = self.name | |||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | |||
| validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name) | |||
| self.add_prim_attr('n', len(inputs)) | |||
| shp0 = inputs[0] | |||
| for i, shp in enumerate(inputs): | |||
| @@ -837,7 +837,7 @@ class AddN(PrimitiveWithInfer): | |||
| def infer_dtype(self, inputs): | |||
| cls_name = self.name | |||
| validator.check_value_type("inputs", inputs, [tuple, list], cls_name) | |||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | |||
| validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name) | |||
| args = {} | |||
| contains_undetermined = False | |||
| for i, dtype in enumerate(inputs): | |||
| @@ -910,7 +910,7 @@ class AccumulateNV2(PrimitiveWithInfer): | |||
| def infer_shape(self, inputs): | |||
| cls_name = self.name | |||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | |||
| validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name) | |||
| self.add_prim_attr('n', len(inputs)) | |||
| shp0 = inputs[0] | |||
| for i, shp in enumerate(inputs): | |||
| @@ -920,7 +920,7 @@ class AccumulateNV2(PrimitiveWithInfer): | |||
| def infer_dtype(self, inputs): | |||
| cls_name = self.name | |||
| validator.check_value_type("inputs", inputs, [tuple, list], cls_name) | |||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | |||
| validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name) | |||
| args = {} | |||
| for i, dtype in enumerate(inputs): | |||
| args[f"inputs[{i}]"] = dtype | |||
| @@ -1488,7 +1488,7 @@ class HistogramFixedWidth(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, nbins, dtype='int32'): | |||
| self.nbins = validator.check_value_type("nbins", nbins, [int], self.name) | |||
| validator.check_integer("nbins", nbins, 1, Rel.GE, self.name) | |||
| validator.check_int(nbins, 1, Rel.GE, "nbins", self.name) | |||
| valid_values = ['int32', 'int64'] | |||
| self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name) | |||
| self.init_prim_io_names(inputs=['x', 'range'], outputs=['y']) | |||
| @@ -2810,8 +2810,8 @@ class NPUGetFloatStatus(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| cls_name = self.name | |||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) | |||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) | |||
| validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name) | |||
| validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name) | |||
| return [8] | |||
| def infer_dtype(self, x_dtype): | |||
| @@ -2853,8 +2853,8 @@ class NPUClearFloatStatus(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| cls_name = self.name | |||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) | |||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) | |||
| validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name) | |||
| validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name) | |||
| return [8] | |||
| def infer_dtype(self, x_dtype): | |||
| @@ -3023,9 +3023,9 @@ class NMSWithMask(PrimitiveWithInfer): | |||
| def infer_shape(self, bboxes_shape): | |||
| cls_name = self.name | |||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | |||
| validator.check_equal_int(len(bboxes_shape), 2, "bboxes rank", cls_name) | |||
| validator.check_positive_int(bboxes_shape[0], "bboxes.shape[0]", cls_name) | |||
| validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | |||
| validator.check_equal_int(bboxes_shape[1], 5, "bboxes.shape[1]", cls_name) | |||
| num = bboxes_shape[0] | |||
| return (bboxes_shape, (num,), (num,)) | |||
| @@ -3572,11 +3572,11 @@ class IFMR(PrimitiveWithInfer): | |||
| validator.check_value_type("offset_flag", with_offset, [bool], self.name) | |||
| def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape): | |||
| validator.check_integer("dims of data_min", len(data_min_shape), 1, Rel.EQ, self.name) | |||
| validator.check_integer("data_min[0]", data_min_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_integer("dims of data_max", len(data_max_shape), 1, Rel.EQ, self.name) | |||
| validator.check_integer("data_max[0]", data_max_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_integer("dims of cumsum", len(cumsum_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name) | |||
| validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name) | |||
| validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name) | |||
| validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name) | |||
| validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name) | |||
| return (1,), (1,) | |||
| def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): | |||
| @@ -98,7 +98,7 @@ class Flatten(PrimitiveWithInfer): | |||
| pass | |||
| def infer_shape(self, input_x): | |||
| validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name) | |||
| validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name) | |||
| prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:]) | |||
| return input_x[0], prod | |||
| @@ -146,7 +146,7 @@ class Softmax(PrimitiveWithInfer): | |||
| validator.check_value_type("item of axis", item, [int], self.name) | |||
| def infer_shape(self, logits): | |||
| validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name) | |||
| validator.check_int(len(self.axis), 1, Rel.GE, "length of axis", self.name) | |||
| rank = len(logits) | |||
| for axis_v in self.axis: | |||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||
| @@ -636,7 +636,7 @@ class FusedBatchNorm(Primitive): | |||
| def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): | |||
| self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | |||
| outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) | |||
| self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) | |||
| self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name) | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| self._update_parameter = True | |||
| @@ -709,17 +709,17 @@ class FusedBatchNormEx(PrimitiveWithInfer): | |||
| def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): | |||
| self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | |||
| outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve']) | |||
| self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) | |||
| self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name) | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| self._update_parameter = True | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| def infer_shape(self, input_x, scale, bias, mean, variance): | |||
| validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(scale), 1, "scale rank", self.name) | |||
| validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) | |||
| validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) | |||
| validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(mean), 1, "mean rank", self.name) | |||
| validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) | |||
| validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) | |||
| return (input_x, scale, scale, scale, scale, scale) | |||
| @@ -757,7 +757,7 @@ class BNTrainingReduce(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum']) | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(x_shape), 4, "x rank", self.name) | |||
| return ([x_shape[1]], [x_shape[1]]) | |||
| def infer_dtype(self, x_type): | |||
| @@ -822,13 +822,13 @@ class BNTrainingUpdate(PrimitiveWithInfer): | |||
| self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate') | |||
| def infer_shape(self, x, sum, square_sum, scale, b, mean, variance): | |||
| validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name) | |||
| validator.check_integer("sum rank", len(sum), 1, Rel.EQ, self.name) | |||
| validator.check_integer("square_sum rank", len(square_sum), 1, Rel.EQ, self.name) | |||
| validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) | |||
| validator.check_integer("b rank", len(b), 1, Rel.EQ, self.name) | |||
| validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) | |||
| validator.check_integer("variance rank", len(variance), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(x), 4, "x rank", self.name) | |||
| validator.check_equal_int(len(sum), 1, "sum rank", self.name) | |||
| validator.check_equal_int(len(square_sum), 1, "square_sum rank", self.name) | |||
| validator.check_equal_int(len(scale), 1, "scale rank", self.name) | |||
| validator.check_equal_int(len(b), 1, "b rank", self.name) | |||
| validator.check_equal_int(len(mean), 1, "mean rank", self.name) | |||
| validator.check_equal_int(len(variance), 1, "variance rank", self.name) | |||
| validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name) | |||
| validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name) | |||
| validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name) | |||
| @@ -904,11 +904,11 @@ class BatchNorm(PrimitiveWithInfer): | |||
| outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) | |||
| def infer_shape(self, input_x, scale, bias, mean, variance): | |||
| validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(scale), 1, "scale rank", self.name) | |||
| validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) | |||
| validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) | |||
| if not self.is_training: | |||
| validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(mean), 1, "mean rank", self.name) | |||
| validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) | |||
| validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) | |||
| return (input_x, scale, scale, scale, scale) | |||
| @@ -1010,7 +1010,7 @@ class Conv2D(PrimitiveWithInfer): | |||
| if isinstance(pad, int): | |||
| pad = (pad,) * 4 | |||
| else: | |||
| validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(pad), 4, 'pad size', self.name) | |||
| self.padding = pad | |||
| self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) | |||
| @@ -1020,15 +1020,15 @@ class Conv2D(PrimitiveWithInfer): | |||
| for item in pad: | |||
| validator.check_non_negative_int(item, 'pad item', self.name) | |||
| self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) | |||
| self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) | |||
| self.group = validator.check_positive_int(group, 'group', self.name) | |||
| self.add_prim_attr('offset_a', 0) | |||
| def infer_shape(self, x_shape, w_shape, b_shape=None): | |||
| validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) | |||
| validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(w_shape), 4, "weight rank", self.name) | |||
| validator.check_equal_int(len(x_shape), 4, "x rank", self.name) | |||
| validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) | |||
| validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) | |||
| validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) | |||
| @@ -1150,7 +1150,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): | |||
| if isinstance(pad, int): | |||
| pad = (pad,) * 4 | |||
| else: | |||
| validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(pad), 4, 'pad size', self.name) | |||
| self.padding = pad | |||
| self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) | |||
| if pad_mode != 'pad' and pad != (0, 0, 0, 0): | |||
| @@ -1158,15 +1158,15 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): | |||
| if self.pad_mode == 'pad': | |||
| for item in pad: | |||
| validator.check_non_negative_int(item, 'pad item', self.name) | |||
| self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name) | |||
| self.mode = validator.check_equal_int(mode, 3, "mode", self.name) | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name) | |||
| self.group = validator.check_positive_int(group, "group", self.name) | |||
| self.add_prim_attr('offset_a', 0) | |||
| def infer_shape(self, x_shape, w_shape, b_shape=None): | |||
| validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) | |||
| validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(w_shape), 4, "weight rank", self.name) | |||
| validator.check_equal_int(len(x_shape), 4, "x rank", self.name) | |||
| validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) | |||
| validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) | |||
| @@ -1250,7 +1250,7 @@ class _Pool(PrimitiveWithInfer): | |||
| self.add_prim_attr("strides", self.strides) | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(x_shape), 4, "x rank", self.name) | |||
| batch, channel, input_h, input_w = x_shape | |||
| if self.is_maxpoolwithargmax: | |||
| _, kernel_h, kernel_w, _ = self.ksize | |||
| @@ -1536,7 +1536,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): | |||
| if isinstance(pad, int): | |||
| pad = (pad,) * 4 | |||
| else: | |||
| validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(pad), 4, 'pad size', self.name) | |||
| self.padding = pad | |||
| self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) | |||
| if pad_mode != 'pad' and pad != (0, 0, 0, 0): | |||
| @@ -1547,7 +1547,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): | |||
| pad_mode = pad_mode.upper() | |||
| self.add_prim_attr('pad_mode', pad_mode) | |||
| self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) | |||
| self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) | |||
| self.group = validator.check_positive_int(group, 'group', self.name) | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| if pad_list: | |||
| @@ -1624,8 +1624,8 @@ class BiasAdd(PrimitiveWithInfer): | |||
| self.add_prim_attr('data_format', 'NCHW') | |||
| def infer_shape(self, x_shape, b_shape): | |||
| validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) | |||
| validator.check_integer("bias rank", len(b_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) | |||
| validator.check_equal_int(len(b_shape), 1, "bias rank", self.name) | |||
| validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name) | |||
| return x_shape | |||
| @@ -2007,10 +2007,10 @@ class RNNTLoss(PrimitiveWithInfer): | |||
| outputs=['costs', 'grads']) | |||
| def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape): | |||
| validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name) | |||
| validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name) | |||
| validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(acts_shape), 4, 'acts_rank', self.name) | |||
| validator.check_equal_int(len(labels_shape), 2, 'labels_rank', self.name) | |||
| validator.check_equal_int(len(input_length_shape), 1, 'input_length_rank', self.name) | |||
| validator.check_equal_int(len(label_length_shape), 1, 'label_length_rank', self.name) | |||
| validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) | |||
| validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name) | |||
| validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) | |||
| @@ -2080,11 +2080,11 @@ class SGD(PrimitiveWithInfer): | |||
| def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape, | |||
| accum_shape, momentum_shape, stat_shape): | |||
| validator.check_positive_int(len(parameters_shape), "parameters rank", self.name) | |||
| validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name) | |||
| validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name) | |||
| validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name) | |||
| validator.check_int(len(learning_rate_shape), 0, Rel.GE, f'learning rate rank', self.name) | |||
| validator.check_positive_int(len(accum_shape), "accumulation rank", self.name) | |||
| validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name) | |||
| validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name) | |||
| validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name) | |||
| validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name) | |||
| validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name) | |||
| return parameters_shape | |||
| @@ -2780,17 +2780,17 @@ class LSTM(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape, h_shape, c_shape, w_shape): | |||
| # (seq, batch_size, feature) | |||
| validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name) | |||
| validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(x_shape), 3, "x rank", self.name) | |||
| validator.check_equal_int(x_shape[2], self.input_size, "x[2]", self.name) | |||
| # h and c should be same shape | |||
| validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(h_shape), 3, "h rank", self.name) | |||
| validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name) | |||
| # (num_layers * num_directions, batch, hidden_size) | |||
| validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) | |||
| validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ, self.name) | |||
| validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ, self.name) | |||
| validator.check_int(h_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h[0]", self.name) | |||
| validator.check_equal_int(h_shape[1], x_shape[1], "h[1]", self.name) | |||
| validator.check_int(h_shape[2], self.hidden_size, Rel.EQ, "h[2]", self.name) | |||
| y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions) | |||
| @@ -2918,7 +2918,7 @@ class Pad(PrimitiveWithInfer): | |||
| def infer_shape(self, x): | |||
| paddings = np.array(self.paddings) | |||
| validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ, self.name) | |||
| validator.check_int(paddings.size, len(x) * 2, Rel.EQ, 'paddings.shape', self.name) | |||
| if not np.all(paddings >= 0): | |||
| raise ValueError('All elements of paddings must be >= 0.') | |||
| y_shape = () | |||
| @@ -2992,7 +2992,7 @@ class MirrorPad(PrimitiveWithInfer): | |||
| x_shape = list(input_x['shape']) | |||
| paddings_value = paddings['value'].asnumpy() | |||
| paddings_size = paddings_value.size | |||
| validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name) | |||
| validator.check_int(paddings_size, len(x_shape) * 2, Rel.EQ, 'paddings.shape', self.name) | |||
| if not np.all(paddings_value >= 0): | |||
| raise ValueError('All elements of paddings must be >= 0.') | |||
| adjust = 0 | |||
| @@ -3276,7 +3276,7 @@ class FusedSparseAdam(PrimitiveWithInfer): | |||
| beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): | |||
| validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) | |||
| validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) | |||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||
| if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]: | |||
| raise ValueError(f"For '{self.name}', the shape of updates should be [] or " | |||
| @@ -3409,7 +3409,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer): | |||
| beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): | |||
| validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) | |||
| validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) | |||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||
| if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]: | |||
| raise ValueError(f"For '{self.name}', the shape of updates should be [] or " | |||
| @@ -3513,7 +3513,7 @@ class FusedSparseFtrl(PrimitiveWithInfer): | |||
| validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) | |||
| if len(var_shape) > 1: | |||
| validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) | |||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||
| return [1], [1], [1] | |||
| @@ -3602,7 +3602,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): | |||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||
| def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) | |||
| return [1], [1] | |||
| def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): | |||
| @@ -3869,25 +3869,25 @@ class ApplyAdaMax(PrimitiveWithInfer): | |||
| validator.check("v_shape", v_shape, "var_shape", var_shape, Rel.EQ, self.name) | |||
| validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name) | |||
| beta1_power_shp_len = len(beta1_power_shape) | |||
| validator.check_integer("beta1 power's rank", beta1_power_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(beta1_power_shp_len, 1, Rel.LE, "beta1 power's rank", self.name) | |||
| if beta1_power_shp_len == 1: | |||
| validator.check_integer("beta1_power_shape[0]", beta1_power_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(beta1_power_shape[0], 1, Rel.EQ, "beta1_power_shape[0]", self.name) | |||
| lr_shp_len = len(lr_shape) | |||
| validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) | |||
| if lr_shp_len == 1: | |||
| validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) | |||
| beta1_shp_len = len(beta1_shape) | |||
| validator.check_integer("beta1's rank", beta1_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(beta1_shp_len, 1, Rel.LE, "beta1's rank", self.name) | |||
| if beta1_shp_len == 1: | |||
| validator.check_integer("beta1_shape[0]", beta1_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(beta1_shape[0], 1, Rel.EQ, "beta1_shape[0]", self.name) | |||
| beta2_shp_len = len(beta2_shape) | |||
| validator.check_integer("beta2's rank", beta2_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(beta2_shp_len, 1, Rel.LE, "beta2's rank", self.name) | |||
| if beta2_shp_len == 1: | |||
| validator.check_integer("beta2_shape[0]", beta2_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(beta2_shape[0], 1, Rel.EQ, "beta2_shape[0]", self.name) | |||
| epsilon_shp_len = len(epsilon_shape) | |||
| validator.check_integer("epsilon's rank", epsilon_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(epsilon_shp_len, 1, Rel.LE, "epsilon's rank", self.name) | |||
| if epsilon_shp_len == 1: | |||
| validator.check_integer("epsilon_shape[0]", epsilon_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name) | |||
| return var_shape, m_shape, v_shape | |||
| def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype, | |||
| @@ -3985,17 +3985,17 @@ class ApplyAdadelta(PrimitiveWithInfer): | |||
| validator.check("accum_update_shape", accum_update_shape, "var_shape", var_shape, Rel.EQ, self.name) | |||
| validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name) | |||
| lr_shp_len = len(lr_shape) | |||
| validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) | |||
| if lr_shp_len == 1: | |||
| validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) | |||
| rho_shp_len = len(rho_shape) | |||
| validator.check_integer("rho's rank", rho_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(rho_shp_len, 1, Rel.LE, "rho's rank", self.name) | |||
| if rho_shp_len == 1: | |||
| validator.check_integer("rho_shape[0]", rho_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(rho_shape[0], 1, Rel.EQ, "rho_shape[0]", self.name) | |||
| epsilon_shp_len = len(epsilon_shape) | |||
| validator.check_integer("lepsilon's rank", epsilon_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(epsilon_shp_len, 1, Rel.LE, "lepsilon's rank", self.name) | |||
| if epsilon_shp_len == 1: | |||
| validator.check_integer("epsilon_shape[0]", epsilon_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name) | |||
| return var_shape, accum_shape, accum_update_shape | |||
| def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype, | |||
| @@ -4077,9 +4077,9 @@ class ApplyAdagrad(PrimitiveWithInfer): | |||
| validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name) | |||
| validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name) | |||
| lr_shp_len = len(lr_shape) | |||
| validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) | |||
| if lr_shp_len == 1: | |||
| validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) | |||
| return var_shape, accum_shape | |||
| def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): | |||
| @@ -4161,9 +4161,9 @@ class ApplyAdagradV2(PrimitiveWithInfer): | |||
| validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) | |||
| validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name) | |||
| lr_shp_len = len(lr_shape) | |||
| validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) | |||
| if lr_shp_len == 1: | |||
| validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) | |||
| return var_shape, accum_shape | |||
| def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): | |||
| @@ -4249,7 +4249,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer): | |||
| validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) | |||
| if len(var_shape) > 1: | |||
| validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) | |||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||
| return var_shape, accum_shape | |||
| @@ -4338,7 +4338,7 @@ class SparseApplyAdagradV2(PrimitiveWithInfer): | |||
| validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) | |||
| if len(var_shape) > 1: | |||
| validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) | |||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||
| return var_shape, accum_shape | |||
| @@ -4428,17 +4428,17 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): | |||
| validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name) | |||
| validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name) | |||
| lr_shp_len = len(lr_shape) | |||
| validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) | |||
| if lr_shp_len == 1: | |||
| validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) | |||
| l1_shp_len = len(l1_shape) | |||
| validator.check_integer("l1's rank", l1_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(l1_shp_len, 1, Rel.LE, "l1's rank", self.name) | |||
| if l1_shp_len == 1: | |||
| validator.check_integer("l1_shape[0]", l1_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name) | |||
| l2_shp_len = len(l2_shape) | |||
| validator.check_integer("l2's rank", l2_shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(l2_shp_len, 1, Rel.LE, "l2's rank", self.name) | |||
| if l2_shp_len == 1: | |||
| validator.check_integer("l2_shape[0]", l2_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name) | |||
| return var_shape, accum_shape | |||
| def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype): | |||
| @@ -4532,7 +4532,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck): | |||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||
| def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) | |||
| def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): | |||
| args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} | |||
| @@ -4623,21 +4623,21 @@ class ApplyAddSign(PrimitiveWithInfer): | |||
| validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name) | |||
| validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name) | |||
| lr_shape_len = len(lr_shape) | |||
| validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name) | |||
| if lr_shape_len == 1: | |||
| validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) | |||
| alpha_shape_len = len(alpha_shape) | |||
| validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name) | |||
| if alpha_shape_len == 1: | |||
| validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name) | |||
| sign_decay_shape_len = len(sign_decay_shape) | |||
| validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name) | |||
| if sign_decay_shape_len == 1: | |||
| validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name) | |||
| beta_shape_len = len(beta_shape) | |||
| validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name) | |||
| if beta_shape_len == 1: | |||
| validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name) | |||
| return var_shape, m_shape | |||
| def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype): | |||
| @@ -4732,21 +4732,21 @@ class ApplyPowerSign(PrimitiveWithInfer): | |||
| validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name) | |||
| validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name) | |||
| lr_shape_len = len(lr_shape) | |||
| validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name) | |||
| if lr_shape_len == 1: | |||
| validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) | |||
| logbase_shape_len = len(logbase_shape) | |||
| validator.check_integer("logbase's rank", logbase_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(logbase_shape_len, 1, Rel.LE, "logbase's rank", self.name) | |||
| if logbase_shape_len == 1: | |||
| validator.check_integer("logbase_shape[0]", logbase_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(logbase_shape[0], 1, Rel.EQ, "logbase_shape[0]", self.name) | |||
| sign_decay_shape_len = len(sign_decay_shape) | |||
| validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name) | |||
| if sign_decay_shape_len == 1: | |||
| validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name) | |||
| beta_shape_len = len(beta_shape) | |||
| validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name) | |||
| if beta_shape_len == 1: | |||
| validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name) | |||
| return var_shape, m_shape | |||
| def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype): | |||
| @@ -4812,9 +4812,9 @@ class ApplyGradientDescent(PrimitiveWithInfer): | |||
| def infer_shape(self, var_shape, alpha_shape, delta_shape): | |||
| validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name) | |||
| alpha_shape_len = len(alpha_shape) | |||
| validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name) | |||
| if alpha_shape_len == 1: | |||
| validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name) | |||
| return var_shape | |||
| def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype): | |||
| @@ -4887,17 +4887,17 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer): | |||
| def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape): | |||
| validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name) | |||
| alpha_shape_len = len(alpha_shape) | |||
| validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name) | |||
| if alpha_shape_len == 1: | |||
| validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name) | |||
| l1_shape_len = len(l1_shape) | |||
| validator.check_integer("l1's rank", l1_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(l1_shape_len, 1, Rel.LE, "l1's rank", self.name) | |||
| if l1_shape_len == 1: | |||
| validator.check_integer("l1_shape[0]", l1_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name) | |||
| l2_shape_len = len(l2_shape) | |||
| validator.check_integer("l2's rank", l2_shape_len, 1, Rel.LE, self.name) | |||
| validator.check_int(l2_shape_len, 1, Rel.LE, "l2's rank", self.name) | |||
| if l2_shape_len == 1: | |||
| validator.check_integer("l2_shape[0]", l2_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name) | |||
| return var_shape | |||
| def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype): | |||
| @@ -4965,13 +4965,13 @@ class LARSUpdate(PrimitiveWithInfer): | |||
| validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ, | |||
| self.name) | |||
| shp_len = len(weight_decay_shape) | |||
| validator.check_integer("weight decay's rank", shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(shp_len, 1, Rel.LE, "weight decay's rank", self.name) | |||
| if shp_len == 1: | |||
| validator.check_integer("weight_decay_shape[0]", weight_decay_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(weight_decay_shape[0], 1, Rel.EQ, "weight_decay_shape[0]", self.name) | |||
| shp_len = len(learning_rate_shape) | |||
| validator.check_integer("learning rate's rank", shp_len, 1, Rel.LE, self.name) | |||
| validator.check_int(shp_len, 1, Rel.LE, "learning rate's rank", self.name) | |||
| if shp_len == 1: | |||
| validator.check_integer("learning_rate_shape[0]", learning_rate_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(learning_rate_shape[0], 1, Rel.EQ, "learning_rate_shape[0]", self.name) | |||
| return weight_shape | |||
| def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype, | |||
| @@ -5155,7 +5155,7 @@ class SparseApplyFtrl(PrimitiveWithCheck): | |||
| validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) | |||
| if len(var_shape) > 1: | |||
| validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) | |||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||
| def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): | |||
| @@ -5251,7 +5251,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): | |||
| validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) | |||
| if len(var_shape) > 1: | |||
| validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) | |||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||
| return var_shape, accum_shape, linear_shape | |||
| @@ -5288,7 +5288,7 @@ class Dropout(PrimitiveWithInfer): | |||
| self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name) | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) | |||
| validator.check_int(len(x_shape), 1, Rel.GE, "x_shape", self.name) | |||
| mask_shape = x_shape | |||
| return x_shape, mask_shape | |||
| @@ -5352,11 +5352,11 @@ class CTCLoss(PrimitiveWithInfer): | |||
| self.ignore_longer_outputs_than_inputs_ = ignore_longer_outputs_than_inputs | |||
| def infer_shape(self, inputs, labels_indices, labels_values, sequence_length): | |||
| validator.check_integer("inputs rank", len(inputs), 3, Rel.EQ, self.name) | |||
| validator.check_integer("labels_indices rank", len(labels_indices), 2, Rel.EQ, self.name) | |||
| validator.check_integer("labels_indices dim one", labels_indices[1], 2, Rel.EQ, self.name) | |||
| validator.check_integer("labels_values rank", len(labels_values), 1, Rel.EQ, self.name) | |||
| validator.check_integer("sequence_length rank", len(sequence_length), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(inputs), 3, Rel.EQ, "inputs rank", self.name) | |||
| validator.check_int(len(labels_indices), 2, Rel.EQ, "labels_indices rank", self.name) | |||
| validator.check_int(labels_indices[1], 2, Rel.EQ, "labels_indices dim one", self.name) | |||
| validator.check_int(len(labels_values), 1, Rel.EQ, "labels_values rank", self.name) | |||
| validator.check_int(len(sequence_length), 1, Rel.EQ, "sequence_length rank", self.name) | |||
| validator.check('labels_indices size', labels_indices[0], 'labels_values size', | |||
| labels_values[0], Rel.EQ, self.name) | |||
| validator.check('inputs batch_size', inputs[1], 'sequence_length batch_size', | |||
| @@ -5422,8 +5422,8 @@ class CTCGreedyDecoder(PrimitiveWithInfer): | |||
| self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name) | |||
| def infer_shape(self, inputs_shape, sequence_length_shape): | |||
| validator.check_integer("inputs rank", len(inputs_shape), 3, Rel.EQ, self.name) | |||
| validator.check_integer("sequence_length rank", len(sequence_length_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name) | |||
| validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name) | |||
| validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size', | |||
| sequence_length_shape[0], Rel.EQ, self.name) | |||
| total_decoded_outputs = -1 | |||
| @@ -5517,11 +5517,11 @@ class BasicLSTMCell(PrimitiveWithInfer): | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape): | |||
| validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name) | |||
| validator.check_int(len(x_shape), 2, Rel.EQ, "x rank", self.name) | |||
| validator.check_int(len(h_shape), 2, Rel.EQ, "h rank", self.name) | |||
| validator.check_int(len(c_shape), 2, Rel.EQ, "c rank", self.name) | |||
| validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name) | |||
| validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name) | |||
| validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name) | |||
| validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name) | |||
| validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name) | |||
| @@ -5637,11 +5637,11 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape): | |||
| validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name) | |||
| validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name) | |||
| validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name) | |||
| validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name) | |||
| validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name) | |||
| validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name) | |||
| validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name) | |||
| validator.check_int(len(h_shape), 3, Rel.EQ, "h_shape", self.name) | |||
| validator.check_int(len(c_shape), 3, Rel.EQ, "c_shape", self.name) | |||
| if seq_shape is not None: | |||
| raise ValueError(f"For {self.name}, seq_shape should be None.") | |||
| @@ -5654,7 +5654,7 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size", | |||
| input_size + hidden_size, Rel.EQ, self.name) | |||
| validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name) | |||
| validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name) | |||
| validator.check_int(h_shape[0], 1, Rel.EQ, "h_shape[0]", self.name) | |||
| validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name) | |||
| validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name) | |||
| validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name) | |||
| @@ -5754,5 +5754,5 @@ class LRN(PrimitiveWithInfer): | |||
| return x_dtype | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name) | |||
| validator.check_int(len(x_shape), 4, Rel.EQ, "x_shape", self.name) | |||
| return x_shape | |||
| @@ -98,16 +98,16 @@ class BoundingBoxEncode(PrimitiveWithInfer): | |||
| validator.check_value_type("means[%d]" % i, value, [float], self.name) | |||
| for i, value in enumerate(stds): | |||
| validator.check_value_type("stds[%d]" % i, value, [float], self.name) | |||
| validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) | |||
| validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(means), 4, "means len", self.name) | |||
| validator.check_equal_int(len(stds), 4, "stds len", self.name) | |||
| def infer_shape(self, anchor_box, groundtruth_box): | |||
| validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ, | |||
| self.name) | |||
| validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name) | |||
| validator.check("groundtruth_box rank", len(groundtruth_box), "", 2, Rel.EQ, self.name) | |||
| validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) | |||
| validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name) | |||
| validator.check_equal_int(groundtruth_box[1], 4, 'groundtruth_box shape[1]', self.name) | |||
| return anchor_box | |||
| def infer_dtype(self, anchor_box, groundtruth_box): | |||
| @@ -153,18 +153,18 @@ class BoundingBoxDecode(PrimitiveWithInfer): | |||
| for i, value in enumerate(stds): | |||
| validator.check_value_type("stds[%d]" % i, value, [float], self.name) | |||
| validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name) | |||
| validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) | |||
| validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(means), 4, "means len", self.name) | |||
| validator.check_equal_int(len(stds), 4, "stds len", self.name) | |||
| if max_shape is not None: | |||
| validator.check_value_type('max_shape', max_shape, [tuple], self.name) | |||
| validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(max_shape), 2, "max_shape len", self.name) | |||
| def infer_shape(self, anchor_box, deltas): | |||
| validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name) | |||
| validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name) | |||
| validator.check("deltas rank", len(deltas), "", 2, Rel.EQ, self.name) | |||
| validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) | |||
| validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name) | |||
| validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name) | |||
| validator.check_equal_int(deltas[1], 4, 'deltas shape[1]', self.name) | |||
| return anchor_box | |||
| def infer_dtype(self, anchor_box, deltas): | |||
| @@ -272,10 +272,10 @@ class IOU(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap']) | |||
| def infer_shape(self, anchor_boxes, gt_boxes): | |||
| validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name) | |||
| validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name) | |||
| validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name) | |||
| validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name) | |||
| validator.check_equal_int(gt_boxes[1], 4, 'gt_boxes shape[1]', self.name) | |||
| validator.check_equal_int(anchor_boxes[1], 4, 'anchor_boxes shape[1]', self.name) | |||
| validator.check_equal_int(len(anchor_boxes), 2, 'anchor_boxes rank', self.name) | |||
| validator.check_equal_int(len(gt_boxes), 2, 'gt_boxes rank', self.name) | |||
| iou = [gt_boxes[0], anchor_boxes[0]] | |||
| return iou | |||
| @@ -356,8 +356,8 @@ class RandomChoiceWithMask(PrimitiveWithInfer): | |||
| Validator.check_value_type('seed2', seed2, [int], self.name) | |||
| def infer_shape(self, x_shape): | |||
| Validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name) | |||
| Validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name) | |||
| Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name) | |||
| Validator.check_int(len(x_shape), 5, Rel.LE, "input_x rank", self.name) | |||
| return ([self.count, len(x_shape)], [self.count]) | |||
| def infer_dtype(self, x_dtype): | |||
| @@ -227,7 +227,7 @@ class PrimitiveWithCheck(Primitive): | |||
| >>> def __init__(self): | |||
| >>> pass | |||
| >>> def check_shape(self, input_x): | |||
| >>> validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name) | |||
| >>> validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name) | |||
| >>> | |||
| >>> def check_dtype(self, input_x): | |||
| >>> validator.check_subclass("input_x", input_x, mstype.tensor, self.name) | |||
| @@ -89,12 +89,12 @@ class ConvertToQuantNetwork: | |||
| def __init__(self, **kwargs): | |||
| self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) | |||
| self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) | |||
| self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE) | |||
| self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay") | |||
| self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay") | |||
| self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold") | |||
| self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) | |||
| self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) | |||
| self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) | |||
| self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn") | |||
| self.weight_bits = Validator.check_non_negative_int(kwargs["num_bits"][0], "weights bit") | |||
| self.act_bits = Validator.check_int(kwargs["num_bits"][-1], 0, Rel.GE, "activations bit") | |||
| self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel") | |||
| self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel") | |||
| self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric") | |||
| @@ -21,7 +21,7 @@ from PIL import Image | |||
| from mindspore import log as logger | |||
| from ..._checkparam import _check_str_by_regular | |||
| from ..._checkparam import Validator | |||
| from ..anf_ir_pb2 import DataType, ModelProto | |||
| from ..summary_pb2 import Event | |||
| @@ -47,8 +47,8 @@ def get_event_file_name(prefix, suffix): | |||
| Returns: | |||
| String, the name of event log file. | |||
| """ | |||
| _check_str_by_regular(prefix) | |||
| _check_str_by_regular(suffix) | |||
| Validator.check_str_by_regular(prefix) | |||
| Validator.check_str_by_regular(suffix) | |||
| file_name = "" | |||
| time_second = str(int(time.time())) | |||
| hostname = platform.node() | |||
| @@ -21,7 +21,7 @@ import threading | |||
| from mindspore import log as logger | |||
| from ..._c_expression import Tensor | |||
| from ..._checkparam import _check_str_by_regular | |||
| from ..._checkparam import Validator | |||
| from .._utils import _check_lineage_value, _check_to_numpy, _make_directory | |||
| from ._summary_adapter import get_event_file_name, package_graph_event | |||
| from ._writer_pool import WriterPool | |||
| @@ -103,8 +103,8 @@ class SummaryRecord: | |||
| self._closed, self._event_writer = False, None | |||
| self._mode, self._data_pool = 'train', _dictlist() | |||
| _check_str_by_regular(file_prefix) | |||
| _check_str_by_regular(file_suffix) | |||
| Validator.check_str_by_regular(file_prefix) | |||
| Validator.check_str_by_regular(file_suffix) | |||
| self.log_path = _make_directory(log_dir) | |||
| @@ -12,10 +12,10 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test checkparameter """ | |||
| """ test check parameter """ | |||
| import pytest | |||
| import numpy as np | |||
| from mindspore._checkparam import twice, Validator | |||
| from mindspore._checkparam import Validator, twice | |||
| kernel_size = 5 | |||
| kernel_size1 = twice(kernel_size) | |||
| @@ -18,7 +18,7 @@ import numpy as np | |||
| import pytest | |||
| from mindspore import context, Tensor, Parameter, ParameterTuple, nn | |||
| from mindspore._checkparam import _check_str_by_regular | |||
| from mindspore._checkparam import Validator | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.initializer import initializer | |||
| @@ -124,15 +124,15 @@ def test_check_str_by_regular(): | |||
| str4 = ".12_sf.asdf" | |||
| str5 = "12_sf.a$sdf." | |||
| str6 = "12+sf.asdf" | |||
| _check_str_by_regular(str1) | |||
| _check_str_by_regular(str2) | |||
| _check_str_by_regular(str3) | |||
| Validator.check_str_by_regular(str1) | |||
| Validator.check_str_by_regular(str2) | |||
| Validator.check_str_by_regular(str3) | |||
| with pytest.raises(ValueError): | |||
| _check_str_by_regular(str4) | |||
| Validator.check_str_by_regular(str4) | |||
| with pytest.raises(ValueError): | |||
| _check_str_by_regular(str5) | |||
| Validator.check_str_by_regular(str5) | |||
| with pytest.raises(ValueError): | |||
| _check_str_by_regular(str6) | |||
| Validator.check_str_by_regular(str6) | |||
| def test_parameter_compute(): | |||
| para_1 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test1') | |||