Browse Source

fix wrap twice problem

tags/v1.2.0-rc1
xiefangqi 5 years ago
parent
commit
edce230586
1 changed files with 17 additions and 8 deletions
  1. +17
    -8
      mindspore/train/dataset_helper.py

+ 17
- 8
mindspore/train/dataset_helper.py View File

@@ -39,6 +39,22 @@ def _send_data_no_flag(dataset, epoch_num):
exec_dataset.send(epoch_num)


def _dynamic_sink_scenario(dataset, dataset_iter):
"""Special scenario with dynamic shape and sink_size=1."""
flag = False
ms_role = os.getenv("MS_ROLE")
if hasattr(dataset_iter, "sink_size") and \
dataset_iter.sink_size == 1 and \
dataset.get_dataset_size() != 1 and \
hasattr(dataset_iter, "sink_count") and \
dataset_iter.sink_count == 1 and \
context.get_context("device_target") == "Ascend" and \
context.get_context("mode") == context.GRAPH_MODE and \
ms_role != "MS_WORKER":
flag = True
return flag


class _DataWrapper(nn.Cell):
"""
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
@@ -107,14 +123,7 @@ def connect_network_with_dataset(network, dataset_helper):
return network

queue_name = dataset.__transfer_dataset__.queue_name
if hasattr(dataset_iter, "sink_size") and \
dataset_iter.sink_size == 1 and \
hasattr(dataset_iter, "sink_count") and \
dataset_iter.sink_count == 1 and \
context.get_context("device_target") == "Ascend" and \
context.get_context("mode") == context.GRAPH_MODE and \
ms_role != "MS_WORKER":

if _dynamic_sink_scenario(dataset, dataset_iter):
if not hasattr(dataset_iter, '__network__'):
dataset_iter.__network__ = network
network = dataset_iter.__network__


Loading…
Cancel
Save