| @@ -27,13 +27,13 @@ from ..ops import operations as P | |||
| def _send_data(dataset, epoch_num): | |||
| """Engine dataset to write data to tdt queue.""" | |||
| if not hasattr(dataset, '__has_sent__'): | |||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||
| exec_dataset = dataset.__transfer_dataset__ | |||
| exec_dataset.send(epoch_num) | |||
| dataset.__has_sent__ = True | |||
| def _send_data_no_flag(dataset, epoch_num): | |||
| """Engine dataset to write data to tdt queue directly.""" | |||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||
| exec_dataset = dataset.__transfer_dataset__ | |||
| exec_dataset.send(epoch_num) | |||
| @@ -88,11 +88,13 @@ def connect_network_with_dataset(network, dataset_helper): | |||
| if isinstance(dataset_iter, _DatasetIterNormal): | |||
| raise RuntimeError("Dataset should be connected with network only in sink mode.") | |||
| if not hasattr(dataset, '__ME_INITED__') and (context.get_context("device_target") == "Ascend" \ | |||
| or context.get_context("device_target") == "GPU") and not context.get_context("enable_ge"): | |||
| dataset.__ME_INITED__ = True | |||
| if not hasattr(dataset, '__me_inited__') and (context.get_context("device_target") == "Ascend" | |||
| or context.get_context("device_target") == "GPU") and not \ | |||
| context.get_context("enable_ge"): | |||
| dataset.__me_inited__ = True | |||
| dataset_types, dataset_shapes = dataset_helper.types_shapes() | |||
| queue_name = dataset.__TRANSFER_DATASET__.queue_name | |||
| queue_name = dataset.__transfer_dataset__.queue_name | |||
| network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) | |||
| return network | |||
| @@ -175,18 +177,18 @@ class _DatasetIter: | |||
| self.sink_size = sink_size | |||
| self.sink_count = 1 | |||
| if not hasattr(dataset, '__TRANSFER_DATASET__'): | |||
| if not hasattr(dataset, '__transfer_dataset__'): | |||
| if hasattr(dataset, '__loop_size__'): | |||
| self.sink_size = dataset.__loop_size__ | |||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) | |||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size) | |||
| if not hasattr(dataset, '__no_send__'): | |||
| _send_data(dataset, epoch_num) | |||
| else: | |||
| _send_data_no_flag(dataset, epoch_num) | |||
| self.stop_send = dataset.__TRANSFER_DATASET__.stop_send | |||
| self.continue_send = dataset.__TRANSFER_DATASET__.continue_send | |||
| self.stop_send = dataset.__transfer_dataset__.stop_send | |||
| self.continue_send = dataset.__transfer_dataset__.continue_send | |||
| self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | |||
| def __iter__(self): | |||
| @@ -273,7 +275,7 @@ class _DatasetIterMS(_DatasetIter): | |||
| else: | |||
| self.sink_count = dataset.get_dataset_size() | |||
| queue_name = dataset.__TRANSFER_DATASET__.queue_name | |||
| queue_name = dataset.__transfer_dataset__.queue_name | |||
| self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | |||
| @@ -25,14 +25,14 @@ from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_s | |||
| def _send_data(dataset, epoch_num): | |||
| """Engine dataset to write data to tdt queue.""" | |||
| if not hasattr(dataset, '__has_sent__'): | |||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||
| exec_dataset = dataset.__transfer_dataset__ | |||
| exec_dataset.send(epoch_num) | |||
| dataset.__has_sent__ = True | |||
| def _send_data_no_flag(dataset, epoch_num): | |||
| """Engine dataset to write data to tdt queue directly.""" | |||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||
| exec_dataset = dataset.__transfer_dataset__ | |||
| exec_dataset.send(epoch_num) | |||
| @@ -100,17 +100,17 @@ class _DatasetIter: | |||
| self.sink_size = sink_size | |||
| self.sink_count = 1 | |||
| if not hasattr(dataset, '__TRANSFER_DATASET__'): | |||
| if not hasattr(dataset, '__transfer_dataset__'): | |||
| if hasattr(dataset, '__loop_size__'): | |||
| self.sink_size = dataset.__loop_size__ | |||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) | |||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size) | |||
| if not hasattr(dataset, '__no_send__'): | |||
| _send_data(dataset, epoch_num) | |||
| else: | |||
| _send_data_no_flag(dataset, epoch_num) | |||
| self.stop_send = dataset.__TRANSFER_DATASET__.stop_send | |||
| self.stop_send = dataset.__transfer_dataset__.stop_send | |||
| self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | |||
| def __iter__(self): | |||
| @@ -187,5 +187,5 @@ class _DatasetIterMS(_DatasetIter): | |||
| else: | |||
| self.sink_count = dataset.get_dataset_size() | |||
| queue_name = dataset.__TRANSFER_DATASET__.queue_name | |||
| queue_name = dataset.__transfer_dataset__.queue_name | |||
| self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | |||
| @@ -57,7 +57,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): | |||
| # transform data format | |||
| dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | |||
| init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name, | |||
| init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name, | |||
| dataset_size, | |||
| batch_size, | |||
| dataset_types, | |||
| @@ -24,14 +24,14 @@ from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes | |||
| def _send_data(dataset, epoch_num): | |||
| """Engine dataset to write data to tdt queue.""" | |||
| if not hasattr(dataset, '__has_sent__'): | |||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||
| exec_dataset = dataset.__transfer_dataset__ | |||
| exec_dataset.send(epoch_num) | |||
| dataset.__has_sent__ = True | |||
| def _send_data_no_flag(dataset, epoch_num): | |||
| """Engine dataset to write data to tdt queue directly.""" | |||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||
| exec_dataset = dataset.__transfer_dataset__ | |||
| exec_dataset.send(epoch_num) | |||
| @@ -107,17 +107,17 @@ class _DatasetIter: | |||
| self.sink_size = sink_size | |||
| self.sink_count = 1 | |||
| if not hasattr(dataset, '__TRANSFER_DATASET__'): | |||
| if not hasattr(dataset, '__transfer_dataset__'): | |||
| if hasattr(dataset, '__loop_size__'): | |||
| self.sink_size = dataset.__loop_size__ | |||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) | |||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size) | |||
| if not hasattr(dataset, '__no_send__'): | |||
| _send_data(dataset, epoch_num) | |||
| else: | |||
| _send_data_no_flag(dataset, epoch_num) | |||
| self.stop_send = dataset.__TRANSFER_DATASET__.stop_send | |||
| self.stop_send = dataset.__transfer_dataset__.stop_send | |||
| self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | |||
| def __iter__(self): | |||
| @@ -71,7 +71,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): | |||
| # transform data format | |||
| dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | |||
| init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name, | |||
| init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name, | |||
| dataset_size, | |||
| batch_size, | |||
| dataset_types, | |||
| @@ -22,7 +22,7 @@ from mindspore.context import ParallelMode | |||
| def _send_data(dataset): | |||
| """Engine dataset to write data to tdt queue.""" | |||
| if not hasattr(dataset, '__has_sent__'): | |||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||
| exec_dataset = dataset.__transfer_dataset__ | |||
| exec_dataset.send() | |||
| dataset.__has_sent__ = True | |||
| @@ -71,12 +71,12 @@ class _DatasetIter: | |||
| def __init__(self, dataset): | |||
| self.loop_size = 1 | |||
| if not hasattr(dataset, '__TRANSFER_DATASET__'): | |||
| if not hasattr(dataset, '__transfer_dataset__'): | |||
| if not hasattr(dataset, '__loop_size__'): | |||
| self.loop_size = dataset.get_dataset_size() | |||
| else: | |||
| self.loop_size = dataset.__loop_size__ | |||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) | |||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.loop_size) | |||
| if not hasattr(dataset, '__no_send__'): | |||
| _send_data(dataset) | |||
| @@ -67,7 +67,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): | |||
| # transform data format | |||
| dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | |||
| init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name, | |||
| init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name, | |||
| dataset_size, | |||
| batch_size, | |||
| dataset_types, | |||