Merge pull request !1276 from 李鸿章/context_managertags/v0.5.0-beta
| @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell | |||||
| from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | ||||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | ||||
| from mindspore.train import amp | from mindspore.train import amp | ||||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks | |||||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||||
| from mindspore.train.parallel_utils import ParallelMode | from mindspore.train.parallel_utils import ParallelMode | ||||
| from model.dataset_helper import DatasetHelper | from model.dataset_helper import DatasetHelper | ||||
| @@ -374,7 +374,6 @@ class Model: | |||||
| self._train_network.set_broadcast_flag() | self._train_network.set_broadcast_flag() | ||||
| # build callback list | # build callback list | ||||
| list_callback = _build_callbacks(callbacks) | |||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| cb_params.train_network = self._train_network | cb_params.train_network = self._train_network | ||||
| cb_params.epoch_num = epoch | cb_params.epoch_num = epoch | ||||
| @@ -385,17 +384,17 @@ class Model: | |||||
| cb_params.parallel_mode = self._parallel_mode | cb_params.parallel_mode = self._parallel_mode | ||||
| cb_params.device_number = self._device_number | cb_params.device_number = self._device_number | ||||
| cb_params.train_dataset = train_dataset | cb_params.train_dataset = train_dataset | ||||
| cb_params.list_callback = list_callback | |||||
| cb_params.list_callback = callbacks | |||||
| if dataset_sink_mode: | |||||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||||
| with _CallbackManager(callbacks) as list_callback: | |||||
| if not dataset_sink_mode: | |||||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||||
| elif context.get_context("mode") == context.PYNATIVE_MODE: | |||||
| logger.warning("The pynative mode cannot support dataset sink mode currently." | logger.warning("The pynative mode cannot support dataset sink mode currently." | ||||
| "So the training process will be performed with dataset not sink.") | "So the training process will be performed with dataset not sink.") | ||||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | self._train_process(epoch, train_dataset, list_callback, cb_params) | ||||
| else: | else: | ||||
| self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | ||||
| else: | |||||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||||
| def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | ||||
| """ | """ | ||||
| @@ -408,7 +407,7 @@ class Model: | |||||
| returned and passed to the network. Otherwise, a tuple (data, label) should | returned and passed to the network. Otherwise, a tuple (data, label) should | ||||
| be returned, and the data and label are passed to the network and loss | be returned, and the data and label are passed to the network and loss | ||||
| function respectively. | function respectively. | ||||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| """ | """ | ||||
| iter_first_order = self._frequency - 1 | iter_first_order = self._frequency - 1 | ||||
| @@ -473,7 +472,7 @@ class Model: | |||||
| returned and passed to the network. Otherwise, a tuple (data, label) should | returned and passed to the network. Otherwise, a tuple (data, label) should | ||||
| be returned, and the data and label are passed to the network and loss | be returned, and the data and label are passed to the network and loss | ||||
| function respectively. | function respectively. | ||||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| """ | """ | ||||
| dataset_helper, _ = self._exec_preprocess(self._train_network, | dataset_helper, _ = self._exec_preprocess(self._train_network, | ||||
| @@ -580,7 +579,7 @@ class Model: | |||||
| Args: | Args: | ||||
| valid_dataset (Dataset): Dataset to evaluate the model. | valid_dataset (Dataset): Dataset to evaluate the model. | ||||
| list_callback (ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| Returns: | Returns: | ||||
| @@ -619,7 +618,7 @@ class Model: | |||||
| Args: | Args: | ||||
| valid_dataset (Dataset): Dataset to evaluate the model. | valid_dataset (Dataset): Dataset to evaluate the model. | ||||
| list_callback (ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| Returns: | Returns: | ||||
| @@ -678,7 +677,6 @@ class Model: | |||||
| if not self._metric_fns: | if not self._metric_fns: | ||||
| raise ValueError("metric fn can not be None or empty.") | raise ValueError("metric fn can not be None or empty.") | ||||
| list_callback = _build_callbacks(callbacks) | |||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| cb_params.eval_network = self._eval_network | cb_params.eval_network = self._eval_network | ||||
| cb_params.valid_dataset = valid_dataset | cb_params.valid_dataset = valid_dataset | ||||
| @@ -691,9 +689,10 @@ class Model: | |||||
| self._clear_metrics() | self._clear_metrics() | ||||
| if dataset_sink_mode: | |||||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||||
| with _CallbackManager(callbacks) as list_callback: | |||||
| if dataset_sink_mode: | |||||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||||
| def predict(self, *predict_data): | def predict(self, *predict_data): | ||||
| """ | """ | ||||
| @@ -18,6 +18,7 @@ import os | |||||
| import stat | import stat | ||||
| import shutil | import shutil | ||||
| import time | import time | ||||
| from contextlib import ExitStack | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| @@ -282,80 +283,11 @@ def _summary_cb_for_save_op(summary_list): | |||||
| return ret | return ret | ||||
| def _build_callbacks(callbacks): | |||||
| """ | |||||
| Contain a list of callback. | |||||
| Args: | |||||
| callbacks (list): Callback functions list, Support None, a single Callback object, or a list. | |||||
| Returns: | |||||
| List, a list of callback functions. | |||||
| """ | |||||
| if callbacks: | |||||
| if isinstance(callbacks, tuple): | |||||
| raise TypeError("Callbacks cannot be a tuple. Please check it.") | |||||
| if not isinstance(callbacks, list): | |||||
| callbacks = [callbacks] | |||||
| else: | |||||
| callbacks = [] | |||||
| excute_callbacks = [] | |||||
| for cb in callbacks: | |||||
| if cb is None or not isinstance(cb, Callback): | |||||
| raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.") | |||||
| excute_callbacks.append(cb) | |||||
| return _ListCallback(excute_callbacks) | |||||
| class _ListCallback: | |||||
| """ | |||||
| Sequential execution of callback functions. | |||||
| Execute Callback functions at certain points. | |||||
| Args: | |||||
| callbacks (list): Callback functions list. | |||||
| """ | |||||
| def __init__(self, callbacks): | |||||
| super(_ListCallback, self).__init__() | |||||
| self._callbacks = callbacks | |||||
| def begin(self, run_context): | |||||
| """Called once before network training.""" | |||||
| for cb in self._callbacks: | |||||
| cb.begin(run_context) | |||||
| def epoch_begin(self, run_context): | |||||
| """Called before each epoch begin.""" | |||||
| for cb in self._callbacks: | |||||
| cb.epoch_begin(run_context) | |||||
| def epoch_end(self, run_context): | |||||
| """Called after each epoch finished.""" | |||||
| for cb in self._callbacks: | |||||
| cb.epoch_end(run_context) | |||||
| def step_begin(self, run_context): | |||||
| """Called before each epoch begin.""" | |||||
| for cb in self._callbacks: | |||||
| cb.step_begin(run_context) | |||||
| def step_end(self, run_context): | |||||
| """Called after each step finished.""" | |||||
| for cb in self._callbacks: | |||||
| cb.step_end(run_context) | |||||
| def end(self, run_context): | |||||
| """Called once after network training.""" | |||||
| for cb in self._callbacks: | |||||
| cb.end(run_context) | |||||
| class Callback: | class Callback: | ||||
| """ | """ | ||||
| Abstract base class used to build a callback function. | |||||
| Abstract base class used to build a callback class. Callbacks are context managers | |||||
| which will be entered and exited when passing into the Model. | |||||
| You can leverage this mechanism to init and release resources automatically. | |||||
| Callback function will execution some operating to the current step or epoch. | Callback function will execution some operating to the current step or epoch. | ||||
| @@ -369,8 +301,13 @@ class Callback: | |||||
| >>> print_cb = Print_info() | >>> print_cb = Print_info() | ||||
| >>> model.train(epoch, dataset, callbacks=print_cb) | >>> model.train(epoch, dataset, callbacks=print_cb) | ||||
| """ | """ | ||||
| def __init__(self): | |||||
| pass | |||||
| def __enter__(self): | |||||
| """Return the enter target.""" | |||||
| return self | |||||
| def __exit__(self, *err): | |||||
| """Release resources here if have any.""" | |||||
| def begin(self, run_context): | def begin(self, run_context): | ||||
| """ | """ | ||||
| @@ -421,6 +358,67 @@ class Callback: | |||||
| """ | """ | ||||
| class _CallbackManager(Callback): | |||||
| """ | |||||
| Sequential execution of callback functions. | |||||
| Execute Callback functions at certain points. | |||||
| Args: | |||||
| callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list. | |||||
| """ | |||||
| def __init__(self, callbacks): | |||||
| self._callbacks, self._stack = [], None | |||||
| if isinstance(callbacks, Callback): | |||||
| self._callbacks.append(callbacks) | |||||
| elif callbacks is not None: | |||||
| for cb in callbacks: | |||||
| if not isinstance(cb, Callback): | |||||
| raise TypeError("%r is not an instance of %r" % (cb, Callback)) | |||||
| self._callbacks.append(cb) | |||||
| def __enter__(self): | |||||
| if self._stack is None: | |||||
| self._stack = ExitStack().__enter__() | |||||
| self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks] | |||||
| return self | |||||
| def __exit__(self, *err): | |||||
| return self._stack.__exit__(*err) | |||||
| def begin(self, run_context): | |||||
| """Called once before network training.""" | |||||
| for cb in self._callbacks: | |||||
| cb.begin(run_context) | |||||
| def epoch_begin(self, run_context): | |||||
| """Called before each epoch begin.""" | |||||
| for cb in self._callbacks: | |||||
| cb.epoch_begin(run_context) | |||||
| def epoch_end(self, run_context): | |||||
| """Called after each epoch finished.""" | |||||
| for cb in self._callbacks: | |||||
| cb.epoch_end(run_context) | |||||
| def step_begin(self, run_context): | |||||
| """Called before each epoch begin.""" | |||||
| for cb in self._callbacks: | |||||
| cb.step_begin(run_context) | |||||
| def step_end(self, run_context): | |||||
| """Called after each step finished.""" | |||||
| for cb in self._callbacks: | |||||
| cb.step_end(run_context) | |||||
| def end(self, run_context): | |||||
| """Called once after network training.""" | |||||
| for cb in self._callbacks: | |||||
| cb.end(run_context) | |||||
| class SummaryStep(Callback): | class SummaryStep(Callback): | ||||
| """ | """ | ||||
| The summary callback class. | The summary callback class. | ||||
| @@ -435,6 +433,13 @@ class SummaryStep(Callback): | |||||
| raise ValueError("`flush_step` should be int and greater than 0") | raise ValueError("`flush_step` should be int and greater than 0") | ||||
| self._summary = summary | self._summary = summary | ||||
| self._flush_step = flush_step | self._flush_step = flush_step | ||||
| def __enter__(self): | |||||
| self._summary.__enter__() | |||||
| return self | |||||
| def __exit__(self, *err): | |||||
| return self._summary.__exit__(*err) | |||||
| def step_end(self, run_context): | def step_end(self, run_context): | ||||
| """ | """ | ||||
| @@ -19,7 +19,7 @@ from mindspore import log as logger | |||||
| from ..common.tensor import Tensor | from ..common.tensor import Tensor | ||||
| from ..nn.metrics import get_metrics | from ..nn.metrics import get_metrics | ||||
| from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool | from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool | ||||
| from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks | |||||
| from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||||
| from .. import context | from .. import context | ||||
| from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | ||||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | ||||
| @@ -334,8 +334,6 @@ class Model: | |||||
| if self._parameter_broadcast: | if self._parameter_broadcast: | ||||
| self._train_network.set_broadcast_flag() | self._train_network.set_broadcast_flag() | ||||
| # build callback list | |||||
| list_callback = _build_callbacks(callbacks) | |||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| cb_params.train_network = self._train_network | cb_params.train_network = self._train_network | ||||
| cb_params.epoch_num = epoch | cb_params.epoch_num = epoch | ||||
| @@ -346,17 +344,18 @@ class Model: | |||||
| cb_params.parallel_mode = self._parallel_mode | cb_params.parallel_mode = self._parallel_mode | ||||
| cb_params.device_number = self._device_number | cb_params.device_number = self._device_number | ||||
| cb_params.train_dataset = train_dataset | cb_params.train_dataset = train_dataset | ||||
| cb_params.list_callback = list_callback | |||||
| cb_params.list_callback = callbacks | |||||
| if dataset_sink_mode: | |||||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||||
| # build callback list | |||||
| with _CallbackManager(callbacks) as list_callback: | |||||
| if not dataset_sink_mode: | |||||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||||
| elif context.get_context("mode") == context.PYNATIVE_MODE: | |||||
| logger.warning("The pynative mode cannot support dataset sink mode currently." | logger.warning("The pynative mode cannot support dataset sink mode currently." | ||||
| "So the training process will be performed with dataset not sink.") | "So the training process will be performed with dataset not sink.") | ||||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | self._train_process(epoch, train_dataset, list_callback, cb_params) | ||||
| else: | else: | ||||
| self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | ||||
| else: | |||||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||||
| def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | ||||
| """ | """ | ||||
| @@ -369,7 +368,7 @@ class Model: | |||||
| returned and passed to the network. Otherwise, a tuple (data, label) should | returned and passed to the network. Otherwise, a tuple (data, label) should | ||||
| be returned, and the data and label are passed to the network and loss | be returned, and the data and label are passed to the network and loss | ||||
| function respectively. | function respectively. | ||||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| """ | """ | ||||
| dataset_helper, train_network = self._exec_preprocess(self._train_network, | dataset_helper, train_network = self._exec_preprocess(self._train_network, | ||||
| @@ -417,7 +416,7 @@ class Model: | |||||
| returned and passed to the network. Otherwise, a tuple (data, label) should | returned and passed to the network. Otherwise, a tuple (data, label) should | ||||
| be returned, and the data and label are passed to the network and loss | be returned, and the data and label are passed to the network and loss | ||||
| function respectively. | function respectively. | ||||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| """ | """ | ||||
| dataset_helper, _ = self._exec_preprocess(self._train_network, | dataset_helper, _ = self._exec_preprocess(self._train_network, | ||||
| @@ -524,7 +523,7 @@ class Model: | |||||
| Args: | Args: | ||||
| valid_dataset (Dataset): Dataset to evaluate the model. | valid_dataset (Dataset): Dataset to evaluate the model. | ||||
| list_callback (ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| Returns: | Returns: | ||||
| @@ -563,7 +562,7 @@ class Model: | |||||
| Args: | Args: | ||||
| valid_dataset (Dataset): Dataset to evaluate the model. | valid_dataset (Dataset): Dataset to evaluate the model. | ||||
| list_callback (ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| Returns: | Returns: | ||||
| @@ -622,7 +621,6 @@ class Model: | |||||
| if not self._metric_fns: | if not self._metric_fns: | ||||
| raise ValueError("metric fn can not be None or empty.") | raise ValueError("metric fn can not be None or empty.") | ||||
| list_callback = _build_callbacks(callbacks) | |||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| cb_params.eval_network = self._eval_network | cb_params.eval_network = self._eval_network | ||||
| cb_params.valid_dataset = valid_dataset | cb_params.valid_dataset = valid_dataset | ||||
| @@ -635,9 +633,10 @@ class Model: | |||||
| self._clear_metrics() | self._clear_metrics() | ||||
| if dataset_sink_mode: | |||||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||||
| with _CallbackManager(callbacks) as list_callback: | |||||
| if dataset_sink_mode: | |||||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||||
| def predict(self, *predict_data): | def predict(self, *predict_data): | ||||
| """ | """ | ||||
| @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell | |||||
| from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | ||||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | ||||
| from mindspore.train import amp | from mindspore.train import amp | ||||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks | |||||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||||
| from mindspore.train.parallel_utils import ParallelMode | from mindspore.train.parallel_utils import ParallelMode | ||||
| from .dataset_helper import DatasetHelper | from .dataset_helper import DatasetHelper | ||||
| @@ -392,7 +392,6 @@ class Model: | |||||
| self._train_network.set_broadcast_flag() | self._train_network.set_broadcast_flag() | ||||
| # build callback list | # build callback list | ||||
| list_callback = _build_callbacks(callbacks) | |||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| cb_params.train_network = self._train_network | cb_params.train_network = self._train_network | ||||
| cb_params.epoch_num = epoch | cb_params.epoch_num = epoch | ||||
| @@ -403,17 +402,17 @@ class Model: | |||||
| cb_params.parallel_mode = self._parallel_mode | cb_params.parallel_mode = self._parallel_mode | ||||
| cb_params.device_number = self._device_number | cb_params.device_number = self._device_number | ||||
| cb_params.train_dataset = train_dataset | cb_params.train_dataset = train_dataset | ||||
| cb_params.list_callback = list_callback | |||||
| cb_params.list_callback = callbacks | |||||
| if dataset_sink_mode: | |||||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||||
| with _CallbackManager(callbacks) as list_callback: | |||||
| if not dataset_sink_mode: | |||||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||||
| elif context.get_context("mode") == context.PYNATIVE_MODE: | |||||
| logger.warning("The pynative mode cannot support dataset sink mode currently." | logger.warning("The pynative mode cannot support dataset sink mode currently." | ||||
| "So the training process will be performed with dataset not sink.") | "So the training process will be performed with dataset not sink.") | ||||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | self._train_process(epoch, train_dataset, list_callback, cb_params) | ||||
| else: | else: | ||||
| self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | ||||
| else: | |||||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||||
| def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | ||||
| """ | """ | ||||
| @@ -426,7 +425,7 @@ class Model: | |||||
| returned and passed to the network. Otherwise, a tuple (data, label) should | returned and passed to the network. Otherwise, a tuple (data, label) should | ||||
| be returned, and the data and label are passed to the network and loss | be returned, and the data and label are passed to the network and loss | ||||
| function respectively. | function respectively. | ||||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| """ | """ | ||||
| iter_first_order = self._frequency - 1 | iter_first_order = self._frequency - 1 | ||||
| @@ -490,7 +489,7 @@ class Model: | |||||
| returned and passed to the network. Otherwise, a tuple (data, label) should | returned and passed to the network. Otherwise, a tuple (data, label) should | ||||
| be returned, and the data and label are passed to the network and loss | be returned, and the data and label are passed to the network and loss | ||||
| function respectively. | function respectively. | ||||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||||
| list_callback (Callback): Executor of callback list. Default: None. | |||||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | cb_params (_InternalCallbackParam): Callback parameters. Default: None. | ||||
| """ | """ | ||||
| dataset_helper, _ = self._exec_preprocess(self._train_network, | dataset_helper, _ = self._exec_preprocess(self._train_network, | ||||
| @@ -695,7 +694,6 @@ class Model: | |||||
| if not self._metric_fns: | if not self._metric_fns: | ||||
| raise ValueError("metric fn can not be None or empty.") | raise ValueError("metric fn can not be None or empty.") | ||||
| list_callback = _build_callbacks(callbacks) | |||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| cb_params.eval_network = self._eval_network | cb_params.eval_network = self._eval_network | ||||
| cb_params.valid_dataset = valid_dataset | cb_params.valid_dataset = valid_dataset | ||||
| @@ -708,9 +706,10 @@ class Model: | |||||
| self._clear_metrics() | self._clear_metrics() | ||||
| if dataset_sink_mode: | |||||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||||
| with _CallbackManager(callbacks) as list_callback: | |||||
| if dataset_sink_mode: | |||||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||||
| def predict(self, *predict_data): | def predict(self, *predict_data): | ||||
| """ | """ | ||||
| @@ -156,12 +156,19 @@ def get_dataset(): | |||||
| class ImageSummaryCallback: | class ImageSummaryCallback: | ||||
| def __init__(self, summaryRecord): | |||||
| self._summaryRecord = summaryRecord | |||||
| def __init__(self, summary_record): | |||||
| self._summary_record = summary_record | |||||
| def __enter__(self): | |||||
| return self | |||||
| def __exit__(self, *err): | |||||
| pass | |||||
| def record(self, step, train_network=None): | def record(self, step, train_network=None): | ||||
| self._summaryRecord.record(step, train_network) | |||||
| self._summaryRecord.flush() | |||||
| self._summary_record.record(step, train_network) | |||||
| self._summary_record.flush() | |||||
| def test_image_summary_train(): | def test_image_summary_train(): | ||||
| @@ -180,6 +180,12 @@ class CallbackTest: | |||||
| def __init__(self): | def __init__(self): | ||||
| pass | pass | ||||
| def __enter__(self): | |||||
| return self | |||||
| def __exit__(self, *err): | |||||
| pass | |||||
| def record(self, step, *args): | def record(self, step, *args): | ||||
| print(step, args) | print(step, args) | ||||
| @@ -15,6 +15,7 @@ | |||||
| """test callback function.""" | """test callback function.""" | ||||
| import os | import os | ||||
| import stat | import stat | ||||
| from unittest import mock | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| @@ -27,7 +28,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell | |||||
| from mindspore.nn.optim import Momentum | from mindspore.nn.optim import Momentum | ||||
| from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \ | from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \ | ||||
| _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \ | _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \ | ||||
| _build_callbacks, CheckpointConfig, _set_cur_net | |||||
| _CallbackManager, Callback, CheckpointConfig, _set_cur_net | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| @@ -122,13 +123,13 @@ def test_loss_monitor_sink_mode(): | |||||
| run_context = RunContext(cb_params) | run_context = RunContext(cb_params) | ||||
| loss_cb = LossMonitor(1) | loss_cb = LossMonitor(1) | ||||
| callbacks = [loss_cb] | callbacks = [loss_cb] | ||||
| callbacklist = _build_callbacks(callbacks) | |||||
| callbacklist.begin(run_context) | |||||
| callbacklist.epoch_begin(run_context) | |||||
| callbacklist.step_begin(run_context) | |||||
| callbacklist.step_end(run_context) | |||||
| callbacklist.epoch_end(run_context) | |||||
| callbacklist.end(run_context) | |||||
| with _CallbackManager(callbacks) as callbacklist: | |||||
| callbacklist.begin(run_context) | |||||
| callbacklist.epoch_begin(run_context) | |||||
| callbacklist.step_begin(run_context) | |||||
| callbacklist.step_end(run_context) | |||||
| callbacklist.epoch_end(run_context) | |||||
| callbacklist.end(run_context) | |||||
| def test_loss_monitor_normal_mode(): | def test_loss_monitor_normal_mode(): | ||||
| @@ -269,29 +270,61 @@ def test_checkpoint_save_ckpt_seconds(): | |||||
| ckpt_cb2.step_end(run_context) | ckpt_cb2.step_end(run_context) | ||||
| def test_build_callbacks(): | |||||
| """Test_build_callbacks.""" | |||||
| def test_CallbackManager(): | |||||
| """TestCallbackManager.""" | |||||
| ck_obj = ModelCheckpoint() | ck_obj = ModelCheckpoint() | ||||
| loss_cb_1 = LossMonitor(1) | loss_cb_1 = LossMonitor(1) | ||||
| callbacks = [None] | callbacks = [None] | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| callbacks = _build_callbacks(callbacks) | |||||
| _CallbackManager(callbacks) | |||||
| callbacks = ['Error'] | callbacks = ['Error'] | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| callbacks = _build_callbacks(callbacks) | |||||
| _CallbackManager(callbacks) | |||||
| callbacks = [ck_obj, loss_cb_1, 'Error', None] | callbacks = [ck_obj, loss_cb_1, 'Error', None] | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| _ = _build_callbacks(callbacks) | |||||
| _CallbackManager(callbacks) | |||||
| def test_CallbackManager_exit_called(): | |||||
| with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit: | |||||
| cb1, cb2 = Callback(), Callback() | |||||
| with _CallbackManager([cb1, cb2]): | |||||
| pass | |||||
| for call_args in mock_exit.call_args_list: | |||||
| assert call_args == mock.call(mock.ANY, None, None, None) | |||||
| assert mock_exit.call_count == 2 | |||||
| def test_CallbackManager_exit_called_when_raises(): | |||||
| with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit: | |||||
| cb1, cb2 = Callback(), Callback() | |||||
| with pytest.raises(ValueError): | |||||
| with _CallbackManager([cb1, cb2]): | |||||
| raise ValueError() | |||||
| for call_args in mock_exit.call_args_list: | |||||
| assert call_args == mock.call(*[mock.ANY] * 4) | |||||
| assert mock_exit.call_count == 2 | |||||
| def test_CallbackManager_begin_called(): | |||||
| context = dict() | |||||
| with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin: | |||||
| cb1, cb2 = Callback(), Callback() | |||||
| with _CallbackManager([cb1, cb2]) as cm: | |||||
| cm.begin(context) | |||||
| for call_args in mock_begin.call_args_list: | |||||
| assert call_args == mock.call(context) | |||||
| assert mock_begin.call_count == 2 | |||||
| def test_RunContext(): | def test_RunContext(): | ||||
| """Test RunContext.""" | """Test RunContext.""" | ||||
| context_err = 666 | context_err = 666 | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| _ = RunContext(context_err) | |||||
| RunContext(context_err) | |||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| cb_params.member1 = 1 | cb_params.member1 = 1 | ||||