|
|
|
@@ -28,8 +28,8 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch |
|
|
|
Args: |
|
|
|
dataset_path(string): the path of dataset. |
|
|
|
do_train(bool): whether dataset is used for train or eval. |
|
|
|
repeat_num(int): the repeat times of dataset. Default: 1 |
|
|
|
batch_size(int): the batch size of dataset. Default: 32 |
|
|
|
repeat_num(int): the repeat times of dataset. Default: 1. |
|
|
|
batch_size(int): the batch size of dataset. Default: 32. |
|
|
|
|
|
|
|
Returns: |
|
|
|
dataset |
|
|
|
@@ -43,9 +43,12 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch |
|
|
|
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, |
|
|
|
num_shards=rank_size, shard_id=rank_id) |
|
|
|
elif platform == "GPU": |
|
|
|
from mindspore.communication.management import get_rank, get_group_size |
|
|
|
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, |
|
|
|
num_shards=get_group_size(), shard_id=get_rank()) |
|
|
|
if do_train: |
|
|
|
from mindspore.communication.management import get_rank, get_group_size |
|
|
|
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, |
|
|
|
num_shards=get_group_size(), shard_id=get_rank()) |
|
|
|
else: |
|
|
|
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) |
|
|
|
else: |
|
|
|
raise ValueError("Unsupport platform.") |
|
|
|
|
|
|
|
|