Browse Source

modify repeatedly export quant net mindir

pull/14325/head
changzherui 4 years ago
parent
commit
c1ada6a3e5
1 changed files with 10 additions and 5 deletions
  1. +10
    -5
      mindspore/train/serialization.py

+ 10
- 5
mindspore/train/serialization.py View File

@@ -18,6 +18,8 @@ import sys
import stat import stat
import math import math
import shutil import shutil
import time
import copy
from threading import Thread, Lock from threading import Thread, Lock
import numpy as np import numpy as np


@@ -756,6 +758,9 @@ def _quant_export(network, *inputs, file_format, **kwargs):
supported_formats = ['AIR', 'MINDIR'] supported_formats = ['AIR', 'MINDIR']
quant_mode_formats = ['AUTO', 'MANUAL'] quant_mode_formats = ['AUTO', 'MANUAL']


quant_net = copy.deepcopy(network)
quant_net._create_time = int(time.time() * 1e9)

mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean'] mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean']
std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev'] std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev']


@@ -772,17 +777,17 @@ def _quant_export(network, *inputs, file_format, **kwargs):
if file_format not in supported_formats: if file_format not in supported_formats:
raise ValueError('Illegal file format {}.'.format(file_format)) raise ValueError('Illegal file format {}.'.format(file_format))


network.set_train(False)
quant_net.set_train(False)
if file_format == "MINDIR": if file_format == "MINDIR":
if quant_mode == 'MANUAL': if quant_mode == 'MANUAL':
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
else: else:
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
else: else:
if quant_mode == 'MANUAL': if quant_mode == 'MANUAL':
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs)
else: else:
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs)
deploy_net = exporter.run() deploy_net = exporter.run()
return deploy_net return deploy_net




Loading…
Cancel
Save