diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index d38eda2468..e42ca26715 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -39,6 +39,32 @@ def _send_data_no_flag(dataset, epoch_num): exec_dataset.send(epoch_num) +class _DataWrapper(nn.Cell): + """ + Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the + dataset channel 'queue_name' and performs the forward computation. + """ + + def __init__(self, network, dataset_types, dataset_shapes, queue_name): + super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) + # Also copy the flag in `network` construct + flags = getattr(network.__class__.construct, "_mindspore_flags", {}) + self.info = (dataset_types, dataset_shapes) + self.add_flags(**flags) + self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) + self.network = network + + def construct(self): + outputs = self.get_next() + return self.network(*outputs) + + +def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name): + if not isinstance(network, _DataWrapper): + network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) + return network + + def connect_network_with_dataset(network, dataset_helper): """ Connect the `network` with dataset in `dataset_helper`. @@ -70,24 +96,6 @@ def connect_network_with_dataset(network, dataset_helper): >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) """ - class _DataWrapper(nn.Cell): - """ - Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the - dataset channel 'queue_name' and performs the forward computation. - """ - - def __init__(self, network, dataset_types, dataset_shapes, queue_name): - super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) - # Also copy the flag in `network` construct - flags = getattr(network.__class__.construct, "_mindspore_flags", {}) - self.add_flags(**flags) - self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) - self.network = network - - def construct(self): - outputs = self.get_next() - return self.network(*outputs) - dataset_iter = dataset_helper.iter dataset = dataset_iter.dataset @@ -98,11 +106,14 @@ def connect_network_with_dataset(network, dataset_helper): if ms_role in ("MS_PSERVER", "MS_SCHED"): return network - if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ - and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ - and context.get_context("device_target") == "Ascend" \ - and context.get_context("mode") == context.GRAPH_MODE \ - and ms_role != "MS_WORKER": + queue_name = dataset.__transfer_dataset__.queue_name + if hasattr(dataset_iter, "sink_size") and \ + dataset_iter.sink_size == 1 and \ + hasattr(dataset_iter, "sink_count") and \ + dataset_iter.sink_count == 1 and \ + context.get_context("device_target") == "Ascend" and \ + context.get_context("mode") == context.GRAPH_MODE and \ + ms_role != "MS_WORKER": if not hasattr(dataset_iter, '__network__'): dataset_iter.__network__ = network @@ -118,21 +129,19 @@ def connect_network_with_dataset(network, dataset_helper): if _need_to_full(): device_num = _get_device_num() dataset_shapes = _to_full_shapes(dataset_shapes, device_num) - network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name) + + network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name) dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr( dataset_iter, '__network_manage__') else dict() dataset_iter.__network_manage__[key] = network - return network - if not hasattr(dataset, '__me_inited__') and context.get_context("device_target") in ("Ascend", "GPU")\ - and not context.get_context("enable_ge"): + if not hasattr(dataset, '__me_inited__') and \ + not context.get_context("enable_ge") and \ + context.get_context("device_target") in ("Ascend", "GPU"): dataset.__me_inited__ = True - dataset_types, dataset_shapes = dataset_helper.types_shapes() - queue_name = dataset.__transfer_dataset__.queue_name - - network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) + network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name) return network