From 70953f8cc36def9e8801449a3e896de6b741dda7 Mon Sep 17 00:00:00 2001 From: liuht-0807 Date: Mon, 4 Dec 2023 20:41:15 +0800 Subject: [PATCH 1/5] [FIX] fix exceed GPU memory generating hetero_spec --- .../market/heterogeneous/organizer/hetero_map/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index 2e31195..6735879 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -11,6 +11,8 @@ from .....specification import HeteroMapTableSpecification, RKMETableSpecificati from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer from .trainer import Trainer, TransTabCollatorForCL +from loguru import logger + class HeteroMap(nn.Module): """ @@ -127,6 +129,7 @@ class HeteroMap(nn.Module): self.base_temperature = base_temperature self.num_partition = num_partition self.overlap_ratio = overlap_ratio + self.max_process_size = 20480 self.to(device) def to(self, device: Union[str, torch.device]): @@ -306,6 +309,10 @@ class HeteroMap(nn.Module): """ self.eval() output_feas_list = [] + + if eval_batch_size * x_test.shape[1] > self.max_process_size: + eval_batch_size = int(self.max_process_size / x_test.shape[1]) + for i in range(0, len(x_test), eval_batch_size): bs_x_test = x_test.iloc[i : i + eval_batch_size] with torch.no_grad(): From d5b45f5d9fcb2742242ce5df5c93cdb217ca0ad9 Mon Sep 17 00:00:00 2001 From: liuht-0807 Date: Mon, 4 Dec 2023 20:44:29 +0800 Subject: [PATCH 2/5] [FIX] delete loguru --- .../market/heterogeneous/organizer/hetero_map/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index 6735879..4ebb0a8 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -9,9 +9,7 @@ from torch import Tensor, nn from .....utils import allocate_cuda_idx, choose_device from .....specification import HeteroMapTableSpecification, RKMETableSpecification from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer -from .trainer import Trainer, TransTabCollatorForCL - -from loguru import logger +from .trainer import TransTabCollatorForCL class HeteroMap(nn.Module): From 2ff7e15248b513702379a870fb389bf006806c8a Mon Sep 17 00:00:00 2001 From: liuht Date: Wed, 6 Dec 2023 15:51:20 +0800 Subject: [PATCH 3/5] [FIX] fix wrong delete --- learnware/market/heterogeneous/organizer/hetero_map/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index 4ebb0a8..e124bcb 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -9,7 +9,7 @@ from torch import Tensor, nn from .....utils import allocate_cuda_idx, choose_device from .....specification import HeteroMapTableSpecification, RKMETableSpecification from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer -from .trainer import TransTabCollatorForCL +from .trainer import TransTabCollatorForCL, Trainer class HeteroMap(nn.Module): From a9c443aefbbc42c34923922e72515bd92ef83450 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 7 Dec 2023 21:04:39 +0800 Subject: [PATCH 4/5] [FIX] fix bugs about batch_size --- learnware/market/heterogeneous/organizer/hetero_map/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index e124bcb..aca3d02 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -309,7 +309,7 @@ class HeteroMap(nn.Module): output_feas_list = [] if eval_batch_size * x_test.shape[1] > self.max_process_size: - eval_batch_size = int(self.max_process_size / x_test.shape[1]) + eval_batch_size = max(1, self.max_process_size // x_test.shape[1]) for i in range(0, len(x_test), eval_batch_size): bs_x_test = x_test.iloc[i : i + eval_batch_size] From 25a8493ec5cffb8711568f580623e01b9b6089e6 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 7 Dec 2023 21:13:05 +0800 Subject: [PATCH 5/5] [FIX] delete extra import --- learnware/market/anchor/organizer.py | 3 +-- learnware/market/anchor/searcher.py | 3 +-- learnware/market/base.py | 27 +++++++++++-------- learnware/market/easy/database_ops.py | 2 +- .../organizer/hetero_map/__init__.py | 4 +-- .../organizer/hetero_map/feature_extractor.py | 2 +- 6 files changed, 22 insertions(+), 19 deletions(-) diff --git a/learnware/market/anchor/organizer.py b/learnware/market/anchor/organizer.py index 405060b..4c3b668 100644 --- a/learnware/market/anchor/organizer.py +++ b/learnware/market/anchor/organizer.py @@ -1,9 +1,8 @@ -from typing import List, Dict, Tuple, Any +from typing import Dict from ..easy.organizer import EasyOrganizer from ...logger import get_module_logger from ...learnware import Learnware -from ...specification import BaseStatSpecification logger = get_module_logger("anchor_organizer") diff --git a/learnware/market/anchor/searcher.py b/learnware/market/anchor/searcher.py index 34d326d..60bca9f 100644 --- a/learnware/market/anchor/searcher.py +++ b/learnware/market/anchor/searcher.py @@ -1,7 +1,6 @@ -from typing import List, Dict, Tuple, Any, Union +from typing import List, Tuple, Any from .user_info import AnchoredUserInfo -from ..base import BaseUserInfo from ..easy.searcher import EasySearcher from ...logger import get_module_logger from ...learnware import Learnware diff --git a/learnware/market/base.py b/learnware/market/base.py index 78d06e6..fce2f18 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -3,7 +3,7 @@ from __future__ import annotations import traceback import zipfile import tempfile -from typing import Tuple, Any, List, Union, Dict, Optional +from typing import Tuple, Any, List, Union, Optional from dataclasses import dataclass from ..learnware import Learnware, get_learnware_from_dirpath from ..logger import get_module_logger @@ -45,7 +45,7 @@ class BaseUserInfo: def update_semantic_spec(self, semantic_spec: dict): self.semantic_spec = semantic_spec - + def update_stat_info(self, name: str, item: Any): """Update stat_info by market @@ -64,28 +64,35 @@ class SingleSearchItem: learnware: Learnware score: Optional[float] = None + @dataclass class MultipleSearchItem: learnwares: List[Learnware] score: float - + + class SearchResults: - def __init__(self, single_results: Optional[List[SingleSearchItem]] = None, multiple_results: Optional[List[MultipleSearchItem]] = None): + def __init__( + self, + single_results: Optional[List[SingleSearchItem]] = None, + multiple_results: Optional[List[MultipleSearchItem]] = None, + ): self.update_single_results([] if single_results is None else single_results) self.update_multiple_results([] if multiple_results is None else multiple_results) - + def get_single_results(self) -> List[SingleSearchItem]: return self.single_results - + def get_multiple_results(self) -> List[MultipleSearchItem]: return self.multiple_results - + def update_single_results(self, single_results: List[SingleSearchItem]): self.single_results = single_results - + def update_multiple_results(self, multiple_results: List[MultipleSearchItem]): self.multiple_results = multiple_results + class LearnwareMarket: """Base interface for market, it provide the interface of search/add/detele/update learnwares""" @@ -179,9 +186,7 @@ class LearnwareMarket: zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) - def search_learnware( - self, user_info: BaseUserInfo, check_status: int = None, **kwargs - ) -> SearchResults: + def search_learnware(self, user_info: BaseUserInfo, check_status: int = None, **kwargs) -> SearchResults: """Search learnwares based on user_info from learnwares with check_status Parameters diff --git a/learnware/market/easy/database_ops.py b/learnware/market/easy/database_ops.py index 7f8e87c..e27577c 100644 --- a/learnware/market/easy/database_ops.py +++ b/learnware/market/easy/database_ops.py @@ -1,6 +1,6 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import create_engine, text -from sqlalchemy import Column, Integer, Text, DateTime, String +from sqlalchemy import Column, Text, String import os import json import traceback diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index aca3d02..b2f39fe 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -1,10 +1,10 @@ -from typing import Callable, List, Optional, Union +from typing import Callable, Union import numpy as np import pandas as pd import torch import torch.nn.functional as F -from torch import Tensor, nn +from torch import nn from .....utils import allocate_cuda_idx, choose_device from .....specification import HeteroMapTableSpecification, RKMETableSpecification diff --git a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py index d98bd2e..325f74e 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py @@ -1,6 +1,6 @@ import math import os -from typing import Callable, Dict, List, Union +from typing import Dict, List, Union import numpy as np import pandas as pd