|
|
|
@@ -39,6 +39,32 @@ def _send_data_no_flag(dataset, epoch_num): |
|
|
|
exec_dataset.send(epoch_num) |
|
|
|
|
|
|
|
|
|
|
|
class _DataWrapper(nn.Cell): |
|
|
|
""" |
|
|
|
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the |
|
|
|
dataset channel 'queue_name' and performs the forward computation. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, network, dataset_types, dataset_shapes, queue_name): |
|
|
|
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) |
|
|
|
# Also copy the flag in `network` construct |
|
|
|
flags = getattr(network.__class__.construct, "_mindspore_flags", {}) |
|
|
|
self.info = (dataset_types, dataset_shapes) |
|
|
|
self.add_flags(**flags) |
|
|
|
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) |
|
|
|
self.network = network |
|
|
|
|
|
|
|
def construct(self): |
|
|
|
outputs = self.get_next() |
|
|
|
return self.network(*outputs) |
|
|
|
|
|
|
|
|
|
|
|
def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name): |
|
|
|
if not isinstance(network, _DataWrapper): |
|
|
|
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) |
|
|
|
return network |
|
|
|
|
|
|
|
|
|
|
|
def connect_network_with_dataset(network, dataset_helper): |
|
|
|
""" |
|
|
|
Connect the `network` with dataset in `dataset_helper`. |
|
|
|
@@ -70,24 +96,6 @@ def connect_network_with_dataset(network, dataset_helper): |
|
|
|
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) |
|
|
|
""" |
|
|
|
|
|
|
|
class _DataWrapper(nn.Cell): |
|
|
|
""" |
|
|
|
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the |
|
|
|
dataset channel 'queue_name' and performs the forward computation. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, network, dataset_types, dataset_shapes, queue_name): |
|
|
|
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) |
|
|
|
# Also copy the flag in `network` construct |
|
|
|
flags = getattr(network.__class__.construct, "_mindspore_flags", {}) |
|
|
|
self.add_flags(**flags) |
|
|
|
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) |
|
|
|
self.network = network |
|
|
|
|
|
|
|
def construct(self): |
|
|
|
outputs = self.get_next() |
|
|
|
return self.network(*outputs) |
|
|
|
|
|
|
|
dataset_iter = dataset_helper.iter |
|
|
|
dataset = dataset_iter.dataset |
|
|
|
|
|
|
|
@@ -98,11 +106,14 @@ def connect_network_with_dataset(network, dataset_helper): |
|
|
|
if ms_role in ("MS_PSERVER", "MS_SCHED"): |
|
|
|
return network |
|
|
|
|
|
|
|
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ |
|
|
|
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ |
|
|
|
and context.get_context("device_target") == "Ascend" \ |
|
|
|
and context.get_context("mode") == context.GRAPH_MODE \ |
|
|
|
and ms_role != "MS_WORKER": |
|
|
|
queue_name = dataset.__transfer_dataset__.queue_name |
|
|
|
if hasattr(dataset_iter, "sink_size") and \ |
|
|
|
dataset_iter.sink_size == 1 and \ |
|
|
|
hasattr(dataset_iter, "sink_count") and \ |
|
|
|
dataset_iter.sink_count == 1 and \ |
|
|
|
context.get_context("device_target") == "Ascend" and \ |
|
|
|
context.get_context("mode") == context.GRAPH_MODE and \ |
|
|
|
ms_role != "MS_WORKER": |
|
|
|
|
|
|
|
if not hasattr(dataset_iter, '__network__'): |
|
|
|
dataset_iter.__network__ = network |
|
|
|
@@ -118,21 +129,19 @@ def connect_network_with_dataset(network, dataset_helper): |
|
|
|
if _need_to_full(): |
|
|
|
device_num = _get_device_num() |
|
|
|
dataset_shapes = _to_full_shapes(dataset_shapes, device_num) |
|
|
|
network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name) |
|
|
|
|
|
|
|
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name) |
|
|
|
dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr( |
|
|
|
dataset_iter, '__network_manage__') else dict() |
|
|
|
dataset_iter.__network_manage__[key] = network |
|
|
|
|
|
|
|
return network |
|
|
|
|
|
|
|
if not hasattr(dataset, '__me_inited__') and context.get_context("device_target") in ("Ascend", "GPU")\ |
|
|
|
and not context.get_context("enable_ge"): |
|
|
|
if not hasattr(dataset, '__me_inited__') and \ |
|
|
|
not context.get_context("enable_ge") and \ |
|
|
|
context.get_context("device_target") in ("Ascend", "GPU"): |
|
|
|
dataset.__me_inited__ = True |
|
|
|
|
|
|
|
dataset_types, dataset_shapes = dataset_helper.types_shapes() |
|
|
|
queue_name = dataset.__transfer_dataset__.queue_name |
|
|
|
|
|
|
|
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) |
|
|
|
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name) |
|
|
|
return network |
|
|
|
|
|
|
|
|
|
|
|
|