| @@ -16,6 +16,7 @@ __all__ = [ | |||
| "ResultsMonitor", | |||
| 'HasMonitorCallback', | |||
| "FitlogCallback", | |||
| "TimerCallback", | |||
| # collators | |||
| 'Collator', | |||
| @@ -21,7 +21,9 @@ __all__ = [ | |||
| "ResultsMonitor", | |||
| 'HasMonitorCallback', | |||
| "FitlogCallback" | |||
| "FitlogCallback", | |||
| "TimerCallback" | |||
| ] | |||
| @@ -37,4 +39,4 @@ from .torch_callbacks import * | |||
| from .more_evaluate_callback import MoreEvaluateCallback | |||
| from .has_monitor_callback import ResultsMonitor, HasMonitorCallback | |||
| from .fitlog_callback import FitlogCallback | |||
| from .timer_callback import TimerCallback | |||
| @@ -171,7 +171,7 @@ class ResultsMonitor: | |||
| @property | |||
| def log_name(self) -> str: | |||
| """ | |||
| 内部用于打印信息使用 | |||
| 内部用于打印当前类别信息使用 | |||
| :return: | |||
| """ | |||
| @@ -106,11 +106,11 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
| def on_train_end(self, trainer): | |||
| if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 | |||
| if self.real_save_folder: | |||
| logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...") | |||
| logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...") | |||
| trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||
| model_load_fn=self.model_load_fn) | |||
| else: | |||
| logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...") | |||
| logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...") | |||
| self.buffer.seek(0) | |||
| trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||
| if self.delete_after_after: | |||
| @@ -1,5 +1,4 @@ | |||
| import json | |||
| import sys | |||
| from typing import Union | |||
| __all__ = [ | |||
| @@ -16,8 +15,21 @@ from fastNLP.core.log import logger | |||
| class ProgressCallback(HasMonitorCallback): | |||
| def __init__(self, monitor, larger_better, must_have_monitor=False): | |||
| super(ProgressCallback, self).__init__(monitor=monitor, larger_better=larger_better, | |||
| must_have_monitor=must_have_monitor) | |||
| self.best_monitor_epoch = -1 | |||
| self.best_monitor_step = -1 | |||
| def record_better_monitor(self, trainer): | |||
| self.best_monitor_step = trainer.global_forward_batches | |||
| self.best_monitor_epoch = trainer.cur_epoch_idx | |||
| def on_train_end(self, trainer): | |||
| f_rich_progress.stop() | |||
| if self.best_monitor_epoch != -1: | |||
| msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \ | |||
| f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}." | |||
| logger.info(msg) | |||
| @property | |||
| def name(self): # progress bar的名称 | |||
| @@ -97,6 +109,7 @@ class RichCallback(ProgressCallback): | |||
| advance=None, completed=trainer.cur_epoch_idx, refresh=True) | |||
| def on_train_end(self, trainer): | |||
| super(RichCallback, self).on_train_end(trainer) | |||
| self.clear_tasks() | |||
| def on_before_backward(self, trainer, outputs): | |||
| @@ -121,8 +134,8 @@ class RichCallback(ProgressCallback): | |||
| text_style = '' | |||
| characters = '-' | |||
| if self.monitor is not None: | |||
| monitor_value = self.get_monitor_value(results) | |||
| if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
| if self.is_better_results(results, keep_if_better=True): | |||
| self.record_better_monitor(trainer) | |||
| if abs(self.monitor_value) != float('inf'): | |||
| rule_style = 'spring_green3' | |||
| text_style = '[bold]' | |||
| @@ -201,8 +214,8 @@ class RawTextCallback(ProgressCallback): | |||
| base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | |||
| text = '' | |||
| if self.monitor is not None: | |||
| monitor_value = self.get_monitor_value(results) | |||
| if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
| if self.is_better_results(results, keep_if_better=True): | |||
| self.record_better_monitor(trainer) | |||
| if abs(self.monitor_value) != float('inf'): | |||
| text = '+'*self.num_signs + base_text + '+'*self.num_signs | |||
| if len(text) == 0: | |||
| @@ -266,6 +279,7 @@ class TqdmCallback(ProgressCallback): | |||
| self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) | |||
| def on_train_end(self, trainer): | |||
| super(TqdmCallback, self).on_train_end(trainer) | |||
| self.clear_tasks() | |||
| def on_before_backward(self, trainer, outputs): | |||
| @@ -287,8 +301,8 @@ class TqdmCallback(ProgressCallback): | |||
| base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | |||
| text = '' | |||
| if self.monitor is not None: | |||
| monitor_value = self.get_monitor_value(results) | |||
| if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
| if self.is_better_results(results, keep_if_better=True): | |||
| self.record_better_monitor(trainer) | |||
| if abs(self.monitor_value) != float('inf'): | |||
| text = '+'*self.num_signs + base_text + '+'*self.num_signs | |||
| if len(text) == 0: | |||
| @@ -0,0 +1,152 @@ | |||
| import time | |||
| from .callback import Callback | |||
| from ..log import logger | |||
| __all__ = ['TimerCallback'] | |||
| class _Timer: | |||
| """Timer.""" | |||
| def __init__(self, name): | |||
| self.name_ = name | |||
| self.elapsed_ = 0.0 | |||
| self.started_ = False | |||
| self.start_time = time.time() | |||
| def start(self): | |||
| """Start the timer.""" | |||
| assert not self.started_, f'{self.name_} timer has already been started' | |||
| self.start_time = time.time() | |||
| self.started_ = True | |||
| def stop(self): | |||
| """Stop the timer.""" | |||
| assert self.started_, f'{self.name_} timer is not started' | |||
| self.elapsed_ += (time.time() - self.start_time) | |||
| self.started_ = False | |||
| def reset(self): | |||
| """Reset timer.""" | |||
| self.elapsed_ = 0.0 | |||
| self.started_ = False | |||
| def elapsed(self, reset=True): | |||
| """Calculate the elapsed time.""" | |||
| started_ = self.started_ | |||
| # If the timing in progress, end it first. | |||
| if self.started_: | |||
| self.stop() | |||
| # Get the elapsed time. | |||
| elapsed_ = self.elapsed_ | |||
| # Reset the elapsed time | |||
| if reset: | |||
| self.reset() | |||
| # If timing was in progress, set it back. | |||
| if started_: | |||
| self.start() | |||
| return elapsed_ | |||
| class Timers: | |||
| """Group of timers.""" | |||
| def __init__(self): | |||
| self.timers = {} | |||
| def __call__(self, name): | |||
| if name not in self.timers: | |||
| self.timers[name] = _Timer(name) | |||
| return self.timers[name] | |||
| def __contains__(self, item): | |||
| return item in self.timers | |||
| def reset(self): | |||
| for timer in self.timers.values(): | |||
| timer.reset() | |||
class TimerCallback(Callback):
    """
    Log timing information collected during training: data-fetch, forward,
    backward, optimizer-step and evaluate durations, plus cumulative train
    and total time.
    """

    def __init__(self, print_every=-1, time_ndigit=3):
        """
        :param print_every: controls when timing info is logged.

            * *negative* -- log once every ``abs(print_every)`` epochs;
            * *0* -- log only once, when training ends;
            * *positive* -- log every ``print_every`` steps;
        :param time_ndigit: number of decimal digits kept when rounding seconds.
        """
        assert isinstance(print_every, int), "print_every must be an int number."
        self.timers = Timers()
        self.print_every = print_every
        self.time_ndigit = time_ndigit

    def on_train_begin(self, trainer):
        self.timers('total').start()
        self.timers('train').start()

    def on_fetch_data_begin(self, trainer):
        self.timers('fetch-data').start()

    def on_fetch_data_end(self, trainer):
        self.timers('fetch-data').stop()

    def on_train_batch_begin(self, trainer, batch, indices):
        self.timers('forward').start()

    def on_before_backward(self, trainer, outputs):
        self.timers('forward').stop()
        self.timers('backward').start()

    def on_after_backward(self, trainer):
        self.timers('backward').stop()

    def on_before_optimizers_step(self, trainer, optimizers):
        self.timers('optimize').start()

    def on_after_optimizers_step(self, trainer, optimizers):
        self.timers('optimize').stop()

    def on_evaluate_begin(self, trainer):
        # Evaluation time is excluded from the 'train' timer.
        self.timers('train').stop()
        self.timers('evaluate').start()

    def on_evaluate_end(self, trainer, results):
        self.timers('evaluate').stop()
        self.timers('train').start()

    def format_timer(self, reset=True):
        """Build a ', name: Xs'-style summary of all known timers.

        :param reset: forwarded to :meth:`_Timer.elapsed`; when True each
            timer's accumulated total is cleared after being read.
        :return: summary string (empty if every timer reads as zero).
        """
        line = ''
        timers = ['fetch-data', 'forward', 'backward', 'optimize', 'evaluate', 'train', 'total']
        for timer_name in timers:
            if timer_name not in self.timers:
                continue
            timer = self.timers(timer_name)
            elapsed = round(timer.elapsed(reset=reset), self.time_ndigit)
            if elapsed != 0:  # skip timers that never accumulated measurable time
                line = line + f', {timer_name}: {elapsed}s'
        return line

    def on_train_batch_end(self, trainer):
        if self.print_every > 0 and trainer.global_forward_batches % self.print_every == 0:
            line = self.format_timer()
            logger.info(f"Running {self.print_every} batches{line}")

    def on_train_epoch_end(self, trainer):
        # NOTE(review): if trainer.cur_epoch_idx is 0-based this also fires on
        # epoch 0 -- confirm against the Trainer's epoch-counting convention.
        if self.print_every < 0 and trainer.cur_epoch_idx % abs(self.print_every) == 0:
            line = self.format_timer()
            logger.info(f"Running {abs(self.print_every)} epochs{line}")

    def on_train_end(self, trainer):
        if self.print_every == 0:
            line = self.format_timer()
            logger.info(f"Training finished{line}")
| @@ -41,10 +41,12 @@ class TrainBatchLoop(Loop): | |||
| batch = next(dataloader) | |||
| indices = get_batch_indices() | |||
| except StopIteration: | |||
| trainer.on_fetch_data_end() | |||
| break | |||
| trainer.on_fetch_data_end() | |||
| try: | |||
| trainer.on_fetch_data_end() | |||
| batch = match_and_substitute_params(trainer.input_mapping, batch) | |||
| batch = trainer.move_data_to_device(batch) | |||
| @@ -108,6 +108,9 @@ class TorchDataLoader(DataLoader): | |||
| if not isinstance(dataset, _FDataSet): | |||
| dataset = _FDataSet(dataset) | |||
| if num_workers>0 and multiprocessing_context is None: | |||
| multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程 | |||
| if batch_sampler is not None: | |||
| batch_size = 1 | |||
| shuffle = False | |||