Merge pull request !2836 from chenzhongming/master
tags/v0.6.0-beta
@@ -1193,9 +1193,9 @@ class QuantBlock(Cell):
         self.dequant = dequant_op
         self.dequant_scale = dequant_scale
         self.bias = bias
-        self.has_bias = bias is None
+        self.has_bias = bias is not None
         self.activation = activation
-        self.has_act = activation is None
+        self.has_act = activation is not None
         self.bias_add = P.BiasAdd()

     def construct(self, x):
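The corrected flags matter because `construct` applies the bias add and the activation only when the matching flag is true; with the inverted checks, a block built with a bias or an activation would silently skip them. A minimal sketch of that control flow, assuming a simplified `construct` (names such as `core_op` and `weight` are illustrative, and the real method also runs the quant and dequant ops around the core kernel):

    def construct(self, x):
        x = self.core_op(x, self.weight)       # fused conv/dense kernel (simplified)
        if self.has_bias:                      # true only when a bias tensor was supplied
            x = self.bias_add(x, self.bias)
        if self.has_act:                       # true only when an activation cell was supplied
            x = self.activation(x)
        return x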
@@ -86,7 +86,7 @@ class LossMonitor(Callback):
         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
             print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
-                  "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format(
+                  "loss: [{:5.4f}], avg loss: [{:5.4f}], time: [{:5.4f}ms]".format(
                       cb_params.cur_epoch_num, cb_params.epoch_num,
                       cur_step_in_epoch, int(cb_params.batch_num),
                       step_loss, np.mean(self.losses),
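Aside from the explicit `ms` unit, the `avg loss` field is fed from `np.mean(self.losses)`; a quick way to preview what the updated line renders as (values are made up):

    >>> print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
    ...       "loss: [{:5.4f}], avg loss: [{:5.4f}], time: [{:5.4f}ms]".format(
    ...           3, 10, 125, 1875, 0.0524, 0.0761, 12.3456))
    Epoch: [  3/ 10], step: [  125/ 1875], loss: [0.0524], avg loss: [0.0761], time: [12.3456ms]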
@@ -33,7 +33,6 @@ from ...ops.operations import _inner_ops as inner
 from ...train import serialization
 from . import quant_utils

 _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                    nn.ReLU6: quant.ReLU6Quant,
                    nn.HSigmoid: quant.HSigmoidQuant,
@@ -178,7 +177,6 @@ class ConvertToQuantNetwork:
                                       dilation=conv_inner.dilation,
                                       group=conv_inner.group,
                                       eps=bn_inner.eps,
-                                      momentum=1 - bn_inner.momentum,
                                       quant_delay=self.weight_qdelay,
                                       freeze_bn=self.freeze_bn,
                                       per_channel=self.weight_channel,
@@ -268,16 +266,16 @@ class ConvertToQuantNetwork:
                                    narrow_range=self.act_range)


-class ExportQuantNetworkDeploy:
+class ExportToQuantInferNetwork:
     """
-    Convert quantization aware network to deploy network.
+    Convert quantization aware network to infer network.

     Args:
-        network (Cell): MindSpore network produced by `convert_quant_network`.
-        inputs (Tensor): Inputs of the `network`.
+        network (Cell): MindSpore network produced by the API `convert_quant_network`.
+        inputs (Tensor): Input tensors of the quantization aware training network.

     Returns:
-        Cell, converted network.
+        Cell, GEIR backend infer network.
     """

     __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
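For orientation, the renamed class is normally driven through the module-level `export` helper further down in this diff rather than instantiated by hand; a direct call would look roughly like this (illustrative, where `network` comes from `convert_quant_network` and `inputs` is a matching input Tensor):

    exporter = ExportToQuantInferNetwork(network, inputs)
    infer_net = exporter.run()   # Cell ready to pass to serialization.export with 'GEIR'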
@@ -287,7 +285,7 @@ class ExportQuantNetworkDeploy:
         network = validator.check_isinstance('network', network, (nn.Cell,))
         self.data_type = mstype.int8
         self.network = copy.deepcopy(network)
-        self.all_paramters = {p.name: p for p in self.network.get_parameters()}
+        self.all_parameters = {p.name: p for p in self.network.get_parameters()}
         self.get_inputs_table(inputs)

     def get_inputs_table(self, inputs):
@@ -315,8 +313,8 @@ class ExportQuantNetworkDeploy:
            info = self.quant_info_table.get(w_minq_name, None)
            if info:
                fack_quant_a_in_op, minq_name = info
-                maxq = self.all_paramters[minq_name[:-4] + "maxq"]
-                minq = self.all_paramters[minq_name]
+                maxq = self.all_parameters[minq_name[:-4] + "maxq"]
+                minq = self.all_parameters[minq_name]
                scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
            else:
                logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
@@ -357,7 +355,7 @@ class ExportQuantNetworkDeploy:
         return block

     def _convert_quant2deploy(self, network):
-        """Convet network's all quant subcell to deploy subcell."""
+        """Convert all quant subcells of the network to deploy subcells."""
         cells = network.name_cells()
         change = False
         for name in cells:
@@ -395,18 +393,26 @@ class ExportQuantNetworkDeploy:
         return network


-def export_geir(network, *inputs, file_name):
+def export(network, *inputs, file_name, file_format='GEIR'):
     """
-    Exports MindSpore quant predict model to deploy with GEIR.
+    Exports a MindSpore quantization predict model for deployment with GEIR.

     Args:
         network (Cell): MindSpore network produced by `convert_quant_network`.
-        inputs (Tensor): Inputs of the `network`.
+        inputs (Tensor): Inputs of the quantization aware training network.
         file_name (str): File name of model to export.
+        file_format (str): MindSpore currently supports the 'GEIR' format for exported quantization aware models.
+
+            - GEIR: Graph Engine Intermediate Representation. An intermediate representation format for Ascend models.
     """
-    exporter = ExportQuantNetworkDeploy(network, *inputs)
-    deploy_net = exporter.run()
-    serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR")
+    supported_formats = ['GEIR']
+    if file_format not in supported_formats:
+        raise ValueError('Illegal file format {}.'.format(file_format))
+
+    if file_format == 'GEIR':
+        exporter = ExportToQuantInferNetwork(network, *inputs)
+        deploy_net = exporter.run()
+        serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)


 def convert_quant_network(network,
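A minimal usage sketch of the new `quant.export` entry point, borrowing the LeNet names and a 32x32 input shape from the example script added later in this diff (a trained quantization aware checkpoint would normally be loaded into `network` first):

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.train.quant import quant
    from src.lenet_fusion import LeNet5 as LeNet5Fusion   # from the example script below

    network = LeNet5Fusion(10)
    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
    inputs = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
    quant.export(network, inputs, file_name="lenet_quant", file_format='GEIR')  # only 'GEIR' is accepted for now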
@@ -443,6 +449,7 @@ def convert_quant_network(network,
         Cell, Network which has change to quantization aware training network cell.
     """
     support_device = ["Ascend", "GPU"]
+
     def convert2list(name, value):
         if not isinstance(value, list) and not isinstance(value, tuple):
             value = [value]
@@ -457,7 +464,7 @@ def convert_quant_network(network,
     narrow_range = convert2list("narrow range", narrow_range)

     if context.get_context('device_target') not in support_device:
-        raise KeyError("Not support {} backend.".format(context.get_context('device_target')))
+        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

     net = ConvertToQuantNetwork(network=network,
                                 quant_delay=quant_delay,
@@ -160,7 +160,10 @@ def load_checkpoint(ckpt_file_name, net=None):
     if not isinstance(ckpt_file_name, str):
         raise ValueError("The ckpt_file_name must be string.")

-    if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
+    if not os.path.exists(ckpt_file_name):
+        raise ValueError("The checkpoint file does not exist.")
+
+    if ckpt_file_name[-5:] != ".ckpt":
         raise ValueError("Please input the correct checkpoint file name.")

     if os.path.getsize(ckpt_file_name) == 0:
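Splitting the combined check lets callers tell a missing file apart from a wrongly named one; for instance (hypothetical paths):

    from mindspore.train.serialization import load_checkpoint

    load_checkpoint("./no_such_file.ckpt")   # path missing -> ValueError: The checkpoint file does not exist.
    load_checkpoint("./lenet_model.pb")      # file exists but is not .ckpt -> ValueError: Please input the correct checkpoint file name.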
@@ -57,7 +57,7 @@ if __name__ == "__main__":
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

     # load check point into network
-    param_dict = load_checkpoint(args.ckpt_path, network.type)
+    param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)

     print("============== Starting Testing ==============")
@@ -49,7 +49,7 @@ if __name__ == "__main__":
     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
-    # convert fusion netwrok to quantization aware network
+    # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

     # define loss
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Export a quantization aware training network for inference with the `GEIR` backend.
+"""
+
+import argparse
+import numpy as np
+
+import mindspore
+from mindspore import Tensor
+from mindspore import context
+from mindspore.train.quant import quant
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+from src.config import mnist_cfg as cfg
+from src.lenet_fusion import LeNet5 as LeNet5Fusion
+
+parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
+parser.add_argument('--device_target', type=str, default="Ascend",
+                    choices=['Ascend', 'GPU'],
+                    help='device where the code will be implemented (default: Ascend)')
+parser.add_argument('--data_path', type=str, default="./MNIST_Data",
+                    help='path where the dataset is saved')
+parser.add_argument('--ckpt_path', type=str, default="",
+                    help='if mode is test, must provide the path of the trained ckpt file')
+parser.add_argument('--dataset_sink_mode', type=bool, default=True,
+                    help='dataset_sink_mode is False or True')
+args = parser.parse_args()
+
+if __name__ == "__main__":
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    # define fusion network
+    network = LeNet5Fusion(cfg.num_classes)
+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
+
+    # load quantization aware network checkpoint
+    param_dict = load_checkpoint(args.ckpt_path)
+    load_param_into_net(network, param_dict)
+
+    # export network
+    inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
+    quant.export(network, inputs, file_name="lenet_quant", file_format='GEIR')
@@ -22,7 +22,7 @@ import os
 import argparse

 import mindspore.nn as nn
 from mindspore import context
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
 from src.dataset import create_dataset
@@ -54,7 +54,6 @@ if __name__ == "__main__":
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

     # call back and monitor
-    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
     config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                    keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
@@ -63,6 +62,6 @@ if __name__ == "__main__":
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

     print("============== Starting Training ==============")
-    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
+    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
                 dataset_sink_mode=args.dataset_sink_mode)
     print("============== End Training ==============")
@@ -23,7 +23,7 @@ import argparse

 import mindspore.nn as nn
 from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
 from mindspore.train.quant import quant
@@ -51,20 +51,19 @@ if __name__ == "__main__":
     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

     # load quantization aware network checkpoint
-    param_dict = load_checkpoint(args.ckpt_path, network.type)
+    param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)

-    # convert fusion network to quantization aware network
-    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

     # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
     # define network optimization
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

     # call back and monitor
-    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
     config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                    keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
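Moving the conversion above the checkpoint load is the substantive part of this hunk: presumably the checkpoint referenced by the comment was saved from a quantization aware network, so its parameter names (for example the fake-quant min/max parameters) only exist after `convert_quant_network` has rewritten the cells, and loading it into the plain fusion network would leave those parameters unmatched. A minimal sketch of the intended order, mirroring the hunk above:

    network = LeNet5Fusion(cfg.num_classes)
    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
    param_dict = load_checkpoint(args.ckpt_path)   # ckpt saved from the converted (quant aware) network
    load_param_into_net(network, param_dict)       # parameter names now match the converted network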
@@ -73,6 +72,6 @@ if __name__ == "__main__":
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

     print("============== Starting Training ==============")
-    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
+    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
                 dataset_sink_mode=args.dataset_sink_mode)
     print("============== End Training ==============")