|
|
|
@@ -39,6 +39,22 @@ def _send_data_no_flag(dataset, epoch_num): |
|
|
|
exec_dataset.send(epoch_num) |
|
|
|
|
|
|
|
|
|
|
|
def _dynamic_sink_scenario(dataset, dataset_iter): |
|
|
|
"""Special scenario with dynamic shape and sink_size=1.""" |
|
|
|
flag = False |
|
|
|
ms_role = os.getenv("MS_ROLE") |
|
|
|
if hasattr(dataset_iter, "sink_size") and \ |
|
|
|
dataset_iter.sink_size == 1 and \ |
|
|
|
dataset.get_dataset_size() != 1 and \ |
|
|
|
hasattr(dataset_iter, "sink_count") and \ |
|
|
|
dataset_iter.sink_count == 1 and \ |
|
|
|
context.get_context("device_target") == "Ascend" and \ |
|
|
|
context.get_context("mode") == context.GRAPH_MODE and \ |
|
|
|
ms_role != "MS_WORKER": |
|
|
|
flag = True |
|
|
|
return flag |
|
|
|
|
|
|
|
|
|
|
|
class _DataWrapper(nn.Cell): |
|
|
|
""" |
|
|
|
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the |
|
|
|
@@ -107,14 +123,7 @@ def connect_network_with_dataset(network, dataset_helper): |
|
|
|
return network |
|
|
|
|
|
|
|
queue_name = dataset.__transfer_dataset__.queue_name |
|
|
|
if hasattr(dataset_iter, "sink_size") and \ |
|
|
|
dataset_iter.sink_size == 1 and \ |
|
|
|
hasattr(dataset_iter, "sink_count") and \ |
|
|
|
dataset_iter.sink_count == 1 and \ |
|
|
|
context.get_context("device_target") == "Ascend" and \ |
|
|
|
context.get_context("mode") == context.GRAPH_MODE and \ |
|
|
|
ms_role != "MS_WORKER": |
|
|
|
|
|
|
|
if _dynamic_sink_scenario(dataset, dataset_iter): |
|
|
|
if not hasattr(dataset_iter, '__network__'): |
|
|
|
dataset_iter.__network__ = network |
|
|
|
network = dataset_iter.__network__ |
|
|
|
|