Merge pull request !2272 from wangnan39/remove_dataset_send_from_model_inittags/v0.5.0-beta
| @@ -15,6 +15,7 @@ | |||||
| """Dataset help for minddata dataset""" | """Dataset help for minddata dataset""" | ||||
| from mindspore._checkparam import check_bool | from mindspore._checkparam import check_bool | ||||
| from mindspore.parallel._utils import _get_device_num, _get_parallel_mode | from mindspore.parallel._utils import _get_device_num, _get_parallel_mode | ||||
| from mindspore.train.dataset_helper import _send_data | |||||
| from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ | from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ | ||||
| _to_full_shapes | _to_full_shapes | ||||
| from mindspore.train.parallel_utils import ParallelMode | from mindspore.train.parallel_utils import ParallelMode | ||||
| @@ -67,7 +68,13 @@ class _DatasetIter: | |||||
| self.loop_size = dataset.get_dataset_size() | self.loop_size = dataset.get_dataset_size() | ||||
| else: | else: | ||||
| self.loop_size = dataset.__loop_size__ | self.loop_size = dataset.__loop_size__ | ||||
| dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name | |||||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) | |||||
| dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name | |||||
| if not hasattr(dataset, '__no_send__'): | |||||
| _send_data(dataset) | |||||
| else: | |||||
| _send_data(dataset) | |||||
| self.ind = 0 | self.ind = 0 | ||||
| self.dataset = dataset | self.dataset = dataset | ||||
| @@ -16,11 +16,10 @@ | |||||
| import os | import os | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.dtype import dtype_to_nptype | |||||
| from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| from mindspore.common.dtype import pytype_to_dtype | |||||
| def _convert_type(types): | def _convert_type(types): | ||||
| @@ -64,8 +63,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): | |||||
| input_indexs, | input_indexs, | ||||
| phase=phase) | phase=phase) | ||||
| # engine dataset to write data to tdt queue | |||||
| exec_dataset.send() | |||||
| return exec_dataset | return exec_dataset | ||||
| @@ -23,6 +23,14 @@ from ..nn.wrap import GetNextSingleOp | |||||
| from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full | from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full | ||||
| def _send_data(dataset): | |||||
| """Engine dataset to write data to tdt queue.""" | |||||
| if not hasattr(dataset, '__has_sent__'): | |||||
| exec_dataset = dataset.__TRANSFER_DATASET__ | |||||
| exec_dataset.send() | |||||
| dataset.__has_sent__ = True | |||||
| class DatasetHelper: | class DatasetHelper: | ||||
| """ | """ | ||||
| Help function to use the Minddata dataset. | Help function to use the Minddata dataset. | ||||
| @@ -81,7 +89,13 @@ class _DatasetIter: | |||||
| self.loop_size = dataset.get_dataset_size() | self.loop_size = dataset.get_dataset_size() | ||||
| else: | else: | ||||
| self.loop_size = dataset.__loop_size__ | self.loop_size = dataset.__loop_size__ | ||||
| dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name | |||||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) | |||||
| dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name | |||||
| if not hasattr(dataset, '__no_send__'): | |||||
| _send_data(dataset) | |||||
| else: | |||||
| _send_data(dataset) | |||||
| self.ind = 0 | self.ind = 0 | ||||
| self.dataset = dataset | self.dataset = dataset | ||||
| @@ -285,7 +285,7 @@ class Model: | |||||
| if self._parameter_broadcast: | if self._parameter_broadcast: | ||||
| self._train_network.set_broadcast_flag() | self._train_network.set_broadcast_flag() | ||||
| train_dataset.__no_send__ = True | |||||
| train_dataset_helper, train_network = self._exec_preprocess(self._train_network, | train_dataset_helper, train_network = self._exec_preprocess(self._train_network, | ||||
| is_train=True, | is_train=True, | ||||
| phase='train', | phase='train', | ||||
| @@ -302,6 +302,7 @@ class Model: | |||||
| self._eval_network.set_train(False) | self._eval_network.set_train(False) | ||||
| self._eval_network.phase = 'eval' | self._eval_network.phase = 'eval' | ||||
| valid_dataset.__no_send__ = True | |||||
| valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, | valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, | ||||
| is_train=False, | is_train=False, | ||||
| phase='eval', | phase='eval', | ||||
| @@ -15,6 +15,7 @@ | |||||
| """Dataset help for minddata dataset""" | """Dataset help for minddata dataset""" | ||||
| from mindspore._checkparam import check_bool | from mindspore._checkparam import check_bool | ||||
| from mindspore.parallel._utils import _get_device_num, _get_parallel_mode | from mindspore.parallel._utils import _get_device_num, _get_parallel_mode | ||||
| from mindspore.train.dataset_helper import _send_data | |||||
| from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ | from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ | ||||
| _to_full_shapes | _to_full_shapes | ||||
| from mindspore.train.parallel_utils import ParallelMode | from mindspore.train.parallel_utils import ParallelMode | ||||
| @@ -69,7 +70,13 @@ class _DatasetIter: | |||||
| self.loop_size = dataset.get_dataset_size() | self.loop_size = dataset.get_dataset_size() | ||||
| else: | else: | ||||
| self.loop_size = dataset.__loop_size__ | self.loop_size = dataset.__loop_size__ | ||||
| dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name | |||||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) | |||||
| dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name | |||||
| if not hasattr(dataset, '__no_send__'): | |||||
| _send_data(dataset) | |||||
| else: | |||||
| _send_data(dataset) | |||||
| self.ind = 0 | self.ind = 0 | ||||
| self.dataset = dataset | self.dataset = dataset | ||||