| @@ -656,10 +656,13 @@ class EvaluateCallback(Callback): | |||
| for key, tester in self.testers.items(): | |||
| try: | |||
| eval_result = tester.test() | |||
| self.pbar.write("Evaluation on {}:".format(key)) | |||
| self.pbar.write(tester._format_eval_results(eval_result)) | |||
| # self.pbar.write("Evaluation on {}:".format(key)) | |||
| self.logger.info("Evaluation on {}:".format(key)) | |||
| # self.pbar.write(tester._format_eval_results(eval_result)) | |||
| self.logger.info(tester._format_eval_results(eval_result)) | |||
| except Exception: | |||
| self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||
| # self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||
| self.logger.info("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||
| class LRScheduler(Callback): | |||
| @@ -22,7 +22,7 @@ from .optimizer import Optimizer | |||
| from .utils import _build_args | |||
| from .utils import _move_dict_value_to_device | |||
| from .utils import _get_func_signature | |||
| from ..io.logger import initLogger | |||
| from ..io.logger import init_logger | |||
| from pkg_resources import parse_version | |||
| __all__ = [ | |||
| @@ -140,7 +140,7 @@ class DistTrainer(): | |||
| self.cp_save_path = None | |||
| # use INFO in the master, WARN for others | |||
| initLogger(log_path, level=logging.INFO if self.is_master else logging.WARNING) | |||
| init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING) | |||
| self.logger = logging.getLogger(__name__) | |||
| self.logger.info("Setup Distributed Trainer") | |||
| self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( | |||
| @@ -56,6 +56,7 @@ from .utils import _move_model_to_device | |||
| from ._parallel_utils import _data_parallel_wrapper | |||
| from ._parallel_utils import _model_contains_inner_module | |||
| from functools import partial | |||
| from ..io.logger import init_logger, get_logger | |||
| __all__ = [ | |||
| "Tester" | |||
| @@ -103,6 +104,8 @@ class Tester(object): | |||
| self.batch_size = batch_size | |||
| self.verbose = verbose | |||
| self.use_tqdm = use_tqdm | |||
| init_logger(stdout='tqdm' if use_tqdm else 'plain') | |||
| self.logger = get_logger(__name__) | |||
| if isinstance(data, DataSet): | |||
| self.data_iterator = DataSetIter( | |||
| @@ -181,7 +184,8 @@ class Tester(object): | |||
| end_time = time.time() | |||
| test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!' | |||
| pbar.write(test_str) | |||
| # pbar.write(test_str) | |||
| self.logger.info(test_str) | |||
| pbar.close() | |||
| except _CheckError as e: | |||
| prev_func_signature = _get_func_signature(self._predict_func) | |||
| @@ -353,8 +353,7 @@ from .utils import _get_func_signature | |||
| from .utils import _get_model_device | |||
| from .utils import _move_model_to_device | |||
| from ._parallel_utils import _model_contains_inner_module | |||
| from ..io.logger import initLogger | |||
| import logging | |||
| from ..io.logger import init_logger, get_logger | |||
| class Trainer(object): | |||
| @@ -552,8 +551,8 @@ class Trainer(object): | |||
| log_path = None | |||
| if save_path is not None: | |||
| log_path = os.path.join(os.path.dirname(save_path), 'log') | |||
| initLogger(log_path) | |||
| self.logger = logging.getLogger(__name__) | |||
| init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain') | |||
| self.logger = get_logger(__name__) | |||
| self.use_tqdm = use_tqdm | |||
| self.pbar = None | |||
| @@ -701,8 +700,8 @@ class Trainer(object): | |||
| eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
| self.n_steps) + \ | |||
| self.tester._format_eval_results(eval_res) | |||
| pbar.write(eval_str + '\n') | |||
| # pbar.write(eval_str + '\n') | |||
| self.logger.info(eval_str) | |||
| # ================= mini-batch end ==================== # | |||
| # lr decay; early stopping | |||
| @@ -661,7 +661,7 @@ class _pseudo_tqdm: | |||
| 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | |||
| """ | |||
| def __init__(self, **kwargs): | |||
| self.logger = logging.getLogger() | |||
| self.logger = logging.getLogger(__name__) | |||
| def write(self, info): | |||
| self.logger.info(info) | |||
| @@ -0,0 +1,88 @@ | |||
| import logging | |||
| import logging.config | |||
| import torch | |||
| import _pickle as pickle | |||
| import os | |||
| import sys | |||
| import warnings | |||
try:
    from tqdm.auto import tqdm
except ImportError:
    tqdm = None

if tqdm is None:
    class TqdmLoggingHandler(logging.StreamHandler):
        """Fallback console handler used when tqdm is not installed.

        Behaves as a plain ``StreamHandler`` on ``sys.stdout``.
        """

        def __init__(self, level=logging.INFO):
            super().__init__(sys.stdout)
            self.setLevel(level)
else:
    class TqdmLoggingHandler(logging.Handler):
        """Console handler that routes records through ``tqdm.write``.

        Writing via tqdm keeps active progress bars intact instead of
        interleaving raw lines with them.
        """

        def __init__(self, level=logging.INFO):
            super().__init__(level)

        def emit(self, record):
            try:
                tqdm.write(self.format(record))
                self.flush()
            except (KeyboardInterrupt, SystemExit):
                # never swallow interpreter shutdown signals
                raise
            except:
                self.handleError(record)
def init_logger(path=None, stdout='tqdm', level='INFO'):
    """Initialize the shared 'fastNLP' logger and return it.

    Safe to call repeatedly: each handler class is attached at most once
    (existing handler classes on the logger are detected and not duplicated).

    :param path: optional log-file path, appended to (``mode='a'``); parent
        directories are created if missing. ``None`` disables file logging.
    :param stdout: console mode -- ``'tqdm'`` routes output through
        ``tqdm.write`` (progress-bar friendly), ``'plain'`` writes to
        ``sys.stdout``, ``'none'`` adds no console handler.
    :param level: logging level, either an ``int`` from :mod:`logging` or a
        case-insensitive name: 'info', 'debug', 'warn'/'warning', 'error'.
    :raises ValueError: if ``stdout`` or a string ``level`` is unrecognized.
    :raises FileExistsError: if ``path`` exists but is not a regular file.
    :return: the configured ``logging.Logger`` named ``'fastNLP'``.
    """
    if stdout not in ('none', 'plain', 'tqdm'):
        raise ValueError("stdout must be one of ['none', 'plain', 'tqdm'], got {!r}".format(stdout))
    if not isinstance(level, int):
        _name_to_level = {'info': logging.INFO, 'debug': logging.DEBUG,
                          'warn': logging.WARN, 'warning': logging.WARN,
                          'error': logging.ERROR}
        try:
            level = _name_to_level[level.lower()]
        except KeyError:
            # clear error instead of leaking a bare KeyError to the caller
            raise ValueError('unknown logging level: {!r}'.format(level))

    logger = logging.getLogger('fastNLP')
    logger.setLevel(level)
    # exact handler classes already attached -- used to initialize only once
    handlers_type = {type(h) for h in logger.handlers}

    # --- console handler (at most one per class) ---
    stream_handler = None
    if stdout == 'plain' and logging.StreamHandler not in handlers_type:
        stream_handler = logging.StreamHandler(sys.stdout)
    elif stdout == 'tqdm' and TqdmLoggingHandler not in handlers_type:
        stream_handler = TqdmLoggingHandler(level)
    if stream_handler is not None:
        stream_handler.setLevel(level)
        stream_handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s'))
        logger.addHandler(stream_handler)

    # --- file handler ---
    if path is not None and logging.FileHandler not in handlers_type:
        if os.path.exists(path):
            if not os.path.isfile(path):
                # was an `assert`, which `python -O` strips; fail loudly instead
                raise FileExistsError('log path {!r} exists and is not a regular file'.format(path))
            warnings.warn('log already exists in {}'.format(path))
        os.makedirs(os.path.abspath(os.path.dirname(path)), exist_ok=True)
        file_handler = logging.FileHandler(path, mode='a')
        file_handler.setLevel(level)
        file_handler.setFormatter(logging.Formatter(
            fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
            datefmt='%Y/%m/%d %H:%M:%S'))
        logger.addHandler(file_handler)
    return logger


# convenience alias so call sites can write ``get_logger(__name__)``
get_logger = logging.getLogger
| @@ -111,17 +111,17 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |||
| print(device) | |||
| # 4.定义train方法 | |||
| # trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||
| # sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), | |||
| # metrics=[metric], | |||
| # dev_data=datainfo.datasets['test'], device=device, | |||
| # check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, | |||
| # n_epochs=ops.train_epoch, num_workers=4) | |||
| trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||
| metrics=[metric], | |||
| dev_data=datainfo.datasets['test'], device='cuda', | |||
| batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks, | |||
| n_epochs=ops.train_epoch, num_workers=4) | |||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||
| sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), | |||
| metrics=[metric], use_tqdm=False, | |||
| dev_data=datainfo.datasets['test'], device=device, | |||
| check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, | |||
| n_epochs=ops.train_epoch, num_workers=4) | |||
| # trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||
| # metrics=[metric], | |||
| # dev_data=datainfo.datasets['test'], device='cuda', | |||
| # batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks, | |||
| # n_epochs=ops.train_epoch, num_workers=4) | |||
| if __name__ == "__main__": | |||