Browse Source

Add multi-node training support for ResNet-50

tags/v0.7.0-beta
gengdongjie 5 years ago
parent
commit
00f7a936bf
2 changed files with 23 additions and 7 deletions
  1. +4
    -1
      model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
  2. +19
    -6
      model_zoo/official/cv/resnet/src/dataset.py

+ 4
- 1
model_zoo/official/cv/resnet/scripts/run_distribute_train.sh View File

@@ -79,10 +79,13 @@ export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
export RANK_TABLE_FILE=$PATH1

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))

for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i


+ 19
- 6
model_zoo/official/cv/resnet/src/dataset.py View File

@@ -37,8 +37,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
dataset
"""
if target == "Ascend":
device_num = int(os.getenv("DEVICE_NUM"))
rank_id = int(os.getenv("RANK_ID"))
device_num, rank_id = _get_rank_info()
else:
init("nccl")
rank_id = get_rank()
@@ -93,8 +92,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
dataset
"""
if target == "Ascend":
device_num = int(os.getenv("DEVICE_NUM"))
rank_id = int(os.getenv("RANK_ID"))
device_num, rank_id = _get_rank_info()
else:
init("nccl")
rank_id = get_rank()
@@ -153,8 +151,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
Returns:
dataset
"""
device_num = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
device_num, rank_id = _get_rank_info()

if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
@@ -203,3 +200,19 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
ds = ds.repeat(repeat_num)

return ds


def _get_rank_info():
    """
    get rank size and rank id

    Reads the ``RANK_SIZE`` environment variable to decide whether the
    current process belongs to a multi-device job. An unset, empty or
    whitespace-only value is treated as 1 (single device).

    Returns:
        tuple, ``(rank_size, rank_id)``. When ``RANK_SIZE`` is greater than 1
        the pair is ``(get_group_size(), get_rank())``; otherwise it is
        ``(1, 0)``.

    Raises:
        ValueError: if ``RANK_SIZE`` is set to a non-empty value that is not
            an integer.
    """
    # `export RANK_SIZE=` leaves the variable set but empty; int("") would
    # raise, so fall back to the single-device default in that case as well.
    raw_rank_size = os.environ.get("RANK_SIZE", "").strip()
    rank_size = int(raw_rank_size) if raw_rank_size else 1

    # Single-device (or nonsensical <= 1) configuration: no group to query.
    if rank_size <= 1:
        return 1, 0

    # NOTE(review): presumably the communication backend must already be
    # initialised before these two calls -- confirm against the callers.
    return get_group_size(), get_rank()

Loading…
Cancel
Save