| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Check parameters.""" | |||
| import re | |||
| import inspect | |||
| import math | |||
| @@ -20,10 +21,9 @@ from enum import Enum | |||
| from functools import reduce, wraps | |||
| from itertools import repeat | |||
| from collections.abc import Iterable | |||
| import numpy as np | |||
| from mindspore import log as logger | |||
| from .common import dtype as mstype | |||
| from mindspore.common import dtype as mstype | |||
| # Named string regular expression | |||
| @@ -103,18 +103,17 @@ class Validator: | |||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError): | |||
| """ | |||
| Method for judging relation between two int values or list/tuple made up of ints. | |||
| This method is not suitable for judging relation between floats, since it does not consider float error. | |||
| """ | |||
| rel_fn = Rel.get_fns(rel) | |||
| if not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||
| raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_integer(arg_name, arg_value, value, rel, prim_name): | |||
| def check_integer(arg_name, arg_value, value, rel, prim_name=None): | |||
| """Integer value judgment.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||
| @@ -135,6 +134,20 @@ class Validator: | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_isinstance(arg_name, arg_value, classes): | |||
| """Check arg isinstance of classes""" | |||
| if not isinstance(arg_value, classes): | |||
| raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_bool(arg_name, arg_value): | |||
| """Check arg isinstance of bool""" | |||
| if not isinstance(arg_value, bool): | |||
| raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | |||
| """Method for checking whether an int value is in some range.""" | |||
| @@ -208,6 +221,27 @@ class Validator: | |||
| """Checks valid value.""" | |||
| if arg_value is None: | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_type(arg_name, arg_value, valid_types): | |||
| """Type checking.""" | |||
| def raise_error_msg(): | |||
| """func for raising error message when check failed""" | |||
| type_names = [t.__name__ for t in valid_types] | |||
| num_types = len(valid_types) | |||
| raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||
| if isinstance(arg_value, type(mstype.tensor)): | |||
| arg_value = arg_value.element_type() | |||
| # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and | |||
| # `check_type('x', True, [bool, int])` will check pass | |||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||
| raise_error_msg() | |||
| if isinstance(arg_value, tuple(valid_types)): | |||
| return arg_value | |||
| raise_error_msg() | |||
| @staticmethod | |||
| def check_type_same(args, valid_values, prim_name): | |||
| @@ -239,7 +273,6 @@ class Validator: | |||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): | |||
| """ | |||
| Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. | |||
| If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. | |||
| """ | |||
| @@ -335,63 +368,6 @@ class Validator: | |||
| f'{tuple(exp_shape)}, but got {shape}.') | |||
| class ParamValidator: | |||
| """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | |||
| @staticmethod | |||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ): | |||
| """This method is only used for check int values, since when compare float values, | |||
| we need consider float error.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| if not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | |||
| raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||
| @staticmethod | |||
| def check_integer(arg_name, arg_value, value, rel): | |||
| """Integer value judgment.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||
| if type_mismatch or not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(value) | |||
| raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_isinstance(arg_name, arg_value, classes): | |||
| """Check arg isinstance of classes""" | |||
| if not isinstance(arg_value, classes): | |||
| raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_bool(arg_name, arg_value): | |||
| """Check arg isinstance of bool""" | |||
| if not isinstance(arg_value, bool): | |||
| raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_type(arg_name, arg_value, valid_types): | |||
| """Type checking.""" | |||
| def raise_error_msg(): | |||
| """func for raising error message when check failed""" | |||
| type_names = [t.__name__ for t in valid_types] | |||
| num_types = len(valid_types) | |||
| raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||
| if isinstance(arg_value, type(mstype.tensor)): | |||
| arg_value = arg_value.element_type() | |||
| # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and | |||
| # `check_type('x', True, [bool, int])` will check pass | |||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||
| raise_error_msg() | |||
| if isinstance(arg_value, tuple(valid_types)): | |||
| return arg_value | |||
| raise_error_msg() | |||
| def check_int(input_param): | |||
| """Int type judgment.""" | |||
| if isinstance(input_param, int) and not isinstance(input_param, bool): | |||
| @@ -638,7 +614,6 @@ def args_type_check(*type_args, **type_kwargs): | |||
| if value is not None and not isinstance(value, bound_types[name]): | |||
| raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) | |||
| return func(*args, **kwargs) | |||
| return wrapper | |||
| return type_check | |||
| @@ -21,7 +21,7 @@ from ...ops import operations as P | |||
| from ...ops.primitive import PrimitiveWithInfer, prim_attr_register | |||
| from ...ops.composite import multitype_ops as C | |||
| from ...ops.operations import _grad_ops as G | |||
| from ..._checkparam import ParamValidator as validator | |||
| from ..._checkparam import Validator | |||
| from ..cell import Cell, GraphKernel | |||
| @@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel): | |||
| use_locking=False, | |||
| gradient_scale=1.0): | |||
| super(ApplyMomentum, self).__init__() | |||
| self.gradient_scale = validator.check_type('gradient_scale', gradient_scale, [float]) | |||
| self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float]) | |||
| self.fake_output_assign_1 = InplaceAssign() | |||
| self.fake_output_assign_1.add_prim_attr("fake_output", True) | |||
| self.fake_output_assign_2 = InplaceAssign() | |||
| @@ -334,7 +334,7 @@ class ReduceMean(GraphKernel): | |||
| def __init__(self, keep_dims=True): | |||
| super(ReduceMean, self).__init__() | |||
| self.keep_dims = validator.check_type('keep_dims', keep_dims, [bool]) | |||
| self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool]) | |||
| self.sum = P.ReduceSum(self.keep_dims) | |||
| def construct(self, x, axis): | |||
| @@ -431,8 +431,8 @@ class LayerNormForward(GraphKernel): | |||
| """ Forward function of the LayerNorm operator. """ | |||
| def __init__(self, begin_norm_axis=1, begin_params_axis=1): | |||
| super(LayerNormForward, self).__init__() | |||
| self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int]) | |||
| self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int]) | |||
| self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int]) | |||
| self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int]) | |||
| self.mul = P.Mul() | |||
| self.sum_keep_dims = P.ReduceSum(keep_dims=True) | |||
| self.sub = P.Sub() | |||
| @@ -686,7 +686,7 @@ class LogSoftmax(GraphKernel): | |||
| def __init__(self, axis=-1): | |||
| super(LogSoftmax, self).__init__() | |||
| self.axis = validator.check_type('axis', axis, [int]) | |||
| self.axis = Validator.check_type('axis', axis, [int]) | |||
| self.max_keep_dims = P.ReduceMax(keep_dims=True) | |||
| self.sub = P.Sub() | |||
| self.exp = P.Exp() | |||
| @@ -952,13 +952,13 @@ class Softmax(GraphKernel): | |||
| def __init__(self, axis): | |||
| super(Softmax, self).__init__() | |||
| validator.check_type("axis", axis, [int, tuple]) | |||
| Validator.check_type("axis", axis, [int, tuple]) | |||
| if isinstance(axis, int): | |||
| self.axis = (axis,) | |||
| else: | |||
| self.axis = axis | |||
| for item in self.axis: | |||
| validator.check_type("item of axis", item, [int]) | |||
| Validator.check_type("item of axis", item, [int]) | |||
| self.max = P.ReduceMax(keep_dims=True) | |||
| self.sub = P.Sub() | |||
| self.exp = P.Exp() | |||
| @@ -21,8 +21,7 @@ from mindspore.ops.primitive import constexpr | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer, Initializer | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._checkparam import ParamValidator as validator, Rel | |||
| from mindspore._checkparam import check_bool, twice, check_int_positive, Validator | |||
| from mindspore._checkparam import Validator, Rel, check_bool, twice, check_int_positive | |||
| from mindspore._extends import cell_attr_register | |||
| from ..cell import Cell | |||
| @@ -240,8 +239,8 @@ class Conv2d(_Conv): | |||
| """Initialize depthwise conv2d op""" | |||
| if context.get_context("device_target") == "Ascend" and self.group > 1: | |||
| self.dilation = self._dilation | |||
| validator.check_integer('group', self.group, self.in_channels, Rel.EQ) | |||
| validator.check_integer('group', self.group, self.out_channels, Rel.EQ) | |||
| Validator.check_integer('group', self.group, self.in_channels, Rel.EQ) | |||
| Validator.check_integer('group', self.group, self.out_channels, Rel.EQ) | |||
| self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=self.pad_mode, | |||
| @@ -23,7 +23,7 @@ from mindspore.ops import functional as F | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, ParamValidator as validator | |||
| from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, Validator | |||
| import mindspore.context as context | |||
| from .normalization import BatchNorm2d, BatchNorm1d | |||
| from .activation import get_activation, ReLU, LeakyReLU | |||
| @@ -133,7 +133,7 @@ class Conv2dBnAct(Cell): | |||
| has_bias=has_bias, | |||
| weight_init=weight_init, | |||
| bias_init=bias_init) | |||
| self.has_bn = validator.check_bool("has_bn", has_bn) | |||
| self.has_bn = Validator.check_bool("has_bn", has_bn) | |||
| self.has_act = activation is not None | |||
| self.after_fake = after_fake | |||
| if has_bn: | |||
| @@ -201,7 +201,7 @@ class DenseBnAct(Cell): | |||
| weight_init, | |||
| bias_init, | |||
| has_bias) | |||
| self.has_bn = validator.check_bool("has_bn", has_bn) | |||
| self.has_bn = Validator.check_bool("has_bn", has_bn) | |||
| self.has_act = activation is not None | |||
| self.after_fake = after_fake | |||
| if has_bn: | |||
| @@ -320,10 +320,10 @@ class FakeQuantWithMinMax(Cell): | |||
| quant_delay=0): | |||
| """Initialize FakeQuantWithMinMax layer""" | |||
| super(FakeQuantWithMinMax, self).__init__() | |||
| validator.check_type("min_init", min_init, [int, float]) | |||
| validator.check_type("max_init", max_init, [int, float]) | |||
| validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) | |||
| validator.check_integer('quant_delay', quant_delay, 0, Rel.GE) | |||
| Validator.check_type("min_init", min_init, [int, float]) | |||
| Validator.check_type("max_init", max_init, [int, float]) | |||
| Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) | |||
| Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE) | |||
| self.min_init = min_init | |||
| self.max_init = max_init | |||
| self.num_bits = num_bits | |||
| @@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell): | |||
| # initialize convolution op and Parameter | |||
| if context.get_context('device_target') == "Ascend" and group > 1: | |||
| validator.check_integer('group', group, in_channels, Rel.EQ) | |||
| validator.check_integer('group', group, out_channels, Rel.EQ) | |||
| Validator.check_integer('group', group, in_channels, Rel.EQ) | |||
| Validator.check_integer('group', group, out_channels, Rel.EQ) | |||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| @@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell): | |||
| self.bias = None | |||
| # initialize convolution op and Parameter | |||
| if context.get_context('device_target') == "Ascend" and group > 1: | |||
| validator.check_integer('group', group, in_channels, Rel.EQ) | |||
| validator.check_integer('group', group, out_channels, Rel.EQ) | |||
| Validator.check_integer('group', group, in_channels, Rel.EQ) | |||
| Validator.check_integer('group', group, out_channels, Rel.EQ) | |||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| @@ -22,7 +22,7 @@ import mindspore.context as context | |||
| from ... import log as logger | |||
| from ... import nn, ops | |||
| from ..._checkparam import ParamValidator as validator | |||
| from ..._checkparam import Validator | |||
| from ..._checkparam import Rel | |||
| from ...common import Tensor | |||
| from ...common import dtype as mstype | |||
| @@ -89,19 +89,19 @@ class ConvertToQuantNetwork: | |||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | |||
| def __init__(self, **kwargs): | |||
| self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) | |||
| self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) | |||
| self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE) | |||
| self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"]) | |||
| self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) | |||
| self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) | |||
| self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) | |||
| self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0]) | |||
| self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1]) | |||
| self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0]) | |||
| self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1]) | |||
| self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0]) | |||
| self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1]) | |||
| self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) | |||
| self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) | |||
| self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE) | |||
| self.bn_fold = Validator.check_bool("bn fold", kwargs["bn_fold"]) | |||
| self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) | |||
| self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) | |||
| self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) | |||
| self.weight_channel = Validator.check_bool("per channel", kwargs["per_channel"][0]) | |||
| self.act_channel = Validator.check_bool("per channel", kwargs["per_channel"][-1]) | |||
| self.weight_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][0]) | |||
| self.act_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][-1]) | |||
| self.weight_range = Validator.check_bool("narrow range", kwargs["narrow_range"][0]) | |||
| self.act_range = Validator.check_bool("narrow range", kwargs["narrow_range"][-1]) | |||
| self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, | |||
| quant.DenseBnAct: self._convert_dense} | |||
| @@ -316,7 +316,7 @@ class ExportToQuantInferNetwork: | |||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | |||
| def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): | |||
| network = validator.check_isinstance('network', network, (nn.Cell,)) | |||
| network = Validator.check_isinstance('network', network, (nn.Cell,)) | |||
| self.input_scale = 1 / std_dev | |||
| self.input_zero_point = round(mean) | |||
| self.data_type = mstype.int8 | |||
| @@ -510,8 +510,8 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format=' | |||
| supported_device = ["Ascend", "GPU"] | |||
| supported_formats = ['AIR', 'MINDIR'] | |||
| mean = validator.check_type("mean", mean, (int, float)) | |||
| std_dev = validator.check_type("std_dev", std_dev, (int, float)) | |||
| mean = Validator.check_type("mean", mean, (int, float)) | |||
| std_dev = Validator.check_type("std_dev", std_dev, (int, float)) | |||
| if context.get_context('device_target') not in supported_device: | |||
| raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) | |||