|
|
|
@@ -93,8 +93,9 @@ def connect_network_with_dataset(network, dataset_helper): |
|
|
|
raise RuntimeError("Dataset should be connected with network only in sink mode.") |
|
|
|
|
|
|
|
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ |
|
|
|
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ |
|
|
|
and context.get_context("device_target") == "Ascend": |
|
|
|
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ |
|
|
|
and context.get_context("device_target") == "Ascend" \ |
|
|
|
and context.get_context("mode") == context.GRAPH_MODE: |
|
|
|
|
|
|
|
if not hasattr(dataset, '__network__'): |
|
|
|
dataset.__network__ = network |
|
|
|
@@ -206,6 +207,7 @@ class DatasetHelper: |
|
|
|
def get_data_info(self): |
|
|
|
return self.iter.get_data_info() |
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIter: |
|
|
|
"""Base iter for dataset helper""" |
|
|
|
|
|
|
|
@@ -286,6 +288,7 @@ class _DatasetIterGE(_DatasetIter): |
|
|
|
|
|
|
|
self.op = op |
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterPyNative(_DatasetIter): |
|
|
|
"""Iter for MS(enable_loop_sink=False).""" |
|
|
|
|
|
|
|
@@ -301,6 +304,7 @@ class _DatasetIterPyNative(_DatasetIter): |
|
|
|
|
|
|
|
self.op = op |
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterMSLoopSink(_DatasetIter): |
|
|
|
"""Iter for context (device_target=Ascend)""" |
|
|
|
|
|
|
|
@@ -354,6 +358,7 @@ class _DatasetIterPSLite(_DatasetIter): |
|
|
|
|
|
|
|
class _DatasetIterNormal: |
|
|
|
"""Iter for normal(non sink) mode, feed the data from host.""" |
|
|
|
|
|
|
|
def __init__(self, dataset, epoch_num=-1): |
|
|
|
self.dataset = dataset |
|
|
|
self.device_num = _get_device_num() |
|
|
|
|