diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 7eeea138f4..b0bf1a2f47 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -99,7 +99,8 @@ def connect_network_with_dataset(network, dataset_helper): if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ and context.get_context("device_target") == "Ascend" \ - and context.get_context("mode") == context.GRAPH_MODE: + and context.get_context("mode") == context.GRAPH_MODE \ + and ms_role != "MS_WORKER": if not hasattr(dataset, '__network__'): dataset.__network__ = network