|
|
|
@@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore.nn.metrics import Accuracy |
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor |
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export |
|
|
|
from mindspore.train import Model |
|
|
|
from mindspore.train.quant import quant |
|
|
|
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net |
|
|
|
@@ -136,7 +136,7 @@ def export_lenet(): |
|
|
|
|
|
|
|
# export network |
|
|
|
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) |
|
|
|
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR') |
|
|
|
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO') |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
|