Browse Source

dataset_sink_mode is supported in model.eval() and not in model.train() in pynative mode

tags/v0.2.0-alpha
guohongzilong 5 years ago
parent
commit
d8b9442ab8
1 changed files with 13 additions and 4 deletions
  1. +13
    -4
      mindspore/train/model.py

+ 13
- 4
mindspore/train/model.py View File

@@ -206,6 +206,8 @@ class Model:
function respectively.
callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None.
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Configure pynative mode, the training process will be performed with
dataset not sink.
"""
epoch = check_int_positive(epoch)
self._train_network.set_train()
@@ -227,8 +229,13 @@ class Model:
cb_params.train_dataset = train_dataset
cb_params.list_callback = list_callback

if dataset_sink_mode and context.get_context("mode") == context.GRAPH_MODE:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
if dataset_sink_mode:
if context.get_context("mode") == context.PYNATIVE_MODE:
logger.warning("The pynative mode cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_process(epoch, train_dataset, list_callback, cb_params)

@@ -349,7 +356,7 @@ class Model:
"""
Training API where the iteration is controlled by python front-end.

Configure to pynative mode, the training will be performed with dataset non-sink mode.
When setting pynative mode, the training process will be performed with dataset not sink.

Note:
CPU is not supported when dataset_sink_mode is true.
@@ -363,6 +370,8 @@ class Model:
function respectively.
callbacks (list): List of callback object. Callbacks which should be excuted while training. Default: None.
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Configure pynative mode, the training process will be performed with
dataset not sink.


Examples:
@@ -508,7 +517,7 @@ class Model:

self._clear_metrics()

if dataset_sink_mode and context.get_context("mode") == context.GRAPH_MODE:
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)



Loading…
Cancel
Save