| @@ -93,8 +93,9 @@ def connect_network_with_dataset(network, dataset_helper): | |||||
| raise RuntimeError("Dataset should be connected with network only in sink mode.") | raise RuntimeError("Dataset should be connected with network only in sink mode.") | ||||
| if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ | if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ | ||||
| and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ | |||||
| and context.get_context("device_target") == "Ascend": | |||||
| and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ | |||||
| and context.get_context("device_target") == "Ascend" \ | |||||
| and context.get_context("mode") == context.GRAPH_MODE: | |||||
| if not hasattr(dataset, '__network__'): | if not hasattr(dataset, '__network__'): | ||||
| dataset.__network__ = network | dataset.__network__ = network | ||||
| @@ -206,6 +207,7 @@ class DatasetHelper: | |||||
| def get_data_info(self): | def get_data_info(self): | ||||
| return self.iter.get_data_info() | return self.iter.get_data_info() | ||||
| class _DatasetIter: | class _DatasetIter: | ||||
| """Base iter for dataset helper""" | """Base iter for dataset helper""" | ||||
| @@ -286,6 +288,7 @@ class _DatasetIterGE(_DatasetIter): | |||||
| self.op = op | self.op = op | ||||
| class _DatasetIterPyNative(_DatasetIter): | class _DatasetIterPyNative(_DatasetIter): | ||||
| """Iter for MS(enable_loop_sink=False).""" | """Iter for MS(enable_loop_sink=False).""" | ||||
| @@ -301,6 +304,7 @@ class _DatasetIterPyNative(_DatasetIter): | |||||
| self.op = op | self.op = op | ||||
| class _DatasetIterMSLoopSink(_DatasetIter): | class _DatasetIterMSLoopSink(_DatasetIter): | ||||
| """Iter for context (device_target=Ascend)""" | """Iter for context (device_target=Ascend)""" | ||||
| @@ -354,6 +358,7 @@ class _DatasetIterPSLite(_DatasetIter): | |||||
| class _DatasetIterNormal: | class _DatasetIterNormal: | ||||
| """Iter for normal(non sink) mode, feed the data from host.""" | """Iter for normal(non sink) mode, feed the data from host.""" | ||||
| def __init__(self, dataset, epoch_num=-1): | def __init__(self, dataset, epoch_num=-1): | ||||
| self.dataset = dataset | self.dataset = dataset | ||||
| self.device_num = _get_device_num() | self.device_num = _get_device_num() | ||||