diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 9335cc838d..56096d8dac 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -157,6 +157,12 @@ class Model: if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: raise ValueError(f"Unsupported arg '{arg}'") + def _check_reuse_dataset(self, dataset): + if not hasattr(dataset, '__model_hash__'): + dataset.__model_hash__ = hash(self) + if hasattr(dataset, '__model_hash__') and dataset.__model_hash__ != hash(self): + raise RuntimeError('The Dataset cannot be bound to different models, please create a new dataset.') + def _build_train_network(self): """Build train network""" network = self._network @@ -388,6 +394,7 @@ class Model: "So the training process will be performed with dataset not sink.") self._train_process(epoch, train_dataset, list_callback, cb_params) else: + self._check_reuse_dataset(train_dataset) self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size) @staticmethod