Merge pull request !1276 from 李鸿章/context_managertags/v0.5.0-beta
| @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | |||
| from mindspore.train import amp | |||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks | |||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| from model.dataset_helper import DatasetHelper | |||
| @@ -374,7 +374,6 @@ class Model: | |||
| self._train_network.set_broadcast_flag() | |||
| # build callback list | |||
| list_callback = _build_callbacks(callbacks) | |||
| cb_params = _InternalCallbackParam() | |||
| cb_params.train_network = self._train_network | |||
| cb_params.epoch_num = epoch | |||
| @@ -385,17 +384,17 @@ class Model: | |||
| cb_params.parallel_mode = self._parallel_mode | |||
| cb_params.device_number = self._device_number | |||
| cb_params.train_dataset = train_dataset | |||
| cb_params.list_callback = list_callback | |||
| cb_params.list_callback = callbacks | |||
| if dataset_sink_mode: | |||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||
| with _CallbackManager(callbacks) as list_callback: | |||
| if not dataset_sink_mode: | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| elif context.get_context("mode") == context.PYNATIVE_MODE: | |||
| logger.warning("The pynative mode cannot support dataset sink mode currently." | |||
| "So the training process will be performed with dataset not sink.") | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| else: | |||
| self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | |||
| else: | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | |||
| """ | |||
| @@ -408,7 +407,7 @@ class Model: | |||
| returned and passed to the network. Otherwise, a tuple (data, label) should | |||
| be returned, and the data and label are passed to the network and loss | |||
| function respectively. | |||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| """ | |||
| iter_first_order = self._frequency - 1 | |||
| @@ -473,7 +472,7 @@ class Model: | |||
| returned and passed to the network. Otherwise, a tuple (data, label) should | |||
| be returned, and the data and label are passed to the network and loss | |||
| function respectively. | |||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| """ | |||
| dataset_helper, _ = self._exec_preprocess(self._train_network, | |||
| @@ -580,7 +579,7 @@ class Model: | |||
| Args: | |||
| valid_dataset (Dataset): Dataset to evaluate the model. | |||
| list_callback (ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| Returns: | |||
| @@ -619,7 +618,7 @@ class Model: | |||
| Args: | |||
| valid_dataset (Dataset): Dataset to evaluate the model. | |||
| list_callback (ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| Returns: | |||
| @@ -678,7 +677,6 @@ class Model: | |||
| if not self._metric_fns: | |||
| raise ValueError("metric fn can not be None or empty.") | |||
| list_callback = _build_callbacks(callbacks) | |||
| cb_params = _InternalCallbackParam() | |||
| cb_params.eval_network = self._eval_network | |||
| cb_params.valid_dataset = valid_dataset | |||
| @@ -691,9 +689,10 @@ class Model: | |||
| self._clear_metrics() | |||
| if dataset_sink_mode: | |||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||
| with _CallbackManager(callbacks) as list_callback: | |||
| if dataset_sink_mode: | |||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||
| def predict(self, *predict_data): | |||
| """ | |||
| @@ -18,6 +18,7 @@ import os | |||
| import stat | |||
| import shutil | |||
| import time | |||
| from contextlib import ExitStack | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| @@ -282,80 +283,11 @@ def _summary_cb_for_save_op(summary_list): | |||
| return ret | |||
| def _build_callbacks(callbacks): | |||
| """ | |||
| Contain a list of callback. | |||
| Args: | |||
| callbacks (list): Callback functions list, Support None, a single Callback object, or a list. | |||
| Returns: | |||
| List, a list of callback functions. | |||
| """ | |||
| if callbacks: | |||
| if isinstance(callbacks, tuple): | |||
| raise TypeError("Callbacks cannot be a tuple. Please check it.") | |||
| if not isinstance(callbacks, list): | |||
| callbacks = [callbacks] | |||
| else: | |||
| callbacks = [] | |||
| excute_callbacks = [] | |||
| for cb in callbacks: | |||
| if cb is None or not isinstance(cb, Callback): | |||
| raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.") | |||
| excute_callbacks.append(cb) | |||
| return _ListCallback(excute_callbacks) | |||
| class _ListCallback: | |||
| """ | |||
| Sequential execution of callback functions. | |||
| Execute Callback functions at certain points. | |||
| Args: | |||
| callbacks (list): Callback functions list. | |||
| """ | |||
| def __init__(self, callbacks): | |||
| super(_ListCallback, self).__init__() | |||
| self._callbacks = callbacks | |||
| def begin(self, run_context): | |||
| """Called once before network training.""" | |||
| for cb in self._callbacks: | |||
| cb.begin(run_context) | |||
| def epoch_begin(self, run_context): | |||
| """Called before each epoch begin.""" | |||
| for cb in self._callbacks: | |||
| cb.epoch_begin(run_context) | |||
| def epoch_end(self, run_context): | |||
| """Called after each epoch finished.""" | |||
| for cb in self._callbacks: | |||
| cb.epoch_end(run_context) | |||
| def step_begin(self, run_context): | |||
| """Called before each epoch begin.""" | |||
| for cb in self._callbacks: | |||
| cb.step_begin(run_context) | |||
| def step_end(self, run_context): | |||
| """Called after each step finished.""" | |||
| for cb in self._callbacks: | |||
| cb.step_end(run_context) | |||
| def end(self, run_context): | |||
| """Called once after network training.""" | |||
| for cb in self._callbacks: | |||
| cb.end(run_context) | |||
| class Callback: | |||
| """ | |||
| Abstract base class used to build a callback function. | |||
| Abstract base class used to build a callback class. Callbacks are context managers | |||
| which will be entered and exited when passing into the Model. | |||
| You can leverage this mechanism to init and release resources automatically. | |||
| Callback function will execution some operating to the current step or epoch. | |||
| @@ -369,8 +301,13 @@ class Callback: | |||
| >>> print_cb = Print_info() | |||
| >>> model.train(epoch, dataset, callbacks=print_cb) | |||
| """ | |||
| def __init__(self): | |||
| pass | |||
| def __enter__(self): | |||
| """Return the enter target.""" | |||
| return self | |||
| def __exit__(self, *err): | |||
| """Release resources here if have any.""" | |||
| def begin(self, run_context): | |||
| """ | |||
| @@ -421,6 +358,67 @@ class Callback: | |||
| """ | |||
| class _CallbackManager(Callback): | |||
| """ | |||
| Sequential execution of callback functions. | |||
| Execute Callback functions at certain points. | |||
| Args: | |||
| callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list. | |||
| """ | |||
| def __init__(self, callbacks): | |||
| self._callbacks, self._stack = [], None | |||
| if isinstance(callbacks, Callback): | |||
| self._callbacks.append(callbacks) | |||
| elif callbacks is not None: | |||
| for cb in callbacks: | |||
| if not isinstance(cb, Callback): | |||
| raise TypeError("%r is not an instance of %r" % (cb, Callback)) | |||
| self._callbacks.append(cb) | |||
| def __enter__(self): | |||
| if self._stack is None: | |||
| self._stack = ExitStack().__enter__() | |||
| self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks] | |||
| return self | |||
| def __exit__(self, *err): | |||
| return self._stack.__exit__(*err) | |||
| def begin(self, run_context): | |||
| """Called once before network training.""" | |||
| for cb in self._callbacks: | |||
| cb.begin(run_context) | |||
| def epoch_begin(self, run_context): | |||
| """Called before each epoch begin.""" | |||
| for cb in self._callbacks: | |||
| cb.epoch_begin(run_context) | |||
| def epoch_end(self, run_context): | |||
| """Called after each epoch finished.""" | |||
| for cb in self._callbacks: | |||
| cb.epoch_end(run_context) | |||
| def step_begin(self, run_context): | |||
| """Called before each epoch begin.""" | |||
| for cb in self._callbacks: | |||
| cb.step_begin(run_context) | |||
| def step_end(self, run_context): | |||
| """Called after each step finished.""" | |||
| for cb in self._callbacks: | |||
| cb.step_end(run_context) | |||
| def end(self, run_context): | |||
| """Called once after network training.""" | |||
| for cb in self._callbacks: | |||
| cb.end(run_context) | |||
| class SummaryStep(Callback): | |||
| """ | |||
| The summary callback class. | |||
| @@ -435,6 +433,13 @@ class SummaryStep(Callback): | |||
| raise ValueError("`flush_step` should be int and greater than 0") | |||
| self._summary = summary | |||
| self._flush_step = flush_step | |||
| def __enter__(self): | |||
| self._summary.__enter__() | |||
| return self | |||
| def __exit__(self, *err): | |||
| return self._summary.__exit__(*err) | |||
| def step_end(self, run_context): | |||
| """ | |||
| @@ -19,7 +19,7 @@ from mindspore import log as logger | |||
| from ..common.tensor import Tensor | |||
| from ..nn.metrics import get_metrics | |||
| from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool | |||
| from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks | |||
| from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from .. import context | |||
| from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | |||
| @@ -334,8 +334,6 @@ class Model: | |||
| if self._parameter_broadcast: | |||
| self._train_network.set_broadcast_flag() | |||
| # build callback list | |||
| list_callback = _build_callbacks(callbacks) | |||
| cb_params = _InternalCallbackParam() | |||
| cb_params.train_network = self._train_network | |||
| cb_params.epoch_num = epoch | |||
| @@ -346,17 +344,18 @@ class Model: | |||
| cb_params.parallel_mode = self._parallel_mode | |||
| cb_params.device_number = self._device_number | |||
| cb_params.train_dataset = train_dataset | |||
| cb_params.list_callback = list_callback | |||
| cb_params.list_callback = callbacks | |||
| if dataset_sink_mode: | |||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||
| # build callback list | |||
| with _CallbackManager(callbacks) as list_callback: | |||
| if not dataset_sink_mode: | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| elif context.get_context("mode") == context.PYNATIVE_MODE: | |||
| logger.warning("The pynative mode cannot support dataset sink mode currently." | |||
| "So the training process will be performed with dataset not sink.") | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| else: | |||
| self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | |||
| else: | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | |||
| """ | |||
| @@ -369,7 +368,7 @@ class Model: | |||
| returned and passed to the network. Otherwise, a tuple (data, label) should | |||
| be returned, and the data and label are passed to the network and loss | |||
| function respectively. | |||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| """ | |||
| dataset_helper, train_network = self._exec_preprocess(self._train_network, | |||
| @@ -417,7 +416,7 @@ class Model: | |||
| returned and passed to the network. Otherwise, a tuple (data, label) should | |||
| be returned, and the data and label are passed to the network and loss | |||
| function respectively. | |||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| """ | |||
| dataset_helper, _ = self._exec_preprocess(self._train_network, | |||
| @@ -524,7 +523,7 @@ class Model: | |||
| Args: | |||
| valid_dataset (Dataset): Dataset to evaluate the model. | |||
| list_callback (ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| Returns: | |||
| @@ -563,7 +562,7 @@ class Model: | |||
| Args: | |||
| valid_dataset (Dataset): Dataset to evaluate the model. | |||
| list_callback (ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| Returns: | |||
| @@ -622,7 +621,6 @@ class Model: | |||
| if not self._metric_fns: | |||
| raise ValueError("metric fn can not be None or empty.") | |||
| list_callback = _build_callbacks(callbacks) | |||
| cb_params = _InternalCallbackParam() | |||
| cb_params.eval_network = self._eval_network | |||
| cb_params.valid_dataset = valid_dataset | |||
| @@ -635,9 +633,10 @@ class Model: | |||
| self._clear_metrics() | |||
| if dataset_sink_mode: | |||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||
| with _CallbackManager(callbacks) as list_callback: | |||
| if dataset_sink_mode: | |||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||
| def predict(self, *predict_data): | |||
| """ | |||
| @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | |||
| from mindspore.train import amp | |||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks | |||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| from .dataset_helper import DatasetHelper | |||
| @@ -392,7 +392,6 @@ class Model: | |||
| self._train_network.set_broadcast_flag() | |||
| # build callback list | |||
| list_callback = _build_callbacks(callbacks) | |||
| cb_params = _InternalCallbackParam() | |||
| cb_params.train_network = self._train_network | |||
| cb_params.epoch_num = epoch | |||
| @@ -403,17 +402,17 @@ class Model: | |||
| cb_params.parallel_mode = self._parallel_mode | |||
| cb_params.device_number = self._device_number | |||
| cb_params.train_dataset = train_dataset | |||
| cb_params.list_callback = list_callback | |||
| cb_params.list_callback = callbacks | |||
| if dataset_sink_mode: | |||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||
| with _CallbackManager(callbacks) as list_callback: | |||
| if not dataset_sink_mode: | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| elif context.get_context("mode") == context.PYNATIVE_MODE: | |||
| logger.warning("The pynative mode cannot support dataset sink mode currently." | |||
| "So the training process will be performed with dataset not sink.") | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| else: | |||
| self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | |||
| else: | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | |||
| """ | |||
| @@ -426,7 +425,7 @@ class Model: | |||
| returned and passed to the network. Otherwise, a tuple (data, label) should | |||
| be returned, and the data and label are passed to the network and loss | |||
| function respectively. | |||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| """ | |||
| iter_first_order = self._frequency - 1 | |||
| @@ -490,7 +489,7 @@ class Model: | |||
| returned and passed to the network. Otherwise, a tuple (data, label) should | |||
| be returned, and the data and label are passed to the network and loss | |||
| function respectively. | |||
| list_callback (_ListCallback): Executor of callback list. Default: None. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| """ | |||
| dataset_helper, _ = self._exec_preprocess(self._train_network, | |||
| @@ -695,7 +694,6 @@ class Model: | |||
| if not self._metric_fns: | |||
| raise ValueError("metric fn can not be None or empty.") | |||
| list_callback = _build_callbacks(callbacks) | |||
| cb_params = _InternalCallbackParam() | |||
| cb_params.eval_network = self._eval_network | |||
| cb_params.valid_dataset = valid_dataset | |||
| @@ -708,9 +706,10 @@ class Model: | |||
| self._clear_metrics() | |||
| if dataset_sink_mode: | |||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||
| with _CallbackManager(callbacks) as list_callback: | |||
| if dataset_sink_mode: | |||
| return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) | |||
| return self._eval_process(valid_dataset, list_callback, cb_params) | |||
| def predict(self, *predict_data): | |||
| """ | |||
| @@ -156,12 +156,19 @@ def get_dataset(): | |||
| class ImageSummaryCallback: | |||
| def __init__(self, summaryRecord): | |||
| self._summaryRecord = summaryRecord | |||
| def __init__(self, summary_record): | |||
| self._summary_record = summary_record | |||
| def __enter__(self): | |||
| return self | |||
| def __exit__(self, *err): | |||
| pass | |||
| def record(self, step, train_network=None): | |||
| self._summaryRecord.record(step, train_network) | |||
| self._summaryRecord.flush() | |||
| self._summary_record.record(step, train_network) | |||
| self._summary_record.flush() | |||
| def test_image_summary_train(): | |||
| @@ -180,6 +180,12 @@ class CallbackTest: | |||
| def __init__(self): | |||
| pass | |||
| def __enter__(self): | |||
| return self | |||
| def __exit__(self, *err): | |||
| pass | |||
| def record(self, step, *args): | |||
| print(step, args) | |||
| @@ -15,6 +15,7 @@ | |||
| """test callback function.""" | |||
| import os | |||
| import stat | |||
| from unittest import mock | |||
| import numpy as np | |||
| import pytest | |||
| @@ -27,7 +28,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Momentum | |||
| from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \ | |||
| _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \ | |||
| _build_callbacks, CheckpointConfig, _set_cur_net | |||
| _CallbackManager, Callback, CheckpointConfig, _set_cur_net | |||
| class Net(nn.Cell): | |||
| @@ -122,13 +123,13 @@ def test_loss_monitor_sink_mode(): | |||
| run_context = RunContext(cb_params) | |||
| loss_cb = LossMonitor(1) | |||
| callbacks = [loss_cb] | |||
| callbacklist = _build_callbacks(callbacks) | |||
| callbacklist.begin(run_context) | |||
| callbacklist.epoch_begin(run_context) | |||
| callbacklist.step_begin(run_context) | |||
| callbacklist.step_end(run_context) | |||
| callbacklist.epoch_end(run_context) | |||
| callbacklist.end(run_context) | |||
| with _CallbackManager(callbacks) as callbacklist: | |||
| callbacklist.begin(run_context) | |||
| callbacklist.epoch_begin(run_context) | |||
| callbacklist.step_begin(run_context) | |||
| callbacklist.step_end(run_context) | |||
| callbacklist.epoch_end(run_context) | |||
| callbacklist.end(run_context) | |||
| def test_loss_monitor_normal_mode(): | |||
| @@ -269,29 +270,61 @@ def test_checkpoint_save_ckpt_seconds(): | |||
| ckpt_cb2.step_end(run_context) | |||
| def test_build_callbacks(): | |||
| """Test_build_callbacks.""" | |||
| def test_CallbackManager(): | |||
| """TestCallbackManager.""" | |||
| ck_obj = ModelCheckpoint() | |||
| loss_cb_1 = LossMonitor(1) | |||
| callbacks = [None] | |||
| with pytest.raises(TypeError): | |||
| callbacks = _build_callbacks(callbacks) | |||
| _CallbackManager(callbacks) | |||
| callbacks = ['Error'] | |||
| with pytest.raises(TypeError): | |||
| callbacks = _build_callbacks(callbacks) | |||
| _CallbackManager(callbacks) | |||
| callbacks = [ck_obj, loss_cb_1, 'Error', None] | |||
| with pytest.raises(TypeError): | |||
| _ = _build_callbacks(callbacks) | |||
| _CallbackManager(callbacks) | |||
| def test_CallbackManager_exit_called(): | |||
| with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit: | |||
| cb1, cb2 = Callback(), Callback() | |||
| with _CallbackManager([cb1, cb2]): | |||
| pass | |||
| for call_args in mock_exit.call_args_list: | |||
| assert call_args == mock.call(mock.ANY, None, None, None) | |||
| assert mock_exit.call_count == 2 | |||
| def test_CallbackManager_exit_called_when_raises(): | |||
| with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit: | |||
| cb1, cb2 = Callback(), Callback() | |||
| with pytest.raises(ValueError): | |||
| with _CallbackManager([cb1, cb2]): | |||
| raise ValueError() | |||
| for call_args in mock_exit.call_args_list: | |||
| assert call_args == mock.call(*[mock.ANY] * 4) | |||
| assert mock_exit.call_count == 2 | |||
| def test_CallbackManager_begin_called(): | |||
| context = dict() | |||
| with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin: | |||
| cb1, cb2 = Callback(), Callback() | |||
| with _CallbackManager([cb1, cb2]) as cm: | |||
| cm.begin(context) | |||
| for call_args in mock_begin.call_args_list: | |||
| assert call_args == mock.call(context) | |||
| assert mock_begin.call_count == 2 | |||
| def test_RunContext(): | |||
| """Test RunContext.""" | |||
| context_err = 666 | |||
| with pytest.raises(TypeError): | |||
| _ = RunContext(context_err) | |||
| RunContext(context_err) | |||
| cb_params = _InternalCallbackParam() | |||
| cb_params.member1 = 1 | |||