| @@ -18,6 +18,8 @@ from ._callback import Callback | |||||
| from ._callback import CallbackManager as _CallbackManager | from ._callback import CallbackManager as _CallbackManager | ||||
| from ._callback import InternalCallbackParam as _InternalCallbackParam | from ._callback import InternalCallbackParam as _InternalCallbackParam | ||||
| from ._callback import RunContext | from ._callback import RunContext | ||||
| from ._callback import checkpoint_cb_for_save_op as _checkpoint_cb_for_save_op | |||||
| from ._callback import set_cur_net as _set_cur_net | |||||
| from ._checkpoint import CheckpointConfig | from ._checkpoint import CheckpointConfig | ||||
| from ._checkpoint import CheckpointManager as _CheckpointManager | from ._checkpoint import CheckpointManager as _CheckpointManager | ||||
| from ._checkpoint import ModelCheckpoint | from ._checkpoint import ModelCheckpoint | ||||
| @@ -160,16 +160,25 @@ class CallbackManager(Callback): | |||||
| self._callbacks, self._stack = [], None | self._callbacks, self._stack = [], None | ||||
| if isinstance(callbacks, Callback): | if isinstance(callbacks, Callback): | ||||
| self._callbacks.append(callbacks) | self._callbacks.append(callbacks) | ||||
| elif callbacks is not None: | |||||
| elif isinstance(callbacks, list): | |||||
| for cb in callbacks: | for cb in callbacks: | ||||
| if not isinstance(cb, Callback): | if not isinstance(cb, Callback): | ||||
| raise TypeError("%r is not an instance of %r" % (cb, Callback)) | |||||
| raise TypeError("The 'callbacks' contains not-a-Callback item.") | |||||
| self._callbacks.append(cb) | self._callbacks.append(cb) | ||||
| elif callbacks is not None: | |||||
| raise TypeError("The 'callbacks' is not a Callback or a list of Callback.") | |||||
| def __enter__(self): | def __enter__(self): | ||||
| if self._stack is None: | if self._stack is None: | ||||
| self._stack = ExitStack().__enter__() | |||||
| self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks] | |||||
| callbacks, self._stack = [], ExitStack().__enter__() | |||||
| for callback in self._callbacks: | |||||
| target = self._stack.enter_context(callback) | |||||
| if not isinstance(target, Callback): | |||||
| logger.warning("Please return 'self' or a Callback as the enter target.") | |||||
| callbacks.append(callback) | |||||
| else: | |||||
| callbacks.append(target) | |||||
| self._callbacks = callbacks | |||||
| return self | return self | ||||
| def __exit__(self, *err): | def __exit__(self, *err): | ||||
| @@ -27,8 +27,7 @@ from mindspore.common.tensor import Tensor | |||||
| from mindspore.nn import TrainOneStepCell, WithLossCell | from mindspore.nn import TrainOneStepCell, WithLossCell | ||||
| from mindspore.nn.optim import Momentum | from mindspore.nn.optim import Momentum | ||||
| from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ | from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ | ||||
| _CallbackManager, Callback, CheckpointConfig | |||||
| from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op | |||||
| _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op | |||||
| from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist | from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist | ||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| @@ -189,7 +188,7 @@ def test_checkpoint_cb_for_save_op(): | |||||
| one_param['name'] = "conv1.weight" | one_param['name'] = "conv1.weight" | ||||
| one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32) | one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32) | ||||
| parameter_list.append(one_param) | parameter_list.append(one_param) | ||||
| checkpoint_cb_for_save_op(parameter_list) | |||||
| _checkpoint_cb_for_save_op(parameter_list) | |||||
| def test_checkpoint_cb_for_save_op_update_net(): | def test_checkpoint_cb_for_save_op_update_net(): | ||||
| @@ -200,8 +199,8 @@ def test_checkpoint_cb_for_save_op_update_net(): | |||||
| one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32) | one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32) | ||||
| parameter_list.append(one_param) | parameter_list.append(one_param) | ||||
| net = Net() | net = Net() | ||||
| set_cur_net(net) | |||||
| checkpoint_cb_for_save_op(parameter_list) | |||||
| _set_cur_net(net) | |||||
| _checkpoint_cb_for_save_op(parameter_list) | |||||
| assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1 | assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1 | ||||