Browse Source

mode_export_modelzoo

tags/v1.1.0
bai-yangfan 5 years ago
parent
commit
f0d18d2ce0
4 changed files with 7 additions and 8 deletions
  1. +1
    -2
      mindspore/train/serialization.py
  2. +2
    -2
      model_zoo/official/cv/lenet_quant/export.py
  3. +2
    -2
      model_zoo/official/cv/mobilenetv2_quant/export.py
  4. +2
    -2
      tests/st/quantization/lenet_quant/test_lenet_quant.py

+ 1
- 2
mindspore/train/serialization.py View File

@@ -29,10 +29,9 @@ from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data
from mindspore._checkparam import check_input_data, Validator
from mindspore.train.quant import quant
import mindspore.context as context
from .._checkparam import Validator

__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
"build_searched_strategy", "merge_sliced_parameter"]


+ 2
- 2
model_zoo/official/cv/lenet_quant/export.py View File

@@ -23,7 +23,7 @@ import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.quant import quant
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion
@@ -52,4 +52,4 @@ if __name__ == "__main__":

# export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
quant.export(network, inputs, file_name="lenet_quant", file_format='AIR')
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO')

+ 2
- 2
model_zoo/official/cv/mobilenetv2_quant/export.py View File

@@ -20,7 +20,7 @@ import numpy as np
import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore.train.quant import quant

from src.mobilenetV2 import mobilenetV2
@@ -50,5 +50,5 @@ if __name__ == '__main__':
# export network
print("============== Starting export ==============")
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
quant.export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR')
export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR', quant_mode='AUTO')
print("============== End export ==============")

+ 2
- 2
tests/st/quantization/lenet_quant/test_lenet_quant.py View File

@@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore.train import Model
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
@@ -136,7 +136,7 @@ def export_lenet():

# export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR')
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO')


@pytest.mark.level0


Loading…
Cancel
Save