Browse Source

!9480 fix eval symmetric bug

From: @xiaoyisd
Reviewed-by: @liangchenghui,@chujinjin
Signed-off-by: @liangchenghui
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
57da31bfdd
1 changed files with 3 additions and 1 deletions
  1. +3
    -1
      model_zoo/official/cv/mobilenetv2_quant/eval.py

+ 3
- 1
model_zoo/official/cv/mobilenetv2_quant/eval.py View File

@@ -41,10 +41,12 @@ if __name__ == '__main__':
config_device_target = config_ascend_quant config_device_target = config_ascend_quant
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False) device_id=device_id, save_graphs=False)
symmetric_list = [True, False]
elif args_opt.device_target == "GPU": elif args_opt.device_target == "GPU":
config_device_target = config_gpu_quant config_device_target = config_gpu_quant
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
device_id=device_id, save_graphs=False) device_id=device_id, save_graphs=False)
symmetric_list = [False, False]
else: else:
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target)) raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))


@@ -53,7 +55,7 @@ if __name__ == '__main__':
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
quantizer = QuantizationAwareTraining(bn_fold=True, quantizer = QuantizationAwareTraining(bn_fold=True,
per_channel=[True, False], per_channel=[True, False],
symmetric=[True, False])
symmetric=symmetric_list)
network = quantizer.quantize(network) network = quantizer.quantize(network)
# define network loss # define network loss
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


Loading…
Cancel
Save