From e029674ca8caaebe3f79bb94f87a27334c8bced5 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 2 Nov 2023 10:10:59 +0800 Subject: [PATCH] [MNT] format code by black --- learnware/market/easy2/checker.py | 4 ++-- learnware/specification/__init__.py | 8 +++++++- learnware/specification/regular/text/rkme.py | 3 ++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 1fc0fef..fa8f26c 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -91,7 +91,7 @@ class EasyStatisticalChecker(BaseChecker): if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") return self.INVALID_LEARNWARE - + def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): text_list = [] for i in range(num): @@ -106,7 +106,7 @@ class EasyStatisticalChecker(BaseChecker): else: raise ValueError("Type should be en or zh") return text_list - + if is_text: inputs = generate_random_text_list(10) else: diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 7fbf500..b27ef5b 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,9 @@ from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .base import Specification, BaseStatSpecification -from .regular import RegularStatsSpecification, RKMEStatSpecification, RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification +from .regular import ( + RegularStatsSpecification, + RKMEStatSpecification, + RKMETableSpecification, + RKMEImageSpecification, + RKMETextSpecification, +) diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index cc8659e..117b032 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -10,6 +10,7 @@ logger = get_module_logger("RKMETextSpecification", "INFO") class RKMETextSpecification(RKMETableSpecification): """Reduced Kernel Mean Embedding (RKME) Specification for Text""" + def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): RKMETableSpecification.__init__(self, gamma, cuda_idx) self.language = [] @@ -59,7 +60,7 @@ class RKMETextSpecification(RKMETableSpecification): @staticmethod def get_language_ids(X): try: - text = ' '.join(X) + text = " ".join(X) lang = langdetect.detect(text) langs = langdetect.detect_langs(text) return [l.lang for l in langs]