Browse Source

Merge pull request #132 from Learnware-LAMDA/fix_hetero_spec_gen

[FIX] fix exceed GPU memory generating hetero_spec
tags/v0.3.2
Gene GitHub 2 years ago
parent
commit
0ff68a1f03
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 28 additions and 20 deletions
  1. +1
    -2
      learnware/market/anchor/organizer.py
  2. +1
    -2
      learnware/market/anchor/searcher.py
  3. +16
    -11
      learnware/market/base.py
  4. +1
    -1
      learnware/market/easy/database_ops.py
  5. +8
    -3
      learnware/market/heterogeneous/organizer/hetero_map/__init__.py
  6. +1
    -1
      learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py

+ 1
- 2
learnware/market/anchor/organizer.py View File

@@ -1,9 +1,8 @@
from typing import List, Dict, Tuple, Any
from typing import Dict

from ..easy.organizer import EasyOrganizer
from ...logger import get_module_logger
from ...learnware import Learnware
from ...specification import BaseStatSpecification

logger = get_module_logger("anchor_organizer")



+ 1
- 2
learnware/market/anchor/searcher.py View File

@@ -1,7 +1,6 @@
from typing import List, Dict, Tuple, Any, Union
from typing import List, Tuple, Any

from .user_info import AnchoredUserInfo
from ..base import BaseUserInfo
from ..easy.searcher import EasySearcher
from ...logger import get_module_logger
from ...learnware import Learnware


+ 16
- 11
learnware/market/base.py View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import traceback
import zipfile
import tempfile
from typing import Tuple, Any, List, Union, Dict, Optional
from typing import Tuple, Any, List, Union, Optional
from dataclasses import dataclass
from ..learnware import Learnware, get_learnware_from_dirpath
from ..logger import get_module_logger
@@ -45,7 +45,7 @@ class BaseUserInfo:

def update_semantic_spec(self, semantic_spec: dict):
self.semantic_spec = semantic_spec
def update_stat_info(self, name: str, item: Any):
"""Update stat_info by market

@@ -64,28 +64,35 @@ class SingleSearchItem:
learnware: Learnware
score: Optional[float] = None


@dataclass
class MultipleSearchItem:
learnwares: List[Learnware]
score: float


class SearchResults:
def __init__(self, single_results: Optional[List[SingleSearchItem]] = None, multiple_results: Optional[List[MultipleSearchItem]] = None):
def __init__(
self,
single_results: Optional[List[SingleSearchItem]] = None,
multiple_results: Optional[List[MultipleSearchItem]] = None,
):
self.update_single_results([] if single_results is None else single_results)
self.update_multiple_results([] if multiple_results is None else multiple_results)
def get_single_results(self) -> List[SingleSearchItem]:
return self.single_results
def get_multiple_results(self) -> List[MultipleSearchItem]:
return self.multiple_results
def update_single_results(self, single_results: List[SingleSearchItem]):
self.single_results = single_results
def update_multiple_results(self, multiple_results: List[MultipleSearchItem]):
self.multiple_results = multiple_results


class LearnwareMarket:
"""Base interface for market, it provide the interface of search/add/detele/update learnwares"""

@@ -179,9 +186,7 @@ class LearnwareMarket:
zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)

def search_learnware(
self, user_info: BaseUserInfo, check_status: int = None, **kwargs
) -> SearchResults:
def search_learnware(self, user_info: BaseUserInfo, check_status: int = None, **kwargs) -> SearchResults:
"""Search learnwares based on user_info from learnwares with check_status

Parameters


+ 1
- 1
learnware/market/easy/database_ops.py View File

@@ -1,6 +1,6 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import create_engine, text
from sqlalchemy import Column, Integer, Text, DateTime, String
from sqlalchemy import Column, Text, String
import os
import json
import traceback


+ 8
- 3
learnware/market/heterogeneous/organizer/hetero_map/__init__.py View File

@@ -1,15 +1,15 @@
from typing import Callable, List, Optional, Union
from typing import Callable, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch import nn

from .....utils import allocate_cuda_idx, choose_device
from .....specification import HeteroMapTableSpecification, RKMETableSpecification
from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer
from .trainer import Trainer, TransTabCollatorForCL
from .trainer import TransTabCollatorForCL, Trainer


class HeteroMap(nn.Module):
@@ -127,6 +127,7 @@ class HeteroMap(nn.Module):
self.base_temperature = base_temperature
self.num_partition = num_partition
self.overlap_ratio = overlap_ratio
self.max_process_size = 20480
self.to(device)

def to(self, device: Union[str, torch.device]):
@@ -306,6 +307,10 @@ class HeteroMap(nn.Module):
"""
self.eval()
output_feas_list = []

if eval_batch_size * x_test.shape[1] > self.max_process_size:
eval_batch_size = max(1, self.max_process_size // x_test.shape[1])

for i in range(0, len(x_test), eval_batch_size):
bs_x_test = x_test.iloc[i : i + eval_batch_size]
with torch.no_grad():


+ 1
- 1
learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py View File

@@ -1,6 +1,6 @@
import math
import os
from typing import Callable, Dict, List, Union
from typing import Dict, List, Union

import numpy as np
import pandas as pd


Loading…
Cancel
Save