From 400282a9cd6f1f1604e6852fe9a563d2a2c6dcab Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Mon, 8 Jan 2024 22:49:21 +0800 Subject: [PATCH] [ENH] resolve main warnings of pylint and add copyright --- ablkit/bridge/base_bridge.py | 10 +- ablkit/bridge/simple_bridge.py | 22 ++-- ablkit/data/evaluation/base_metric.py | 6 + ablkit/data/evaluation/reasoning_metric.py | 8 ++ ablkit/data/evaluation/symbol_accuracy.py | 9 +- ablkit/data/structures/base_data_element.py | 8 +- ablkit/data/structures/list_data.py | 12 +- ablkit/learning/abl_model.py | 34 +++--- ablkit/learning/basic_nn.py | 8 +- .../torch_dataset/classification_dataset.py | 6 + .../torch_dataset/prediction_dataset.py | 6 + .../torch_dataset/regression_dataset.py | 6 + ablkit/reasoning/kb.py | 27 +++-- ablkit/reasoning/reasoner.py | 35 +++--- ablkit/utils/cache.py | 103 +++++++++++++----- ablkit/utils/logger.py | 32 ++++-- ablkit/utils/manager.py | 7 +- ablkit/utils/utils.py | 8 +- 18 files changed, 247 insertions(+), 100 deletions(-) diff --git a/ablkit/bridge/base_bridge.py b/ablkit/bridge/base_bridge.py index 8ec0ebd..9b10be3 100644 --- a/ablkit/bridge/base_bridge.py +++ b/ablkit/bridge/base_bridge.py @@ -1,3 +1,9 @@ +""" +This module contains the base class for the Bridge part. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + from abc import ABCMeta, abstractmethod from typing import Any, List, Optional, Tuple, Union @@ -31,11 +37,11 @@ class BaseBridge(metaclass=ABCMeta): def __init__(self, model: ABLModel, reasoner: Reasoner) -> None: if not isinstance(model, ABLModel): raise TypeError( - "Expected an instance of ABLModel, but received type: {}".format(type(model)) + f"Expected an instance of ABLModel, but received type: {type(model)}" ) if not isinstance(reasoner, Reasoner): raise TypeError( - "Expected an instance of Reasoner, but received type: {}".format(type(reasoner)) + f"Expected an instance of Reasoner, but received type: {type(reasoner)}" ) self.model = model diff --git a/ablkit/bridge/simple_bridge.py b/ablkit/bridge/simple_bridge.py index 845fdbd..b053902 100644 --- a/ablkit/bridge/simple_bridge.py +++ b/ablkit/bridge/simple_bridge.py @@ -1,3 +1,9 @@ +""" +This module contains a simple implementation of the Bridge part. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + import os.path as osp from typing import Any, List, Optional, Tuple, Union @@ -221,7 +227,7 @@ class SimpleBridge(BaseBridge): Labeled data should be in the same format as ``train_data``. The only difference is that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be utilized to train the model. Defaults to None. - val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 + val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 pylint: disable=line-too-long Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label`` and ``Y`` can be either None or not, which depends on the evaluation metircs in ``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate @@ -327,10 +333,11 @@ class SimpleBridge(BaseBridge): Parameters ---------- - val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 - Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object - with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be - either None or not, which depends on the evaluation metircs in ``self.metric_list``. + val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long + Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` + object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` + and ``Y`` can be either None or not, which depends on the evaluation metircs in + ``self.metric_list``. """ val_data_examples = self.data_preprocess("val", val_data) self._valid(val_data_examples) @@ -346,10 +353,11 @@ class SimpleBridge(BaseBridge): Parameters ---------- - test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 + test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` - can be either None or not, which depends on the evaluation metircs in ``self.metric_list``. + can be either None or not, which depends on the evaluation metircs in + ``self.metric_list``. """ print_log("Test start:", logger="current") test_data_examples = self.data_preprocess("test", test_data) diff --git a/ablkit/data/evaluation/base_metric.py b/ablkit/data/evaluation/base_metric.py index 59824d3..9edf4d8 100644 --- a/ablkit/data/evaluation/base_metric.py +++ b/ablkit/data/evaluation/base_metric.py @@ -1,3 +1,9 @@ +""" +This module contains the base class used for evaluation. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + import logging from abc import ABCMeta, abstractmethod from typing import Any, List, Optional diff --git a/ablkit/data/evaluation/reasoning_metric.py b/ablkit/data/evaluation/reasoning_metric.py index 9a010bb..7fb8855 100644 --- a/ablkit/data/evaluation/reasoning_metric.py +++ b/ablkit/data/evaluation/reasoning_metric.py @@ -1,3 +1,10 @@ +""" +This module contains the ReasoningMetric, which is used for evaluating the model performance +on tasks need reasoning. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + from typing import Optional from ...reasoning import KBBase @@ -34,6 +41,7 @@ class ReasoningMetric(BaseMetric): super().__init__(prefix) self.kb = kb + # pylint: disable=protected-access def process(self, data_examples: ListData) -> None: """ Process a batch of data examples. diff --git a/ablkit/data/evaluation/symbol_accuracy.py b/ablkit/data/evaluation/symbol_accuracy.py index 0eac6ca..3c2fa01 100644 --- a/ablkit/data/evaluation/symbol_accuracy.py +++ b/ablkit/data/evaluation/symbol_accuracy.py @@ -1,4 +1,8 @@ -from typing import Optional +""" +This module contains the class SymbolAccuracy, which is used for evaluating symbol-level accuracy. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" import numpy as np @@ -20,9 +24,6 @@ class SymbolAccuracy(BaseMetric): metrics of different tasks. Inherits from BaseMetric. Default to None. """ - def __init__(self, prefix: Optional[str] = None) -> None: - super().__init__(prefix) - def process(self, data_examples: ListData) -> None: """ Processes a batch of data examples. diff --git a/ablkit/data/structures/base_data_element.py b/ablkit/data/structures/base_data_element.py index 4129de5..2e9b000 100644 --- a/ablkit/data/structures/base_data_element.py +++ b/ablkit/data/structures/base_data_element.py @@ -1,6 +1,8 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# Modified from -# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py +""" +Copyright (c) OpenMMLab. All rights reserved. +Modified from +https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py # noqa: E501 pylint: disable=line-too-long +""" import copy from typing import Any, Iterator, Optional, Tuple, Type, Union diff --git a/ablkit/data/structures/list_data.py b/ablkit/data/structures/list_data.py index b937e18..f90474c 100644 --- a/ablkit/data/structures/list_data.py +++ b/ablkit/data/structures/list_data.py @@ -1,6 +1,8 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# Modified from -# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa +""" +Copyright (c) OpenMMLab. All rights reserved. +Modified from +https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa: E501 pylint: disable=line-too-long +""" from typing import List, Union @@ -54,7 +56,7 @@ class ListData(BaseDataElement): ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``. This design is inspired by and extends the functionalities of the ``BaseDataElement`` - class implemented in `MMEngine `_. # noqa: E501 + class implemented in `MMEngine `_. # noqa: E501 pylint: disable=line-too-long Examples: >>> from ablkit.data.structures import ListData @@ -72,7 +74,7 @@ class ListData(BaseDataElement): DATA FIELDS Y: [1, 2, 3] gt_pseudo_label: [[1, 2], [3, 4], [5, 6]] - X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 + X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 pylint: disable=line-too-long ) at 0x7f3bbf1991c0> >>> print(data_examples[:1]) None: """ diff --git a/ablkit/learning/basic_nn.py b/ablkit/learning/basic_nn.py index 3693bb7..8782b1d 100644 --- a/ablkit/learning/basic_nn.py +++ b/ablkit/learning/basic_nn.py @@ -1,3 +1,9 @@ +""" +This module contains the class BasicNN, which servers as a wrapper for PyTorch NN models. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + from __future__ import annotations import logging @@ -474,7 +480,7 @@ class BasicNN: raise ValueError("X should not be None.") if y is None: y = [0] * len(X) - if not (len(y) == len(X)): + if not len(y) == len(X): raise ValueError("X and y should have equal length.") dataset = ClassificationDataset(X, y, transform=self.train_transform) diff --git a/ablkit/learning/torch_dataset/classification_dataset.py b/ablkit/learning/torch_dataset/classification_dataset.py index a45acd3..a1c3e50 100644 --- a/ablkit/learning/torch_dataset/classification_dataset.py +++ b/ablkit/learning/torch_dataset/classification_dataset.py @@ -1,3 +1,9 @@ +""" +Implementation of PyTorch dataset class used for classification. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + from typing import Any, Callable, List, Tuple, Optional import torch diff --git a/ablkit/learning/torch_dataset/prediction_dataset.py b/ablkit/learning/torch_dataset/prediction_dataset.py index 1abf12e..14dce7b 100644 --- a/ablkit/learning/torch_dataset/prediction_dataset.py +++ b/ablkit/learning/torch_dataset/prediction_dataset.py @@ -1,3 +1,9 @@ +""" +Implementation of PyTorch dataset class used for Prediction. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + from typing import Any, Callable, List, Tuple, Optional import torch diff --git a/ablkit/learning/torch_dataset/regression_dataset.py b/ablkit/learning/torch_dataset/regression_dataset.py index 956b38c..978836e 100644 --- a/ablkit/learning/torch_dataset/regression_dataset.py +++ b/ablkit/learning/torch_dataset/regression_dataset.py @@ -1,3 +1,9 @@ +""" +Implementation of PyTorch dataset class used for regression. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + from typing import Any, List, Tuple from torch.utils.data import Dataset diff --git a/ablkit/reasoning/kb.py b/ablkit/reasoning/kb.py index 0954cd6..3202d8d 100644 --- a/ablkit/reasoning/kb.py +++ b/ablkit/reasoning/kb.py @@ -1,3 +1,10 @@ +""" +This module contains the classes KBBase, GroundKB, and PrologKB, which provide wrappers +for different kinds of knowledge bases. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + import bisect import inspect import logging @@ -394,7 +401,7 @@ class GroundKB(KBBase): base. The second element is a list of reasoning results corresponding to each candidate, i.e., the outcome of the ``logic_forward`` function. """ - if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list: + if not self.GKB or len(pseudo_label) not in self.GKB_len_list: return [], [] all_candidates, all_reasoning_results = self._find_candidate_GKB(pseudo_label, y) @@ -478,7 +485,7 @@ class PrologKB(KBBase): super().__init__(pseudo_label_list) try: - import pyswip + import pyswip # pylint: disable=import-outside-toplevel except (IndexError, ImportError): print( "A Prolog-based knowledge base is in use. Please install SWI-Prolog using the" @@ -493,7 +500,7 @@ class PrologKB(KBBase): raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") self.prolog.consult(self.pl_file) - def logic_forward(self, pseudo_label: List[Any]) -> Any: + def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any: """ Consult prolog with the query ``logic_forward(pseudo_labels, Res).``, and set the returned ``Res`` as the reasoning results. To use this default function, there must be @@ -504,11 +511,15 @@ class PrologKB(KBBase): ---------- pseudo_label : List[Any] Pseudo-labels of an example. + x : List[Any] + The corresponding input example. If the information from the input + is not required in the reasoning process, then this parameter will not have + any effect. """ - result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"] + result = list(self.prolog.query(f"logic_forward({pseudo_label}, Res)."))[0]["Res"] if result == "true": return True - elif result == "false": + if result == "false": return False return result @@ -517,7 +528,7 @@ class PrologKB(KBBase): pseudo_label: List[Any], revision_idx: List[int], ) -> List[Any]: - import re + import re # pylint: disable=import-outside-toplevel revision_pseudo_label = pseudo_label.copy() revision_pseudo_label = flatten(revision_pseudo_label) @@ -533,7 +544,7 @@ class PrologKB(KBBase): self, pseudo_label: List[Any], y: Any, - x: List[Any], + x: List[Any], # pylint: disable=unused-argument revision_idx: List[int], ) -> str: """ @@ -563,7 +574,7 @@ class PrologKB(KBBase): query_string = "logic_forward(" query_string += self._revision_pseudo_label(pseudo_label, revision_idx) key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None) - query_string += ",%s)." % y if not key_is_none_flag else ")." + query_string += f",{y})." if not key_is_none_flag else ")." return query_string def revise_at_idx( diff --git a/ablkit/reasoning/reasoner.py b/ablkit/reasoning/reasoner.py index 7e3af8a..a016c3c 100644 --- a/ablkit/reasoning/reasoner.py +++ b/ablkit/reasoning/reasoner.py @@ -1,3 +1,10 @@ +""" +This module contains the class Reasoner, which is used for minimizing the inconsistency +between the knowledge base and learning models. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + import inspect from typing import Any, Callable, List, Optional, Union @@ -252,24 +259,19 @@ class Reasoner: def zoopt_budget(self, symbol_num: int) -> int: """ Set the budget for ZOOpt optimization. The function, in its default implementation, - returns a fixed budget value of 100. However, it can be adjusted to return other fixed - values, or a dynamic budget based on the number of symbols, if desired. For example, - one might choose to set the budget as 100 times ``symbol_num``. + returns a budget value of 10 * ``symbol_num``. Parameters ---------- symbol_num : int - The number of symbols to be considered in the ZOOpt optimization process. Although this - parameter can be used to compute a dynamic optimization budget, by default it is not - utilized in the calculation. + The number of symbols to be considered in the ZOOpt optimization process. Returns ------- int - The budget for ZOOpt optimization. By default, this is a fixed value of 100, - irrespective of the symbol_num value. + The budget for ZOOpt optimization. By default, this is 10 * ``symbol_num``. """ - return 100 + return 10 * symbol_num def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int: """ @@ -288,19 +290,18 @@ class Reasoner: if max_revision == -1: return symbol_num - elif isinstance(max_revision, float): - if not (0 <= max_revision <= 1): + if isinstance(max_revision, float): + if not 0 <= max_revision <= 1: raise ValueError( "If max_revision is a float, it must be between 0 and 1, " + f"but got {max_revision}" ) return round(symbol_num * max_revision) - else: - if max_revision < 0: - raise ValueError( - f"If max_revision is an int, it must be non-negative, but got {max_revision}" - ) - return max_revision + if max_revision < 0: + raise ValueError( + f"If max_revision is an int, it must be non-negative, but got {max_revision}" + ) + return max_revision def abduce(self, data_example: ListData) -> List[Any]: """ diff --git a/ablkit/utils/cache.py b/ablkit/utils/cache.py index 5b9e348..9f22102 100644 --- a/ablkit/utils/cache.py +++ b/ablkit/utils/cache.py @@ -1,14 +1,15 @@ -# Python module wrapper for _functools C module -# to allow utilities written in Python to be added -# to the functools module. -# Written by Nick Coghlan , -# Raymond Hettinger , -# and Łukasz Langa . -# Copyright (C) 2006-2013 Python Software Foundation. -# See C source code for _functools credits/copyright -# Modified from -# https://github.com/python/cpython/blob/3.12/Lib/functools.py - +""" +Python module wrapper for _functools C module +to allow utilities written in Python to be added +to the functools module. +Written by Nick Coghlan , +Raymond Hettinger , +and Łukasz Langa . + Copyright (C) 2006-2013 Python Software Foundation. +See C source code for _functools credits/copyright +Modified from +https://github.com/python/cpython/blob/3.12/Lib/functools.py +""" from typing import Callable, Generic, TypeVar @@ -18,30 +19,58 @@ PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields class Cache(Generic[K, T]): - def __init__(self, func: Callable[[K], T]): - """Create cache + """ + A generic caching mechanism that stores the results of a function call and + retrieves them to avoid repeated calculations. - :param func: Function this cache evaluates - :param cache: If true, do in memory caching. - :param cache_root: If not None, cache to files at the provided path. - :param key_func: Convert the key into a hashable object if needed - """ + This class implements a dictionary-based cache with a circular doubly linked + list to manage the cache entries efficiently. It is designed to be generic, + allowing for caching of any callable function. + + Parameters + ---------- + func : Callable[[K], T] + The function to be cached. This function takes an argument of type K and + returns a value of type T. + """ + + def __init__(self, func: Callable[[K], T]): self.func = func self.has_init = False + self.cache = False + self.cache_dict = {} + self.key_func = None + self.max_size = 0 + + self.hits, self.misses = 0, 0 + self.full = False + self.root = [] # root of the circular doubly linked list + self.root[:] = [self.root, self.root, None, None] + def __getitem__(self, obj, *args) -> T: return self.get_from_dict(obj, *args) def clear_cache(self): - """Invalidate entire cache.""" + """ + Invalidate the entire cache. + """ self.cache_dict.clear() - def _init_cache(self, obj): + def init_cache(self, obj): + """ + Initialize the cache settings. + + Parameters + ---------- + obj : Any + The object containing settings for cache initialization. + """ if self.has_init: return self.cache = True - self.cache_dict = dict() + self.cache_dict = {} self.key_func = obj.key_func self.max_size = obj.cache_size @@ -53,9 +82,23 @@ class Cache(Generic[K, T]): self.has_init = True def get_from_dict(self, obj, *args) -> T: - """Implements dict based cache.""" + """ + Retrieve a value from the cache or compute it using ``self.func``. + + Parameters + ---------- + obj : Any + The object to which the cached method/function belongs. + *args : Any + Arguments used in key generation for cache retrieval or function computation. + + Returns + ------- + T + The value from the cache or computed by the function. + """ # x is not used in cache key - pred_pseudo_label, y, x, *res_args = args + pred_pseudo_label, y, _x, *res_args = args cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args) link = self.cache_dict.get(cache_key) if link is not None: @@ -96,15 +139,23 @@ class Cache(Generic[K, T]): def abl_cache(): + """ + Decorator to enable caching for a function. + + Returns + ------- + Callable + The wrapped function with caching capability. + """ + def decorator(func): cache_instance = Cache(func) def wrapper(obj, *args): if obj.use_cache: - cache_instance._init_cache(obj) + cache_instance.init_cache(obj) return cache_instance.get_from_dict(obj, *args) - else: - return func(obj, *args) + return func(obj, *args) return wrapper diff --git a/ablkit/utils/logger.py b/ablkit/utils/logger.py index 0256cda..2f234ce 100644 --- a/ablkit/utils/logger.py +++ b/ablkit/utils/logger.py @@ -1,6 +1,8 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# Modified from -# https://github.com/open-mmlab/mmengine/blob/main/mmengine/logging/logger.py +""" +Copyright (c) OpenMMLab. All rights reserved. +Modified from +https://github.com/open-mmlab/mmengine/blob/main/mmengine/logging/logger.py +""" import logging import os @@ -132,13 +134,13 @@ class ABLFormatter(logging.Formatter): Formatted result. """ if record.levelno == logging.ERROR: - self._style._fmt = self.err_format + self._style._fmt = self.err_format # pylint: disable=protected-access elif record.levelno == logging.WARNING: - self._style._fmt = self.warn_format + self._style._fmt = self.warn_format # pylint: disable=protected-access elif record.levelno == logging.INFO: - self._style._fmt = self.info_format + self._style._fmt = self.info_format # pylint: disable=protected-access elif record.levelno == logging.DEBUG: - self._style._fmt = self.debug_format + self._style._fmt = self.debug_format # pylint: disable=protected-access result = logging.Formatter.format(self, record) return result @@ -215,7 +217,7 @@ class ABLLogger(Logger, ManagerMixin): self.handlers.append(stream_handler) if log_file is None: - import time + import time # pylint: disable=import-outside-toplevel local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime()) @@ -234,10 +236,20 @@ class ABLLogger(Logger, ManagerMixin): @property def log_file(self): + """Get the file path of the log. + + Returns: + str: Path of the log. + """ return self._log_file @property def log_dir(self): + """Get the directory where the log is stored. + + Returns: + str: Directory where the log is stored. + """ return self._log_dir @classmethod @@ -284,11 +296,11 @@ class ABLLogger(Logger, ManagerMixin): level : Union[int, str] The logging level to set. """ - self.level = logging._checkLevel(level) + self.level = logging._checkLevel(level) # pylint: disable=protected-access _accquire_lock() # The same logic as ``logging.Manager._clear_cache``. for logger in ABLLogger._instance_dict.values(): - logger._cache.clear() + logger._cache.clear() # pylint: disable=protected-access _release_lock() diff --git a/ablkit/utils/manager.py b/ablkit/utils/manager.py index d93b784..c763041 100644 --- a/ablkit/utils/manager.py +++ b/ablkit/utils/manager.py @@ -1,4 +1,7 @@ -# Copyright (c) OpenMMLab. All rights reserved. +""" +Copyright (c) OpenMMLab. All rights reserved. +""" + import inspect import threading import warnings @@ -72,7 +75,7 @@ class ManagerMixin(metaclass=ManagerMeta): name (str): Name of the instance. Defaults to ''. """ - def __init__(self, name: str = "", **kwargs): + def __init__(self, name: str = ""): assert isinstance(name, str) and name, "name argument must be an non-empty string." self._instance_name = name diff --git a/ablkit/utils/utils.py b/ablkit/utils/utils.py index 2872747..efa2bf2 100644 --- a/ablkit/utils/utils.py +++ b/ablkit/utils/utils.py @@ -1,3 +1,9 @@ +""" +Implementation of utilities used in ablkit. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + from typing import List, Any, Union, Tuple, Optional import numpy as np @@ -198,6 +204,6 @@ def tab_data_to_tuple( return None if len(X) != len(y): raise ValueError( - "The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)) + f"The length of X and y should be the same, but got {len(X)} and {len(y)}." ) return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y))