|
|
|
@@ -18,6 +18,8 @@ import sys |
|
|
|
import stat |
|
|
|
import math |
|
|
|
import shutil |
|
|
|
import time |
|
|
|
import copy |
|
|
|
from threading import Thread, Lock |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
@@ -756,6 +758,9 @@ def _quant_export(network, *inputs, file_format, **kwargs): |
|
|
|
supported_formats = ['AIR', 'MINDIR'] |
|
|
|
quant_mode_formats = ['AUTO', 'MANUAL'] |
|
|
|
|
|
|
|
quant_net = copy.deepcopy(network) |
|
|
|
quant_net._create_time = int(time.time() * 1e9) |
|
|
|
|
|
|
|
mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean'] |
|
|
|
std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev'] |
|
|
|
|
|
|
|
@@ -772,17 +777,17 @@ def _quant_export(network, *inputs, file_format, **kwargs): |
|
|
|
if file_format not in supported_formats: |
|
|
|
raise ValueError('Illegal file format {}.'.format(file_format)) |
|
|
|
|
|
|
|
network.set_train(False) |
|
|
|
quant_net.set_train(False) |
|
|
|
if file_format == "MINDIR": |
|
|
|
if quant_mode == 'MANUAL': |
|
|
|
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True) |
|
|
|
exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True) |
|
|
|
else: |
|
|
|
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True) |
|
|
|
exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True) |
|
|
|
else: |
|
|
|
if quant_mode == 'MANUAL': |
|
|
|
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs) |
|
|
|
exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs) |
|
|
|
else: |
|
|
|
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs) |
|
|
|
exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs) |
|
|
|
deploy_net = exporter.run() |
|
|
|
return deploy_net |
|
|
|
|
|
|
|
|