diff --git a/learnware/market/anchor/organizer.py b/learnware/market/anchor/organizer.py index 405060b..4c3b668 100644 --- a/learnware/market/anchor/organizer.py +++ b/learnware/market/anchor/organizer.py @@ -1,9 +1,8 @@ -from typing import List, Dict, Tuple, Any +from typing import Dict from ..easy.organizer import EasyOrganizer from ...logger import get_module_logger from ...learnware import Learnware -from ...specification import BaseStatSpecification logger = get_module_logger("anchor_organizer") diff --git a/learnware/market/anchor/searcher.py b/learnware/market/anchor/searcher.py index 34d326d..60bca9f 100644 --- a/learnware/market/anchor/searcher.py +++ b/learnware/market/anchor/searcher.py @@ -1,7 +1,6 @@ -from typing import List, Dict, Tuple, Any, Union +from typing import List, Tuple, Any from .user_info import AnchoredUserInfo -from ..base import BaseUserInfo from ..easy.searcher import EasySearcher from ...logger import get_module_logger from ...learnware import Learnware diff --git a/learnware/market/base.py b/learnware/market/base.py index 78d06e6..fce2f18 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -3,7 +3,7 @@ from __future__ import annotations import traceback import zipfile import tempfile -from typing import Tuple, Any, List, Union, Dict, Optional +from typing import Tuple, Any, List, Union, Optional from dataclasses import dataclass from ..learnware import Learnware, get_learnware_from_dirpath from ..logger import get_module_logger @@ -45,7 +45,7 @@ class BaseUserInfo: def update_semantic_spec(self, semantic_spec: dict): self.semantic_spec = semantic_spec - + def update_stat_info(self, name: str, item: Any): """Update stat_info by market @@ -64,28 +64,35 @@ class SingleSearchItem: learnware: Learnware score: Optional[float] = None + @dataclass class MultipleSearchItem: learnwares: List[Learnware] score: float - + + class SearchResults: - def __init__(self, single_results: Optional[List[SingleSearchItem]] = None, multiple_results: Optional[List[MultipleSearchItem]] = None): + def __init__( + self, + single_results: Optional[List[SingleSearchItem]] = None, + multiple_results: Optional[List[MultipleSearchItem]] = None, + ): self.update_single_results([] if single_results is None else single_results) self.update_multiple_results([] if multiple_results is None else multiple_results) - + def get_single_results(self) -> List[SingleSearchItem]: return self.single_results - + def get_multiple_results(self) -> List[MultipleSearchItem]: return self.multiple_results - + def update_single_results(self, single_results: List[SingleSearchItem]): self.single_results = single_results - + def update_multiple_results(self, multiple_results: List[MultipleSearchItem]): self.multiple_results = multiple_results + class LearnwareMarket: """Base interface for market, it provide the interface of search/add/detele/update learnwares""" @@ -179,9 +186,7 @@ class LearnwareMarket: zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) - def search_learnware( - self, user_info: BaseUserInfo, check_status: int = None, **kwargs - ) -> SearchResults: + def search_learnware(self, user_info: BaseUserInfo, check_status: int = None, **kwargs) -> SearchResults: """Search learnwares based on user_info from learnwares with check_status Parameters diff --git a/learnware/market/easy/database_ops.py b/learnware/market/easy/database_ops.py index 7f8e87c..e27577c 100644 --- a/learnware/market/easy/database_ops.py +++ b/learnware/market/easy/database_ops.py @@ -1,6 +1,6 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import create_engine, text -from sqlalchemy import Column, Integer, Text, DateTime, String +from sqlalchemy import Column, Text, String import os import json import traceback diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index 2e31195..b2f39fe 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -1,15 +1,15 @@ -from typing import Callable, List, Optional, Union +from typing import Callable, Union import numpy as np import pandas as pd import torch import torch.nn.functional as F -from torch import Tensor, nn +from torch import nn from .....utils import allocate_cuda_idx, choose_device from .....specification import HeteroMapTableSpecification, RKMETableSpecification from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer -from .trainer import Trainer, TransTabCollatorForCL +from .trainer import TransTabCollatorForCL, Trainer class HeteroMap(nn.Module): @@ -127,6 +127,7 @@ class HeteroMap(nn.Module): self.base_temperature = base_temperature self.num_partition = num_partition self.overlap_ratio = overlap_ratio + self.max_process_size = 20480 self.to(device) def to(self, device: Union[str, torch.device]): @@ -306,6 +307,10 @@ class HeteroMap(nn.Module): """ self.eval() output_feas_list = [] + + if eval_batch_size * x_test.shape[1] > self.max_process_size: + eval_batch_size = max(1, self.max_process_size // x_test.shape[1]) + for i in range(0, len(x_test), eval_batch_size): bs_x_test = x_test.iloc[i : i + eval_batch_size] with torch.no_grad(): diff --git a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py index d98bd2e..325f74e 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py @@ -1,6 +1,6 @@ import math import os -from typing import Callable, Dict, List, Union +from typing import Dict, List, Union import numpy as np import pandas as pd