Browse Source

fix_bug_for autoparallel gpu

tags/v0.7.0-beta
lichenever 5 years ago
parent
commit
ac043ed0d3
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      mindspore/train/model.py

+ 1
- 1
mindspore/train/model.py View File

@@ -420,7 +420,7 @@ class Model:


# for data sink dataset_helper only iter once, other wise iter epoch_size times. # for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper: for inputs in dataset_helper:
if _need_to_full():
if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context) list_callback.step_begin(run_context)
outputs = self._train_network(*inputs) outputs = self._train_network(*inputs)


Loading…
Cancel
Save