|
|
|
@@ -37,8 +37,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target= |
|
|
|
dataset |
|
|
|
""" |
|
|
|
if target == "Ascend": |
|
|
|
device_num = int(os.getenv("DEVICE_NUM")) |
|
|
|
rank_id = int(os.getenv("RANK_ID")) |
|
|
|
device_num, rank_id = _get_rank_info() |
|
|
|
else: |
|
|
|
init("nccl") |
|
|
|
rank_id = get_rank() |
|
|
|
@@ -93,8 +92,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= |
|
|
|
dataset |
|
|
|
""" |
|
|
|
if target == "Ascend": |
|
|
|
device_num = int(os.getenv("DEVICE_NUM")) |
|
|
|
rank_id = int(os.getenv("RANK_ID")) |
|
|
|
device_num, rank_id = _get_rank_info() |
|
|
|
else: |
|
|
|
init("nccl") |
|
|
|
rank_id = get_rank() |
|
|
|
@@ -153,8 +151,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32): |
|
|
|
Returns: |
|
|
|
dataset |
|
|
|
""" |
|
|
|
device_num = int(os.getenv("RANK_SIZE")) |
|
|
|
rank_id = int(os.getenv("RANK_ID")) |
|
|
|
device_num, rank_id = _get_rank_info() |
|
|
|
|
|
|
|
if device_num == 1: |
|
|
|
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) |
|
|
|
@@ -203,3 +200,19 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32): |
|
|
|
ds = ds.repeat(repeat_num) |
|
|
|
|
|
|
|
return ds |
|
|
|
|
|
|
|
|
|
|
|
def _get_rank_info(): |
|
|
|
""" |
|
|
|
get rank size and rank id |
|
|
|
""" |
|
|
|
rank_size = int(os.environ.get("RANK_SIZE", 1)) |
|
|
|
|
|
|
|
if rank_size > 1: |
|
|
|
rank_size = get_group_size() |
|
|
|
rank_id = get_rank() |
|
|
|
else: |
|
|
|
rank_size = 1 |
|
|
|
rank_id = 0 |
|
|
|
|
|
|
|
return rank_size, rank_id |