@@ -328,6 +328,9 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
      x = cnode->input(1);
      count += 1;
    }
    if (x->isa<Parameter>()) {
      fake_quant_table[weight_name] = std::make_pair(nullptr, "input");
    }
    // get the fakequant parameter minq's name
    if (!is_quant_cnode(x)) {
      continue;
@@ -1169,9 +1169,9 @@ class QuantBlock(Cell):
        return x

    def extend_repr(self):
        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
        str_info = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
        if self.has_bias:
            str_info = str_info + f', bias={self.bias}'
            str_info = str_info + f', bias=shape[{self.bias.shape}]'
        if self.has_act:
            str_info = str_info + f', activation={self.activation}'
        str_info = str_info + f', dequant={self.dequant}'
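
Note: the extend_repr change above prints parameter shapes rather than the full
weight/bias tensors, so repr() of a QuantBlock stays a single short line even for
large layers. A standalone sketch of the resulting string, with hypothetical
shapes and a hard-coded core op name standing in for type(self.core_op):

weight_shape = (6, 1, 5, 5)   # hypothetical Conv2d weight shape
bias_shape = (6,)             # hypothetical bias shape
str_info = f'quant=FakeQuantWithMinMax, core_op=Conv2D, weight=shape[{weight_shape}]'
str_info = str_info + f', bias=shape[{bias_shape}]'
print(str_info)
# quant=FakeQuantWithMinMax, core_op=Conv2D, weight=shape[(6, 1, 5, 5)], bias=shape[(6,)]
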
@@ -237,12 +237,14 @@ class PrimitiveWithInfer(Primitive):
        """
        Infer output shape based on input shape.
        Args:
            inputs (tuple(int)): dimensions of input tensors.
            outputs (tuple(int)): dimensions of output tensors.
        Note:
            The shape of scalar is an empty tuple.
        Args:
            args (tuple(int)): shapes of input tensors.
        Return:
            `tuple(int)`, shapes of output tensors.
        """
        return None
@@ -251,8 +253,10 @@ class PrimitiveWithInfer(Primitive):
        Infer output dtype based on input dtype.
        Args:
            inputs (mstype): data type of inputs.
            outputs (mstype): data type of outputs.
            args (:class:`mindspore.dtype`): data type of inputs.
        Return:
            :class:`mindspore.dtype`, data type of outputs.
        """
        return None
@@ -261,8 +265,10 @@ class PrimitiveWithInfer(Primitive):
        Infer output value based on input value at compile time.
        Args:
            inputs (any): value of inputs.
            outputs (any): value of outputs.
            args (Any): value of inputs.
        Return:
            Value of outputs. Return `None` if the value cannot be inferred at compile time.
        """
        return None
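
Note: a minimal sketch of the infer_shape / infer_dtype / infer_value contract
documented above. `CustomIdentity` is a hypothetical primitive written for
illustration, not an operator from the codebase; it forwards the shape and dtype
of its single input and declines compile-time value inference.

from mindspore.ops import prim_attr_register, PrimitiveWithInfer

class CustomIdentity(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """Register the primitive; no attributes are needed for this sketch."""

    def infer_shape(self, x_shape):
        # One positional shape argument per input; a scalar input arrives as ().
        return x_shape

    def infer_dtype(self, x_dtype):
        # Receives and returns mindspore.dtype values.
        return x_dtype

    def infer_value(self, x_value):
        # Returning None means the value cannot be inferred at compile time.
        return None
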
@@ -318,9 +318,12 @@ class ExportToQuantInferNetwork:
        info = self.quant_info_table.get(w_minq_name, None)
        if info:
            fack_quant_a_in_op, minq_name = info
            maxq = self.all_parameters[minq_name[:-4] + "maxq"]
            minq = self.all_parameters[minq_name]
            scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
            if minq_name == 'input':
                scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
            else:
                maxq = self.all_parameters[minq_name[:-4] + "maxq"]
                minq = self.all_parameters[minq_name]
                scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
        else:
            logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
            return None
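
Note: together with the C++ change above, an activation fake-quant node that
traces back to a graph input Parameter is recorded with the sentinel minq name
'input', and this branch then reuses the network-level self.input_scale /
self.input_zero_point instead of looking up learned minq/maxq parameters. The
helper below is an illustrative sketch of how an asymmetric scale/zero-point
pair can be derived from recorded min/max values; it is not claimed to be the
exact math inside quant_utils.scale_zp_from_data.

import numpy as np

def scale_zp_from_min_max(minq, maxq, num_bits=8):
    """Map the recorded float range [minq, maxq] onto [0, 2**num_bits - 1]."""
    qmin, qmax = 0, 2 ** num_bits - 1
    minq, maxq = min(float(minq), 0.0), max(float(maxq), 0.0)  # range must contain 0
    scale = (maxq - minq) / (qmax - qmin)
    zero_point = int(np.round(qmin - minq / scale)) if scale else qmin
    return scale, zero_point

print(scale_zp_from_min_max(-1.0, 1.0))  # approximately (0.00784, 128)
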
@@ -104,19 +104,20 @@ def weight2int(data, scale, zero_point):
        raise ValueError("`scale` and `zero_point` should have the same shape.")
    if scale.shape[0] < 0:
        raise ValueError("`scale` and `zero_point` shape should greater than zero.")
    if scale.shape[0] == data.shape[0]:
        # `Conv2d` or `Dense` op weight
        shape_list = [-1] + [1] * len(data.shape[1:])
        scale = scale.reshape(shape_list)
        zero_point = zero_point.reshape(shape_list)
    elif scale.shape[0] == data.shape[1]:
        # `DepthwiseConv2d` op weight
        shape_list = [1, -1] + [1] * len(data.shape[2:])
        scale = scale.reshape(shape_list)
        zero_point = zero_point.reshape(shape_list)
    else:
        raise ValueError("Unsupported weight shape({})".format(data.shape))
    if len(scale.shape) > 1:
        # for perchannel
        if scale.shape[0] == data.shape[0]:
            # `Conv2d` or `Dense` op weight
            shape_list = [-1] + [1] * len(data.shape[1:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        elif scale.shape[0] == data.shape[1]:
            # `DepthwiseConv2d` op weight
            shape_list = [1, -1] + [1] * len(data.shape[2:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        else:
            raise ValueError("Unsupported weight shape({})".format(data.shape))
    return np.round((data / scale) + zero_point)
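
Note: a standalone numpy sketch (with made-up shapes) of the two cases the new
`len(scale.shape) > 1` guard above distinguishes: a per-layer scale broadcasts
against the weight as-is, while per-channel scales (one per output channel) must
be reshaped so that each output channel is divided by its own scale and zero point.

import numpy as np

weight = np.random.randn(8, 3, 3, 3).astype(np.float32)  # Conv2d weight: (out, in, kh, kw)

# Per-layer: one scale/zero-point for the whole tensor; plain scalar broadcasting.
scale_layer, zp_layer = np.float32(0.02), np.float32(0.0)
q_layer = np.round(weight / scale_layer + zp_layer)

# Per-channel: one scale/zero-point per output channel. Reshaping to
# [-1, 1, 1, 1] aligns the values with axis 0, matching the `Conv2d`/`Dense`
# branch above, so each output channel gets its own quantization parameters.
scale_ch = np.random.uniform(0.01, 0.05, size=(8,)).astype(np.float32)
zp_ch = np.zeros(8, dtype=np.float32)
shape_list = [-1] + [1] * (weight.ndim - 1)
q_ch = np.round(weight / scale_ch.reshape(shape_list) + zp_ch.reshape(shape_list))

print(q_layer.shape, q_ch.shape)  # (8, 3, 3, 3) (8, 3, 3, 3)
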
@@ -1,115 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV2"""
from mindspore import nn
from mindspore.ops import operations as P
def make_divisible(input_x, div_by=8):
    return int((input_x + div_by) // div_by)
def _conv_bn(in_channel,
             out_channel,
             ksize,
             stride=1):
    """Get a conv2d batchnorm and relu layer."""
    return nn.SequentialCell(
        [nn.Conv2d(in_channel,
                   out_channel,
                   kernel_size=ksize,
                   stride=stride),
         nn.BatchNorm2d(out_channel)])
class InvertedResidual(nn.Cell):
    def __init__(self, inp, oup, stride, expend_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(inp * expend_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        if expend_ratio == 1:
            self.conv = nn.SequentialCell([
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(),
                nn.Conv2d(hidden_dim, oup, 1, 1),
                nn.BatchNorm2d(oup)
            ])
        else:
            self.conv = nn.SequentialCell([
                nn.Conv2d(inp, hidden_dim, 1, 1),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(),
                nn.Conv2d(hidden_dim, oup, 1, 1),
                nn.BatchNorm2d(oup)
            ])
    def construct(self, input_x):
        out = self.conv(input_x)
        if self.use_res_connect:
            out = input_x + out
        return out
class MobileNetV2(nn.Cell):
    def __init__(self, num_class=1000, input_size=224, width_mul=1.):
        super(MobileNetV2, self).__init__()
        _ = input_size
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 230, 1, 1],
        ]
        if width_mul > 1.0:
            last_channel = make_divisible(last_channel * width_mul)
        self.last_channel = last_channel
        features = [_conv_bn(3, input_channel, 3, 2)]
        for t, c, n, s in inverted_residual_setting:
            out_channel = make_divisible(c * width_mul) if t > 1 else c
            for i in range(n):
                if i == 0:
                    features.append(block(input_channel, out_channel, s, t))
                else:
                    features.append(block(input_channel, out_channel, 1, t))
                input_channel = out_channel
        features.append(_conv_bn(input_channel, self.last_channel, 1))
        self.features = nn.SequentialCell(features)
        self.mean = P.ReduceMean(keep_dims=False)
        self.classifier = nn.Dense(self.last_channel, num_class)
    def construct(self, input_x):
        out = input_x
        out = self.features(out)
        out = self.mean(out, (2, 3))
        out = self.classifier(out)
        return out
@@ -1,122 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mobile net v2"""
from mindspore import nn
from mindspore.ops import operations as P
def make_divisible(input_x, div_by=8):
    return int((input_x + div_by) // div_by)
def _conv_bn(in_channel,
             out_channel,
             ksize,
             stride=1):
    """Get a conv2d batchnorm and relu layer."""
    return nn.SequentialCell(
        [nn.Conv2dBnAct(in_channel,
                        out_channel,
                        kernel_size=ksize,
                        stride=stride,
                        has_bn=True)])
class InvertedResidual(nn.Cell):
    def __init__(self, inp, oup, stride, expend_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(inp * expend_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        if expend_ratio == 1:
            self.conv = nn.SequentialCell([
                nn.Conv2dBnAct(hidden_dim,
                               hidden_dim,
                               3,
                               stride,
                               group=hidden_dim,
                               has_bn=True,
                               activation='relu6'),
                nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
                               has_bn=True)
            ])
        else:
            self.conv = nn.SequentialCell([
                nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
                               has_bn=True,
                               activation='relu6'),
                nn.Conv2dBnAct(hidden_dim,
                               hidden_dim,
                               3,
                               stride,
                               group=hidden_dim,
                               has_bn=True,
                               activation='relu6'),
                nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
                               has_bn=True)
            ])
        self.add = P.TensorAdd()
    def construct(self, input_x):
        out = self.conv(input_x)
        if self.use_res_connect:
            out = self.add(input_x, out)
        return out
class MobileNetV2(nn.Cell):
    def __init__(self, num_class=1000, input_size=224, width_mul=1.):
        super(MobileNetV2, self).__init__()
        _ = input_size
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 230, 1, 1],
        ]
        if width_mul > 1.0:
            last_channel = make_divisible(last_channel * width_mul)
        self.last_channel = last_channel
        features = [_conv_bn(3, input_channel, 3, 2)]
        for t, c, n, s in inverted_residual_setting:
            out_channel = make_divisible(c * width_mul) if t > 1 else c
            for i in range(n):
                if i == 0:
                    features.append(block(input_channel, out_channel, s, t))
                else:
                    features.append(block(input_channel, out_channel, 1, t))
                input_channel = out_channel
        features.append(_conv_bn(input_channel, self.last_channel, 1))
        self.features = nn.SequentialCell(features)
        self.mean = P.ReduceMean(keep_dims=False)
        self.classifier = nn.DenseBnAct(self.last_channel, num_class)
    def construct(self, input_x):
        out = input_x
        out = self.features(out)
        out = self.mean(out, (2, 3))
        out = self.classifier(out)
        return out
@@ -20,7 +20,7 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.quant import quant as qat
from mobilenetv2_combined import MobileNetV2
from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu', pad_mode="valid")
        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
@@ -67,20 +67,19 @@ def test_qat_lenet():
    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
    net = LeNet5()
    net = qat.convert_quant_network(
        net, freeze_bn=10000, num_bits=8)
        net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
    # should load the checkpoint. mock here
    for param in net.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")
    qat.export(net, img, file_name="quant.pb")
@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_mobile():
    net = MobileNetV2()
    network = mobilenetV2(num_classes=1000)
    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
    net = qat.convert_quant_network(
        net, quant_delay=0, bn_fold=True, freeze_bn=10000, num_bits=8)
    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
    # should load the checkpoint. mock here
    for param in net.get_parameters():
    for param in network.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")
    qat.export(network, img, file_name="quant.pb")