Merge pull request !2836 from chenzhongming/mastertags/v0.6.0-beta
| @@ -1193,9 +1193,9 @@ class QuantBlock(Cell): | |||
| self.dequant = dequant_op | |||
| self.dequant_scale = dequant_scale | |||
| self.bias = bias | |||
| self.has_bias = bias is None | |||
| self.has_bias = bias is not None | |||
| self.activation = activation | |||
| self.has_act = activation is None | |||
| self.has_act = activation is not None | |||
| self.bias_add = P.BiasAdd() | |||
| def construct(self, x): | |||
| @@ -86,7 +86,7 @@ class LossMonitor(Callback): | |||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | |||
| print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " | |||
| "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format( | |||
| "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}ms]".format( | |||
| cb_params.cur_epoch_num, cb_params.epoch_num, | |||
| cur_step_in_epoch, int(cb_params.batch_num), | |||
| step_loss, np.mean(self.losses), | |||
| @@ -33,7 +33,6 @@ from ...ops.operations import _inner_ops as inner | |||
| from ...train import serialization | |||
| from . import quant_utils | |||
| _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, | |||
| nn.ReLU6: quant.ReLU6Quant, | |||
| nn.HSigmoid: quant.HSigmoidQuant, | |||
| @@ -178,7 +177,6 @@ class ConvertToQuantNetwork: | |||
| dilation=conv_inner.dilation, | |||
| group=conv_inner.group, | |||
| eps=bn_inner.eps, | |||
| momentum=1 - bn_inner.momentum, | |||
| quant_delay=self.weight_qdelay, | |||
| freeze_bn=self.freeze_bn, | |||
| per_channel=self.weight_channel, | |||
| @@ -268,16 +266,16 @@ class ConvertToQuantNetwork: | |||
| narrow_range=self.act_range) | |||
| class ExportQuantNetworkDeploy: | |||
| class ExportToQuantInferNetwork: | |||
| """ | |||
| Convert quantization aware network to deploy network. | |||
| Convert quantization aware network to infer network. | |||
| Args: | |||
| network (Cell): MindSpore network produced by `convert_quant_network`. | |||
| inputs (Tensor): Inputs of the `network`. | |||
| network (Cell): MindSpore network API `convert_quant_network`. | |||
| inputs (Tensor): Input tensors of the `quantization aware training network`. | |||
| Returns: | |||
| Cell, converted network. | |||
| Cell, GEIR backend Infer network. | |||
| """ | |||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | |||
| @@ -287,7 +285,7 @@ class ExportQuantNetworkDeploy: | |||
| network = validator.check_isinstance('network', network, (nn.Cell,)) | |||
| self.data_type = mstype.int8 | |||
| self.network = copy.deepcopy(network) | |||
| self.all_paramters = {p.name: p for p in self.network.get_parameters()} | |||
| self.all_parameters = {p.name: p for p in self.network.get_parameters()} | |||
| self.get_inputs_table(inputs) | |||
| def get_inputs_table(self, inputs): | |||
| @@ -315,8 +313,8 @@ class ExportQuantNetworkDeploy: | |||
| info = self.quant_info_table.get(w_minq_name, None) | |||
| if info: | |||
| fack_quant_a_in_op, minq_name = info | |||
| maxq = self.all_paramters[minq_name[:-4] + "maxq"] | |||
| minq = self.all_paramters[minq_name] | |||
| maxq = self.all_parameters[minq_name[:-4] + "maxq"] | |||
| minq = self.all_parameters[minq_name] | |||
| scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) | |||
| else: | |||
| logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}") | |||
| @@ -357,7 +355,7 @@ class ExportQuantNetworkDeploy: | |||
| return block | |||
| def _convert_quant2deploy(self, network): | |||
| """Convet network's all quant subcell to deploy subcell.""" | |||
| """Convert network's all quant subcell to deploy subcell.""" | |||
| cells = network.name_cells() | |||
| change = False | |||
| for name in cells: | |||
| @@ -395,18 +393,26 @@ class ExportQuantNetworkDeploy: | |||
| return network | |||
| def export_geir(network, *inputs, file_name): | |||
| def export(network, *inputs, file_name, file_format='GEIR'): | |||
| """ | |||
| Exports MindSpore quant predict model to deploy with GEIR. | |||
| Exports MindSpore quantization predict model to deploy with GEIR. | |||
| Args: | |||
| network (Cell): MindSpore network produced by `convert_quant_network`. | |||
| inputs (Tensor): Inputs of the `network`. | |||
| inputs (Tensor): Inputs of the `quantization aware training network`. | |||
| file_name (str): File name of model to export. | |||
| file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model. | |||
| - GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model. | |||
| """ | |||
| exporter = ExportQuantNetworkDeploy(network, *inputs) | |||
| deploy_net = exporter.run() | |||
| serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR") | |||
| supported_formats = ['GEIR'] | |||
| if file_format not in supported_formats: | |||
| raise ValueError('Illegal file format {}.'.format(file_format)) | |||
| if file_format == 'GEIR': | |||
| exporter = ExportToQuantInferNetwork(network, *inputs) | |||
| deploy_net = exporter.run() | |||
| serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format) | |||
| def convert_quant_network(network, | |||
| @@ -443,6 +449,7 @@ def convert_quant_network(network, | |||
| Cell, Network which has change to quantization aware training network cell. | |||
| """ | |||
| support_device = ["Ascend", "GPU"] | |||
| def convert2list(name, value): | |||
| if not isinstance(value, list) and not isinstance(value, tuple): | |||
| value = [value] | |||
| @@ -457,7 +464,7 @@ def convert_quant_network(network, | |||
| narrow_range = convert2list("narrow range", narrow_range) | |||
| if context.get_context('device_target') not in support_device: | |||
| raise KeyError("Not support {} backend.".format(context.get_context('device_target'))) | |||
| raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) | |||
| net = ConvertToQuantNetwork(network=network, | |||
| quant_delay=quant_delay, | |||
| @@ -160,7 +160,10 @@ def load_checkpoint(ckpt_file_name, net=None): | |||
| if not isinstance(ckpt_file_name, str): | |||
| raise ValueError("The ckpt_file_name must be string.") | |||
| if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt": | |||
| if not os.path.exists(ckpt_file_name): | |||
| raise ValueError("The checkpoint file is not exist.") | |||
| if ckpt_file_name[-5:] != ".ckpt": | |||
| raise ValueError("Please input the correct checkpoint file name.") | |||
| if os.path.getsize(ckpt_file_name) == 0: | |||
| @@ -57,7 +57,7 @@ if __name__ == "__main__": | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| # load check point into network | |||
| param_dict = load_checkpoint(args.ckpt_path, network.type) | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| load_param_into_net(network, param_dict) | |||
| print("============== Starting Testing ==============") | |||
| @@ -49,7 +49,7 @@ if __name__ == "__main__": | |||
| # define fusion network | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| # convert fusion netwrok to quantization aware network | |||
| # convert fusion network to quantization aware network | |||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | |||
| # define loss | |||
| @@ -0,0 +1,56 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| export quantization aware training network to infer `GEIR` backend. | |||
| """ | |||
| import argparse | |||
| import numpy as np | |||
| import mindspore | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.train.quant import quant | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.config import mnist_cfg as cfg | |||
| from src.lenet_fusion import LeNet5 as LeNet5Fusion | |||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", | |||
| choices=['Ascend', 'GPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
| parser.add_argument('--data_path', type=str, default="./MNIST_Data", | |||
| help='path where the dataset is saved') | |||
| parser.add_argument('--ckpt_path', type=str, default="", | |||
| help='if mode is test, must provide path where the trained ckpt file') | |||
| parser.add_argument('--dataset_sink_mode', type=bool, default=True, | |||
| help='dataset_sink_mode is False or True') | |||
| args = parser.parse_args() | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| # define fusion network | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| # convert fusion network to quantization aware network | |||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | |||
| # load quantization aware network checkpoint | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| load_param_into_net(network, param_dict) | |||
| # export network | |||
| inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32) | |||
| quant.export(network, inputs, file_name="lenet_quant", file_format='GEIR') | |||
| @@ -22,7 +22,7 @@ import os | |||
| import argparse | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from src.dataset import create_dataset | |||
| @@ -54,7 +54,6 @@ if __name__ == "__main__": | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| # call back and monitor | |||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||
| config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) | |||
| @@ -63,6 +62,6 @@ if __name__ == "__main__": | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| print("============== Starting Training ==============") | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()], | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], | |||
| dataset_sink_mode=args.dataset_sink_mode) | |||
| print("============== End Training ==============") | |||
| @@ -23,7 +23,7 @@ import argparse | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from mindspore.train.quant import quant | |||
| @@ -51,20 +51,19 @@ if __name__ == "__main__": | |||
| # define fusion network | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| # convert fusion network to quantization aware network | |||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | |||
| # load quantization aware network checkpoint | |||
| param_dict = load_checkpoint(args.ckpt_path, network.type) | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| load_param_into_net(network, param_dict) | |||
| # convert fusion network to quantization aware network | |||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | |||
| # define network loss | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| # define network optimization | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| # call back and monitor | |||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||
| config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) | |||
| @@ -73,6 +72,6 @@ if __name__ == "__main__": | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| print("============== Starting Training ==============") | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()], | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], | |||
| dataset_sink_mode=args.dataset_sink_mode) | |||
| print("============== End Training ==============") | |||