Browse Source

Merge pull request #132 from Learnware-LAMDA/fix_hetero_spec_gen

[FIX] Fix exceeding GPU memory when generating hetero_spec
tags/v0.3.2
Gene GitHub 2 years ago
parent
commit
0ff68a1f03
No known key found for this signature in database. GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 28 additions and 20 deletions
  1. +1
    -2
      learnware/market/anchor/organizer.py
  2. +1
    -2
      learnware/market/anchor/searcher.py
  3. +16
    -11
      learnware/market/base.py
  4. +1
    -1
      learnware/market/easy/database_ops.py
  5. +8
    -3
      learnware/market/heterogeneous/organizer/hetero_map/__init__.py
  6. +1
    -1
      learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py

+ 1
- 2
learnware/market/anchor/organizer.py View File

@@ -1,9 +1,8 @@
from typing import List, Dict, Tuple, Any
from typing import Dict


from ..easy.organizer import EasyOrganizer from ..easy.organizer import EasyOrganizer
from ...logger import get_module_logger from ...logger import get_module_logger
from ...learnware import Learnware from ...learnware import Learnware
from ...specification import BaseStatSpecification


logger = get_module_logger("anchor_organizer") logger = get_module_logger("anchor_organizer")




+ 1
- 2
learnware/market/anchor/searcher.py View File

@@ -1,7 +1,6 @@
from typing import List, Dict, Tuple, Any, Union
from typing import List, Tuple, Any


from .user_info import AnchoredUserInfo from .user_info import AnchoredUserInfo
from ..base import BaseUserInfo
from ..easy.searcher import EasySearcher from ..easy.searcher import EasySearcher
from ...logger import get_module_logger from ...logger import get_module_logger
from ...learnware import Learnware from ...learnware import Learnware


+ 16
- 11
learnware/market/base.py View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import traceback import traceback
import zipfile import zipfile
import tempfile import tempfile
from typing import Tuple, Any, List, Union, Dict, Optional
from typing import Tuple, Any, List, Union, Optional
from dataclasses import dataclass from dataclasses import dataclass
from ..learnware import Learnware, get_learnware_from_dirpath from ..learnware import Learnware, get_learnware_from_dirpath
from ..logger import get_module_logger from ..logger import get_module_logger
@@ -45,7 +45,7 @@ class BaseUserInfo:


def update_semantic_spec(self, semantic_spec: dict): def update_semantic_spec(self, semantic_spec: dict):
self.semantic_spec = semantic_spec self.semantic_spec = semantic_spec
def update_stat_info(self, name: str, item: Any): def update_stat_info(self, name: str, item: Any):
"""Update stat_info by market """Update stat_info by market


@@ -64,28 +64,35 @@ class SingleSearchItem:
learnware: Learnware learnware: Learnware
score: Optional[float] = None score: Optional[float] = None



@dataclass @dataclass
class MultipleSearchItem: class MultipleSearchItem:
learnwares: List[Learnware] learnwares: List[Learnware]
score: float score: float


class SearchResults: class SearchResults:
def __init__(self, single_results: Optional[List[SingleSearchItem]] = None, multiple_results: Optional[List[MultipleSearchItem]] = None):
def __init__(
self,
single_results: Optional[List[SingleSearchItem]] = None,
multiple_results: Optional[List[MultipleSearchItem]] = None,
):
self.update_single_results([] if single_results is None else single_results) self.update_single_results([] if single_results is None else single_results)
self.update_multiple_results([] if multiple_results is None else multiple_results) self.update_multiple_results([] if multiple_results is None else multiple_results)
def get_single_results(self) -> List[SingleSearchItem]: def get_single_results(self) -> List[SingleSearchItem]:
return self.single_results return self.single_results
def get_multiple_results(self) -> List[MultipleSearchItem]: def get_multiple_results(self) -> List[MultipleSearchItem]:
return self.multiple_results return self.multiple_results
def update_single_results(self, single_results: List[SingleSearchItem]): def update_single_results(self, single_results: List[SingleSearchItem]):
self.single_results = single_results self.single_results = single_results
def update_multiple_results(self, multiple_results: List[MultipleSearchItem]): def update_multiple_results(self, multiple_results: List[MultipleSearchItem]):
self.multiple_results = multiple_results self.multiple_results = multiple_results



class LearnwareMarket: class LearnwareMarket:
"""Base interface for market, it provides the interface of search/add/delete/update learnwares""" """Base interface for market, it provides the interface of search/add/delete/update learnwares"""


@@ -179,9 +186,7 @@ class LearnwareMarket:
zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
) )


def search_learnware(
self, user_info: BaseUserInfo, check_status: int = None, **kwargs
) -> SearchResults:
def search_learnware(self, user_info: BaseUserInfo, check_status: int = None, **kwargs) -> SearchResults:
"""Search learnwares based on user_info from learnwares with check_status """Search learnwares based on user_info from learnwares with check_status


Parameters Parameters


+ 1
- 1
learnware/market/easy/database_ops.py View File

@@ -1,6 +1,6 @@
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import create_engine, text from sqlalchemy import create_engine, text
from sqlalchemy import Column, Integer, Text, DateTime, String
from sqlalchemy import Column, Text, String
import os import os
import json import json
import traceback import traceback


+ 8
- 3
learnware/market/heterogeneous/organizer/hetero_map/__init__.py View File

@@ -1,15 +1,15 @@
from typing import Callable, List, Optional, Union
from typing import Callable, Union


import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor, nn
from torch import nn


from .....utils import allocate_cuda_idx, choose_device from .....utils import allocate_cuda_idx, choose_device
from .....specification import HeteroMapTableSpecification, RKMETableSpecification from .....specification import HeteroMapTableSpecification, RKMETableSpecification
from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer
from .trainer import Trainer, TransTabCollatorForCL
from .trainer import TransTabCollatorForCL, Trainer




class HeteroMap(nn.Module): class HeteroMap(nn.Module):
@@ -127,6 +127,7 @@ class HeteroMap(nn.Module):
self.base_temperature = base_temperature self.base_temperature = base_temperature
self.num_partition = num_partition self.num_partition = num_partition
self.overlap_ratio = overlap_ratio self.overlap_ratio = overlap_ratio
self.max_process_size = 20480
self.to(device) self.to(device)


def to(self, device: Union[str, torch.device]): def to(self, device: Union[str, torch.device]):
@@ -306,6 +307,10 @@ class HeteroMap(nn.Module):
""" """
self.eval() self.eval()
output_feas_list = [] output_feas_list = []

if eval_batch_size * x_test.shape[1] > self.max_process_size:
eval_batch_size = max(1, self.max_process_size // x_test.shape[1])

for i in range(0, len(x_test), eval_batch_size): for i in range(0, len(x_test), eval_batch_size):
bs_x_test = x_test.iloc[i : i + eval_batch_size] bs_x_test = x_test.iloc[i : i + eval_batch_size]
with torch.no_grad(): with torch.no_grad():


+ 1
- 1
learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py View File

@@ -1,6 +1,6 @@
import math import math
import os import os
from typing import Callable, Dict, List, Union
from typing import Dict, List, Union


import numpy as np import numpy as np
import pandas as pd import pandas as pd


Loading…
Cancel
Save