Merge pull request !4714 from chenzhongming/new_mastertags/v0.7.0-beta
| @@ -51,6 +51,8 @@ class _Conv(Cell): | |||
| self.kernel_size = kernel_size | |||
| self.stride = stride | |||
| self.pad_mode = pad_mode | |||
| self.weight_init = weight_init | |||
| self.bias_init = bias_init | |||
| if isinstance(padding, int): | |||
| Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name) | |||
| self.padding = padding | |||
| @@ -85,12 +87,12 @@ class _Conv(Cell): | |||
| shape = [in_channels, out_channels // group, *kernel_size] | |||
| else: | |||
| shape = [out_channels, in_channels // group, *kernel_size] | |||
| self.weight = Parameter(initializer(weight_init, shape), name='weight') | |||
| self.weight = Parameter(initializer(self.weight_init, shape), name='weight') | |||
| if check_bool(has_bias): | |||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') | |||
| self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias') | |||
| else: | |||
| if bias_init != 'zeros': | |||
| if self.bias_init != 'zeros': | |||
| logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.") | |||
| self.bias = None | |||
| @@ -249,11 +251,8 @@ class Conv2d(_Conv): | |||
| self.dilation, | |||
| self.group, | |||
| self.has_bias, | |||
| self.weight, | |||
| self.bias) | |||
| if self.has_bias: | |||
| s += ', bias={}'.format(self.bias) | |||
| self.weight_init, | |||
| self.bias_init) | |||
| return s | |||
| @@ -431,11 +430,8 @@ class Conv1d(_Conv): | |||
| self.dilation, | |||
| self.group, | |||
| self.has_bias, | |||
| self.weight, | |||
| self.bias) | |||
| if self.has_bias: | |||
| s += ', bias={}'.format(self.bias) | |||
| self.weight_init, | |||
| self.bias_init) | |||
| return s | |||
| @@ -605,8 +601,8 @@ class Conv2dTranspose(_Conv): | |||
| self.dilation, | |||
| self.group, | |||
| self.has_bias, | |||
| self.weight, | |||
| self.bias) | |||
| self.weight_init, | |||
| self.bias_init) | |||
| return s | |||
| @@ -788,8 +784,8 @@ class Conv1dTranspose(_Conv): | |||
| self.dilation, | |||
| self.group, | |||
| self.has_bias, | |||
| self.weight, | |||
| self.bias) | |||
| self.weight_init, | |||
| self.bias_init) | |||
| return s | |||
| @@ -30,7 +30,7 @@ from src.mobilenetV2 import mobilenet_v2 | |||
| parser = argparse.ArgumentParser(description='Image classification') | |||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | |||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | |||
| parser.add_argument('--device_targe', type=str, default=None, help='run device_targe') | |||
| parser.add_argument('--device_target', type=str, default=None, help='run device_target') | |||
| args_opt = parser.parse_args() | |||
| @@ -73,6 +73,7 @@ run_gpu() | |||
| mpirun -n $2 --allow-run-as-root \ | |||
| python ${BASEPATH}/../train.py \ | |||
| --dataset_path=$4 \ | |||
| --pre_trained=$5 \ | |||
| --device_target=$1 \ | |||
| &> ../train.log & # dataset train folder | |||
| } | |||
| @@ -81,7 +82,7 @@ if [ $# -gt 6 ] || [ $# -lt 4 ] | |||
| then | |||
| echo "Usage:\n \ | |||
| Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]\n \ | |||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ | |||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ | |||
| " | |||
| exit 1 | |||
| fi | |||
| @@ -49,10 +49,10 @@ de.config.set_seed(1) | |||
| parser = argparse.ArgumentParser(description='Image classification') | |||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | |||
| parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') | |||
| parser.add_argument('--device_targe', type=str, default=None, help='run device_targe') | |||
| parser.add_argument('--device_target', type=str, default=None, help='run device_target') | |||
| args_opt = parser.parse_args() | |||
| if args_opt.device_targe == "Ascend": | |||
| if args_opt.device_target == "Ascend": | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| rank_id = int(os.getenv('RANK_ID')) | |||
| rank_size = int(os.getenv('RANK_SIZE')) | |||
| @@ -61,7 +61,7 @@ if args_opt.device_targe == "Ascend": | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="Ascend", | |||
| device_id=device_id, save_graphs=False) | |||
| elif args_opt.device_targe == "GPU": | |||
| elif args_opt.device_target == "GPU": | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="GPU", | |||
| save_graphs=False) | |||
| @@ -161,13 +161,13 @@ class Monitor(Callback): | |||
| if __name__ == '__main__': | |||
| if args_opt.device_targe == "GPU": | |||
| if args_opt.device_target == "GPU": | |||
| # train on gpu | |||
| print("train args: ", args_opt) | |||
| print("cfg: ", config_gpu) | |||
| # define network | |||
| net = mobilenet_v2(num_classes=config_gpu.num_classes, device_targe="GPU") | |||
| net = mobilenet_v2(num_classes=config_gpu.num_classes, device_target="GPU") | |||
| # define loss | |||
| if config_gpu.label_smooth > 0: | |||
| loss = CrossEntropyWithLabelSmooth(smooth_factor=config_gpu.label_smooth, | |||
| @@ -179,7 +179,7 @@ if __name__ == '__main__': | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | |||
| do_train=True, | |||
| config=config_gpu, | |||
| device_targe=args_opt.device_targe, | |||
| device_target=args_opt.device_target, | |||
| repeat_num=1, | |||
| batch_size=config_gpu.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| @@ -216,7 +216,7 @@ if __name__ == '__main__': | |||
| # begin train | |||
| model.train(epoch_size, dataset, callbacks=cb) | |||
| print("============== End Training ==============") | |||
| elif args_opt.device_targe == "Ascend": | |||
| elif args_opt.device_target == "Ascend": | |||
| # train on ascend | |||
| print("train args: ", args_opt, "\ncfg: ", config_ascend, | |||
| "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) | |||
| @@ -228,7 +228,7 @@ if __name__ == '__main__': | |||
| init() | |||
| epoch_size = config_ascend.epoch_size | |||
| net = mobilenet_v2(num_classes=config_ascend.num_classes, device_targe="Ascend") | |||
| net = mobilenet_v2(num_classes=config_ascend.num_classes, device_target="Ascend") | |||
| net.to_float(mstype.float16) | |||
| for _, cell in net.cells_and_names(): | |||
| if isinstance(cell, nn.Dense): | |||
| @@ -242,7 +242,7 @@ if __name__ == '__main__': | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | |||
| do_train=True, | |||
| config=config_ascend, | |||
| device_targe=args_opt.device_targe, | |||
| device_target=args_opt.device_target, | |||
| repeat_num=1, | |||
| batch_size=config_ascend.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| @@ -276,4 +276,4 @@ if __name__ == '__main__': | |||
| cb += [ckpt_cb] | |||
| model.train(epoch_size, dataset, callbacks=cb) | |||
| else: | |||
| raise ValueError("Unsupported device_targe.") | |||
| raise ValueError("Unsupported device_target.") | |||
| @@ -27,8 +27,8 @@ Dataset used: [imagenet](http://www.image-net.org/) | |||
| # Environment Requirements | |||
| - Hardware(Ascend/GPU) | |||
| - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. | |||
| - Hardware(GPU) | |||
| - Prepare hardware environment with GPU processor. | |||
| - Framework | |||
| - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) | |||
| - For more information, please check the resources below: | |||
| @@ -60,14 +60,12 @@ Dataset used: [imagenet](http://www.image-net.org/) | |||
| ### Usage | |||
| - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] | |||
| - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] | |||
| ### Launch | |||
| ``` | |||
| # training example | |||
| Ascend: sh run_train.sh Ascend 8 192.168.0.1 0,1,2,3,4,5,6,7 ~/imagenet/train/ | |||
| GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ | |||
| ``` | |||
| @@ -86,14 +84,12 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 | |||
| ### Usage | |||
| - Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] | |||
| - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] | |||
| ### Launch | |||
| ``` | |||
| # infer example | |||
| Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt | |||
| GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt | |||
| ``` | |||