Merge pull request !5471 from chenfei_mindspore/rm-bool-arg-of-scripttags/v1.0.0
| @@ -38,8 +38,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data", | |||||
| help='path where the dataset is saved') | help='path where the dataset is saved') | ||||
| parser.add_argument('--ckpt_path', type=str, default="", | parser.add_argument('--ckpt_path', type=str, default="", | ||||
| help='if mode is test, must provide path where the trained ckpt file') | help='if mode is test, must provide path where the trained ckpt file') | ||||
| parser.add_argument('--dataset_sink_mode', type=bool, default=True, | |||||
| help='dataset_sink_mode is False or True') | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -67,5 +65,5 @@ if __name__ == "__main__": | |||||
| raise ValueError("Load param into net fail!") | raise ValueError("Load param into net fail!") | ||||
| print("============== Starting Testing ==============") | print("============== Starting Testing ==============") | ||||
| acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | |||||
| acc = model.eval(ds_eval, dataset_sink_mode=True) | |||||
| print("============== {} ==============".format(acc)) | print("============== {} ==============".format(acc)) | ||||
| @@ -36,8 +36,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data", | |||||
| help='path where the dataset is saved') | help='path where the dataset is saved') | ||||
| parser.add_argument('--ckpt_path', type=str, default="", | parser.add_argument('--ckpt_path', type=str, default="", | ||||
| help='if mode is test, must provide path where the trained ckpt file') | help='if mode is test, must provide path where the trained ckpt file') | ||||
| parser.add_argument('--dataset_sink_mode', type=bool, default=True, | |||||
| help='dataset_sink_mode is False or True') | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -41,8 +41,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data", | |||||
| help='path where the dataset is saved') | help='path where the dataset is saved') | ||||
| parser.add_argument('--ckpt_path', type=str, default="", | parser.add_argument('--ckpt_path', type=str, default="", | ||||
| help='if mode is test, must provide path where the trained ckpt file') | help='if mode is test, must provide path where the trained ckpt file') | ||||
| parser.add_argument('--dataset_sink_mode', type=bool, default=True, | |||||
| help='dataset_sink_mode is False or True') | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -76,5 +74,5 @@ if __name__ == "__main__": | |||||
| print("============== Starting Training ==============") | print("============== Starting Training ==============") | ||||
| model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], | model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], | ||||
| dataset_sink_mode=args.dataset_sink_mode) | |||||
| dataset_sink_mode=True) | |||||
| print("============== End Training ==============") | print("============== End Training ==============") | ||||
| @@ -32,7 +32,6 @@ parser = argparse.ArgumentParser(description='Image classification') | |||||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | ||||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | ||||
| parser.add_argument('--device_target', type=str, default=None, help='Run device target') | parser.add_argument('--device_target', type=str, default=None, help='Run device target') | ||||
| parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training') | |||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -51,9 +50,8 @@ if __name__ == '__main__': | |||||
| # define fusion network | # define fusion network | ||||
| network = mobilenetV2(num_classes=config_device_target.num_classes) | network = mobilenetV2(num_classes=config_device_target.num_classes) | ||||
| if args_opt.quantization_aware: | |||||
| # convert fusion network to quantization aware network | |||||
| network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) | |||||
| # convert fusion network to quantization aware network | |||||
| network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) | |||||
| # define network loss | # define network loss | ||||
| loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') | loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') | ||||
| @@ -50,5 +50,4 @@ python ${BASEPATH}/../eval.py \ | |||||
| --device_target=$1 \ | --device_target=$1 \ | ||||
| --dataset_path=$2 \ | --dataset_path=$2 \ | ||||
| --checkpoint_path=$3 \ | --checkpoint_path=$3 \ | ||||
| --quantization_aware=True \ | |||||
| &> infer.log & # dataset val folder path | &> infer.log & # dataset val folder path | ||||