|
|
|
@@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype |
|
|
|
import mindspore.dataset.engine as de |
|
|
|
import mindspore.dataset.transforms.vision.c_transforms as C |
|
|
|
import mindspore.dataset.transforms.c_transforms as C2 |
|
|
|
from mindspore.communication.management import get_rank, get_group_size |
|
|
|
from mindspore.communication.management import init, get_rank, get_group_size |
|
|
|
|
|
|
|
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): |
|
|
|
""" |
|
|
|
@@ -40,6 +40,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target=" |
|
|
|
device_num = int(os.getenv("DEVICE_NUM")) |
|
|
|
rank_id = int(os.getenv("RANK_ID")) |
|
|
|
else: |
|
|
|
init("nccl") |
|
|
|
rank_id = get_rank() |
|
|
|
device_num = get_group_size() |
|
|
|
|
|
|
|
|