|
|
|
@@ -19,9 +19,9 @@ import numpy as np |
|
|
|
import mindspore.context as context |
|
|
|
from mindspore.communication.management import init |
|
|
|
from mindspore.train.dataset_helper import DatasetHelper |
|
|
|
from mindspore.communication._comm_helper import GlobalComm |
|
|
|
from ....dataset_mock import MindData |
|
|
|
|
|
|
|
|
|
|
|
def get_dataset(batch_size=1): |
|
|
|
dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32) |
|
|
|
dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1), |
|
|
|
@@ -75,7 +75,9 @@ def test_dataset_iter_normal(): |
|
|
|
|
|
|
|
@pytest.mark.skipif('not context.get_context("enable_ge")') |
|
|
|
def test_dataset_iter_ge(): |
|
|
|
GlobalComm.CHECK_ENVS = False |
|
|
|
init("hccl") |
|
|
|
GlobalComm.CHECK_ENVS = True |
|
|
|
dataset = get_dataset(32) |
|
|
|
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) |
|
|
|
count = 0 |
|
|
|
@@ -87,7 +89,9 @@ def test_dataset_iter_ge(): |
|
|
|
|
|
|
|
@pytest.mark.skipif('context.get_context("enable_ge")') |
|
|
|
def test_dataset_iter_ms_loop_sink(): |
|
|
|
GlobalComm.CHECK_ENVS = False |
|
|
|
init("hccl") |
|
|
|
GlobalComm.CHECK_ENVS = True |
|
|
|
context.set_context(enable_loop_sink=True) |
|
|
|
dataset = get_dataset(32) |
|
|
|
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) |
|
|
|
@@ -101,7 +105,9 @@ def test_dataset_iter_ms_loop_sink(): |
|
|
|
|
|
|
|
@pytest.mark.skipif('context.get_context("enable_ge")') |
|
|
|
def test_dataset_iter_ms(): |
|
|
|
GlobalComm.CHECK_ENVS = False |
|
|
|
init("hccl") |
|
|
|
GlobalComm.CHECK_ENVS = True |
|
|
|
context.set_context(enable_loop_sink=False) |
|
|
|
dataset = get_dataset(32) |
|
|
|
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) |