| @@ -52,21 +52,25 @@ cpus=`cat /proc/cpuinfo| grep "processor"| wc -l` | |||||
| avg=`expr $cpus \/ $RANK_SIZE` | avg=`expr $cpus \/ $RANK_SIZE` | ||||
| gap=`expr $avg \- 1` | gap=`expr $avg \- 1` | ||||
| script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" | |||||
| src_dir=$script_dir/.. | |||||
| start_idx=0 | |||||
| for((i=0;i<RANK_SIZE;i++)) | for((i=0;i<RANK_SIZE;i++)) | ||||
| do | do | ||||
| start=`expr $i \* $avg` | start=`expr $i \* $avg` | ||||
| end=`expr $start \+ $gap` | end=`expr $start \+ $gap` | ||||
| cmdopt=$start"-"$end | cmdopt=$start"-"$end | ||||
| export DEVICE_ID=$i | |||||
| export DEVICE_ID=`expr $i \+ $start_idx` | |||||
| export RANK_ID=$i | export RANK_ID=$i | ||||
| rm -rf ./train_parallel$i | |||||
| mkdir ./train_parallel$i | |||||
| cp *.py ./train_parallel$i | |||||
| cp -r src ./train_parallel$i | |||||
| cd ./train_parallel$i || exit | |||||
| rm -rf ./train_parallel$DEVICE_ID | |||||
| mkdir ./train_parallel$DEVICE_ID | |||||
| cp $src_dir/*.py ./train_parallel$DEVICE_ID | |||||
| cp -r $src_dir/src ./train_parallel$DEVICE_ID | |||||
| cd ./train_parallel$DEVICE_ID || exit | |||||
| echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type" | echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type" | ||||
| env > env.log | env > env.log | ||||
| taskset -c $cmdopt python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 --dataset=$dataset_type &> log & | |||||
| taskset -c $cmdopt python train.py --data_path=$2 --device_target="Ascend" --device_id=$DEVICE_ID --is_distributed=1 --dataset=$dataset_type &> log & | |||||
| cd .. | cd .. | ||||
| done | done | ||||
| @@ -47,14 +47,14 @@ cifar_cfg = edict({ | |||||
| # config for vgg16, imagenet2012 | # config for vgg16, imagenet2012 | ||||
| imagenet_cfg = edict({ | imagenet_cfg = edict({ | ||||
| "num_classes": 1000, | "num_classes": 1000, | ||||
| "lr": 0.01, | |||||
| "lr": 0.04, | |||||
| "lr_init": 0.01, | "lr_init": 0.01, | ||||
| "lr_max": 0.1, | "lr_max": 0.1, | ||||
| "lr_epochs": '30,60,90,120', | "lr_epochs": '30,60,90,120', | ||||
| "lr_scheduler": 'cosine_annealing', | "lr_scheduler": 'cosine_annealing', | ||||
| "warmup_epochs": 0, | "warmup_epochs": 0, | ||||
| "batch_size": 32, | |||||
| "max_epoch": 150, | |||||
| "batch_size": 64, | |||||
| "max_epoch": 90, | |||||
| "momentum": 0.9, | "momentum": 0.9, | ||||
| "weight_decay": 1e-4, | "weight_decay": 1e-4, | ||||
| "loss_scale": 1024, | "loss_scale": 1024, | ||||
| @@ -61,7 +61,7 @@ def parse_args(cloud_args=None): | |||||
| parser.add_argument('--lr_gamma', type=float, default=0.1, | parser.add_argument('--lr_gamma', type=float, default=0.1, | ||||
| help='decrease lr by a factor of exponential lr_scheduler') | help='decrease lr by a factor of exponential lr_scheduler') | ||||
| parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') | parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') | ||||
| parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler') | |||||
| parser.add_argument('--T_max', type=int, default=90, help='T-max in cosine_annealing scheduler') | |||||
| # logging and checkpoint related | # logging and checkpoint related | ||||
| parser.add_argument('--log_interval', type=int, default=100, help='logging interval') | parser.add_argument('--log_interval', type=int, default=100, help='logging interval') | ||||
| @@ -140,7 +140,7 @@ if __name__ == '__main__': | |||||
| device_num = args.group_size | device_num = args.group_size | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| gradients_mean=True, all_reduce_fusion_config=[3, 10, 12, 15]) | |||||
| gradients_mean=True, all_reduce_fusion_config=[2, 18]) | |||||
| else: | else: | ||||
| if args.device_target == "Ascend": | if args.device_target == "Ascend": | ||||
| context.set_context(device_id=args.device_id) | context.set_context(device_id=args.device_id) | ||||