|
|
|
@@ -41,10 +41,12 @@ if __name__ == '__main__': |
|
|
|
config_device_target = config_ascend_quant |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", |
|
|
|
device_id=device_id, save_graphs=False) |
|
|
|
symmetric_list = [True, False] |
|
|
|
elif args_opt.device_target == "GPU": |
|
|
|
config_device_target = config_gpu_quant |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", |
|
|
|
device_id=device_id, save_graphs=False) |
|
|
|
symmetric_list = [False, False] |
|
|
|
else: |
|
|
|
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target)) |
|
|
|
|
|
|
|
@@ -53,7 +55,7 @@ if __name__ == '__main__': |
|
|
|
# convert fusion network to quantization aware network |
|
|
|
quantizer = QuantizationAwareTraining(bn_fold=True, |
|
|
|
per_channel=[True, False], |
|
|
|
symmetric=[True, False]) |
|
|
|
symmetric=symmetric_list) |
|
|
|
network = quantizer.quantize(network) |
|
|
|
# define network loss |
|
|
|
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') |
|
|
|
|