Browse Source

!7247 export access control

Merge pull request !7247 from baiyangfan/access_control
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
186275a517
1 changed files with 17 additions and 0 deletions
  1. +17
    -0
      tests/st/quantization/lenet_quant/test_lenet_quant.py

+ 17
- 0
tests/st/quantization/lenet_quant/test_lenet_quant.py View File

@@ -19,6 +19,8 @@ train and infer lenet quantization network
import os
import pytest
from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@@ -30,6 +32,7 @@ from dataset import create_dataset
from config import nonquant_cfg, quant_cfg
from lenet import LeNet5
from lenet_fusion import LeNet5 as LeNet5Fusion
import numpy as np

device_target = 'GPU'
data_path = "/home/workspace/mindspore_dataset/mnist"
@@ -122,6 +125,19 @@ def eval_quant():
print("============== {} ==============".format(acc))
assert acc['Accuracy'] > 0.98

def export_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
cfg = quant_cfg
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
per_channel=[True, False], symmetric=[True, False])

# export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR')


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -130,6 +146,7 @@ def test_lenet_quant():
train_lenet()
train_lenet_quant()
eval_quant()
export_lenet()


if __name__ == "__main__":


Loading…
Cancel
Save