| @@ -19,6 +19,8 @@ train and infer lenet quantization network | |||||
| import os | import os | ||||
| import pytest | import pytest | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore import Tensor | |||||
| from mindspore.common import dtype as mstype | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.nn.metrics import Accuracy | from mindspore.nn.metrics import Accuracy | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | ||||
| @@ -30,6 +32,7 @@ from dataset import create_dataset | |||||
| from config import nonquant_cfg, quant_cfg | from config import nonquant_cfg, quant_cfg | ||||
| from lenet import LeNet5 | from lenet import LeNet5 | ||||
| from lenet_fusion import LeNet5 as LeNet5Fusion | from lenet_fusion import LeNet5 as LeNet5Fusion | ||||
| import numpy as np | |||||
| device_target = 'GPU' | device_target = 'GPU' | ||||
| data_path = "/home/workspace/mindspore_dataset/mnist" | data_path = "/home/workspace/mindspore_dataset/mnist" | ||||
| @@ -122,6 +125,19 @@ def eval_quant(): | |||||
| print("============== {} ==============".format(acc)) | print("============== {} ==============".format(acc)) | ||||
| assert acc['Accuracy'] > 0.98 | assert acc['Accuracy'] > 0.98 | ||||
| def export_lenet(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=device_target) | |||||
| cfg = quant_cfg | |||||
| # define fusion network | |||||
| network = LeNet5Fusion(cfg.num_classes) | |||||
| # convert fusion network to quantization aware network | |||||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, | |||||
| per_channel=[True, False], symmetric=[True, False]) | |||||
| # export network | |||||
| inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) | |||||
| quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR') | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @@ -130,6 +146,7 @@ def test_lenet_quant(): | |||||
| train_lenet() | train_lenet() | ||||
| train_lenet_quant() | train_lenet_quant() | ||||
| eval_quant() | eval_quant() | ||||
| export_lenet() | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||