From dbacead74e20b7712a137924ceb4699cd85e5c45 Mon Sep 17 00:00:00 2001 From: "jiangnana.jnn" Date: Thu, 11 Aug 2022 13:24:16 +0800 Subject: [PATCH] [ to #43850241] fix json dump numpy Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9703595 * fix json dump numpy * Merge remote-tracking branch 'origin' into fix/josn_dump --- modelscope/trainers/hooks/logger/base.py | 2 +- .../trainers/hooks/logger/text_logger_hook.py | 5 +++-- modelscope/utils/json_utils.py | 17 +++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 modelscope/utils/json_utils.py diff --git a/modelscope/trainers/hooks/logger/base.py b/modelscope/trainers/hooks/logger/base.py index 0e9ebb1f..e1da251f 100644 --- a/modelscope/trainers/hooks/logger/base.py +++ b/modelscope/trainers/hooks/logger/base.py @@ -15,7 +15,7 @@ class LoggerHook(Hook): """Base class for logger hooks. Args: - interval (int): Logging interval (every k iterations). + interval (int): Logging interval (every k iterations). It is interval of iterations even by_epoch is true. ignore_last (bool): Ignore the log of last iterations in each epoch if less than `interval`. reset_flag (bool): Whether to clear the output buffer after logging. diff --git a/modelscope/trainers/hooks/logger/text_logger_hook.py b/modelscope/trainers/hooks/logger/text_logger_hook.py index 168792d9..a204284c 100644 --- a/modelscope/trainers/hooks/logger/text_logger_hook.py +++ b/modelscope/trainers/hooks/logger/text_logger_hook.py @@ -12,6 +12,7 @@ from modelscope.metainfo import Hooks from modelscope.trainers.hooks.builder import HOOKS from modelscope.trainers.hooks.logger.base import LoggerHook from modelscope.utils.constant import LogKeys, ModeKeys +from modelscope.utils.json_utils import EnhancedEncoder from modelscope.utils.torch_utils import get_dist_info, is_master @@ -23,7 +24,7 @@ class TextLoggerHook(LoggerHook): by_epoch (bool, optional): Whether EpochBasedtrainer is used. Default: True. interval (int, optional): Logging interval (every k iterations). - Default: 10. + It is interval of iterations even by_epoch is true. Default: 10. ignore_last (bool, optional): Ignore the log of last iterations in each epoch if less than :attr:`interval`. Default: True. reset_flag (bool, optional): Whether to clear the output buffer after @@ -142,7 +143,7 @@ class TextLoggerHook(LoggerHook): if is_master(): with open(self.json_log_path, 'a+') as f: - json.dump(json_log, f) + json.dump(json_log, f, cls=EnhancedEncoder) f.write('\n') def _round_float(self, items, ndigits=5): diff --git a/modelscope/utils/json_utils.py b/modelscope/utils/json_utils.py new file mode 100644 index 00000000..c5bece23 --- /dev/null +++ b/modelscope/utils/json_utils.py @@ -0,0 +1,17 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import json +import numpy as np + + +class EnhancedEncoder(json.JSONEncoder): + """ Enhanced json encoder for not supported types """ + + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj)