|
|
|
@@ -22,7 +22,6 @@ import mindspore.dataset.engine as de |
|
|
|
import mindspore.dataset.transforms.vision.c_transforms as C |
|
|
|
import mindspore.dataset.transforms.c_transforms as C2 |
|
|
|
import mindspore.dataset.transforms.vision.py_transforms as P |
|
|
|
from src.config import config_ascend |
|
|
|
|
|
|
|
|
|
|
|
def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32): |
|
|
|
@@ -42,7 +41,7 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, |
|
|
|
rank_size = int(os.getenv("RANK_SIZE")) |
|
|
|
rank_id = int(os.getenv("RANK_ID")) |
|
|
|
columns_list = ['image', 'label'] |
|
|
|
if config_ascend.data_load_mode == "mindrecord": |
|
|
|
if config.data_load_mode == "mindrecord": |
|
|
|
load_func = partial(de.MindDataset, dataset_path, columns_list) |
|
|
|
else: |
|
|
|
load_func = partial(de.ImageFolderDatasetV2, dataset_path) |
|
|
|
@@ -54,6 +53,13 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, |
|
|
|
num_shards=rank_size, shard_id=rank_id) |
|
|
|
else: |
|
|
|
ds = load_func(num_parallel_workers=8, shuffle=False) |
|
|
|
elif device_target == "GPU": |
|
|
|
if do_train: |
|
|
|
from mindspore.communication.management import get_rank, get_group_size |
|
|
|
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, |
|
|
|
num_shards=get_group_size(), shard_id=get_rank()) |
|
|
|
else: |
|
|
|
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) |
|
|
|
else: |
|
|
|
raise ValueError("Unsupport device_target.") |
|
|
|
|
|
|
|
|