| @@ -16,6 +16,7 @@ __all__ = [ | |||||
| "ResultsMonitor", | "ResultsMonitor", | ||||
| 'HasMonitorCallback', | 'HasMonitorCallback', | ||||
| "FitlogCallback", | "FitlogCallback", | ||||
| "TimerCallback", | |||||
| # collators | # collators | ||||
| 'Collator', | 'Collator', | ||||
| @@ -21,7 +21,9 @@ __all__ = [ | |||||
| "ResultsMonitor", | "ResultsMonitor", | ||||
| 'HasMonitorCallback', | 'HasMonitorCallback', | ||||
| "FitlogCallback" | |||||
| "FitlogCallback", | |||||
| "TimerCallback" | |||||
| ] | ] | ||||
| @@ -37,4 +39,4 @@ from .torch_callbacks import * | |||||
| from .more_evaluate_callback import MoreEvaluateCallback | from .more_evaluate_callback import MoreEvaluateCallback | ||||
| from .has_monitor_callback import ResultsMonitor, HasMonitorCallback | from .has_monitor_callback import ResultsMonitor, HasMonitorCallback | ||||
| from .fitlog_callback import FitlogCallback | from .fitlog_callback import FitlogCallback | ||||
| from .timer_callback import TimerCallback | |||||
| @@ -171,7 +171,7 @@ class ResultsMonitor: | |||||
| @property | @property | ||||
| def log_name(self) -> str: | def log_name(self) -> str: | ||||
| """ | """ | ||||
| 内部用于打印信息使用 | |||||
| 内部用于打印当前类别信息使用 | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| @@ -106,11 +106,11 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
| if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 | if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 | ||||
| if self.real_save_folder: | if self.real_save_folder: | ||||
| logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...") | |||||
| logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...") | |||||
| trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
| model_load_fn=self.model_load_fn) | model_load_fn=self.model_load_fn) | ||||
| else: | else: | ||||
| logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...") | |||||
| logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...") | |||||
| self.buffer.seek(0) | self.buffer.seek(0) | ||||
| trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | ||||
| if self.delete_after_after: | if self.delete_after_after: | ||||
| @@ -1,5 +1,4 @@ | |||||
| import json | import json | ||||
| import sys | |||||
| from typing import Union | from typing import Union | ||||
| __all__ = [ | __all__ = [ | ||||
| @@ -16,8 +15,21 @@ from fastNLP.core.log import logger | |||||
| class ProgressCallback(HasMonitorCallback): | class ProgressCallback(HasMonitorCallback): | ||||
| def __init__(self, monitor, larger_better, must_have_monitor=False): | |||||
| super(ProgressCallback, self).__init__(monitor=monitor, larger_better=larger_better, | |||||
| must_have_monitor=must_have_monitor) | |||||
| self.best_monitor_epoch = -1 | |||||
| self.best_monitor_step = -1 | |||||
| def record_better_monitor(self, trainer): | |||||
| self.best_monitor_step = trainer.global_forward_batches | |||||
| self.best_monitor_epoch = trainer.cur_epoch_idx | |||||
| def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
| f_rich_progress.stop() | |||||
| if self.best_monitor_epoch != -1: | |||||
| msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \ | |||||
| f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}." | |||||
| logger.info(msg) | |||||
| @property | @property | ||||
| def name(self): # progress bar的名称 | def name(self): # progress bar的名称 | ||||
| @@ -97,6 +109,7 @@ class RichCallback(ProgressCallback): | |||||
| advance=None, completed=trainer.cur_epoch_idx, refresh=True) | advance=None, completed=trainer.cur_epoch_idx, refresh=True) | ||||
| def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
| super(RichCallback, self).on_train_end(trainer) | |||||
| self.clear_tasks() | self.clear_tasks() | ||||
| def on_before_backward(self, trainer, outputs): | def on_before_backward(self, trainer, outputs): | ||||
| @@ -121,8 +134,8 @@ class RichCallback(ProgressCallback): | |||||
| text_style = '' | text_style = '' | ||||
| characters = '-' | characters = '-' | ||||
| if self.monitor is not None: | if self.monitor is not None: | ||||
| monitor_value = self.get_monitor_value(results) | |||||
| if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
| if self.is_better_results(results, keep_if_better=True): | |||||
| self.record_better_monitor(trainer) | |||||
| if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
| rule_style = 'spring_green3' | rule_style = 'spring_green3' | ||||
| text_style = '[bold]' | text_style = '[bold]' | ||||
| @@ -201,8 +214,8 @@ class RawTextCallback(ProgressCallback): | |||||
| base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | ||||
| text = '' | text = '' | ||||
| if self.monitor is not None: | if self.monitor is not None: | ||||
| monitor_value = self.get_monitor_value(results) | |||||
| if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
| if self.is_better_results(results, keep_if_better=True): | |||||
| self.record_better_monitor(trainer) | |||||
| if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
| text = '+'*self.num_signs + base_text + '+'*self.num_signs | text = '+'*self.num_signs + base_text + '+'*self.num_signs | ||||
| if len(text) == 0: | if len(text) == 0: | ||||
| @@ -266,6 +279,7 @@ class TqdmCallback(ProgressCallback): | |||||
| self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) | self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) | ||||
| def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
| super(TqdmCallback, self).on_train_end(trainer) | |||||
| self.clear_tasks() | self.clear_tasks() | ||||
| def on_before_backward(self, trainer, outputs): | def on_before_backward(self, trainer, outputs): | ||||
| @@ -287,8 +301,8 @@ class TqdmCallback(ProgressCallback): | |||||
| base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | ||||
| text = '' | text = '' | ||||
| if self.monitor is not None: | if self.monitor is not None: | ||||
| monitor_value = self.get_monitor_value(results) | |||||
| if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
| if self.is_better_results(results, keep_if_better=True): | |||||
| self.record_better_monitor(trainer) | |||||
| if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
| text = '+'*self.num_signs + base_text + '+'*self.num_signs | text = '+'*self.num_signs + base_text + '+'*self.num_signs | ||||
| if len(text) == 0: | if len(text) == 0: | ||||
| @@ -0,0 +1,152 @@ | |||||
| import time | |||||
| from .callback import Callback | |||||
| from ..log import logger | |||||
| __all__ = ['TimerCallback'] | |||||
| class _Timer: | |||||
| """Timer.""" | |||||
| def __init__(self, name): | |||||
| self.name_ = name | |||||
| self.elapsed_ = 0.0 | |||||
| self.started_ = False | |||||
| self.start_time = time.time() | |||||
| def start(self): | |||||
| """Start the timer.""" | |||||
| assert not self.started_, f'{self.name_} timer has already been started' | |||||
| self.start_time = time.time() | |||||
| self.started_ = True | |||||
| def stop(self): | |||||
| """Stop the timer.""" | |||||
| assert self.started_, f'{self.name_} timer is not started' | |||||
| self.elapsed_ += (time.time() - self.start_time) | |||||
| self.started_ = False | |||||
| def reset(self): | |||||
| """Reset timer.""" | |||||
| self.elapsed_ = 0.0 | |||||
| self.started_ = False | |||||
| def elapsed(self, reset=True): | |||||
| """Calculate the elapsed time.""" | |||||
| started_ = self.started_ | |||||
| # If the timing in progress, end it first. | |||||
| if self.started_: | |||||
| self.stop() | |||||
| # Get the elapsed time. | |||||
| elapsed_ = self.elapsed_ | |||||
| # Reset the elapsed time | |||||
| if reset: | |||||
| self.reset() | |||||
| # If timing was in progress, set it back. | |||||
| if started_: | |||||
| self.start() | |||||
| return elapsed_ | |||||
class Timers:
    """Registry of named :class:`_Timer` objects, created lazily on first use."""

    def __init__(self):
        self.timers = {}  # maps timer name -> _Timer instance

    def __call__(self, name):
        """Return the timer registered under ``name``, creating it if absent."""
        timer = self.timers.get(name)
        if timer is None:
            timer = _Timer(name)
            self.timers[name] = timer
        return timer

    def __contains__(self, item):
        return item in self.timers

    def reset(self):
        """Reset every registered timer."""
        for timer in self.timers.values():
            timer.reset()
class TimerCallback(Callback):
    """
    Callback that reports timing information for the training run: data
    fetching, forward, backward, optimizer-step, evaluation and total time.
    """

    def __init__(self, print_every=-1, time_ndigit=3):
        """
        :param print_every: when to log the timing summary.

            * *negative*: log once every ``abs(print_every)`` epochs;
            * *0*: log only once, when training finishes;
            * *positive*: log every ``print_every`` global steps;
        :param time_ndigit: number of decimal digits kept for each duration.
        """
        assert isinstance(print_every, int), "print_every must be an int number."
        self.timers = Timers()
        self.print_every = print_every
        self.time_ndigit = time_ndigit

    def on_train_begin(self, trainer):
        self.timers('total').start()
        self.timers('train').start()

    def on_fetch_data_begin(self, trainer):
        self.timers('fetch-data').start()

    def on_fetch_data_end(self, trainer):
        self.timers('fetch-data').stop()

    def on_train_batch_begin(self, trainer, batch, indices):
        self.timers('forward').start()

    def on_before_backward(self, trainer, outputs):
        self.timers('forward').stop()
        self.timers('backward').start()

    def on_after_backward(self, trainer):
        self.timers('backward').stop()

    def on_before_optimizers_step(self, trainer, optimizers):
        self.timers('optimize').start()

    def on_after_optimizers_step(self, trainer, optimizers):
        self.timers('optimize').stop()

    def on_evaluate_begin(self, trainer):
        # Keep evaluation time out of the 'train' bucket.
        self.timers('train').stop()
        self.timers('evaluate').start()

    def on_evaluate_end(self, trainer, results):
        self.timers('evaluate').stop()
        self.timers('train').start()

    def format_timer(self, reset=True):
        """Build a ``, name: 1.234s`` summary for every timer that has run.

        :param reset: if True, each timer's accumulator is zeroed after reading.
        :return: summary string; empty when no timer accumulated any time.
        """
        line = ''
        timers = ['fetch-data', 'forward', 'backward', 'optimize', 'evaluate', 'train', 'total']
        for timer_name in timers:
            if timer_name not in self.timers:  # never started -> nothing to report
                continue
            timer = self.timers(timer_name)
            elapsed = round(timer.elapsed(reset=reset), self.time_ndigit)
            if elapsed != 0:
                line = line + f', {timer_name}: {elapsed}s'
        return line

    def on_train_batch_end(self, trainer):
        if self.print_every > 0 and trainer.global_forward_batches % self.print_every == 0:
            line = self.format_timer()
            logger.info(f"Running {self.print_every} batches{line}")

    def on_train_epoch_end(self, trainer):
        if self.print_every < 0 and trainer.cur_epoch_idx % abs(self.print_every) == 0:
            line = self.format_timer()
            logger.info(f"Running {abs(self.print_every)} epochs{line}")

    def on_train_end(self, trainer):
        if self.print_every == 0:
            line = self.format_timer()
            logger.info(f"Training finished{line}")
| @@ -41,10 +41,12 @@ class TrainBatchLoop(Loop): | |||||
| batch = next(dataloader) | batch = next(dataloader) | ||||
| indices = get_batch_indices() | indices = get_batch_indices() | ||||
| except StopIteration: | except StopIteration: | ||||
| trainer.on_fetch_data_end() | |||||
| break | break | ||||
| trainer.on_fetch_data_end() | |||||
| try: | try: | ||||
| trainer.on_fetch_data_end() | |||||
| batch = match_and_substitute_params(trainer.input_mapping, batch) | batch = match_and_substitute_params(trainer.input_mapping, batch) | ||||
| batch = trainer.move_data_to_device(batch) | batch = trainer.move_data_to_device(batch) | ||||
| @@ -108,6 +108,9 @@ class TorchDataLoader(DataLoader): | |||||
| if not isinstance(dataset, _FDataSet): | if not isinstance(dataset, _FDataSet): | ||||
| dataset = _FDataSet(dataset) | dataset = _FDataSet(dataset) | ||||
| if num_workers>0 and multiprocessing_context is None: | |||||
| multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程 | |||||
| if batch_sampler is not None: | if batch_sampler is not None: | ||||
| batch_size = 1 | batch_size = 1 | ||||
| shuffle = False | shuffle = False | ||||