| @@ -1,8 +1,10 @@ | |||||
| __all__ = [ | __all__ = [ | ||||
| 'FitlogCallback' | 'FitlogCallback' | ||||
| ] | ] | ||||
| from fastNLP import HasMonitorCallback | |||||
| import fitlog | |||||
| from .has_monitor_callback import HasMonitorCallback | |||||
| from ...envs import _module_available | |||||
| if _module_available('fitlog'): | |||||
| import fitlog | |||||
| class FitlogCallback(HasMonitorCallback): | class FitlogCallback(HasMonitorCallback): | ||||
| @@ -25,6 +27,8 @@ class FitlogCallback(HasMonitorCallback): | |||||
| :param log_exception: 是否记录 ``exception`` 。 | :param log_exception: 是否记录 ``exception`` 。 | ||||
| :param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。 | :param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。 | ||||
| """ | """ | ||||
| assert _module_available('fitlog'), "fitlog is not installed." | |||||
| super().__init__(monitor=monitor, larger_better=larger_better) | super().__init__(monitor=monitor, larger_better=larger_better) | ||||
| self.log_exception = log_exception | self.log_exception = log_exception | ||||
| self.log_loss_every = log_loss_every | self.log_loss_every = log_loss_every | ||||