| @@ -565,20 +565,20 @@ def check_input_format(input_param): | |||||
| def _expand_tuple(n_dimensions): | def _expand_tuple(n_dimensions): | ||||
| """To expand a number to tuple.""" | |||||
| """To expand a int number to tuple.""" | |||||
| def convert(m): | def convert(m): | ||||
| if not isinstance(m, tuple): | if not isinstance(m, tuple): | ||||
| if isinstance(m, int): | |||||
| if isinstance(m, int) and not isinstance(m, bool): | |||||
| return tuple(repeat(m, n_dimensions)) | return tuple(repeat(m, n_dimensions)) | ||||
| raise TypeError("Input type must be int or tuple.") | |||||
| raise TypeError("Input type must be int or tuple[int].") | |||||
| if not len(m) is n_dimensions: | if not len(m) is n_dimensions: | ||||
| raise TypeError("Input dimension is incorrect.") | |||||
| raise TypeError("Input tuple dimension is incorrect.") | |||||
| for i in m: | for i in m: | ||||
| if not isinstance(i, int): | |||||
| raise TypeError("Incorrect type inside of a tuple!") | |||||
| if not isinstance(i, int) or isinstance(i, bool): | |||||
| raise TypeError("Incorrect type inside of a tuple, must be int!") | |||||
| return m | return m | ||||
| return convert | return convert | ||||
| @@ -410,12 +410,12 @@ class QuantizationAwareTraining(Quantizer): | |||||
| """ | """ | ||||
| act_class = activation.__class__ | act_class = activation.__class__ | ||||
| act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid] | act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid] | ||||
| act_list_withfakebefore = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish] | |||||
| act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish] | |||||
| if act_class in act_list: | if act_class in act_list: | ||||
| return quant.ActQuant(activation=activation, | return quant.ActQuant(activation=activation, | ||||
| quant_config=self.quant_config, | quant_config=self.quant_config, | ||||
| quant_dtype=self.act_dtype) | quant_dtype=self.act_dtype) | ||||
| if act_class in act_list_withfakebefore: | |||||
| if act_class in act_list_with_fake_before: | |||||
| return quant.ActQuant(activation=activation, | return quant.ActQuant(activation=activation, | ||||
| ema=True, | ema=True, | ||||
| fake_before=True, | fake_before=True, | ||||
| @@ -762,13 +762,10 @@ class Conv2dBnWithoutFoldQuant(Cell): | |||||
| quant_config=quant_config_default, | quant_config=quant_config_default, | ||||
| quant_dtype=QuantDtype.INT8): | quant_dtype=QuantDtype.INT8): | ||||
| super(Conv2dBnWithoutFoldQuant, self).__init__() | super(Conv2dBnWithoutFoldQuant, self).__init__() | ||||
| if isinstance(kernel_size, int): | |||||
| self.kernel_size = (kernel_size, kernel_size) | |||||
| else: | |||||
| self.kernel_size = kernel_size | |||||
| self.in_channels = Validator.check_positive_int(in_channels) | self.in_channels = Validator.check_positive_int(in_channels) | ||||
| self.out_channels = Validator.check_positive_int(out_channels) | self.out_channels = Validator.check_positive_int(out_channels) | ||||
| self.has_bias = has_bias | self.has_bias = has_bias | ||||
| self.kernel_size = twice(kernel_size) | |||||
| self.stride = twice(stride) | self.stride = twice(stride) | ||||
| self.dilation = twice(dilation) | self.dilation = twice(dilation) | ||||
| self.pad_mode = pad_mode | self.pad_mode = pad_mode | ||||
| @@ -884,13 +881,10 @@ class Conv2dQuant(Cell): | |||||
| quant_config=quant_config_default, | quant_config=quant_config_default, | ||||
| quant_dtype=QuantDtype.INT8): | quant_dtype=QuantDtype.INT8): | ||||
| super(Conv2dQuant, self).__init__() | super(Conv2dQuant, self).__init__() | ||||
| if isinstance(kernel_size, int): | |||||
| self.kernel_size = (kernel_size, kernel_size) | |||||
| else: | |||||
| self.kernel_size = kernel_size | |||||
| self.in_channels = Validator.check_positive_int(in_channels) | self.in_channels = Validator.check_positive_int(in_channels) | ||||
| self.out_channels = Validator.check_positive_int(out_channels) | self.out_channels = Validator.check_positive_int(out_channels) | ||||
| self.has_bias = has_bias | self.has_bias = has_bias | ||||
| self.kernel_size = twice(kernel_size) | |||||
| self.stride = twice(stride) | self.stride = twice(stride) | ||||
| self.dilation = twice(dilation) | self.dilation = twice(dilation) | ||||
| self.pad_mode = pad_mode | self.pad_mode = pad_mode | ||||
| @@ -52,7 +52,7 @@ def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=Fals | |||||
| validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) | validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) | ||||
| ret_value = _get_return_value() | ret_value = _get_return_value() | ||||
| for item in ret_value: | for item in ret_value: | ||||
| if isinstance(item, int) and item > 0: | |||||
| if isinstance(item, int) and not isinstance(item, bool) and item > 0: | |||||
| continue | continue | ||||
| _raise_message() | _raise_message() | ||||
| return ret_value | return ret_value | ||||