|
|
|
@@ -19,6 +19,8 @@ train and infer lenet quantization network |
|
|
|
import os |
|
|
|
import pytest |
|
|
|
from mindspore import context |
|
|
|
from mindspore import Tensor |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore.nn.metrics import Accuracy |
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor |
|
|
|
@@ -30,6 +32,7 @@ from dataset import create_dataset |
|
|
|
from config import nonquant_cfg, quant_cfg |
|
|
|
from lenet import LeNet5 |
|
|
|
from lenet_fusion import LeNet5 as LeNet5Fusion |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
device_target = 'GPU' |
|
|
|
data_path = "/home/workspace/mindspore_dataset/mnist" |
|
|
|
@@ -122,6 +125,19 @@ def eval_quant(): |
|
|
|
print("============== {} ==============".format(acc)) |
|
|
|
assert acc['Accuracy'] > 0.98 |
|
|
|
|
|
|
|
def export_lenet(): |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=device_target) |
|
|
|
cfg = quant_cfg |
|
|
|
# define fusion network |
|
|
|
network = LeNet5Fusion(cfg.num_classes) |
|
|
|
# convert fusion network to quantization aware network |
|
|
|
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, |
|
|
|
per_channel=[True, False], symmetric=[True, False]) |
|
|
|
|
|
|
|
# export network |
|
|
|
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) |
|
|
|
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR') |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@pytest.mark.platform_x86_gpu_training |
|
|
|
@@ -130,6 +146,7 @@ def test_lenet_quant(): |
|
|
|
train_lenet() |
|
|
|
train_lenet_quant() |
|
|
|
eval_quant() |
|
|
|
export_lenet() |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|