Browse Source

!5597 Modify model._train method

Merge pull request !5597 from liuyang/MT
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
1093c97f76
2 changed files with 2 additions and 4 deletions
  1. +2
    -2
      mindspore/train/model.py
  2. +0
    -2
      model_zoo/official/cv/lenet/train.py

+ 2
- 2
mindspore/train/model.py View File

@@ -378,8 +378,8 @@ class Model:
with _CallbackManager(callbacks) as list_callback:
if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("mode") == context.PYNATIVE_MODE:
logger.warning("The pynative mode cannot support dataset sink mode currently."
elif context.get_context("mode") == context.PYNATIVE_MODE or context.get_context("device_target") == "CPU":
logger.warning("The pynative mode and CPU cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params)
else:


+ 0
- 2
model_zoo/official/cv/lenet/train.py View File

@@ -44,8 +44,6 @@ if __name__ == "__main__":

args = parser.parse_args()

if args.device_target == "CPU":
args.dataset_sink_mode = False

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"),


Loading…
Cancel
Save