
!15723 quant mode change in quant export

From: @zhang__sss
Reviewed-by: @zh_qh,@guoqi1024
Signed-off-by: @zh_qh
pull/15723/MERGE
mindspore-ci-bot (Gitee), 4 years ago
commit c04e8fce4d
9 changed files with 65 additions and 135 deletions
  1. +15 -73   mindspore/compression/export/quant_export.py
  2. +28 -21   mindspore/train/serialization.py
  3. +1  -1    model_zoo/official/cv/lenet_quant/export.py
  4. +10 -7    model_zoo/official/cv/lenet_quant/src/lenet_quant.py
  5. +3  -5    model_zoo/official/cv/mobilenetv2_quant/export.py
  6. +1  -8    model_zoo/official/cv/resnet50_quant/eval.py
  7. +3  -8    model_zoo/official/cv/resnet50_quant/export.py
  8. +1  -10   model_zoo/official/cv/resnet50_quant/train.py
  9. +3  -2    model_zoo/official/cv/yolov3_darknet53_quant/export.py

+15 -73  mindspore/compression/export/quant_export.py

@@ -30,18 +30,20 @@ from ..quant import quant_utils
 from ..quant.qat import QuantizationAwareTraining, _AddFakeQuantInput, _AddFakeQuantAfterSubCell

-__all__ = ["ExportToQuantInferNetwork", "ExportManualQuantNetwork"]
+__all__ = ["ExportToQuantInferNetwork"]


 class ExportToQuantInferNetwork:
     """
     Convert quantization aware network to infer network.

     Args:
-        network (Cell): MindSpore network API `convert_quant_network`.
+        network (Cell): MindSpore quantization aware training network.
         inputs (Tensor): Input tensors of the `quantization aware training network`.
-        mean (int): Input data mean. Default: 127.5.
-        std_dev (int, float): Input data variance. Default: 127.5.
-        is_mindir (bool): Whether is MINDIR format. Default: False.
+        mean (int, float): The mean of input data after preprocessing, used for quantizing the first layer of network.
+            Default: 127.5.
+        std_dev (int, float): The variance of input data after preprocessing, used for quantizing the first layer
+            of network. Default: 127.5.
+        is_mindir (bool): Whether export MINDIR format. Default: False.

     Returns:
         Cell, Infer network.
@@ -59,9 +61,11 @@ class ExportToQuantInferNetwork:
         self.mean = mean
         self.std_dev = std_dev
         self.is_mindir = is_mindir
+        self.upcell = None
+        self.upname = None

     def get_inputs_table(self, inputs):
-        """Get the support info for quant export."""
+        """Get the input quantization parameters of quantization cell for quant export."""
         phase_name = 'export_quant'
         graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
         self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)
@@ -151,7 +155,6 @@ class ExportToQuantInferNetwork:
         dequant_param = np.zeros(scale_length, dtype=np.uint64)
         for index in range(scale_length):
             dequant_param[index] += uint32_deq_scale[index]
-
         scale_deq = Tensor(dequant_param, mstype.uint64)
         # get op
         if isinstance(cell_core, quant.DenseQuant):
@@ -170,69 +173,8 @@ class ExportToQuantInferNetwork:
         block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
         return block

-    def _convert_quant2deploy(self, network):
-        """Convert network's all quant subcell to deploy subcell."""
-        cells = network.name_cells()
-        change = False
-        for name in cells:
-            subcell = cells[name]
-            if subcell == network:
-                continue
-            cell_core = None
-            fake_quant_act = None
-            activation = None
-            if isinstance(subcell, nn.Conv2dBnAct):
-                cell_core = subcell.conv
-                activation = subcell.activation
-                fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
-            elif isinstance(subcell, nn.DenseBnAct):
-                cell_core = subcell.dense
-                activation = subcell.activation
-                fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
-            if cell_core is not None:
-                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
-                if new_subcell:
-                    prefix = subcell.param_prefix
-                    new_subcell.update_parameters_name(prefix + '.')
-                    network.insert_child_to_cell(name, new_subcell)
-                    change = True
-            elif isinstance(subcell, _AddFakeQuantAfterSubCell):
-                op = subcell.subcell
-                if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
-                    if self.is_mindir:
-                        op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
-                        op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
-                    network.__delattr__(name)
-                    network.__setattr__(name, op)
-                    change = True
-            else:
-                self._convert_quant2deploy(subcell)
-        if isinstance(network, nn.SequentialCell) and change:
-            network.cell_list = list(network.cells())
-        return network
-
-
-class ExportManualQuantNetwork(ExportToQuantInferNetwork):
-    """
-    Convert manual quantization aware network to infer network.
-
-    Args:
-        network (Cell): MindSpore network API `convert_quant_network`.
-        inputs (Tensor): Input tensors of the `quantization aware training network`.
-        mean (int): Input data mean. Default: 127.5.
-        std_dev (int, float): Input data variance. Default: 127.5.
-        is_mindir (bool): Whether is MINDIR format. Default: False.
-
-    Returns:
-        Cell, Infer network.
-    """
-    __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
-
-    def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
-        super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir=is_mindir)
-        self.upcell = None
-        self.upname = None
-
     def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
+        """add output quant info for quant op for export mindir."""
         if self.is_mindir:
             np_type = mstype.dtype_to_nptype(self.data_type)
             _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type)
@@ -251,8 +193,8 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
                 network, change = self._convert_subcell(network, change, name, subcell)
             elif isinstance(subcell, nn.DenseBnAct):
                 network, change = self._convert_subcell(network, change, name, subcell, conv=False)
-            elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
-                                      quant.Conv2dQuant, quant.DenseQuant)):
+            elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv,
+                                      quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)):
                 network, change = self._convert_subcell(network, change, name, subcell, core=False)
             elif isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
                 if self.upcell:
@@ -292,16 +234,16 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
     def _convert_subcell(self, network, change, name, subcell, core=True, conv=True):
         """Convert subcell to ant subcell."""
         new_subcell = None
+        fake_quant_act = None
         if core:
             cell_core = subcell.conv if conv else subcell.dense
             activation = subcell.activation
             if hasattr(activation, 'fake_quant_act'):
                 fake_quant_act = activation.fake_quant_act
-                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
         else:
             cell_core = subcell
             activation = None
-            fake_quant_act = None
+        if cell_core is not None and hasattr(cell_core, "fake_quant_weight"):
             new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
         if new_subcell:
             prefix = subcell.param_prefix
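Note (illustration, not part of this change): with ExportManualQuantNetwork removed, ExportToQuantInferNetwork is the single exporter for both automatically and manually quantized networks. A minimal usage sketch, assuming `quant_net` is a quantization aware training network and the input shape and import path (taken from the file location and serialization.py) hold in your build:

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.compression.export import quant_export

    # quant_net: any quantization aware training network (auto- or manually constructed)
    inputs = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)   # placeholder input shape
    exporter = quant_export.ExportToQuantInferNetwork(quant_net, 127.5, 127.5, inputs, is_mindir=True)
    deploy_net = exporter.run()   # returns the infer network with QuantBlock cells

In practice this class is driven through mindspore.train.serialization.export() rather than called directly.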


+28 -21  mindspore/train/serialization.py

@@ -599,9 +599,12 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):

         kwargs (dict): Configuration options dictionary.

-            - quant_mode: The mode of quant.
-            - mean: Input data mean. Default: 127.5.
-            - std_dev: Input data variance. Default: 127.5.
+            - quant_mode: If the network is quantization aware training network, the quant_mode should
+              be set to "QUANT", else the quant_mode should be set to "NONQUANT".
+            - mean: The mean of input data after preprocessing, used for quantizing the first layer of network.
+              Default: 127.5.
+            - std_dev: The variance of input data after preprocessing, used for quantizing the first layer of network.
+              Default: 127.5.
     """
     logger.info("exporting model file:%s format:%s.", file_name, file_format)
     check_input_data(*inputs, data_class=Tensor)
@@ -755,28 +758,38 @@ def _mindir_save_together(net_dict, model):
             return False
     return True


+def quant_mode_manage(func):
+    """
+    Inherit the quant_mode in old version.
+    """
+    def warpper(network, *inputs, file_format, **kwargs):
+        if not kwargs.get('quant_mode', None):
+            return network
+        quant_mode = kwargs['quant_mode']
+        if quant_mode in ('AUTO', 'MANUAL'):
+            kwargs['quant_mode'] = 'QUANT'
+        return func(network, *inputs, file_format=file_format, **kwargs)
+    return warpper
+
+@quant_mode_manage
 def _quant_export(network, *inputs, file_format, **kwargs):
     """
     Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
     """
-    if not kwargs.get('quant_mode', None):
-        return network
-
     supported_device = ["Ascend", "GPU"]
     supported_formats = ['AIR', 'MINDIR']
-    quant_mode_formats = ['AUTO', 'MANUAL']
+    quant_mode_formats = ['QUANT', 'NONQUANT']

+    quant_mode = kwargs['quant_mode']
+    if quant_mode not in quant_mode_formats:
+        raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.')
+    if quant_mode == 'NONQUANT':
+        return network
     quant_net = copy.deepcopy(network)
     quant_net._create_time = int(time.time() * 1e9)

     mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean']
     std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev']
-
-    quant_mode = kwargs['quant_mode']
-    if quant_mode not in quant_mode_formats:
-        raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.')
-
     mean = Validator.check_value_type("mean", mean, (int, float))
     std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))


@@ -788,15 +801,9 @@ def _quant_export(network, *inputs, file_format, **kwargs):

     quant_net.set_train(False)
     if file_format == "MINDIR":
-        if quant_mode == 'MANUAL':
-            exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
-        else:
-            exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
+        exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
     else:
-        if quant_mode == 'MANUAL':
-            exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs)
-        else:
-            exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs)
+        exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs)
     deploy_net = exporter.run()
     return deploy_net
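Note (illustration, not part of this change): the export() entry point is unchanged; only the accepted quant_mode values move from 'AUTO'/'MANUAL' to 'QUANT'/'NONQUANT', and the quant_mode_manage wrapper keeps old call sites working by mapping the legacy values to 'QUANT'. A minimal call sketch, where `network`, `float_net` and the input shape are placeholders:

    import numpy as np
    import mindspore
    from mindspore import Tensor, export

    inputs = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)   # placeholder shape
    # New behaviour: any QAT network (auto- or manually constructed) exports with quant_mode='QUANT'.
    export(network, inputs, file_name="net_quant", file_format="MINDIR",
           quant_mode="QUANT", mean=0., std_dev=48.106)
    # A non-quantized network passes 'NONQUANT' (or omits quant_mode) and is exported unchanged.
    export(float_net, inputs, file_name="net_fp32", file_format="MINDIR", quant_mode="NONQUANT")
    # Legacy 'AUTO'/'MANUAL' values are still accepted and mapped to 'QUANT' by quant_mode_manage.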




+1 -1  model_zoo/official/cv/lenet_quant/export.py

@@ -54,4 +54,4 @@ if __name__ == "__main__":

     # export network
     inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
-    export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO')
+    export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='QUANT')

+10 -7  model_zoo/official/cv/lenet_quant/src/lenet_quant.py

@@ -15,7 +15,8 @@
 """Manual construct network for LeNet"""

 import mindspore.nn as nn
-
+from mindspore.compression.quant import create_quant_config
+from mindspore.compression.common import QuantDtype


 class LeNet5(nn.Cell):
     """
@@ -34,14 +35,16 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10, channel=1):
         super(LeNet5, self).__init__()
         self.num_class = num_class
+        self.qconfig = create_quant_config(per_channel=(True, False), symmetric=(True, False))

-        self.conv1 = nn.Conv2dBnFoldQuant(channel, 6, 5, pad_mode='valid', per_channel=True, quant_delay=900)
-        self.conv2 = nn.Conv2dBnFoldQuant(6, 16, 5, pad_mode='valid', per_channel=True, quant_delay=900)
-        self.fc1 = nn.DenseQuant(16 * 5 * 5, 120, per_channel=True, quant_delay=900)
-        self.fc2 = nn.DenseQuant(120, 84, per_channel=True, quant_delay=900)
-        self.fc3 = nn.DenseQuant(84, self.num_class, per_channel=True, quant_delay=900)
+        self.conv1 = nn.Conv2dQuant(channel, 6, 5, pad_mode='valid', quant_config=self.qconfig,
+                                    quant_dtype=QuantDtype.INT8)
+        self.conv2 = nn.Conv2dQuant(6, 16, 5, pad_mode='valid', quant_config=self.qconfig, quant_dtype=QuantDtype.INT8)
+        self.fc1 = nn.DenseQuant(16 * 5 * 5, 120, quant_config=self.qconfig, quant_dtype=QuantDtype.INT8)
+        self.fc2 = nn.DenseQuant(120, 84, quant_config=self.qconfig, quant_dtype=QuantDtype.INT8)
+        self.fc3 = nn.DenseQuant(84, self.num_class, quant_config=self.qconfig, quant_dtype=QuantDtype.INT8)

-        self.relu = nn.ActQuant(nn.ReLU(), per_channel=False, quant_delay=900)
+        self.relu = nn.ActQuant(nn.ReLU(), quant_config=self.qconfig, quant_dtype=QuantDtype.INT8)
         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
         self.flatten = nn.Flatten()
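Note (illustration, not part of this change): the manual LeNet now shares one quant config across all quant cells instead of repeating per_channel/quant_delay on every layer. A small sketch of the pattern used above, with layer sizes taken from the LeNet definition:

    import mindspore.nn as nn
    from mindspore.compression.quant import create_quant_config
    from mindspore.compression.common import QuantDtype

    # (weight, activation) settings: per-channel symmetric weights, per-tensor asymmetric activations
    qconfig = create_quant_config(per_channel=(True, False), symmetric=(True, False))
    fc = nn.DenseQuant(120, 84, quant_config=qconfig, quant_dtype=QuantDtype.INT8)
    act = nn.ActQuant(nn.ReLU(), quant_config=qconfig, quant_dtype=QuantDtype.INT8)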




+3 -5  model_zoo/official/cv/mobilenetv2_quant/export.py

@@ -47,9 +47,7 @@ if __name__ == '__main__':
     # export network
     print("============== Starting export ==============")
     inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
-    if args_opt.file_format == 'MINDIR':
-        export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR', quant_mode='AUTO')
-    else:
-        export(network, inputs, file_name="mobilenet_quant", file_format='AIR',
-               quant_mode='AUTO', mean=0., std_dev=48.106)
+    export(network, inputs, file_name="mobilenetv2_quant", file_format=args_opt.file_format,
+           quant_mode='QUANT', mean=0., std_dev=48.106)

     print("============== End export ==============")

+1 -8  model_zoo/official/cv/resnet50_quant/eval.py

@@ -20,13 +20,11 @@ import argparse
 from src.config import config_quant
 from src.dataset import create_dataset
 from src.crossentropy import CrossEntropy
-#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
 from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50

 from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.compression.quant import QuantizationAwareTraining

 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
@@ -42,13 +40,8 @@ if args_opt.device_target == "Ascend":
     context.set_context(device_id=device_id)

 if __name__ == '__main__':
-    # define fusion network
+    # define manual quantization network
     network = resnet50_quant(class_num=config.class_num)
-    # convert fusion network to quantization aware network
-    quantizer = QuantizationAwareTraining(bn_fold=True,
-                                          per_channel=[True, False],
-                                          symmetric=[True, False])
-    network = quantizer.quantize(network)

     # define network loss
     if not config.use_label_smooth:


+3 -8  model_zoo/official/cv/resnet50_quant/export.py

@@ -19,7 +19,6 @@ import numpy as np

 import mindspore
 from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
-from mindspore.compression.quant import QuantizationAwareTraining

 from models.resnet_quant_manual import resnet50_quant
 from src.config import config_quant
@@ -32,13 +31,9 @@ args_opt = parser.parse_args()

 if __name__ == '__main__':
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
-    # define fusion network
+    # define manual quantization network
     network = resnet50_quant(class_num=config_quant.class_num)
-    # convert fusion network to quantization aware network
-    quantizer = QuantizationAwareTraining(bn_fold=True,
-                                          per_channel=[True, False],
-                                          symmetric=[True, False])
-    network = quantizer.quantize(network)

     # load checkpoint
     if args_opt.checkpoint_path:
         param_dict = load_checkpoint(args_opt.checkpoint_path)
@@ -49,5 +44,5 @@ if __name__ == '__main__':
     print("============== Starting export ==============")
     inputs = Tensor(np.ones([1, 3, 224, 224]), mindspore.float32)
     export(network, inputs, file_name="resnet50_quant", file_format=args_opt.file_format,
-           quant_mode='MANUAL', mean=0., std_dev=48.106)
+           quant_mode='QUANT', mean=0., std_dev=48.106)
     print("============== End export ==============")

+1 -10  model_zoo/official/cv/resnet50_quant/train.py

@@ -25,14 +25,12 @@ from mindspore.context import ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint
-from mindspore.compression.quant import QuantizationAwareTraining
 from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
 from mindspore.communication.management import init
 import mindspore.nn as nn
 import mindspore.common.initializer as weight_init
 from mindspore.common import set_seed

-#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
 from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
@@ -80,7 +78,7 @@ if __name__ == '__main__':
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True, all_reduce_fusion_config=[107, 160])

-    # define network
+    # define manual quantization network
     net = resnet50_quant(class_num=config.class_num)
     net.set_train(True)

@@ -112,13 +110,6 @@
                              target=args_opt.device_target)
     step_size = dataset.get_dataset_size()

-    # convert fusion network to quantization aware network
-    quantizer = QuantizationAwareTraining(bn_fold=True,
-                                          per_channel=[True, False],
-                                          symmetric=[True, False],
-                                          one_conv_fold=False)
-    net = quantizer.quantize(net)
-
     # get learning rate
     lr = get_lr(lr_init=config.lr_init,
                 lr_end=0.0,
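Note (illustration, not part of this change): across train.py, eval.py and export.py, resnet_quant_manual.resnet50_quant is now treated as an already quantization aware network, so the QuantizationAwareTraining conversion step is dropped. The resulting export flow, sketched with a placeholder checkpoint path and class count:

    import numpy as np
    import mindspore
    from mindspore import Tensor, export, load_checkpoint, load_param_into_net
    from models.resnet_quant_manual import resnet50_quant  # manually constructed QAT ResNet-50

    net = resnet50_quant(class_num=1001)                   # placeholder class count; no quantizer.quantize(net)
    load_param_into_net(net, load_checkpoint("/path/to/resnet50_quant.ckpt"))  # placeholder path
    net.set_train(False)
    inputs = Tensor(np.ones([1, 3, 224, 224]), mindspore.float32)
    export(net, inputs, file_name="resnet50_quant", file_format="MINDIR",
           quant_mode="QUANT", mean=0., std_dev=48.106)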


+3 -2  model_zoo/official/cv/yolov3_darknet53_quant/export.py

@@ -28,7 +28,7 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id")
 parser.add_argument("--batch_size", type=int, default=1, help="batch size")
 parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
 parser.add_argument("--file_name", type=str, default="yolov3_darknet53_quant", help="output file name.")
-parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR', help='file format')
+parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='MINDIR', help='file format')
 args = parser.parse_args()

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
@@ -50,4 +50,5 @@ if __name__ == "__main__":
     input_data = Tensor(np.zeros(shape), ms.float32)
     input_shape = Tensor(tuple(config.test_img_shape), ms.float32)

-    export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format)
+    export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format,
+           quant_mode='QUANT', mean=0., std_dev=48.106)
