@@ -39,6 +39,32 @@ def _send_data_no_flag(dataset, epoch_num):
     exec_dataset.send(epoch_num)
 
 
+class _DataWrapper(nn.Cell):
+    """
+    Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
+    dataset channel 'queue_name' and performs the forward computation.
+    """
+
+    def __init__(self, network, dataset_types, dataset_shapes, queue_name):
+        super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
+        # Also copy the flag in `network` construct
+        flags = getattr(network.__class__.construct, "_mindspore_flags", {})
+        self.info = (dataset_types, dataset_shapes)
+        self.add_flags(**flags)
+        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
+        self.network = network
+
+    def construct(self):
+        outputs = self.get_next()
+        return self.network(*outputs)
+
+
+def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name):
+    if not isinstance(network, _DataWrapper):
+        network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
+    return network
+
+
 def connect_network_with_dataset(network, dataset_helper):
     """
     Connect the `network` with dataset in `dataset_helper`.
@@ -70,24 +96,6 @@ def connect_network_with_dataset(network, dataset_helper):
         >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
     """
-    class _DataWrapper(nn.Cell):
-        """
-        Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
-        dataset channel 'queue_name' and performs the forward computation.
-        """
-
-        def __init__(self, network, dataset_types, dataset_shapes, queue_name):
-            super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
-            # Also copy the flag in `network` construct
-            flags = getattr(network.__class__.construct, "_mindspore_flags", {})
-            self.add_flags(**flags)
-            self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
-            self.network = network
-
-        def construct(self):
-            outputs = self.get_next()
-            return self.network(*outputs)
-
     dataset_iter = dataset_helper.iter
     dataset = dataset_iter.dataset
@@ -98,11 +106,14 @@ def connect_network_with_dataset(network, dataset_helper):
     if ms_role in ("MS_PSERVER", "MS_SCHED"):
         return network
 
-    if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
-            and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
-            and context.get_context("device_target") == "Ascend" \
-            and context.get_context("mode") == context.GRAPH_MODE \
-            and ms_role != "MS_WORKER":
+    queue_name = dataset.__transfer_dataset__.queue_name
+    if hasattr(dataset_iter, "sink_size") and \
+            dataset_iter.sink_size == 1 and \
+            hasattr(dataset_iter, "sink_count") and \
+            dataset_iter.sink_count == 1 and \
+            context.get_context("device_target") == "Ascend" and \
+            context.get_context("mode") == context.GRAPH_MODE and \
+            ms_role != "MS_WORKER":
         if not hasattr(dataset_iter, '__network__'):
             dataset_iter.__network__ = network
@@ -118,21 +129,19 @@ def connect_network_with_dataset(network, dataset_helper):
         if _need_to_full():
             device_num = _get_device_num()
             dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
-        network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name)
+        network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
         dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr(
             dataset_iter, '__network_manage__') else dict()
         dataset_iter.__network_manage__[key] = network
         return network
 
-    if not hasattr(dataset, '__me_inited__') and context.get_context("device_target") in ("Ascend", "GPU") \
-            and not context.get_context("enable_ge"):
+    if not hasattr(dataset, '__me_inited__') and \
+            not context.get_context("enable_ge") and \
+            context.get_context("device_target") in ("Ascend", "GPU"):
         dataset.__me_inited__ = True
         dataset_types, dataset_shapes = dataset_helper.types_shapes()
-        queue_name = dataset.__transfer_dataset__.queue_name
-        network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
+        network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
     return network
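
Usage sketch (reviewer note, not part of the patch): a minimal, hedged example of how the refactored connect_network_with_dataset is typically exercised in dataset sink mode. The toy dataset, the nn.Dense network, sink_size=1 and the Ascend/graph-mode settings are illustrative assumptions; the sink path that wraps the network with GetNext only applies on an Ascend or GPU backend, so this is a sketch rather than a portable test.

# Minimal usage sketch under the assumptions stated above.
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Toy single-column dataset; in sink mode its batches are pushed to the device queue.
data = {"x": np.random.randn(32, 10).astype(np.float32)}
train_dataset = ds.NumpySlicesDataset(data).batch(8)

net = nn.Dense(10, 1)

# DatasetHelper drives the transfer pipeline; connect_network_with_dataset wraps `net`
# (via _DataWrapper / _generate_dataset_sink_mode_net after this change) so that it
# reads its inputs through GetNext instead of taking them as call arguments.
dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=1)
net_with_get_next = connect_network_with_dataset(net, dataset_helper)

# On the sink path the helper typically yields empty input tuples and the wrapped
# network fetches each batch itself from the device queue.
for inputs in dataset_helper:
    outputs = net_with_get_next(*inputs)

Design note: hoisting _DataWrapper to module level and routing both branches through _generate_dataset_sink_mode_net lets the isinstance check skip re-wrapping an already wrapped network, and computing queue_name once up front removes the duplicated dataset.__transfer_dataset__.queue_name lookups.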