|
|
@@ -99,7 +99,8 @@ def connect_network_with_dataset(network, dataset_helper): |
|
|
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ |
|
|
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ |
|
|
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ |
|
|
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ |
|
|
and context.get_context("device_target") == "Ascend" \ |
|
|
and context.get_context("device_target") == "Ascend" \ |
|
|
and context.get_context("mode") == context.GRAPH_MODE: |
|
|
|
|
|
|
|
|
and context.get_context("mode") == context.GRAPH_MODE \ |
|
|
|
|
|
and ms_role != "MS_WORKER": |
|
|
|
|
|
|
|
|
if not hasattr(dataset, '__network__'): |
|
|
if not hasattr(dataset, '__network__'): |
|
|
dataset.__network__ = network |
|
|
dataset.__network__ = network |
|
|
|