Browse Source

add bool typecheck for conv param

tags/v1.1.0
yuchaojie 5 years ago
parent
commit
2bbe733fd6
4 changed files with 11 additions and 17 deletions
  1. +6
    -6
      mindspore/_checkparam.py
  2. +2
    -2
      mindspore/compression/quant/qat.py
  3. +2
    -8
      mindspore/nn/layer/quant.py
  4. +1
    -1
      mindspore/ops/operations/nn_ops.py

+ 6
- 6
mindspore/_checkparam.py View File

@@ -565,20 +565,20 @@ def check_input_format(input_param):




def _expand_tuple(n_dimensions): def _expand_tuple(n_dimensions):
"""To expand a number to tuple."""
"""To expand a int number to tuple."""


def convert(m): def convert(m):
if not isinstance(m, tuple): if not isinstance(m, tuple):
if isinstance(m, int):
if isinstance(m, int) and not isinstance(m, bool):
return tuple(repeat(m, n_dimensions)) return tuple(repeat(m, n_dimensions))
raise TypeError("Input type must be int or tuple.")
raise TypeError("Input type must be int or tuple[int].")


if not len(m) is n_dimensions: if not len(m) is n_dimensions:
raise TypeError("Input dimension is incorrect.")
raise TypeError("Input tuple dimension is incorrect.")


for i in m: for i in m:
if not isinstance(i, int):
raise TypeError("Incorrect type inside of a tuple!")
if not isinstance(i, int) or isinstance(i, bool):
raise TypeError("Incorrect type inside of a tuple, must be int!")
return m return m


return convert return convert


+ 2
- 2
mindspore/compression/quant/qat.py View File

@@ -410,12 +410,12 @@ class QuantizationAwareTraining(Quantizer):
""" """
act_class = activation.__class__ act_class = activation.__class__
act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid] act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
act_list_withfakebefore = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
if act_class in act_list: if act_class in act_list:
return quant.ActQuant(activation=activation, return quant.ActQuant(activation=activation,
quant_config=self.quant_config, quant_config=self.quant_config,
quant_dtype=self.act_dtype) quant_dtype=self.act_dtype)
if act_class in act_list_withfakebefore:
if act_class in act_list_with_fake_before:
return quant.ActQuant(activation=activation, return quant.ActQuant(activation=activation,
ema=True, ema=True,
fake_before=True, fake_before=True,


+ 2
- 8
mindspore/nn/layer/quant.py View File

@@ -762,13 +762,10 @@ class Conv2dBnWithoutFoldQuant(Cell):
quant_config=quant_config_default, quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8): quant_dtype=QuantDtype.INT8):
super(Conv2dBnWithoutFoldQuant, self).__init__() super(Conv2dBnWithoutFoldQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
else:
self.kernel_size = kernel_size
self.in_channels = Validator.check_positive_int(in_channels) self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels) self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = has_bias self.has_bias = has_bias
self.kernel_size = twice(kernel_size)
self.stride = twice(stride) self.stride = twice(stride)
self.dilation = twice(dilation) self.dilation = twice(dilation)
self.pad_mode = pad_mode self.pad_mode = pad_mode
@@ -884,13 +881,10 @@ class Conv2dQuant(Cell):
quant_config=quant_config_default, quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8): quant_dtype=QuantDtype.INT8):
super(Conv2dQuant, self).__init__() super(Conv2dQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
else:
self.kernel_size = kernel_size
self.in_channels = Validator.check_positive_int(in_channels) self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels) self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = has_bias self.has_bias = has_bias
self.kernel_size = twice(kernel_size)
self.stride = twice(stride) self.stride = twice(stride)
self.dilation = twice(dilation) self.dilation = twice(dilation)
self.pad_mode = pad_mode self.pad_mode = pad_mode


+ 1
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -52,7 +52,7 @@ def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=Fals
validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
ret_value = _get_return_value() ret_value = _get_return_value()
for item in ret_value: for item in ret_value:
if isinstance(item, int) and item > 0:
if isinstance(item, int) and not isinstance(item, bool) and item > 0:
continue continue
_raise_message() _raise_message()
return ret_value return ret_value


Loading…
Cancel
Save