diff --git a/mindspore/train/model.py b/mindspore/train/model.py index a439951b6d..cc5e857a66 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -420,7 +420,7 @@ class Model: # for data sink dataset_helper only iter once, other wise iter epoch_size times. for inputs in dataset_helper: - if _need_to_full(): + if _need_to_full() and context.get_context("device_target") == "GPU": inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) list_callback.step_begin(run_context) outputs = self._train_network(*inputs)