From 2bbe733fd6ee7f0c0aaeb9bcde41b7d377ff7bb6 Mon Sep 17 00:00:00 2001
From: yuchaojie
Date: Sat, 31 Oct 2020 10:30:00 +0800
Subject: [PATCH] add bool typecheck for conv param

---
 mindspore/_checkparam.py           | 12 ++++++------
 mindspore/compression/quant/qat.py |  4 ++--
 mindspore/nn/layer/quant.py        | 10 ++--------
 mindspore/ops/operations/nn_ops.py |  2 +-
 4 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py
index bd9a615ea8..15ef79cf77 100644
--- a/mindspore/_checkparam.py
+++ b/mindspore/_checkparam.py
@@ -565,20 +565,20 @@ def check_input_format(input_param):
 
 def _expand_tuple(n_dimensions):
-    """To expand a number to tuple."""
+    """To expand an int number to tuple."""
 
     def convert(m):
         if not isinstance(m, tuple):
-            if isinstance(m, int):
+            if isinstance(m, int) and not isinstance(m, bool):
                 return tuple(repeat(m, n_dimensions))
-            raise TypeError("Input type must be int or tuple.")
+            raise TypeError("Input type must be int or tuple[int].")
 
         if not len(m) is n_dimensions:
-            raise TypeError("Input dimension is incorrect.")
+            raise TypeError("Input tuple dimension is incorrect.")
 
         for i in m:
-            if not isinstance(i, int):
-                raise TypeError("Incorrect type inside of a tuple!")
+            if not isinstance(i, int) or isinstance(i, bool):
+                raise TypeError("Incorrect type inside of a tuple, must be int!")
 
         return m
 
     return convert
diff --git a/mindspore/compression/quant/qat.py b/mindspore/compression/quant/qat.py
index bdb59414e7..f7fe7cfa69 100644
--- a/mindspore/compression/quant/qat.py
+++ b/mindspore/compression/quant/qat.py
@@ -410,12 +410,12 @@
         """
         act_class = activation.__class__
         act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
-        act_list_withfakebefore = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
+        act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
         if act_class in act_list:
             return quant.ActQuant(activation=activation,
                                   quant_config=self.quant_config,
                                   quant_dtype=self.act_dtype)
-        if act_class in act_list_withfakebefore:
+        if act_class in act_list_with_fake_before:
             return quant.ActQuant(activation=activation,
                                   ema=True,
                                   fake_before=True,
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index edf6f6da43..eaf993eeed 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -762,13 +762,10 @@ class Conv2dBnWithoutFoldQuant(Cell):
                  quant_config=quant_config_default,
                  quant_dtype=QuantDtype.INT8):
         super(Conv2dBnWithoutFoldQuant, self).__init__()
-        if isinstance(kernel_size, int):
-            self.kernel_size = (kernel_size, kernel_size)
-        else:
-            self.kernel_size = kernel_size
         self.in_channels = Validator.check_positive_int(in_channels)
         self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = has_bias
+        self.kernel_size = twice(kernel_size)
         self.stride = twice(stride)
         self.dilation = twice(dilation)
         self.pad_mode = pad_mode
@@ -884,13 +881,10 @@ class Conv2dQuant(Cell):
                  quant_config=quant_config_default,
                  quant_dtype=QuantDtype.INT8):
         super(Conv2dQuant, self).__init__()
-        if isinstance(kernel_size, int):
-            self.kernel_size = (kernel_size, kernel_size)
-        else:
-            self.kernel_size = kernel_size
         self.in_channels = Validator.check_positive_int(in_channels)
         self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = has_bias
+        self.kernel_size = twice(kernel_size)
         self.stride = twice(stride)
         self.dilation = twice(dilation)
         self.pad_mode = pad_mode
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 58fdef21f0..ec47100f08 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -52,7 +52,7 @@ def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=Fals
     validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
     ret_value = _get_return_value()
     for item in ret_value:
-        if isinstance(item, int) and item > 0:
+        if isinstance(item, int) and not isinstance(item, bool) and item > 0:
             continue
         _raise_message()
     return ret_value