@@ -16,12 +16,15 @@
create train or eval dataset.
"""
import os
from multiprocessing import cpu_count
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size
# Number of parallel workers handed to the dataset readers and map
# operations in this file: 12 when the host reports at least 12 CPUs,
# otherwise 8.
THREAD_NUM = 8 if cpu_count() < 12 else 12
def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
@@ -38,15 +41,17 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else :
elif target == "GPU" :
init()
rank_id = get_rank()
device_num = get_group_size()
else:
device_num = 1
if device_num == 1:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12 , shuffle=True)
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=THREAD_NUM , shuffle=True)
else:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12 , shuffle=True,
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=THREAD_NUM , shuffle=True,
num_shards=device_num, shard_id=rank_id)
# define map operations
@@ -66,8 +71,8 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12 )
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12 )
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM )
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM )
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)
@@ -99,9 +104,9 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
device_num = get_group_size()
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12 , shuffle=True)
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=THREAD_NUM , shuffle=True)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12 , shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=THREAD_NUM , shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = 224
@@ -127,8 +132,8 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12 )
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12 )
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM )
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM )
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)