| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Check parameters.""" | """Check parameters.""" | ||||
| import re | import re | ||||
| import inspect | import inspect | ||||
| import math | import math | ||||
| @@ -20,10 +21,9 @@ from enum import Enum | |||||
| from functools import reduce, wraps | from functools import reduce, wraps | ||||
| from itertools import repeat | from itertools import repeat | ||||
| from collections.abc import Iterable | from collections.abc import Iterable | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from .common import dtype as mstype | |||||
| from mindspore.common import dtype as mstype | |||||
| # Named string regular expression | # Named string regular expression | ||||
| @@ -103,18 +103,17 @@ class Validator: | |||||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError): | def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError): | ||||
| """ | """ | ||||
| Method for judging relation between two int values or list/tuple made up of ints. | Method for judging relation between two int values or list/tuple made up of ints. | ||||
| This method is not suitable for judging relation between floats, since it does not consider float error. | This method is not suitable for judging relation between floats, since it does not consider float error. | ||||
| """ | """ | ||||
| rel_fn = Rel.get_fns(rel) | rel_fn = Rel.get_fns(rel) | ||||
| if not rel_fn(arg_value, value): | if not rel_fn(arg_value, value): | ||||
| rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | ||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | ||||
| raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') | raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') | ||||
| return arg_value | |||||
| @staticmethod | @staticmethod | ||||
| def check_integer(arg_name, arg_value, value, rel, prim_name): | |||||
| def check_integer(arg_name, arg_value, value, rel, prim_name=None): | |||||
| """Integer value judgment.""" | """Integer value judgment.""" | ||||
| rel_fn = Rel.get_fns(rel) | rel_fn = Rel.get_fns(rel) | ||||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | ||||
| @@ -135,6 +134,20 @@ class Validator: | |||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.') | raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.') | ||||
| return arg_value | return arg_value | ||||
| @staticmethod | |||||
| def check_isinstance(arg_name, arg_value, classes): | |||||
| """Check arg isinstance of classes""" | |||||
| if not isinstance(arg_value, classes): | |||||
| raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_bool(arg_name, arg_value): | |||||
| """Check arg isinstance of bool""" | |||||
| if not isinstance(arg_value, bool): | |||||
| raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | @staticmethod | ||||
| def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | ||||
| """Method for checking whether an int value is in some range.""" | """Method for checking whether an int value is in some range.""" | ||||
| @@ -208,6 +221,27 @@ class Validator: | |||||
| """Checks valid value.""" | """Checks valid value.""" | ||||
| if arg_value is None: | if arg_value is None: | ||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') | raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') | ||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_type(arg_name, arg_value, valid_types): | |||||
| """Type checking.""" | |||||
| def raise_error_msg(): | |||||
| """func for raising error message when check failed""" | |||||
| type_names = [t.__name__ for t in valid_types] | |||||
| num_types = len(valid_types) | |||||
| raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||||
| if isinstance(arg_value, type(mstype.tensor)): | |||||
| arg_value = arg_value.element_type() | |||||
| # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and | |||||
| # `check_type('x', True, [bool, int])` will check pass | |||||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||||
| raise_error_msg() | |||||
| if isinstance(arg_value, tuple(valid_types)): | |||||
| return arg_value | |||||
| raise_error_msg() | |||||
| @staticmethod | @staticmethod | ||||
| def check_type_same(args, valid_values, prim_name): | def check_type_same(args, valid_values, prim_name): | ||||
| @@ -239,7 +273,6 @@ class Validator: | |||||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): | def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): | ||||
| """ | """ | ||||
| Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. | Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. | ||||
| If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. | If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. | ||||
| """ | """ | ||||
| @@ -335,63 +368,6 @@ class Validator: | |||||
| f'{tuple(exp_shape)}, but got {shape}.') | f'{tuple(exp_shape)}, but got {shape}.') | ||||
| class ParamValidator: | |||||
| """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | |||||
| @staticmethod | |||||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ): | |||||
| """This method is only used for check int values, since when compare float values, | |||||
| we need consider float error.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| if not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | |||||
| raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check_integer(arg_name, arg_value, value, rel): | |||||
| """Integer value judgment.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||||
| if type_mismatch or not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(value) | |||||
| raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_isinstance(arg_name, arg_value, classes): | |||||
| """Check arg isinstance of classes""" | |||||
| if not isinstance(arg_value, classes): | |||||
| raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_bool(arg_name, arg_value): | |||||
| """Check arg isinstance of bool""" | |||||
| if not isinstance(arg_value, bool): | |||||
| raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_type(arg_name, arg_value, valid_types): | |||||
| """Type checking.""" | |||||
| def raise_error_msg(): | |||||
| """func for raising error message when check failed""" | |||||
| type_names = [t.__name__ for t in valid_types] | |||||
| num_types = len(valid_types) | |||||
| raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||||
| if isinstance(arg_value, type(mstype.tensor)): | |||||
| arg_value = arg_value.element_type() | |||||
| # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and | |||||
| # `check_type('x', True, [bool, int])` will check pass | |||||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||||
| raise_error_msg() | |||||
| if isinstance(arg_value, tuple(valid_types)): | |||||
| return arg_value | |||||
| raise_error_msg() | |||||
| def check_int(input_param): | def check_int(input_param): | ||||
| """Int type judgment.""" | """Int type judgment.""" | ||||
| if isinstance(input_param, int) and not isinstance(input_param, bool): | if isinstance(input_param, int) and not isinstance(input_param, bool): | ||||
| @@ -638,7 +614,6 @@ def args_type_check(*type_args, **type_kwargs): | |||||
| if value is not None and not isinstance(value, bound_types[name]): | if value is not None and not isinstance(value, bound_types[name]): | ||||
| raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) | raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) | ||||
| return func(*args, **kwargs) | return func(*args, **kwargs) | ||||
| return wrapper | return wrapper | ||||
| return type_check | return type_check | ||||
| @@ -21,7 +21,7 @@ from ...ops import operations as P | |||||
| from ...ops.primitive import PrimitiveWithInfer, prim_attr_register | from ...ops.primitive import PrimitiveWithInfer, prim_attr_register | ||||
| from ...ops.composite import multitype_ops as C | from ...ops.composite import multitype_ops as C | ||||
| from ...ops.operations import _grad_ops as G | from ...ops.operations import _grad_ops as G | ||||
| from ..._checkparam import ParamValidator as validator | |||||
| from ..._checkparam import Validator | |||||
| from ..cell import Cell, GraphKernel | from ..cell import Cell, GraphKernel | ||||
| @@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel): | |||||
| use_locking=False, | use_locking=False, | ||||
| gradient_scale=1.0): | gradient_scale=1.0): | ||||
| super(ApplyMomentum, self).__init__() | super(ApplyMomentum, self).__init__() | ||||
| self.gradient_scale = validator.check_type('gradient_scale', gradient_scale, [float]) | |||||
| self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float]) | |||||
| self.fake_output_assign_1 = InplaceAssign() | self.fake_output_assign_1 = InplaceAssign() | ||||
| self.fake_output_assign_1.add_prim_attr("fake_output", True) | self.fake_output_assign_1.add_prim_attr("fake_output", True) | ||||
| self.fake_output_assign_2 = InplaceAssign() | self.fake_output_assign_2 = InplaceAssign() | ||||
| @@ -334,7 +334,7 @@ class ReduceMean(GraphKernel): | |||||
| def __init__(self, keep_dims=True): | def __init__(self, keep_dims=True): | ||||
| super(ReduceMean, self).__init__() | super(ReduceMean, self).__init__() | ||||
| self.keep_dims = validator.check_type('keep_dims', keep_dims, [bool]) | |||||
| self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool]) | |||||
| self.sum = P.ReduceSum(self.keep_dims) | self.sum = P.ReduceSum(self.keep_dims) | ||||
| def construct(self, x, axis): | def construct(self, x, axis): | ||||
| @@ -431,8 +431,8 @@ class LayerNormForward(GraphKernel): | |||||
| """ Forward function of the LayerNorm operator. """ | """ Forward function of the LayerNorm operator. """ | ||||
| def __init__(self, begin_norm_axis=1, begin_params_axis=1): | def __init__(self, begin_norm_axis=1, begin_params_axis=1): | ||||
| super(LayerNormForward, self).__init__() | super(LayerNormForward, self).__init__() | ||||
| self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int]) | |||||
| self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int]) | |||||
| self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int]) | |||||
| self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int]) | |||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| self.sum_keep_dims = P.ReduceSum(keep_dims=True) | self.sum_keep_dims = P.ReduceSum(keep_dims=True) | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| @@ -686,7 +686,7 @@ class LogSoftmax(GraphKernel): | |||||
| def __init__(self, axis=-1): | def __init__(self, axis=-1): | ||||
| super(LogSoftmax, self).__init__() | super(LogSoftmax, self).__init__() | ||||
| self.axis = validator.check_type('axis', axis, [int]) | |||||
| self.axis = Validator.check_type('axis', axis, [int]) | |||||
| self.max_keep_dims = P.ReduceMax(keep_dims=True) | self.max_keep_dims = P.ReduceMax(keep_dims=True) | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| self.exp = P.Exp() | self.exp = P.Exp() | ||||
| @@ -952,13 +952,13 @@ class Softmax(GraphKernel): | |||||
| def __init__(self, axis): | def __init__(self, axis): | ||||
| super(Softmax, self).__init__() | super(Softmax, self).__init__() | ||||
| validator.check_type("axis", axis, [int, tuple]) | |||||
| Validator.check_type("axis", axis, [int, tuple]) | |||||
| if isinstance(axis, int): | if isinstance(axis, int): | ||||
| self.axis = (axis,) | self.axis = (axis,) | ||||
| else: | else: | ||||
| self.axis = axis | self.axis = axis | ||||
| for item in self.axis: | for item in self.axis: | ||||
| validator.check_type("item of axis", item, [int]) | |||||
| Validator.check_type("item of axis", item, [int]) | |||||
| self.max = P.ReduceMax(keep_dims=True) | self.max = P.ReduceMax(keep_dims=True) | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| self.exp = P.Exp() | self.exp = P.Exp() | ||||
| @@ -21,8 +21,7 @@ from mindspore.ops.primitive import constexpr | |||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common.initializer import initializer, Initializer | from mindspore.common.initializer import initializer, Initializer | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore._checkparam import ParamValidator as validator, Rel | |||||
| from mindspore._checkparam import check_bool, twice, check_int_positive, Validator | |||||
| from mindspore._checkparam import Validator, Rel, check_bool, twice, check_int_positive | |||||
| from mindspore._extends import cell_attr_register | from mindspore._extends import cell_attr_register | ||||
| from ..cell import Cell | from ..cell import Cell | ||||
| @@ -240,8 +239,8 @@ class Conv2d(_Conv): | |||||
| """Initialize depthwise conv2d op""" | """Initialize depthwise conv2d op""" | ||||
| if context.get_context("device_target") == "Ascend" and self.group > 1: | if context.get_context("device_target") == "Ascend" and self.group > 1: | ||||
| self.dilation = self._dilation | self.dilation = self._dilation | ||||
| validator.check_integer('group', self.group, self.in_channels, Rel.EQ) | |||||
| validator.check_integer('group', self.group, self.out_channels, Rel.EQ) | |||||
| Validator.check_integer('group', self.group, self.in_channels, Rel.EQ) | |||||
| Validator.check_integer('group', self.group, self.out_channels, Rel.EQ) | |||||
| self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1, | self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1, | ||||
| kernel_size=self.kernel_size, | kernel_size=self.kernel_size, | ||||
| pad_mode=self.pad_mode, | pad_mode=self.pad_mode, | ||||
| @@ -23,7 +23,7 @@ from mindspore.ops import functional as F | |||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, ParamValidator as validator | |||||
| from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, Validator | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from .normalization import BatchNorm2d, BatchNorm1d | from .normalization import BatchNorm2d, BatchNorm1d | ||||
| from .activation import get_activation, ReLU, LeakyReLU | from .activation import get_activation, ReLU, LeakyReLU | ||||
| @@ -133,7 +133,7 @@ class Conv2dBnAct(Cell): | |||||
| has_bias=has_bias, | has_bias=has_bias, | ||||
| weight_init=weight_init, | weight_init=weight_init, | ||||
| bias_init=bias_init) | bias_init=bias_init) | ||||
| self.has_bn = validator.check_bool("has_bn", has_bn) | |||||
| self.has_bn = Validator.check_bool("has_bn", has_bn) | |||||
| self.has_act = activation is not None | self.has_act = activation is not None | ||||
| self.after_fake = after_fake | self.after_fake = after_fake | ||||
| if has_bn: | if has_bn: | ||||
| @@ -201,7 +201,7 @@ class DenseBnAct(Cell): | |||||
| weight_init, | weight_init, | ||||
| bias_init, | bias_init, | ||||
| has_bias) | has_bias) | ||||
| self.has_bn = validator.check_bool("has_bn", has_bn) | |||||
| self.has_bn = Validator.check_bool("has_bn", has_bn) | |||||
| self.has_act = activation is not None | self.has_act = activation is not None | ||||
| self.after_fake = after_fake | self.after_fake = after_fake | ||||
| if has_bn: | if has_bn: | ||||
| @@ -320,10 +320,10 @@ class FakeQuantWithMinMax(Cell): | |||||
| quant_delay=0): | quant_delay=0): | ||||
| """Initialize FakeQuantWithMinMax layer""" | """Initialize FakeQuantWithMinMax layer""" | ||||
| super(FakeQuantWithMinMax, self).__init__() | super(FakeQuantWithMinMax, self).__init__() | ||||
| validator.check_type("min_init", min_init, [int, float]) | |||||
| validator.check_type("max_init", max_init, [int, float]) | |||||
| validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) | |||||
| validator.check_integer('quant_delay', quant_delay, 0, Rel.GE) | |||||
| Validator.check_type("min_init", min_init, [int, float]) | |||||
| Validator.check_type("max_init", max_init, [int, float]) | |||||
| Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) | |||||
| Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE) | |||||
| self.min_init = min_init | self.min_init = min_init | ||||
| self.max_init = max_init | self.max_init = max_init | ||||
| self.num_bits = num_bits | self.num_bits = num_bits | ||||
| @@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell): | |||||
| # initialize convolution op and Parameter | # initialize convolution op and Parameter | ||||
| if context.get_context('device_target') == "Ascend" and group > 1: | if context.get_context('device_target') == "Ascend" and group > 1: | ||||
| validator.check_integer('group', group, in_channels, Rel.EQ) | |||||
| validator.check_integer('group', group, out_channels, Rel.EQ) | |||||
| Validator.check_integer('group', group, in_channels, Rel.EQ) | |||||
| Validator.check_integer('group', group, out_channels, Rel.EQ) | |||||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | ||||
| kernel_size=self.kernel_size, | kernel_size=self.kernel_size, | ||||
| pad_mode=pad_mode, | pad_mode=pad_mode, | ||||
| @@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell): | |||||
| self.bias = None | self.bias = None | ||||
| # initialize convolution op and Parameter | # initialize convolution op and Parameter | ||||
| if context.get_context('device_target') == "Ascend" and group > 1: | if context.get_context('device_target') == "Ascend" and group > 1: | ||||
| validator.check_integer('group', group, in_channels, Rel.EQ) | |||||
| validator.check_integer('group', group, out_channels, Rel.EQ) | |||||
| Validator.check_integer('group', group, in_channels, Rel.EQ) | |||||
| Validator.check_integer('group', group, out_channels, Rel.EQ) | |||||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | ||||
| kernel_size=self.kernel_size, | kernel_size=self.kernel_size, | ||||
| pad_mode=pad_mode, | pad_mode=pad_mode, | ||||
| @@ -22,7 +22,7 @@ import mindspore.context as context | |||||
| from ... import log as logger | from ... import log as logger | ||||
| from ... import nn, ops | from ... import nn, ops | ||||
| from ..._checkparam import ParamValidator as validator | |||||
| from ..._checkparam import Validator | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import Tensor | from ...common import Tensor | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| @@ -89,19 +89,19 @@ class ConvertToQuantNetwork: | |||||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | ||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) | |||||
| self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) | |||||
| self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE) | |||||
| self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"]) | |||||
| self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) | |||||
| self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) | |||||
| self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) | |||||
| self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0]) | |||||
| self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1]) | |||||
| self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0]) | |||||
| self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1]) | |||||
| self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0]) | |||||
| self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1]) | |||||
| self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) | |||||
| self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) | |||||
| self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE) | |||||
| self.bn_fold = Validator.check_bool("bn fold", kwargs["bn_fold"]) | |||||
| self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) | |||||
| self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) | |||||
| self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) | |||||
| self.weight_channel = Validator.check_bool("per channel", kwargs["per_channel"][0]) | |||||
| self.act_channel = Validator.check_bool("per channel", kwargs["per_channel"][-1]) | |||||
| self.weight_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][0]) | |||||
| self.act_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][-1]) | |||||
| self.weight_range = Validator.check_bool("narrow range", kwargs["narrow_range"][0]) | |||||
| self.act_range = Validator.check_bool("narrow range", kwargs["narrow_range"][-1]) | |||||
| self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, | self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, | ||||
| quant.DenseBnAct: self._convert_dense} | quant.DenseBnAct: self._convert_dense} | ||||
| @@ -316,7 +316,7 @@ class ExportToQuantInferNetwork: | |||||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | ||||
| def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): | def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): | ||||
| network = validator.check_isinstance('network', network, (nn.Cell,)) | |||||
| network = Validator.check_isinstance('network', network, (nn.Cell,)) | |||||
| self.input_scale = 1 / std_dev | self.input_scale = 1 / std_dev | ||||
| self.input_zero_point = round(mean) | self.input_zero_point = round(mean) | ||||
| self.data_type = mstype.int8 | self.data_type = mstype.int8 | ||||
| @@ -510,8 +510,8 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format=' | |||||
| supported_device = ["Ascend", "GPU"] | supported_device = ["Ascend", "GPU"] | ||||
| supported_formats = ['AIR', 'MINDIR'] | supported_formats = ['AIR', 'MINDIR'] | ||||
| mean = validator.check_type("mean", mean, (int, float)) | |||||
| std_dev = validator.check_type("std_dev", std_dev, (int, float)) | |||||
| mean = Validator.check_type("mean", mean, (int, float)) | |||||
| std_dev = Validator.check_type("std_dev", std_dev, (int, float)) | |||||
| if context.get_context('device_target') not in supported_device: | if context.get_context('device_target') not in supported_device: | ||||
| raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) | raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) | ||||