From c1ada6a3e5e53b68db94ef8223e3be12f3a90b2c Mon Sep 17 00:00:00 2001 From: changzherui Date: Mon, 29 Mar 2021 16:34:48 +0800 Subject: [PATCH] modify repeatedly export quant net mindir --- mindspore/train/serialization.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index e4eec594aa..52b59ed2fc 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -18,6 +18,8 @@ import sys import stat import math import shutil +import time +import copy from threading import Thread, Lock import numpy as np @@ -756,6 +758,9 @@ def _quant_export(network, *inputs, file_format, **kwargs): supported_formats = ['AIR', 'MINDIR'] quant_mode_formats = ['AUTO', 'MANUAL'] + quant_net = copy.deepcopy(network) + quant_net._create_time = int(time.time() * 1e9) + mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean'] std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev'] @@ -772,17 +777,17 @@ def _quant_export(network, *inputs, file_format, **kwargs): if file_format not in supported_formats: raise ValueError('Illegal file format {}.'.format(file_format)) - network.set_train(False) + quant_net.set_train(False) if file_format == "MINDIR": if quant_mode == 'MANUAL': - exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True) + exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True) else: - exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True) + exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True) else: if quant_mode == 'MANUAL': - exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs) + exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs) else: - exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs) + exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs) deploy_net = exporter.run() return deploy_net