| @@ -39,6 +39,32 @@ def _send_data_no_flag(dataset, epoch_num): | |||
| exec_dataset.send(epoch_num) | |||
| class _DataWrapper(nn.Cell): | |||
| """ | |||
| Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the | |||
| dataset channel 'queue_name' and performs the forward computation. | |||
| """ | |||
| def __init__(self, network, dataset_types, dataset_shapes, queue_name): | |||
| super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) | |||
| # Also copy the flag in `network` construct | |||
| flags = getattr(network.__class__.construct, "_mindspore_flags", {}) | |||
| self.info = (dataset_types, dataset_shapes) | |||
| self.add_flags(**flags) | |||
| self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) | |||
| self.network = network | |||
| def construct(self): | |||
| outputs = self.get_next() | |||
| return self.network(*outputs) | |||
| def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name): | |||
| if not isinstance(network, _DataWrapper): | |||
| network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) | |||
| return network | |||
| def connect_network_with_dataset(network, dataset_helper): | |||
| """ | |||
| Connect the `network` with dataset in `dataset_helper`. | |||
| @@ -70,24 +96,6 @@ def connect_network_with_dataset(network, dataset_helper): | |||
| >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) | |||
| """ | |||
| class _DataWrapper(nn.Cell): | |||
| """ | |||
| Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the | |||
| dataset channel 'queue_name' and performs the forward computation. | |||
| """ | |||
| def __init__(self, network, dataset_types, dataset_shapes, queue_name): | |||
| super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) | |||
| # Also copy the flag in `network` construct | |||
| flags = getattr(network.__class__.construct, "_mindspore_flags", {}) | |||
| self.add_flags(**flags) | |||
| self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) | |||
| self.network = network | |||
| def construct(self): | |||
| outputs = self.get_next() | |||
| return self.network(*outputs) | |||
| dataset_iter = dataset_helper.iter | |||
| dataset = dataset_iter.dataset | |||
| @@ -98,11 +106,14 @@ def connect_network_with_dataset(network, dataset_helper): | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| return network | |||
| if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ | |||
| and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ | |||
| and context.get_context("device_target") == "Ascend" \ | |||
| and context.get_context("mode") == context.GRAPH_MODE \ | |||
| and ms_role != "MS_WORKER": | |||
| queue_name = dataset.__transfer_dataset__.queue_name | |||
| if hasattr(dataset_iter, "sink_size") and \ | |||
| dataset_iter.sink_size == 1 and \ | |||
| hasattr(dataset_iter, "sink_count") and \ | |||
| dataset_iter.sink_count == 1 and \ | |||
| context.get_context("device_target") == "Ascend" and \ | |||
| context.get_context("mode") == context.GRAPH_MODE and \ | |||
| ms_role != "MS_WORKER": | |||
| if not hasattr(dataset_iter, '__network__'): | |||
| dataset_iter.__network__ = network | |||
| @@ -118,21 +129,19 @@ def connect_network_with_dataset(network, dataset_helper): | |||
| if _need_to_full(): | |||
| device_num = _get_device_num() | |||
| dataset_shapes = _to_full_shapes(dataset_shapes, device_num) | |||
| network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name) | |||
| network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name) | |||
| dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr( | |||
| dataset_iter, '__network_manage__') else dict() | |||
| dataset_iter.__network_manage__[key] = network | |||
| return network | |||
| if not hasattr(dataset, '__me_inited__') and context.get_context("device_target") in ("Ascend", "GPU")\ | |||
| and not context.get_context("enable_ge"): | |||
| if not hasattr(dataset, '__me_inited__') and \ | |||
| not context.get_context("enable_ge") and \ | |||
| context.get_context("device_target") in ("Ascend", "GPU"): | |||
| dataset.__me_inited__ = True | |||
| dataset_types, dataset_shapes = dataset_helper.types_shapes() | |||
| queue_name = dataset.__transfer_dataset__.queue_name | |||
| network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) | |||
| network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name) | |||
| return network | |||