| @@ -1,3 +1,9 @@ | |||
| """ | |||
| This module contains the base class for the Bridge part. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| from abc import ABCMeta, abstractmethod | |||
| from typing import Any, List, Optional, Tuple, Union | |||
| @@ -31,11 +37,11 @@ class BaseBridge(metaclass=ABCMeta): | |||
| def __init__(self, model: ABLModel, reasoner: Reasoner) -> None: | |||
| if not isinstance(model, ABLModel): | |||
| raise TypeError( | |||
| "Expected an instance of ABLModel, but received type: {}".format(type(model)) | |||
| f"Expected an instance of ABLModel, but received type: {type(model)}" | |||
| ) | |||
| if not isinstance(reasoner, Reasoner): | |||
| raise TypeError( | |||
| "Expected an instance of Reasoner, but received type: {}".format(type(reasoner)) | |||
| f"Expected an instance of Reasoner, but received type: {type(reasoner)}" | |||
| ) | |||
| self.model = model | |||
| @@ -1,3 +1,9 @@ | |||
| """ | |||
| This module contains a simple implementation of the Bridge part. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| import os.path as osp | |||
| from typing import Any, List, Optional, Tuple, Union | |||
| @@ -221,7 +227,7 @@ class SimpleBridge(BaseBridge): | |||
| Labeled data should be in the same format as ``train_data``. The only difference is | |||
| that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be | |||
| utilized to train the model. Defaults to None. | |||
| val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 | |||
| val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 pylint: disable=line-too-long | |||
| Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label`` | |||
| and ``Y`` can be either None or not, which depends on the evaluation metircs in | |||
| ``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate | |||
| @@ -327,10 +333,11 @@ class SimpleBridge(BaseBridge): | |||
| Parameters | |||
| ---------- | |||
| val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 | |||
| Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object | |||
| with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be | |||
| either None or not, which depends on the evaluation metircs in ``self.metric_list``. | |||
| val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long | |||
| Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` | |||
| object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` | |||
| and ``Y`` can be either None or not, which depends on the evaluation metircs in | |||
| ``self.metric_list``. | |||
| """ | |||
| val_data_examples = self.data_preprocess("val", val_data) | |||
| self._valid(val_data_examples) | |||
| @@ -346,10 +353,11 @@ class SimpleBridge(BaseBridge): | |||
| Parameters | |||
| ---------- | |||
| test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 | |||
| test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long | |||
| Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object | |||
| with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` | |||
| can be either None or not, which depends on the evaluation metircs in ``self.metric_list``. | |||
| can be either None or not, which depends on the evaluation metircs in | |||
| ``self.metric_list``. | |||
| """ | |||
| print_log("Test start:", logger="current") | |||
| test_data_examples = self.data_preprocess("test", test_data) | |||
| @@ -1,3 +1,9 @@ | |||
| """ | |||
| This module contains the base class used for evaluation. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| import logging | |||
| from abc import ABCMeta, abstractmethod | |||
| from typing import Any, List, Optional | |||
| @@ -1,3 +1,10 @@ | |||
| """ | |||
| This module contains the ReasoningMetric, which is used for evaluating the model performance | |||
| on tasks need reasoning. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| from typing import Optional | |||
| from ...reasoning import KBBase | |||
| @@ -34,6 +41,7 @@ class ReasoningMetric(BaseMetric): | |||
| super().__init__(prefix) | |||
| self.kb = kb | |||
| # pylint: disable=protected-access | |||
| def process(self, data_examples: ListData) -> None: | |||
| """ | |||
| Process a batch of data examples. | |||
| @@ -1,4 +1,8 @@ | |||
| from typing import Optional | |||
| """ | |||
| This module contains the class SymbolAccuracy, which is used for evaluating symbol-level accuracy. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| import numpy as np | |||
| @@ -20,9 +24,6 @@ class SymbolAccuracy(BaseMetric): | |||
| metrics of different tasks. Inherits from BaseMetric. Default to None. | |||
| """ | |||
| def __init__(self, prefix: Optional[str] = None) -> None: | |||
| super().__init__(prefix) | |||
| def process(self, data_examples: ListData) -> None: | |||
| """ | |||
| Processes a batch of data examples. | |||
| @@ -1,6 +1,8 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| # Modified from | |||
| # https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py | |||
| """ | |||
| Copyright (c) OpenMMLab. All rights reserved. | |||
| Modified from | |||
| https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py # noqa: E501 pylint: disable=line-too-long | |||
| """ | |||
| import copy | |||
| from typing import Any, Iterator, Optional, Tuple, Type, Union | |||
| @@ -1,6 +1,8 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| # Modified from | |||
| # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |||
| """ | |||
| Copyright (c) OpenMMLab. All rights reserved. | |||
| Modified from | |||
| https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa: E501 pylint: disable=line-too-long | |||
| """ | |||
| from typing import List, Union | |||
| @@ -54,7 +56,7 @@ class ListData(BaseDataElement): | |||
| ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``. | |||
| This design is inspired by and extends the functionalities of the ``BaseDataElement`` | |||
| class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501 | |||
| class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501 pylint: disable=line-too-long | |||
| Examples: | |||
| >>> from ablkit.data.structures import ListData | |||
| @@ -72,7 +74,7 @@ class ListData(BaseDataElement): | |||
| DATA FIELDS | |||
| Y: [1, 2, 3] | |||
| gt_pseudo_label: [[1, 2], [3, 4], [5, 6]] | |||
| X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 | |||
| X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 pylint: disable=line-too-long | |||
| ) at 0x7f3bbf1991c0> | |||
| >>> print(data_examples[:1]) | |||
| <ListData( | |||
| @@ -1,3 +1,10 @@ | |||
| """ | |||
| This module contains the class ABLModel, which provides a unified interface for different | |||
| machine learning models. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| import pickle | |||
| from typing import Any, Dict | |||
| @@ -99,21 +106,20 @@ class ABLModel: | |||
| method = getattr(model, operation) | |||
| method(*args, **kwargs) | |||
| else: | |||
| if f"{operation}_path" not in kwargs.keys(): | |||
| if f"{operation}_path" not in kwargs: | |||
| raise ValueError(f"'{operation}_path' should not be None") | |||
| else: | |||
| try: | |||
| if operation == "save": | |||
| with open(kwargs["save_path"], "wb") as file: | |||
| pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||
| elif operation == "load": | |||
| with open(kwargs["load_path"], "rb") as file: | |||
| self.base_model = pickle.load(file) | |||
| except (OSError, pickle.PickleError): | |||
| raise NotImplementedError( | |||
| f"{type(model).__name__} object doesn't have the {operation} method \ | |||
| and the default pickle-based {operation} method failed." | |||
| ) | |||
| try: | |||
| if operation == "save": | |||
| with open(kwargs["save_path"], "wb") as file: | |||
| pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||
| elif operation == "load": | |||
| with open(kwargs["load_path"], "rb") as file: | |||
| self.base_model = pickle.load(file) | |||
| except (OSError, pickle.PickleError) as exc: | |||
| raise NotImplementedError( | |||
| f"{type(model).__name__} object doesn't have the {operation} method \ | |||
| and the default pickle-based {operation} method failed." | |||
| ) from exc | |||
| def save(self, *args, **kwargs) -> None: | |||
| """ | |||
| @@ -1,3 +1,9 @@ | |||
| """ | |||
| This module contains the class BasicNN, which servers as a wrapper for PyTorch NN models. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| from __future__ import annotations | |||
| import logging | |||
| @@ -474,7 +480,7 @@ class BasicNN: | |||
| raise ValueError("X should not be None.") | |||
| if y is None: | |||
| y = [0] * len(X) | |||
| if not (len(y) == len(X)): | |||
| if not len(y) == len(X): | |||
| raise ValueError("X and y should have equal length.") | |||
| dataset = ClassificationDataset(X, y, transform=self.train_transform) | |||
| @@ -1,3 +1,9 @@ | |||
| """ | |||
| Implementation of PyTorch dataset class used for classification. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| from typing import Any, Callable, List, Tuple, Optional | |||
| import torch | |||
| @@ -1,3 +1,9 @@ | |||
| """ | |||
| Implementation of PyTorch dataset class used for Prediction. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| from typing import Any, Callable, List, Tuple, Optional | |||
| import torch | |||
| @@ -1,3 +1,9 @@ | |||
| """ | |||
| Implementation of PyTorch dataset class used for regression. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| from typing import Any, List, Tuple | |||
| from torch.utils.data import Dataset | |||
| @@ -1,3 +1,10 @@ | |||
| """ | |||
| This module contains the classes KBBase, GroundKB, and PrologKB, which provide wrappers | |||
| for different kinds of knowledge bases. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| import bisect | |||
| import inspect | |||
| import logging | |||
| @@ -394,7 +401,7 @@ class GroundKB(KBBase): | |||
| base. The second element is a list of reasoning results corresponding to each | |||
| candidate, i.e., the outcome of the ``logic_forward`` function. | |||
| """ | |||
| if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list: | |||
| if not self.GKB or len(pseudo_label) not in self.GKB_len_list: | |||
| return [], [] | |||
| all_candidates, all_reasoning_results = self._find_candidate_GKB(pseudo_label, y) | |||
| @@ -478,7 +485,7 @@ class PrologKB(KBBase): | |||
| super().__init__(pseudo_label_list) | |||
| try: | |||
| import pyswip | |||
| import pyswip # pylint: disable=import-outside-toplevel | |||
| except (IndexError, ImportError): | |||
| print( | |||
| "A Prolog-based knowledge base is in use. Please install SWI-Prolog using the" | |||
| @@ -493,7 +500,7 @@ class PrologKB(KBBase): | |||
| raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") | |||
| self.prolog.consult(self.pl_file) | |||
| def logic_forward(self, pseudo_label: List[Any]) -> Any: | |||
| def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any: | |||
| """ | |||
| Consult prolog with the query ``logic_forward(pseudo_labels, Res).``, and set the | |||
| returned ``Res`` as the reasoning results. To use this default function, there must be | |||
| @@ -504,11 +511,15 @@ class PrologKB(KBBase): | |||
| ---------- | |||
| pseudo_label : List[Any] | |||
| Pseudo-labels of an example. | |||
| x : List[Any] | |||
| The corresponding input example. If the information from the input | |||
| is not required in the reasoning process, then this parameter will not have | |||
| any effect. | |||
| """ | |||
| result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"] | |||
| result = list(self.prolog.query(f"logic_forward({pseudo_label}, Res)."))[0]["Res"] | |||
| if result == "true": | |||
| return True | |||
| elif result == "false": | |||
| if result == "false": | |||
| return False | |||
| return result | |||
| @@ -517,7 +528,7 @@ class PrologKB(KBBase): | |||
| pseudo_label: List[Any], | |||
| revision_idx: List[int], | |||
| ) -> List[Any]: | |||
| import re | |||
| import re # pylint: disable=import-outside-toplevel | |||
| revision_pseudo_label = pseudo_label.copy() | |||
| revision_pseudo_label = flatten(revision_pseudo_label) | |||
| @@ -533,7 +544,7 @@ class PrologKB(KBBase): | |||
| self, | |||
| pseudo_label: List[Any], | |||
| y: Any, | |||
| x: List[Any], | |||
| x: List[Any], # pylint: disable=unused-argument | |||
| revision_idx: List[int], | |||
| ) -> str: | |||
| """ | |||
| @@ -563,7 +574,7 @@ class PrologKB(KBBase): | |||
| query_string = "logic_forward(" | |||
| query_string += self._revision_pseudo_label(pseudo_label, revision_idx) | |||
| key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None) | |||
| query_string += ",%s)." % y if not key_is_none_flag else ")." | |||
| query_string += f",{y})." if not key_is_none_flag else ")." | |||
| return query_string | |||
| def revise_at_idx( | |||
| @@ -1,3 +1,10 @@ | |||
| """ | |||
| This module contains the class Reasoner, which is used for minimizing the inconsistency | |||
| between the knowledge base and learning models. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| import inspect | |||
| from typing import Any, Callable, List, Optional, Union | |||
| @@ -252,24 +259,19 @@ class Reasoner: | |||
| def zoopt_budget(self, symbol_num: int) -> int: | |||
| """ | |||
| Set the budget for ZOOpt optimization. The function, in its default implementation, | |||
| returns a fixed budget value of 100. However, it can be adjusted to return other fixed | |||
| values, or a dynamic budget based on the number of symbols, if desired. For example, | |||
| one might choose to set the budget as 100 times ``symbol_num``. | |||
| returns a budget value of 10 * ``symbol_num``. | |||
| Parameters | |||
| ---------- | |||
| symbol_num : int | |||
| The number of symbols to be considered in the ZOOpt optimization process. Although this | |||
| parameter can be used to compute a dynamic optimization budget, by default it is not | |||
| utilized in the calculation. | |||
| The number of symbols to be considered in the ZOOpt optimization process. | |||
| Returns | |||
| ------- | |||
| int | |||
| The budget for ZOOpt optimization. By default, this is a fixed value of 100, | |||
| irrespective of the symbol_num value. | |||
| The budget for ZOOpt optimization. By default, this is 10 * ``symbol_num``. | |||
| """ | |||
| return 100 | |||
| return 10 * symbol_num | |||
| def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int: | |||
| """ | |||
| @@ -288,19 +290,18 @@ class Reasoner: | |||
| if max_revision == -1: | |||
| return symbol_num | |||
| elif isinstance(max_revision, float): | |||
| if not (0 <= max_revision <= 1): | |||
| if isinstance(max_revision, float): | |||
| if not 0 <= max_revision <= 1: | |||
| raise ValueError( | |||
| "If max_revision is a float, it must be between 0 and 1, " | |||
| + f"but got {max_revision}" | |||
| ) | |||
| return round(symbol_num * max_revision) | |||
| else: | |||
| if max_revision < 0: | |||
| raise ValueError( | |||
| f"If max_revision is an int, it must be non-negative, but got {max_revision}" | |||
| ) | |||
| return max_revision | |||
| if max_revision < 0: | |||
| raise ValueError( | |||
| f"If max_revision is an int, it must be non-negative, but got {max_revision}" | |||
| ) | |||
| return max_revision | |||
| def abduce(self, data_example: ListData) -> List[Any]: | |||
| """ | |||
| @@ -1,14 +1,15 @@ | |||
| # Python module wrapper for _functools C module | |||
| # to allow utilities written in Python to be added | |||
| # to the functools module. | |||
| # Written by Nick Coghlan <ncoghlan at gmail.com>, | |||
| # Raymond Hettinger <python at rcn.com>, | |||
| # and Łukasz Langa <lukasz at langa.pl>. | |||
| # Copyright (C) 2006-2013 Python Software Foundation. | |||
| # See C source code for _functools credits/copyright | |||
| # Modified from | |||
| # https://github.com/python/cpython/blob/3.12/Lib/functools.py | |||
| """ | |||
| Python module wrapper for _functools C module | |||
| to allow utilities written in Python to be added | |||
| to the functools module. | |||
| Written by Nick Coghlan <ncoghlan at gmail.com>, | |||
| Raymond Hettinger <python at rcn.com>, | |||
| and Łukasz Langa <lukasz at langa.pl>. | |||
| Copyright (C) 2006-2013 Python Software Foundation. | |||
| See C source code for _functools credits/copyright | |||
| Modified from | |||
| https://github.com/python/cpython/blob/3.12/Lib/functools.py | |||
| """ | |||
| from typing import Callable, Generic, TypeVar | |||
| @@ -18,30 +19,58 @@ PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields | |||
| class Cache(Generic[K, T]): | |||
| def __init__(self, func: Callable[[K], T]): | |||
| """Create cache | |||
| """ | |||
| A generic caching mechanism that stores the results of a function call and | |||
| retrieves them to avoid repeated calculations. | |||
| :param func: Function this cache evaluates | |||
| :param cache: If true, do in memory caching. | |||
| :param cache_root: If not None, cache to files at the provided path. | |||
| :param key_func: Convert the key into a hashable object if needed | |||
| """ | |||
| This class implements a dictionary-based cache with a circular doubly linked | |||
| list to manage the cache entries efficiently. It is designed to be generic, | |||
| allowing for caching of any callable function. | |||
| Parameters | |||
| ---------- | |||
| func : Callable[[K], T] | |||
| The function to be cached. This function takes an argument of type K and | |||
| returns a value of type T. | |||
| """ | |||
| def __init__(self, func: Callable[[K], T]): | |||
| self.func = func | |||
| self.has_init = False | |||
| self.cache = False | |||
| self.cache_dict = {} | |||
| self.key_func = None | |||
| self.max_size = 0 | |||
| self.hits, self.misses = 0, 0 | |||
| self.full = False | |||
| self.root = [] # root of the circular doubly linked list | |||
| self.root[:] = [self.root, self.root, None, None] | |||
| def __getitem__(self, obj, *args) -> T: | |||
| return self.get_from_dict(obj, *args) | |||
| def clear_cache(self): | |||
| """Invalidate entire cache.""" | |||
| """ | |||
| Invalidate the entire cache. | |||
| """ | |||
| self.cache_dict.clear() | |||
| def _init_cache(self, obj): | |||
| def init_cache(self, obj): | |||
| """ | |||
| Initialize the cache settings. | |||
| Parameters | |||
| ---------- | |||
| obj : Any | |||
| The object containing settings for cache initialization. | |||
| """ | |||
| if self.has_init: | |||
| return | |||
| self.cache = True | |||
| self.cache_dict = dict() | |||
| self.cache_dict = {} | |||
| self.key_func = obj.key_func | |||
| self.max_size = obj.cache_size | |||
| @@ -53,9 +82,23 @@ class Cache(Generic[K, T]): | |||
| self.has_init = True | |||
| def get_from_dict(self, obj, *args) -> T: | |||
| """Implements dict based cache.""" | |||
| """ | |||
| Retrieve a value from the cache or compute it using ``self.func``. | |||
| Parameters | |||
| ---------- | |||
| obj : Any | |||
| The object to which the cached method/function belongs. | |||
| *args : Any | |||
| Arguments used in key generation for cache retrieval or function computation. | |||
| Returns | |||
| ------- | |||
| T | |||
| The value from the cache or computed by the function. | |||
| """ | |||
| # x is not used in cache key | |||
| pred_pseudo_label, y, x, *res_args = args | |||
| pred_pseudo_label, y, _x, *res_args = args | |||
| cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args) | |||
| link = self.cache_dict.get(cache_key) | |||
| if link is not None: | |||
| @@ -96,15 +139,23 @@ class Cache(Generic[K, T]): | |||
| def abl_cache(): | |||
| """ | |||
| Decorator to enable caching for a function. | |||
| Returns | |||
| ------- | |||
| Callable | |||
| The wrapped function with caching capability. | |||
| """ | |||
| def decorator(func): | |||
| cache_instance = Cache(func) | |||
| def wrapper(obj, *args): | |||
| if obj.use_cache: | |||
| cache_instance._init_cache(obj) | |||
| cache_instance.init_cache(obj) | |||
| return cache_instance.get_from_dict(obj, *args) | |||
| else: | |||
| return func(obj, *args) | |||
| return func(obj, *args) | |||
| return wrapper | |||
| @@ -1,6 +1,8 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| # Modified from | |||
| # https://github.com/open-mmlab/mmengine/blob/main/mmengine/logging/logger.py | |||
| """ | |||
| Copyright (c) OpenMMLab. All rights reserved. | |||
| Modified from | |||
| https://github.com/open-mmlab/mmengine/blob/main/mmengine/logging/logger.py | |||
| """ | |||
| import logging | |||
| import os | |||
| @@ -132,13 +134,13 @@ class ABLFormatter(logging.Formatter): | |||
| Formatted result. | |||
| """ | |||
| if record.levelno == logging.ERROR: | |||
| self._style._fmt = self.err_format | |||
| self._style._fmt = self.err_format # pylint: disable=protected-access | |||
| elif record.levelno == logging.WARNING: | |||
| self._style._fmt = self.warn_format | |||
| self._style._fmt = self.warn_format # pylint: disable=protected-access | |||
| elif record.levelno == logging.INFO: | |||
| self._style._fmt = self.info_format | |||
| self._style._fmt = self.info_format # pylint: disable=protected-access | |||
| elif record.levelno == logging.DEBUG: | |||
| self._style._fmt = self.debug_format | |||
| self._style._fmt = self.debug_format # pylint: disable=protected-access | |||
| result = logging.Formatter.format(self, record) | |||
| return result | |||
| @@ -215,7 +217,7 @@ class ABLLogger(Logger, ManagerMixin): | |||
| self.handlers.append(stream_handler) | |||
| if log_file is None: | |||
| import time | |||
| import time # pylint: disable=import-outside-toplevel | |||
| local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime()) | |||
| @@ -234,10 +236,20 @@ class ABLLogger(Logger, ManagerMixin): | |||
| @property | |||
| def log_file(self): | |||
| """Get the file path of the log. | |||
| Returns: | |||
| str: Path of the log. | |||
| """ | |||
| return self._log_file | |||
| @property | |||
| def log_dir(self): | |||
| """Get the directory where the log is stored. | |||
| Returns: | |||
| str: Directory where the log is stored. | |||
| """ | |||
| return self._log_dir | |||
| @classmethod | |||
| @@ -284,11 +296,11 @@ class ABLLogger(Logger, ManagerMixin): | |||
| level : Union[int, str] | |||
| The logging level to set. | |||
| """ | |||
| self.level = logging._checkLevel(level) | |||
| self.level = logging._checkLevel(level) # pylint: disable=protected-access | |||
| _accquire_lock() | |||
| # The same logic as ``logging.Manager._clear_cache``. | |||
| for logger in ABLLogger._instance_dict.values(): | |||
| logger._cache.clear() | |||
| logger._cache.clear() # pylint: disable=protected-access | |||
| _release_lock() | |||
| @@ -1,4 +1,7 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| """ | |||
| Copyright (c) OpenMMLab. All rights reserved. | |||
| """ | |||
| import inspect | |||
| import threading | |||
| import warnings | |||
| @@ -72,7 +75,7 @@ class ManagerMixin(metaclass=ManagerMeta): | |||
| name (str): Name of the instance. Defaults to ''. | |||
| """ | |||
| def __init__(self, name: str = "", **kwargs): | |||
| def __init__(self, name: str = ""): | |||
| assert isinstance(name, str) and name, "name argument must be an non-empty string." | |||
| self._instance_name = name | |||
| @@ -1,3 +1,9 @@ | |||
| """ | |||
| Implementation of utilities used in ablkit. | |||
| Copyright (c) 2024 LAMDA. All rights reserved. | |||
| """ | |||
| from typing import List, Any, Union, Tuple, Optional | |||
| import numpy as np | |||
| @@ -198,6 +204,6 @@ def tab_data_to_tuple( | |||
| return None | |||
| if len(X) != len(y): | |||
| raise ValueError( | |||
| "The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)) | |||
| f"The length of X and y should be the same, but got {len(X)} and {len(y)}." | |||
| ) | |||
| return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) | |||