| @@ -27,13 +27,13 @@ from ..ops import operations as P | |||||
| def _send_data(dataset, epoch_num): | def _send_data(dataset, epoch_num): | ||||
| """Engine dataset to write data to tdt queue.""" | """Engine dataset to write data to tdt queue.""" | ||||
| if not hasattr(dataset, '__has_sent__'): | if not hasattr(dataset, '__has_sent__'): | ||||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||||
| exec_dataset = dataset.__transfer_dataset__ | |||||
| exec_dataset.send(epoch_num) | exec_dataset.send(epoch_num) | ||||
| dataset.__has_sent__ = True | dataset.__has_sent__ = True | ||||
| def _send_data_no_flag(dataset, epoch_num): | def _send_data_no_flag(dataset, epoch_num): | ||||
| """Engine dataset to write data to tdt queue directly.""" | """Engine dataset to write data to tdt queue directly.""" | ||||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||||
| exec_dataset = dataset.__transfer_dataset__ | |||||
| exec_dataset.send(epoch_num) | exec_dataset.send(epoch_num) | ||||
| @@ -88,11 +88,13 @@ def connect_network_with_dataset(network, dataset_helper): | |||||
| if isinstance(dataset_iter, _DatasetIterNormal): | if isinstance(dataset_iter, _DatasetIterNormal): | ||||
| raise RuntimeError("Dataset should be connected with network only in sink mode.") | raise RuntimeError("Dataset should be connected with network only in sink mode.") | ||||
| if not hasattr(dataset, '__ME_INITED__') and (context.get_context("device_target") == "Ascend" \ | |||||
| or context.get_context("device_target") == "GPU") and not context.get_context("enable_ge"): | |||||
| dataset.__ME_INITED__ = True | |||||
| if not hasattr(dataset, '__me_inited__') and (context.get_context("device_target") == "Ascend" | |||||
| or context.get_context("device_target") == "GPU") and not \ | |||||
| context.get_context("enable_ge"): | |||||
| dataset.__me_inited__ = True | |||||
| dataset_types, dataset_shapes = dataset_helper.types_shapes() | dataset_types, dataset_shapes = dataset_helper.types_shapes() | ||||
| queue_name = dataset.__TRANSFER_DATASET__.queue_name | |||||
| queue_name = dataset.__transfer_dataset__.queue_name | |||||
| network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) | network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) | ||||
| return network | return network | ||||
| @@ -175,18 +177,18 @@ class _DatasetIter: | |||||
| self.sink_size = sink_size | self.sink_size = sink_size | ||||
| self.sink_count = 1 | self.sink_count = 1 | ||||
| if not hasattr(dataset, '__TRANSFER_DATASET__'): | |||||
| if not hasattr(dataset, '__transfer_dataset__'): | |||||
| if hasattr(dataset, '__loop_size__'): | if hasattr(dataset, '__loop_size__'): | ||||
| self.sink_size = dataset.__loop_size__ | self.sink_size = dataset.__loop_size__ | ||||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) | |||||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size) | |||||
| if not hasattr(dataset, '__no_send__'): | if not hasattr(dataset, '__no_send__'): | ||||
| _send_data(dataset, epoch_num) | _send_data(dataset, epoch_num) | ||||
| else: | else: | ||||
| _send_data_no_flag(dataset, epoch_num) | _send_data_no_flag(dataset, epoch_num) | ||||
| self.stop_send = dataset.__TRANSFER_DATASET__.stop_send | |||||
| self.continue_send = dataset.__TRANSFER_DATASET__.continue_send | |||||
| self.stop_send = dataset.__transfer_dataset__.stop_send | |||||
| self.continue_send = dataset.__transfer_dataset__.continue_send | |||||
| self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | ||||
| def __iter__(self): | def __iter__(self): | ||||
| @@ -273,7 +275,7 @@ class _DatasetIterMS(_DatasetIter): | |||||
| else: | else: | ||||
| self.sink_count = dataset.get_dataset_size() | self.sink_count = dataset.get_dataset_size() | ||||
| queue_name = dataset.__TRANSFER_DATASET__.queue_name | |||||
| queue_name = dataset.__transfer_dataset__.queue_name | |||||
| self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | ||||
| @@ -25,14 +25,14 @@ from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_s | |||||
| def _send_data(dataset, epoch_num): | def _send_data(dataset, epoch_num): | ||||
| """Engine dataset to write data to tdt queue.""" | """Engine dataset to write data to tdt queue.""" | ||||
| if not hasattr(dataset, '__has_sent__'): | if not hasattr(dataset, '__has_sent__'): | ||||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||||
| exec_dataset = dataset.__transfer_dataset__ | |||||
| exec_dataset.send(epoch_num) | exec_dataset.send(epoch_num) | ||||
| dataset.__has_sent__ = True | dataset.__has_sent__ = True | ||||
| def _send_data_no_flag(dataset, epoch_num): | def _send_data_no_flag(dataset, epoch_num): | ||||
| """Engine dataset to write data to tdt queue directly.""" | """Engine dataset to write data to tdt queue directly.""" | ||||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||||
| exec_dataset = dataset.__transfer_dataset__ | |||||
| exec_dataset.send(epoch_num) | exec_dataset.send(epoch_num) | ||||
| @@ -100,17 +100,17 @@ class _DatasetIter: | |||||
| self.sink_size = sink_size | self.sink_size = sink_size | ||||
| self.sink_count = 1 | self.sink_count = 1 | ||||
| if not hasattr(dataset, '__TRANSFER_DATASET__'): | |||||
| if not hasattr(dataset, '__transfer_dataset__'): | |||||
| if hasattr(dataset, '__loop_size__'): | if hasattr(dataset, '__loop_size__'): | ||||
| self.sink_size = dataset.__loop_size__ | self.sink_size = dataset.__loop_size__ | ||||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) | |||||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size) | |||||
| if not hasattr(dataset, '__no_send__'): | if not hasattr(dataset, '__no_send__'): | ||||
| _send_data(dataset, epoch_num) | _send_data(dataset, epoch_num) | ||||
| else: | else: | ||||
| _send_data_no_flag(dataset, epoch_num) | _send_data_no_flag(dataset, epoch_num) | ||||
| self.stop_send = dataset.__TRANSFER_DATASET__.stop_send | |||||
| self.stop_send = dataset.__transfer_dataset__.stop_send | |||||
| self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | ||||
| def __iter__(self): | def __iter__(self): | ||||
| @@ -187,5 +187,5 @@ class _DatasetIterMS(_DatasetIter): | |||||
| else: | else: | ||||
| self.sink_count = dataset.get_dataset_size() | self.sink_count = dataset.get_dataset_size() | ||||
| queue_name = dataset.__TRANSFER_DATASET__.queue_name | |||||
| queue_name = dataset.__transfer_dataset__.queue_name | |||||
| self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | ||||
| @@ -57,7 +57,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): | |||||
| # transform data format | # transform data format | ||||
| dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | ||||
| init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name, | |||||
| init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name, | |||||
| dataset_size, | dataset_size, | ||||
| batch_size, | batch_size, | ||||
| dataset_types, | dataset_types, | ||||
| @@ -24,14 +24,14 @@ from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes | |||||
| def _send_data(dataset, epoch_num): | def _send_data(dataset, epoch_num): | ||||
| """Engine dataset to write data to tdt queue.""" | """Engine dataset to write data to tdt queue.""" | ||||
| if not hasattr(dataset, '__has_sent__'): | if not hasattr(dataset, '__has_sent__'): | ||||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||||
| exec_dataset = dataset.__transfer_dataset__ | |||||
| exec_dataset.send(epoch_num) | exec_dataset.send(epoch_num) | ||||
| dataset.__has_sent__ = True | dataset.__has_sent__ = True | ||||
| def _send_data_no_flag(dataset, epoch_num): | def _send_data_no_flag(dataset, epoch_num): | ||||
| """Engine dataset to write data to tdt queue directly.""" | """Engine dataset to write data to tdt queue directly.""" | ||||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||||
| exec_dataset = dataset.__transfer_dataset__ | |||||
| exec_dataset.send(epoch_num) | exec_dataset.send(epoch_num) | ||||
| @@ -107,17 +107,17 @@ class _DatasetIter: | |||||
| self.sink_size = sink_size | self.sink_size = sink_size | ||||
| self.sink_count = 1 | self.sink_count = 1 | ||||
| if not hasattr(dataset, '__TRANSFER_DATASET__'): | |||||
| if not hasattr(dataset, '__transfer_dataset__'): | |||||
| if hasattr(dataset, '__loop_size__'): | if hasattr(dataset, '__loop_size__'): | ||||
| self.sink_size = dataset.__loop_size__ | self.sink_size = dataset.__loop_size__ | ||||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) | |||||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size) | |||||
| if not hasattr(dataset, '__no_send__'): | if not hasattr(dataset, '__no_send__'): | ||||
| _send_data(dataset, epoch_num) | _send_data(dataset, epoch_num) | ||||
| else: | else: | ||||
| _send_data_no_flag(dataset, epoch_num) | _send_data_no_flag(dataset, epoch_num) | ||||
| self.stop_send = dataset.__TRANSFER_DATASET__.stop_send | |||||
| self.stop_send = dataset.__transfer_dataset__.stop_send | |||||
| self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | ||||
| def __iter__(self): | def __iter__(self): | ||||
| @@ -71,7 +71,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): | |||||
| # transform data format | # transform data format | ||||
| dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | ||||
| init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name, | |||||
| init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name, | |||||
| dataset_size, | dataset_size, | ||||
| batch_size, | batch_size, | ||||
| dataset_types, | dataset_types, | ||||
| @@ -22,7 +22,7 @@ from mindspore.context import ParallelMode | |||||
| def _send_data(dataset): | def _send_data(dataset): | ||||
| """Engine dataset to write data to tdt queue.""" | """Engine dataset to write data to tdt queue.""" | ||||
| if not hasattr(dataset, '__has_sent__'): | if not hasattr(dataset, '__has_sent__'): | ||||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||||
| exec_dataset = dataset.__transfer_dataset__ | |||||
| exec_dataset.send() | exec_dataset.send() | ||||
| dataset.__has_sent__ = True | dataset.__has_sent__ = True | ||||
| @@ -71,12 +71,12 @@ class _DatasetIter: | |||||
| def __init__(self, dataset): | def __init__(self, dataset): | ||||
| self.loop_size = 1 | self.loop_size = 1 | ||||
| if not hasattr(dataset, '__TRANSFER_DATASET__'): | |||||
| if not hasattr(dataset, '__transfer_dataset__'): | |||||
| if not hasattr(dataset, '__loop_size__'): | if not hasattr(dataset, '__loop_size__'): | ||||
| self.loop_size = dataset.get_dataset_size() | self.loop_size = dataset.get_dataset_size() | ||||
| else: | else: | ||||
| self.loop_size = dataset.__loop_size__ | self.loop_size = dataset.__loop_size__ | ||||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) | |||||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.loop_size) | |||||
| if not hasattr(dataset, '__no_send__'): | if not hasattr(dataset, '__no_send__'): | ||||
| _send_data(dataset) | _send_data(dataset) | ||||
| @@ -67,7 +67,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): | |||||
| # transform data format | # transform data format | ||||
| dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | ||||
| init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name, | |||||
| init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name, | |||||
| dataset_size, | dataset_size, | ||||
| batch_size, | batch_size, | ||||
| dataset_types, | dataset_types, | ||||