
!3791 modelzoo: repair vgg distribute training problem

Merge pull request !3791 from ms_yan/vgg_8p_D
Tag: v0.7.0-beta
Committed by mindspore-ci-bot (5 years ago)
Commit: b55e5e2ce2
2 changed files with 5 additions and 4 deletions:
  1. model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh (+1, -1)
  2. model_zoo/official/cv/vgg16/train.py (+4, -3)

model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh (+1, -1)

@@ -47,6 +47,6 @@ do
     cd ./train_parallel$i || exit
     echo "start training for rank $RANK_ID, device $DEVICE_ID"
     env > env.log
-    python train.py --data_path=$2 --device_target="Ascend" --device_id=$i &> log &
+    python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log &
     cd ..
done
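With --is_distributed=1 now passed explicitly, each spawned process takes the distributed-initialization branch in train.py instead of silently running as a single device. Below is a minimal sketch of the kind of argparse definition this flag relies on; the exact option defaults and help text in train.py are not shown in this diff, so treat them as assumptions.

# Hypothetical sketch of the argument parsing the launch script depends on;
# the real train.py may differ in defaults and wording.
import argparse

parser = argparse.ArgumentParser(description="VGG16 training")
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--device_target", type=str, default="Ascend")
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--is_distributed", type=int, default=0,
                    help="0: single-device training, 1: multi-device data-parallel training")
args = parser.parse_args()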

model_zoo/official/cv/vgg16/train.py (+4, -3)

@@ -191,12 +191,13 @@ if __name__ == '__main__':
     if args.is_distributed:
         if args.device_target == "Ascend":
             init()
             context.set_context(device_id=args.device_id)
         elif args.device_target == "GPU":
             init("nccl")
-            args.rank = get_rank()
-            args.group_size = get_group_size()
-            device_num = args.group_size
+
+        args.rank = get_rank()
+        args.group_size = get_group_size()
+        device_num = args.group_size
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           mirror_mean=True)
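The removed lines show that args.rank, args.group_size, and device_num were populated only in the GPU branch, so a multi-device Ascend launch never picked up the real rank and device count before reset_auto_parallel_context()/set_auto_parallel_context(); the patch hoists those three assignments out of the GPU branch so both back ends share them. Below is a minimal standalone sketch mirroring the fixed Ascend flow, assuming the mindspore.communication.management API of this era (init, get_rank, get_group_size); the import location of ParallelMode varied between early releases, so that line is an assumption.

# Minimal sketch of the fixed initialization order for Ascend data-parallel
# training; argument handling is simplified to plain variables.
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode  # import path differs across old MindSpore releases

device_id = 0  # normally taken from --device_id / the DEVICE_ID environment variable
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

init()                                    # set up HCCL communication between devices
context.set_context(device_id=device_id)  # same ordering as in the patched train.py

rank = get_rank()                         # now retrieved for Ascend as well, not only GPU
group_size = get_group_size()             # number of participating devices

context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=group_size,
                                  parallel_mode=ParallelMode.DATA_PARALLEL,
                                  mirror_mean=True)  # keyword name used by this MindSpore version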

