Browse Source

!5597 Modify model._train method

Merge pull request !5597 from liuyang/MT
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
1093c97f76
2 changed files with 2 additions and 4 deletions
  1. +2
    -2
      mindspore/train/model.py
  2. +0
    -2
      model_zoo/official/cv/lenet/train.py

+ 2
- 2
mindspore/train/model.py View File

@@ -378,8 +378,8 @@ class Model:
with _CallbackManager(callbacks) as list_callback: with _CallbackManager(callbacks) as list_callback:
if not dataset_sink_mode: if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params) self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("mode") == context.PYNATIVE_MODE:
logger.warning("The pynative mode cannot support dataset sink mode currently."
elif context.get_context("mode") == context.PYNATIVE_MODE or context.get_context("device_target") == "CPU":
logger.warning("The pynative mode and CPU cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.") "So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params) self._train_process(epoch, train_dataset, list_callback, cb_params)
else: else:


+ 0
- 2
model_zoo/official/cv/lenet/train.py View File

@@ -44,8 +44,6 @@ if __name__ == "__main__":


args = parser.parse_args() args = parser.parse_args()


if args.device_target == "CPU":
args.dataset_sink_mode = False


context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"), ds_train = create_dataset(os.path.join(args.data_path, "train"),


Loading…
Cancel
Save