Merge pull request !2236 from 李鸿章/callback  (tags/v0.5.0-beta)
| @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | |||
| from mindspore.train import amp | |||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| from model.dataset_helper import DatasetHelper | |||
| @@ -26,9 +26,9 @@ | |||
| namespace mindspore { | |||
| namespace callbacks { | |||
| const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback"; | |||
| const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op"; | |||
| const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op"; | |||
| const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback"; | |||
| const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op"; | |||
| const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op"; | |||
| const char kSummary[] = "Summary"; | |||
| const char kCheckPoint[] = "Save"; | |||
| const int ONE_SHAPE = 1; | |||
| @@ -25,9 +25,9 @@ | |||
| namespace mindspore { | |||
| namespace callbacks { | |||
| const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback"; | |||
| const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op"; | |||
| const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op"; | |||
| const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback"; | |||
| const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op"; | |||
| const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op"; | |||
| const char kSummary[] = "Summary"; | |||
| const char kCheckPoint[] = "Save"; | |||
| const int ONE_SHAPE = 1; | |||
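Both C++ files above update the module path and function names the runtime uses to locate the Python-side save callbacks, matching the rename from `_checkpoint_cb_for_save_op`/`_summary_cb_for_save_op` in `mindspore.train.callback.callback` to the underscore-free names in `mindspore.train.callback._callback`. A rough Python-level illustration of the lookup these constants must stay in sync with (an assumption about the mechanism, not the actual C++ call path):

    # Hypothetical sketch: resolve the callback entry points by module and function name.
    import importlib

    module = importlib.import_module("mindspore.train.callback._callback")
    checkpoint_cb = getattr(module, "checkpoint_cb_for_save_op")  # was _checkpoint_cb_for_save_op
    summary_cb = getattr(module, "summary_cb_for_save_op")        # was _summary_cb_for_save_op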
| @@ -14,7 +14,15 @@ | |||
| # ============================================================================ | |||
| """Callback related classes and functions.""" | |||
| from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext | |||
| from ._callback import Callback | |||
| from ._callback import CallbackManager as _CallbackManager | |||
| from ._callback import InternalCallbackParam as _InternalCallbackParam | |||
| from ._callback import RunContext | |||
| from ._checkpoint import CheckpointConfig | |||
| from ._checkpoint import CheckpointManager as _CheckpointManager | |||
| from ._checkpoint import ModelCheckpoint | |||
| from ._loss_monitor import LossMonitor | |||
| from ._summary_step import SummaryStep | |||
| from ._time_monitor import TimeMonitor | |||
| __all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", | |||
| "SummaryStep", "CheckpointConfig", "RunContext"] | |||
| __all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"] | |||
| @@ -0,0 +1,260 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Callback related classes and functions.""" | |||
| from contextlib import ExitStack | |||
| from mindspore import log as logger | |||
| from mindspore.train.serialization import _fill_param_into_net | |||
| from mindspore.train.summary.summary_record import _cache_summary_tensor_data | |||
| _cur_net = None | |||
| def set_cur_net(net): | |||
| """ | |||
| Set the current network that is used when saving a checkpoint. | |||
| Args: | |||
| net (Cell): the training network. | |||
| """ | |||
| global _cur_net | |||
| _cur_net = net | |||
| def checkpoint_cb_for_save_op(parameter_list): | |||
| """ | |||
| The checkpoint callback function for MindSpore. | |||
| Will be executed by checkpoint save op. | |||
| Args: | |||
| parameter_list (list): Format is like [{"name": param_name, "data": param_value}, ...], where each value is a Tensor. | |||
| Returns: | |||
| bool, True means the checkpoint was saved successfully. | |||
| """ | |||
| if _cur_net is None: | |||
| logger.warning("_cur_net is None. parameters are not updated.") | |||
| return False | |||
| logger.info("update parameters in the net.") | |||
| _fill_param_into_net(_cur_net, parameter_list) | |||
| set_cur_net(None) | |||
| return True | |||
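For reference, a sketch of how the save op is expected to drive these helpers, mirroring the updated unit test later in this PR (the `net` object is assumed to be an already constructed Cell):

    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op

    parameter_list = [{"name": "conv.weight",
                       "data": Tensor(np.ones((64, 3, 3, 3)), dtype=mstype.float32)}]
    set_cur_net(net)                           # register the network being checkpointed
    checkpoint_cb_for_save_op(parameter_list)  # fills the listed parameters back into `net`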
| def summary_cb_for_save_op(summary_list): | |||
| """ | |||
| The summary callback function for MindSpore. | |||
| Will be executed by summary op. | |||
| Args: | |||
| summary_list (list): Format is like [{"name": tag_name, "data": tensor}, ...], where each value is a scalar or Tensor. | |||
| Returns: | |||
| bool, True means the summary data was saved successfully. | |||
| """ | |||
| ret = _cache_summary_tensor_data(summary_list) | |||
| return ret | |||
| class Callback: | |||
| """ | |||
| Abstract base class used to build a callback class. Callbacks are context managers | |||
| that are entered and exited when passed into the Model. | |||
| You can leverage this mechanism to initialize and release resources automatically. | |||
| Each callback function performs some operation at the current step or epoch. | |||
| Examples: | |||
| >>> class Print_info(Callback): | |||
| >>> def step_end(self, run_context): | |||
| >>> cb_params = run_context.original_args() | |||
| >>> print(cb_params.cur_epoch_num) | |||
| >>> print(cb_params.cur_step_num) | |||
| >>> | |||
| >>> print_cb = Print_info() | |||
| >>> model.train(epoch, dataset, callbacks=print_cb) | |||
| """ | |||
| def __enter__(self): | |||
| """Return the enter target.""" | |||
| return self | |||
| def __exit__(self, *err): | |||
| """Release resources here if have any.""" | |||
| def begin(self, run_context): | |||
| """ | |||
| Called once before the network starts executing. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def epoch_begin(self, run_context): | |||
| """ | |||
| Called before each epoch begins. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def epoch_end(self, run_context): | |||
| """ | |||
| Called after each epoch finished. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def step_begin(self, run_context): | |||
| """ | |||
| Called before each step begins. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def step_end(self, run_context): | |||
| """ | |||
| Called after each step finished. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def end(self, run_context): | |||
| """ | |||
| Called once after network training. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| class CallbackManager(Callback): | |||
| """ | |||
| Sequential execution of callback functions. | |||
| Execute Callback functions at certain points. | |||
| Args: | |||
| callbacks (Union[None, Callback, list[Callback]]): a single callback, a list of callbacks, or None. | |||
| """ | |||
| def __init__(self, callbacks): | |||
| self._callbacks, self._stack = [], None | |||
| if isinstance(callbacks, Callback): | |||
| self._callbacks.append(callbacks) | |||
| elif callbacks is not None: | |||
| for cb in callbacks: | |||
| if not isinstance(cb, Callback): | |||
| raise TypeError("%r is not an instance of %r" % (cb, Callback)) | |||
| self._callbacks.append(cb) | |||
| def __enter__(self): | |||
| if self._stack is None: | |||
| self._stack = ExitStack().__enter__() | |||
| self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks] | |||
| return self | |||
| def __exit__(self, *err): | |||
| return self._stack.__exit__(*err) | |||
| def begin(self, run_context): | |||
| """Called once before network training.""" | |||
| for cb in self._callbacks: | |||
| cb.begin(run_context) | |||
| def epoch_begin(self, run_context): | |||
| Called before each epoch begins. | |||
| for cb in self._callbacks: | |||
| cb.epoch_begin(run_context) | |||
| def epoch_end(self, run_context): | |||
| """Called after each epoch finished.""" | |||
| for cb in self._callbacks: | |||
| cb.epoch_end(run_context) | |||
| def step_begin(self, run_context): | |||
| Called before each step begins. | |||
| for cb in self._callbacks: | |||
| cb.step_begin(run_context) | |||
| def step_end(self, run_context): | |||
| """Called after each step finished.""" | |||
| for cb in self._callbacks: | |||
| cb.step_end(run_context) | |||
| def end(self, run_context): | |||
| """Called once after network training.""" | |||
| for cb in self._callbacks: | |||
| cb.end(run_context) | |||
| class InternalCallbackParam(dict): | |||
| """Internal callback object's parameters.""" | |||
| def __getattr__(self, key): | |||
| return self[key] | |||
| def __setattr__(self, key, value): | |||
| self[key] = value | |||
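A minimal sketch of how the manager fans one event out to every registered callback; the cb_params fields set below are assumptions about what the training loop normally fills in:

    from mindspore.train.callback import (LossMonitor, TimeMonitor, RunContext,
                                          _CallbackManager, _InternalCallbackParam)

    cb_params = _InternalCallbackParam()
    cb_params.cur_epoch_num = 1
    cb_params.cur_step_num = 1
    cb_params.batch_num = 10
    cb_params.net_outputs = 0.5          # placeholder loss value

    run_context = RunContext(cb_params)
    with _CallbackManager([LossMonitor(), TimeMonitor(data_size=10)]) as manager:
        manager.begin(run_context)
        manager.epoch_begin(run_context)
        manager.step_begin(run_context)
        manager.step_end(run_context)    # each callback's step_end runs in registration order
        manager.epoch_end(run_context)
        manager.end(run_context)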
| class RunContext: | |||
| """ | |||
| Provides information about the model. | |||
| Holds information about the original run call made to the model function. | |||
| Callback objects can stop the loop by calling request_stop() of run_context. | |||
| Args: | |||
| original_args (dict): Holds the related information of the model. | |||
| """ | |||
| def __init__(self, original_args): | |||
| if not isinstance(original_args, dict): | |||
| raise TypeError("The arg of RunContext should be dict type.") | |||
| self._original_args = original_args | |||
| self._stop_requested = False | |||
| def original_args(self): | |||
| """ | |||
| Get the _original_args object. | |||
| Returns: | |||
| Dict, an object holding the original arguments of the model. | |||
| """ | |||
| return self._original_args | |||
| def request_stop(self): | |||
| """ | |||
| Set the stop-requested flag during training. | |||
| Callbacks can use this function to request that iterations stop. | |||
| model.train() checks whether this has been called. | |||
| """ | |||
| self._stop_requested = True | |||
| def get_stop_requested(self): | |||
| """ | |||
| Returns whether a stop is requested or not. | |||
| Returns: | |||
| bool, if True, model.train() stops iterating. | |||
| """ | |||
| return self._stop_requested | |||
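One concrete use of RunContext from a user-defined callback is early stopping; a hypothetical sketch (the threshold and the assumption that cb_params.net_outputs is a plain float are illustrative):

    from mindspore.train.callback import Callback

    class StopAtLoss(Callback):
        """Request a training stop once the loss falls below a threshold."""
        def __init__(self, threshold=0.05):
            super(StopAtLoss, self).__init__()
            self.threshold = threshold

        def step_end(self, run_context):
            cb_params = run_context.original_args()
            loss = cb_params.net_outputs
            if isinstance(loss, float) and loss < self.threshold:
                run_context.request_stop()   # model.train() checks get_stop_requested()

    # model.train(epoch_size, train_dataset, callbacks=StopAtLoss(0.05))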
| @@ -12,93 +12,25 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Callback related classes and functions.""" | |||
| """Checkpoint related classes and functions.""" | |||
| import os | |||
| import stat | |||
| import shutil | |||
| import stat | |||
| import time | |||
| from contextlib import ExitStack | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph | |||
| from mindspore.train._utils import _make_directory | |||
| from mindspore import log as logger | |||
| from mindspore._checkparam import check_int_non_negative, check_bool | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.train.summary.summary_record import _cache_summary_tensor_data | |||
| from mindspore._checkparam import check_bool, check_int_non_negative | |||
| from mindspore.train._utils import _make_directory | |||
| from mindspore.train.serialization import _exec_save_checkpoint, _save_graph | |||
| from ._callback import Callback, set_cur_net | |||
| _cur_dir = os.getcwd() | |||
| _cur_net = None | |||
| _save_dir = _cur_dir | |||
| class _CheckpointManager: | |||
| """Manage checkpoint files according to train_config of checkpoint.""" | |||
| def __init__(self): | |||
| self._ckpoint_filelist = [] | |||
| @property | |||
| def ckpoint_filelist(self): | |||
| """Get all the related checkpoint files managed here.""" | |||
| return self._ckpoint_filelist | |||
| @property | |||
| def ckpoint_num(self): | |||
| """Get the number of the related checkpoint files managed here.""" | |||
| return len(self._ckpoint_filelist) | |||
| def update_ckpoint_filelist(self, directory, prefix): | |||
| """Update the checkpoint file list.""" | |||
| self._ckpoint_filelist = [] | |||
| files = os.listdir(directory) | |||
| for filename in files: | |||
| if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix): | |||
| mid_name = filename[len(prefix):-5] | |||
| flag = True | |||
| for char in mid_name: | |||
| if char.isalpha(): | |||
| flag = False | |||
| if flag: | |||
| self._ckpoint_filelist.append(directory + '/' + filename) | |||
| def remove_ckpoint_file(self, file_name): | |||
| """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" | |||
| try: | |||
| os.chmod(file_name, stat.S_IWRITE) | |||
| os.remove(file_name) | |||
| self._ckpoint_filelist.remove(file_name) | |||
| except OSError: | |||
| logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) | |||
| except ValueError: | |||
| logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) | |||
| def remove_oldest_ckpoint_file(self): | |||
| """Remove the oldest checkpoint file from this checkpoint manager and also from the directory.""" | |||
| ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime) | |||
| self.remove_ckpoint_file(ckpoint_files[0]) | |||
| def keep_one_ckpoint_per_minutes(self, minutes, cur_time): | |||
| """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time].""" | |||
| movs = [] | |||
| oldest_file = '' | |||
| oldest_time = cur_time | |||
| for ck_file in self._ckpoint_filelist: | |||
| modify_time = os.path.getmtime(ck_file) | |||
| if cur_time - modify_time < 60 * minutes: | |||
| movs.append(ck_file) | |||
| if modify_time < oldest_time: | |||
| oldest_time = modify_time | |||
| oldest_file = ck_file | |||
| for mv_file in movs: | |||
| if mv_file == oldest_file: | |||
| continue | |||
| self.remove_ckpoint_file(mv_file) | |||
| def _check_file_name_prefix(file_name_prefix): | |||
| """ | |||
| @@ -234,282 +166,6 @@ class CheckpointConfig: | |||
| return checkpoint_policy | |||
| def _set_cur_net(net): | |||
| """ | |||
| Set current net for which we are using to save checkpoint. | |||
| Args: | |||
| net (Cell): train network | |||
| """ | |||
| global _cur_net | |||
| _cur_net = net | |||
| def _checkpoint_cb_for_save_op(parameter_list): | |||
| """ | |||
| The checkpoint callback function for MindSpore. | |||
| Will be executed by checkpoint save op. | |||
| Args: | |||
| parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor. | |||
| Returns: | |||
| bool, true: means save checkpoint success. | |||
| """ | |||
| if _cur_net is None: | |||
| logger.warning("_cur_net is None. parameters are not updated.") | |||
| return False | |||
| logger.info("update parameters in the net.") | |||
| _fill_param_into_net(_cur_net, parameter_list) | |||
| _set_cur_net(None) | |||
| return True | |||
| def _summary_cb_for_save_op(summary_list): | |||
| """ | |||
| The summary callback function for MindSpore. | |||
| Will be executed by summary op. | |||
| Args: | |||
| summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor. | |||
| Returns: | |||
| bool, true: means save summary success. | |||
| """ | |||
| ret = _cache_summary_tensor_data(summary_list) | |||
| return ret | |||
| class Callback: | |||
| """ | |||
| Abstract base class used to build a callback class. Callbacks are context managers | |||
| which will be entered and exited when passing into the Model. | |||
| You can leverage this mechanism to init and release resources automatically. | |||
| Callback function will execution some operating to the current step or epoch. | |||
| Examples: | |||
| >>> class Print_info(Callback): | |||
| >>> def step_end(self, run_context): | |||
| >>> cb_params = run_context.original_args() | |||
| >>> print(cb_params.cur_epoch_num) | |||
| >>> print(cb_params.cur_step_num) | |||
| >>> | |||
| >>> print_cb = Print_info() | |||
| >>> model.train(epoch, dataset, callbacks=print_cb) | |||
| """ | |||
| def __enter__(self): | |||
| """Return the enter target.""" | |||
| return self | |||
| def __exit__(self, *err): | |||
| """Release resources here if have any.""" | |||
| def begin(self, run_context): | |||
| """ | |||
| Called once before the network executing. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def epoch_begin(self, run_context): | |||
| """ | |||
| Called before each epoch beginning. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def epoch_end(self, run_context): | |||
| """ | |||
| Called after each epoch finished. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def step_begin(self, run_context): | |||
| """ | |||
| Called before each epoch beginning. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def step_end(self, run_context): | |||
| """ | |||
| Called after each step finished. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| def end(self, run_context): | |||
| """ | |||
| Called once after network training. | |||
| Args: | |||
| run_context (RunContext): Include some information of the model. | |||
| """ | |||
| class _CallbackManager(Callback): | |||
| """ | |||
| Sequential execution of callback functions. | |||
| Execute Callback functions at certain points. | |||
| Args: | |||
| callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list. | |||
| """ | |||
| def __init__(self, callbacks): | |||
| self._callbacks, self._stack = [], None | |||
| if isinstance(callbacks, Callback): | |||
| self._callbacks.append(callbacks) | |||
| elif callbacks is not None: | |||
| for cb in callbacks: | |||
| if not isinstance(cb, Callback): | |||
| raise TypeError("%r is not an instance of %r" % (cb, Callback)) | |||
| self._callbacks.append(cb) | |||
| def __enter__(self): | |||
| if self._stack is None: | |||
| self._stack = ExitStack().__enter__() | |||
| self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks] | |||
| return self | |||
| def __exit__(self, *err): | |||
| return self._stack.__exit__(*err) | |||
| def begin(self, run_context): | |||
| """Called once before network training.""" | |||
| for cb in self._callbacks: | |||
| cb.begin(run_context) | |||
| def epoch_begin(self, run_context): | |||
| """Called before each epoch begin.""" | |||
| for cb in self._callbacks: | |||
| cb.epoch_begin(run_context) | |||
| def epoch_end(self, run_context): | |||
| """Called after each epoch finished.""" | |||
| for cb in self._callbacks: | |||
| cb.epoch_end(run_context) | |||
| def step_begin(self, run_context): | |||
| """Called before each epoch begin.""" | |||
| for cb in self._callbacks: | |||
| cb.step_begin(run_context) | |||
| def step_end(self, run_context): | |||
| """Called after each step finished.""" | |||
| for cb in self._callbacks: | |||
| cb.step_end(run_context) | |||
| def end(self, run_context): | |||
| """Called once after network training.""" | |||
| for cb in self._callbacks: | |||
| cb.end(run_context) | |||
| class SummaryStep(Callback): | |||
| """ | |||
| The summary callback class. | |||
| Args: | |||
| summary (Object): Summary recode object. | |||
| flush_step (int): Number of interval steps to execute. Default: 10. | |||
| """ | |||
| def __init__(self, summary, flush_step=10): | |||
| super(SummaryStep, self).__init__() | |||
| if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0: | |||
| raise ValueError("`flush_step` should be int and greater than 0") | |||
| self._summary = summary | |||
| self._flush_step = flush_step | |||
| def __enter__(self): | |||
| self._summary.__enter__() | |||
| return self | |||
| def __exit__(self, *err): | |||
| return self._summary.__exit__(*err) | |||
| def step_end(self, run_context): | |||
| """ | |||
| Save summary. | |||
| Args: | |||
| run_context (RunContext): Context of the train running. | |||
| """ | |||
| cb_params = run_context.original_args() | |||
| if cb_params.cur_step_num % self._flush_step == 0: | |||
| self._summary.record(cb_params.cur_step_num, cb_params.train_network) | |||
| @property | |||
| def summary_file_name(self): | |||
| return self._summary.full_file_name | |||
| class _InternalCallbackParam(dict): | |||
| """Internal callback object's parameters.""" | |||
| def __getattr__(self, key): | |||
| return self[key] | |||
| def __setattr__(self, key, value): | |||
| self[key] = value | |||
| class RunContext: | |||
| """ | |||
| Provides information about the model. | |||
| Run call being made. Provides information about original request to model function. | |||
| callback objects can stop the loop by calling request_stop() of run_context. | |||
| Args: | |||
| original_args (dict): Holding the related information of model etc. | |||
| """ | |||
| def __init__(self, original_args): | |||
| if not isinstance(original_args, dict): | |||
| raise TypeError("The arg of RunContext should be dict type.") | |||
| self._original_args = original_args | |||
| self._stop_requested = False | |||
| def original_args(self): | |||
| """ | |||
| Get the _original_args object. | |||
| Returns: | |||
| Dict, a object holding the original arguments of model. | |||
| """ | |||
| return self._original_args | |||
| def request_stop(self): | |||
| """ | |||
| Sets stop requested during training. | |||
| Callbacks can use this function to request stop of iterations. | |||
| model.train() checks whether this is called or not. | |||
| """ | |||
| self._stop_requested = True | |||
| def get_stop_requested(self): | |||
| """ | |||
| Returns whether a stop is requested or not. | |||
| Returns: | |||
| bool, if true, model.train() stops iterations. | |||
| """ | |||
| return self._stop_requested | |||
| class ModelCheckpoint(Callback): | |||
| """ | |||
| @@ -553,7 +209,7 @@ class ModelCheckpoint(Callback): | |||
| self._config = config | |||
| # get existing checkpoint files | |||
| self._manager = _CheckpointManager() | |||
| self._manager = CheckpointManager() | |||
| self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) | |||
| self._graph_saved = False | |||
| @@ -633,7 +289,7 @@ class ModelCheckpoint(Callback): | |||
| self._last_triggered_step = cb_params.cur_step_num | |||
| if context.get_context("enable_ge"): | |||
| _set_cur_net(cb_params.train_network) | |||
| set_cur_net(cb_params.train_network) | |||
| cb_params.train_network.exec_checkpoint_graph() | |||
| _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save) | |||
| @@ -648,57 +304,66 @@ class ModelCheckpoint(Callback): | |||
| return self._latest_ckpt_file_name | |||
| class LossMonitor(Callback): | |||
| """ | |||
| Monitor the loss in training. | |||
| If the loss is NAN or INF, it will terminate training. | |||
| Note: | |||
| If per_print_times is 0 do not print loss. | |||
| Args: | |||
| per_print_times (int): Print loss every times. Default: 1. | |||
| Raises: | |||
| ValueError: If print_step is not int or less than zero. | |||
| """ | |||
| def __init__(self, per_print_times=1): | |||
| super(LossMonitor, self).__init__() | |||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||
| raise ValueError("print_step must be int and >= 0.") | |||
| self._per_print_times = per_print_times | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| loss = cb_params.net_outputs | |||
| if isinstance(loss, (tuple, list)): | |||
| if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): | |||
| loss = loss[0] | |||
| if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): | |||
| loss = np.mean(loss.asnumpy()) | |||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||
| if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): | |||
| raise ValueError("epoch: {} step: {}. Invalid loss, terminating training." | |||
| .format(cb_params.cur_epoch_num, cur_step_in_epoch)) | |||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | |||
| print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True) | |||
| class TimeMonitor(Callback): | |||
| """Time Monitor.""" | |||
| def __init__(self, data_size): | |||
| super(TimeMonitor, self).__init__() | |||
| self.data_size = data_size | |||
| def epoch_begin(self, run_context): | |||
| self.epoch_time = time.time() | |||
| def epoch_end(self, run_context): | |||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||
| per_step_mseconds = epoch_mseconds / self.data_size | |||
| print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) | |||
| class CheckpointManager: | |||
| """Manage checkpoint files according to train_config of checkpoint.""" | |||
| def __init__(self): | |||
| self._ckpoint_filelist = [] | |||
| @property | |||
| def ckpoint_filelist(self): | |||
| """Get all the related checkpoint files managed here.""" | |||
| return self._ckpoint_filelist | |||
| @property | |||
| def ckpoint_num(self): | |||
| """Get the number of the related checkpoint files managed here.""" | |||
| return len(self._ckpoint_filelist) | |||
| def update_ckpoint_filelist(self, directory, prefix): | |||
| """Update the checkpoint file list.""" | |||
| self._ckpoint_filelist = [] | |||
| files = os.listdir(directory) | |||
| for filename in files: | |||
| if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix): | |||
| mid_name = filename[len(prefix):-5] | |||
| flag = True | |||
| for char in mid_name: | |||
| if char.isalpha(): | |||
| flag = False | |||
| if flag: | |||
| self._ckpoint_filelist.append(directory + '/' + filename) | |||
| def remove_ckpoint_file(self, file_name): | |||
| """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" | |||
| try: | |||
| os.chmod(file_name, stat.S_IWRITE) | |||
| os.remove(file_name) | |||
| self._ckpoint_filelist.remove(file_name) | |||
| except OSError: | |||
| logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) | |||
| except ValueError: | |||
| logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) | |||
| def remove_oldest_ckpoint_file(self): | |||
| """Remove the oldest checkpoint file from this checkpoint manager and also from the directory.""" | |||
| ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime) | |||
| self.remove_ckpoint_file(ckpoint_files[0]) | |||
| def keep_one_ckpoint_per_minutes(self, minutes, cur_time): | |||
| """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time].""" | |||
| movs = [] | |||
| oldest_file = '' | |||
| oldest_time = cur_time | |||
| for ck_file in self._ckpoint_filelist: | |||
| modify_time = os.path.getmtime(ck_file) | |||
| if cur_time - modify_time < 60 * minutes: | |||
| movs.append(ck_file) | |||
| if modify_time < oldest_time: | |||
| oldest_time = modify_time | |||
| oldest_file = ck_file | |||
| for mv_file in movs: | |||
| if mv_file == oldest_file: | |||
| continue | |||
| self.remove_ckpoint_file(mv_file) | |||
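A short usage sketch of the renamed manager's retention logic; the directory, prefix and keep limit are illustrative assumptions:

    import time
    from mindspore.train.callback._checkpoint import CheckpointManager

    manager = CheckpointManager()
    manager.update_ckpoint_filelist("./ckpt", "lenet")      # collect lenet*.ckpt files
    if manager.ckpoint_num > 5:                             # enforce a max-keep policy
        manager.remove_oldest_ckpoint_file()
    manager.keep_one_ckpoint_per_minutes(10, time.time())   # keep one file per 10 minutes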
| @@ -0,0 +1,62 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LossMonitor Callback class.""" | |||
| import numpy as np | |||
| from mindspore.common.tensor import Tensor | |||
| from ._callback import Callback | |||
| class LossMonitor(Callback): | |||
| """ | |||
| Monitor the loss in training. | |||
| If the loss is NAN or INF, it will terminate training. | |||
| Note: | |||
| If per_print_times is 0, the loss is not printed. | |||
| Args: | |||
| per_print_times (int): Print the loss every `per_print_times` steps. Default: 1. | |||
| Raises: | |||
| ValueError: If per_print_times is not an int or is less than zero. | |||
| """ | |||
| def __init__(self, per_print_times=1): | |||
| super(LossMonitor, self).__init__() | |||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||
| raise ValueError("print_step must be int and >= 0.") | |||
| self._per_print_times = per_print_times | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| loss = cb_params.net_outputs | |||
| if isinstance(loss, (tuple, list)): | |||
| if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): | |||
| loss = loss[0] | |||
| if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): | |||
| loss = np.mean(loss.asnumpy()) | |||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||
| if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): | |||
| raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( | |||
| cb_params.cur_epoch_num, cur_step_in_epoch)) | |||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | |||
| print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True) | |||
| @@ -0,0 +1,56 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """SummaryStep Callback class.""" | |||
| from ._callback import Callback | |||
| class SummaryStep(Callback): | |||
| """ | |||
| The summary callback class. | |||
| Args: | |||
| summary (Object): Summary record object. | |||
| flush_step (int): Interval, in steps, between summary records. Default: 10. | |||
| """ | |||
| def __init__(self, summary, flush_step=10): | |||
| super(SummaryStep, self).__init__() | |||
| if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0: | |||
| raise ValueError("`flush_step` should be int and greater than 0") | |||
| self._summary = summary | |||
| self._flush_step = flush_step | |||
| def __enter__(self): | |||
| self._summary.__enter__() | |||
| return self | |||
| def __exit__(self, *err): | |||
| return self._summary.__exit__(*err) | |||
| def step_end(self, run_context): | |||
| """ | |||
| Save summary. | |||
| Args: | |||
| run_context (RunContext): Context of the train running. | |||
| """ | |||
| cb_params = run_context.original_args() | |||
| if cb_params.cur_step_num % self._flush_step == 0: | |||
| self._summary.record(cb_params.cur_step_num, cb_params.train_network) | |||
| @property | |||
| def summary_file_name(self): | |||
| return self._summary.full_file_name | |||
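Typical usage, assuming the existing SummaryRecord API and an assumed log directory:

    from mindspore.train.summary import SummaryRecord
    from mindspore.train.callback import SummaryStep

    summary_record = SummaryRecord(log_dir="./summary")
    summary_cb = SummaryStep(summary_record, flush_step=10)  # record every 10 steps
    # model.train(epoch_size, train_dataset, callbacks=[summary_cb])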
| @@ -0,0 +1,35 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """TimeMonitor Callback class.""" | |||
| import time | |||
| from ._callback import Callback | |||
| class TimeMonitor(Callback): | |||
| """Time Monitor.""" | |||
| def __init__(self, data_size): | |||
| super(TimeMonitor, self).__init__() | |||
| self.data_size = data_size | |||
| def epoch_begin(self, run_context): | |||
| self.epoch_time = time.time() | |||
| def epoch_end(self, run_context): | |||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||
| per_step_mseconds = epoch_mseconds / self.data_size | |||
| print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) | |||
| @@ -19,7 +19,7 @@ from mindspore import log as logger | |||
| from ..common.tensor import Tensor | |||
| from ..nn.metrics import get_metrics | |||
| from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool | |||
| from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from .callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from .. import context | |||
| from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | |||
| @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | |||
| from mindspore.train import amp | |||
| from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| from .dataset_helper import DatasetHelper | |||
| @@ -26,10 +26,10 @@ from mindspore.common.api import ms_function | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Momentum | |||
| from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \ | |||
| _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \ | |||
| _CallbackManager, Callback, CheckpointConfig, _set_cur_net | |||
| from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ | |||
| _CallbackManager, Callback, CheckpointConfig | |||
| from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op | |||
| from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist | |||
| class Net(nn.Cell): | |||
| """Net definition.""" | |||
| @@ -187,7 +187,7 @@ def test_checkpoint_cb_for_save_op(): | |||
| one_param['name'] = "conv1.weight" | |||
| one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32) | |||
| parameter_list.append(one_param) | |||
| _checkpoint_cb_for_save_op(parameter_list) | |||
| checkpoint_cb_for_save_op(parameter_list) | |||
| def test_checkpoint_cb_for_save_op_update_net(): | |||
| @@ -198,8 +198,8 @@ def test_checkpoint_cb_for_save_op_update_net(): | |||
| one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32) | |||
| parameter_list.append(one_param) | |||
| net = Net() | |||
| _set_cur_net(net) | |||
| _checkpoint_cb_for_save_op(parameter_list) | |||
| set_cur_net(net) | |||
| checkpoint_cb_for_save_op(parameter_list) | |||
| assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1 | |||
| @@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||
| from mindspore.nn import WithLossCell, TrainOneStepCell | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore.train.callback.callback import _CheckpointManager | |||
| from mindspore.train.callback import _CheckpointManager | |||
| from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \ | |||
| _exec_save_checkpoint, export, _save_graph | |||
| from ..ut_filter import non_graph_engine | |||