[ME]delete ParamValidator and change all to Validator

chenzomi committed 5 years ago
commit 6c9b6d491d · tags/v1.1.0
5 changed files with 78 additions and 104 deletions
  1. mindspore/_checkparam.py (+39, -64)
  2. mindspore/nn/graph_kernels/graph_kernels.py (+8, -8)
  3. mindspore/nn/layer/conv.py (+3, -4)
  4. mindspore/nn/layer/quant.py (+11, -11)
  5. mindspore/train/quant/quant.py (+17, -17)

mindspore/_checkparam.py (+39, -64)

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Check parameters."""
+
 import re
 import inspect
 import math
@@ -20,10 +21,9 @@ from enum import Enum
 from functools import reduce, wraps
 from itertools import repeat
 from collections.abc import Iterable
-
 import numpy as np
 from mindspore import log as logger
-from .common import dtype as mstype
+from mindspore.common import dtype as mstype
 
 
 # Named string regular expression
@@ -103,18 +103,17 @@ class Validator:
     def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
         """
         Method for judging relation between two int values or list/tuple made up of ints.
-
         This method is not suitable for judging relation between floats, since it does not consider float error.
         """
-
         rel_fn = Rel.get_fns(rel)
         if not rel_fn(arg_value, value):
             rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
             msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
             raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
+        return arg_value
 
     @staticmethod
-    def check_integer(arg_name, arg_value, value, rel, prim_name):
+    def check_integer(arg_name, arg_value, value, rel, prim_name=None):
         """Integer value judgment."""
         rel_fn = Rel.get_fns(rel)
         type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
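
A minimal usage sketch of the two methods touched above (values are hypothetical; assumes this commit's mindspore is importable): `check` compares two named values under a `Rel` relation, while `check_integer` rejects bool even though bool subclasses int.

    from mindspore._checkparam import Validator, Rel

    Validator.check('begin_norm_axis', 1, 'rank', 4, rel=Rel.LT)  # 1 < 4, passes
    delay = Validator.check_integer('quant_delay', 3, 0, Rel.GE)  # returns 3
    try:
        Validator.check_integer('quant_delay', True, 0, Rel.GE)   # bool is a type mismatch
    except ValueError as err:
        print(err)
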
@@ -135,6 +134,20 @@ class Validator:
             raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
         return arg_value
 
+    @staticmethod
+    def check_isinstance(arg_name, arg_value, classes):
+        """Check arg isinstance of classes"""
+        if not isinstance(arg_value, classes):
+            raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
+        return arg_value
+
+    @staticmethod
+    def check_bool(arg_name, arg_value):
+        """Check arg isinstance of bool"""
+        if not isinstance(arg_value, bool):
+            raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
+        return arg_value
+
     @staticmethod
     def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
         """Method for checking whether an int value is in some range."""
@@ -208,6 +221,27 @@ class Validator:
         """Checks valid value."""
         if arg_value is None:
             raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
+        return arg_value
+
+    @staticmethod
+    def check_type(arg_name, arg_value, valid_types):
+        """Type checking."""
+        def raise_error_msg():
+            """func for raising error message when check failed"""
+            type_names = [t.__name__ for t in valid_types]
+            num_types = len(valid_types)
+            raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
+                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
+
+        if isinstance(arg_value, type(mstype.tensor)):
+            arg_value = arg_value.element_type()
+        # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
+        # `check_type('x', True, [bool, int])` will check pass
+        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
+            raise_error_msg()
+        if isinstance(arg_value, tuple(valid_types)):
+            return arg_value
+        raise_error_msg()
 
     @staticmethod
     def check_type_same(args, valid_values, prim_name):
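
The `Notice` comment above captures the one subtle case in `check_type`: since bool subclasses int, a bool argument passes only when `bool` is listed explicitly. A minimal sketch, assuming this commit:

    from mindspore._checkparam import Validator

    Validator.check_type('x', 3, [int])           # passes, returns 3
    Validator.check_type('x', True, [bool, int])  # passes, returns True
    try:
        Validator.check_type('x', True, [int])    # bool deliberately rejected
    except TypeError as err:
        print(err)
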
@@ -239,7 +273,6 @@ class Validator:
     def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
         """
         Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
-
         If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
         """
 
@@ -335,63 +368,6 @@ class Validator:
                          f'{tuple(exp_shape)}, but got {shape}.')
 
 
-class ParamValidator:
-    """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
-
-    @staticmethod
-    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
-        """This method is only used for check int values, since when compare float values,
-        we need consider float error."""
-        rel_fn = Rel.get_fns(rel)
-        if not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
-            raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
-
-    @staticmethod
-    def check_integer(arg_name, arg_value, value, rel):
-        """Integer value judgment."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
-        if type_mismatch or not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_isinstance(arg_name, arg_value, classes):
-        """Check arg isinstance of classes"""
-        if not isinstance(arg_value, classes):
-            raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_bool(arg_name, arg_value):
-        """Check arg isinstance of bool"""
-        if not isinstance(arg_value, bool):
-            raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_type(arg_name, arg_value, valid_types):
-        """Type checking."""
-        def raise_error_msg():
-            """func for raising error message when check failed"""
-            type_names = [t.__name__ for t in valid_types]
-            num_types = len(valid_types)
-            raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
-                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
-
-        if isinstance(arg_value, type(mstype.tensor)):
-            arg_value = arg_value.element_type()
-        # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
-        # `check_type('x', True, [bool, int])` will check pass
-        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
-            raise_error_msg()
-        if isinstance(arg_value, tuple(valid_types)):
-            return arg_value
-        raise_error_msg()
-
-
 def check_int(input_param):
     """Int type judgment."""
     if isinstance(input_param, int) and not isinstance(input_param, bool):
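
With `ParamValidator` gone, every call site in the files below migrates from the lowercase alias to the `Validator` class itself; a hedged before/after sketch (the axis value is hypothetical):

    # Before (removed in this commit):
    #     from mindspore._checkparam import ParamValidator as validator
    #     validator.check_type('axis', axis, [int])
    # After:
    from mindspore._checkparam import Validator
    axis = Validator.check_type('axis', -1, [int])
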
@@ -638,7 +614,6 @@ def args_type_check(*type_args, **type_kwargs):
             if value is not None and not isinstance(value, bound_types[name]):
                 raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
             return func(*args, **kwargs)
-
         return wrapper
 
     return type_check

mindspore/nn/graph_kernels/graph_kernels.py (+8, -8)

@@ -21,7 +21,7 @@ from ...ops import operations as P
 from ...ops.primitive import PrimitiveWithInfer, prim_attr_register
 from ...ops.composite import multitype_ops as C
 from ...ops.operations import _grad_ops as G
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator
 from ..cell import Cell, GraphKernel

@@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel):
                  use_locking=False,
                  gradient_scale=1.0):
         super(ApplyMomentum, self).__init__()
-        self.gradient_scale = validator.check_type('gradient_scale', gradient_scale, [float])
+        self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float])
         self.fake_output_assign_1 = InplaceAssign()
         self.fake_output_assign_1.add_prim_attr("fake_output", True)
         self.fake_output_assign_2 = InplaceAssign()
@@ -334,7 +334,7 @@ class ReduceMean(GraphKernel):
 
     def __init__(self, keep_dims=True):
         super(ReduceMean, self).__init__()
-        self.keep_dims = validator.check_type('keep_dims', keep_dims, [bool])
+        self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool])
         self.sum = P.ReduceSum(self.keep_dims)
 
     def construct(self, x, axis):
@@ -431,8 +431,8 @@ class LayerNormForward(GraphKernel):
     """ Forward function of the LayerNorm operator. """
     def __init__(self, begin_norm_axis=1, begin_params_axis=1):
         super(LayerNormForward, self).__init__()
-        self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int])
-        self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int])
+        self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int])
+        self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int])
         self.mul = P.Mul()
         self.sum_keep_dims = P.ReduceSum(keep_dims=True)
         self.sub = P.Sub()
@@ -686,7 +686,7 @@ class LogSoftmax(GraphKernel):
 
     def __init__(self, axis=-1):
         super(LogSoftmax, self).__init__()
-        self.axis = validator.check_type('axis', axis, [int])
+        self.axis = Validator.check_type('axis', axis, [int])
         self.max_keep_dims = P.ReduceMax(keep_dims=True)
         self.sub = P.Sub()
         self.exp = P.Exp()
@@ -952,13 +952,13 @@ class Softmax(GraphKernel):
 
     def __init__(self, axis):
         super(Softmax, self).__init__()
-        validator.check_type("axis", axis, [int, tuple])
+        Validator.check_type("axis", axis, [int, tuple])
         if isinstance(axis, int):
             self.axis = (axis,)
         else:
             self.axis = axis
         for item in self.axis:
-            validator.check_type("item of axis", item, [int])
+            Validator.check_type("item of axis", item, [int])
         self.max = P.ReduceMax(keep_dims=True)
         self.sub = P.Sub()
         self.exp = P.Exp()
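
The `Softmax` change keeps the two-stage axis validation: the argument may be an int or a tuple, and each element must itself be an int. A sketch with a hypothetical axis, assuming this commit:

    from mindspore._checkparam import Validator

    axis = (0, 1)
    Validator.check_type("axis", axis, [int, tuple])       # int or tuple accepted
    for item in axis:
        Validator.check_type("item of axis", item, [int])  # each element must be int
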


mindspore/nn/layer/conv.py (+3, -4)

@@ -21,8 +21,7 @@ from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer, Initializer
 from mindspore.common.tensor import Tensor
-from mindspore._checkparam import ParamValidator as validator, Rel
-from mindspore._checkparam import check_bool, twice, check_int_positive, Validator
+from mindspore._checkparam import Validator, Rel, check_bool, twice, check_int_positive
 from mindspore._extends import cell_attr_register
 from ..cell import Cell


@@ -240,8 +239,8 @@ class Conv2d(_Conv):
         """Initialize depthwise conv2d op"""
         if context.get_context("device_target") == "Ascend" and self.group > 1:
             self.dilation = self._dilation
-            validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
-            validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
+            Validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
+            Validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
             self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                   kernel_size=self.kernel_size,
                                                   pad_mode=self.pad_mode,
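
A sketch of the depthwise constraint enforced above: on Ascend with group > 1, `group` must equal both `in_channels` and `out_channels` (values hypothetical, assuming this commit):

    from mindspore._checkparam import Validator, Rel

    in_channels = out_channels = group = 32  # a hypothetical depthwise configuration
    Validator.check_integer('group', group, in_channels, Rel.EQ)   # passes only if equal
    Validator.check_integer('group', group, out_channels, Rel.EQ)
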


mindspore/nn/layer/quant.py (+11, -11)

@@ -23,7 +23,7 @@ from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
-from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, ParamValidator as validator
+from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, Validator
 import mindspore.context as context
 from .normalization import BatchNorm2d, BatchNorm1d
 from .activation import get_activation, ReLU, LeakyReLU
@@ -133,7 +133,7 @@ class Conv2dBnAct(Cell):
                                has_bias=has_bias,
                                weight_init=weight_init,
                                bias_init=bias_init)
-        self.has_bn = validator.check_bool("has_bn", has_bn)
+        self.has_bn = Validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
         self.after_fake = after_fake
         if has_bn:
@@ -201,7 +201,7 @@ class DenseBnAct(Cell):
                               weight_init,
                               bias_init,
                               has_bias)
-        self.has_bn = validator.check_bool("has_bn", has_bn)
+        self.has_bn = Validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
         self.after_fake = after_fake
         if has_bn:
@@ -320,10 +320,10 @@ class FakeQuantWithMinMax(Cell):
                  quant_delay=0):
         """Initialize FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
-        validator.check_type("min_init", min_init, [int, float])
-        validator.check_type("max_init", max_init, [int, float])
-        validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
-        validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
+        Validator.check_type("min_init", min_init, [int, float])
+        Validator.check_type("max_init", max_init, [int, float])
+        Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
+        Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
         self.min_init = min_init
         self.max_init = max_init
         self.num_bits = num_bits
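
The four checks above validate a fake-quant range; in particular `Rel.LT` enforces `min_init < max_init`. A minimal sketch with hypothetical bounds, assuming this commit:

    from mindspore._checkparam import Validator, Rel

    Validator.check("min_init", -6, "max_init", 6, rel=Rel.LT)      # -6 < 6, passes
    try:
        Validator.check("min_init", 6, "max_init", -6, rel=Rel.LT)  # raises ValueError
    except ValueError as err:
        print(err)
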
@@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell):
 
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
-            validator.check_integer('group', group, in_channels, Rel.EQ)
-            validator.check_integer('group', group, out_channels, Rel.EQ)
+            Validator.check_integer('group', group, in_channels, Rel.EQ)
+            Validator.check_integer('group', group, out_channels, Rel.EQ)
             self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                 kernel_size=self.kernel_size,
                                                 pad_mode=pad_mode,
@@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
         self.bias = None
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
-            validator.check_integer('group', group, in_channels, Rel.EQ)
-            validator.check_integer('group', group, out_channels, Rel.EQ)
+            Validator.check_integer('group', group, in_channels, Rel.EQ)
+            Validator.check_integer('group', group, out_channels, Rel.EQ)
             self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                 kernel_size=self.kernel_size,
                                                 pad_mode=pad_mode,


mindspore/train/quant/quant.py (+17, -17)

@@ -22,7 +22,7 @@ import mindspore.context as context
 
 from ... import log as logger
 from ... import nn, ops
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator
 from ..._checkparam import Rel
 from ...common import Tensor
 from ...common import dtype as mstype
@@ -89,19 +89,19 @@ class ConvertToQuantNetwork:
     __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
 
     def __init__(self, **kwargs):
-        self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
-        self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
-        self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
-        self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"])
-        self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
-        self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
-        self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
-        self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0])
-        self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1])
-        self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0])
-        self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1])
-        self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0])
-        self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1])
+        self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
+        self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
+        self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
+        self.bn_fold = Validator.check_bool("bn fold", kwargs["bn_fold"])
+        self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
+        self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
+        self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
+        self.weight_channel = Validator.check_bool("per channel", kwargs["per_channel"][0])
+        self.act_channel = Validator.check_bool("per channel", kwargs["per_channel"][-1])
+        self.weight_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][0])
+        self.act_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][-1])
+        self.weight_range = Validator.check_bool("narrow range", kwargs["narrow_range"][0])
+        self.act_range = Validator.check_bool("narrow range", kwargs["narrow_range"][-1])
         self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
                                     quant.DenseBnAct: self._convert_dense}
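
Every line of `__init__` above follows the same validate-and-assign pattern: each check returns its input, so the kwargs are normalized in one pass. A trimmed sketch with hypothetical kwargs, assuming this commit:

    from mindspore import nn
    from mindspore._checkparam import Validator, Rel

    kwargs = {"network": nn.ReLU(), "quant_delay": (0, 0), "bn_fold": True}
    network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
    weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
    bn_fold = Validator.check_bool("bn fold", kwargs["bn_fold"])
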


@@ -316,7 +316,7 @@ class ExportToQuantInferNetwork:
     __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
 
     def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
-        network = validator.check_isinstance('network', network, (nn.Cell,))
+        network = Validator.check_isinstance('network', network, (nn.Cell,))
         self.input_scale = 1 / std_dev
         self.input_zero_point = round(mean)
         self.data_type = mstype.int8
@@ -510,8 +510,8 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
     supported_device = ["Ascend", "GPU"]
     supported_formats = ['AIR', 'MINDIR']
 
-    mean = validator.check_type("mean", mean, (int, float))
-    std_dev = validator.check_type("std_dev", std_dev, (int, float))
+    mean = Validator.check_type("mean", mean, (int, float))
+    std_dev = Validator.check_type("std_dev", std_dev, (int, float))
 
     if context.get_context('device_target') not in supported_device:
         raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
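
`check_type` returns its argument on success, which is why `export` can validate and rebind `mean` and `std_dev` in one line. A minimal sketch using the function's defaults, assuming this commit:

    from mindspore._checkparam import Validator

    mean = Validator.check_type("mean", 127.5, (int, float))        # returns 127.5
    std_dev = Validator.check_type("std_dev", 127.5, (int, float))  # returns 127.5
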

