Browse Source

Merge pull request #37 from Learnware-LAMDA/fit_many_specification

Fit Many Specification
tags/v0.3.2
bxdd GitHub 2 years ago
parent
commit
5c5d460e78
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 110 additions and 80 deletions
  1. +0
    -0
      examples/dataset_text_workflow/example_files/example_init.py
  2. +0
    -0
      examples/dataset_text_workflow/example_files/example_yaml.yaml
  3. +0
    -0
      examples/dataset_text_workflow/get_data.py
  4. +1
    -1
      examples/dataset_text_workflow/main.py
  5. +0
    -0
      examples/dataset_text_workflow/utils.py
  6. +41
    -32
      learnware/market/easy2/checker.py
  7. +32
    -25
      learnware/market/easy2/searcher.py
  8. +11
    -0
      learnware/market/utils.py
  9. +17
    -18
      learnware/reuse/job_selector.py
  10. +3
    -3
      learnware/specification/regular/text/rkme.py
  11. +5
    -1
      tests/test_workflow/test_workflow.py

tests/test_text_workflow/example_files/example_init.py → examples/dataset_text_workflow/example_files/example_init.py View File


tests/test_text_workflow/example_files/example_yaml.yaml → examples/dataset_text_workflow/example_files/example_yaml.yaml View File


tests/test_text_workflow/get_data.py → examples/dataset_text_workflow/get_data.py View File


tests/test_text_workflow/main.py → examples/dataset_text_workflow/main.py View File

@@ -1,6 +1,6 @@
import numpy as np
import torch
from get_data import *
from get_data import get_sst2
import os
import random
from utils import generate_uploader, generate_user, TextDataLoader, train, eval_prediction

tests/test_text_workflow/utils.py → examples/dataset_text_workflow/utils.py View File


+ 41
- 32
learnware/market/easy2/checker.py View File

@@ -5,6 +5,7 @@ import random
import string

from ..base import BaseChecker
from ..utils import parse_specification_type
from ...config import C
from ...logger import get_module_logger

@@ -61,6 +62,22 @@ class EasySemanticChecker(BaseChecker):


class EasyStatChecker(BaseChecker):
@staticmethod
def _generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
text_list = []
for i in range(num):
length = random.randint(min_len, max_len)
if text_type == "en":
characters = string.ascii_letters + string.digits + string.punctuation
result_str = "".join(random.choice(characters) for i in range(length))
text_list.append(result_str)
elif text_type == "zh":
result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length))
text_list.append(result_str)
else:
raise ValueError("Type should be en or zh")
return text_list

def __call__(self, learnware):
semantic_spec = learnware.get_specification().get_semantic_spec()

@@ -76,41 +93,32 @@ class EasyStatChecker(BaseChecker):
try:
learnware_model = learnware.get_model()
# Check input shape
if semantic_spec["Data"]["Values"][0] == "Table":
input_shape = (semantic_spec["Input"]["Dimension"],)
else:
input_shape = learnware_model.input_shape
input_shape = learnware_model.input_shape

# Check rkme dimension
is_text = "RKMETextSpecification" in learnware.get_specification().stat_spec
if is_text:
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextSpecification")
else:
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification")
if stat_spec is not None and not is_text:
## WHY: why write this?
if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != (
int(semantic_spec["Input"]["Dimension"]),
):
logger.warning("input shapes of model and semantic specifications are different")
return self.INVALID_LEARNWARE

spec_type = parse_specification_type(learnware.get_specification())
if spec_type is None:
logger.warning(f"No valid specification is found in stat spec {spec_type}")
return self.INVALID_LEARNWARE

if spec_type == "RKMETableSpecification":
stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type)
if stat_spec.get_z().shape[1:] != input_shape:
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.")
return self.INVALID_LEARNWARE

def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
text_list = []
for i in range(num):
length = random.randint(min_len, max_len)
if text_type == "en":
characters = string.ascii_letters + string.digits + string.punctuation
result_str = "".join(random.choice(characters) for i in range(length))
text_list.append(result_str)
elif text_type == "zh":
result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length))
text_list.append(result_str)
else:
raise ValueError("Type should be en or zh")
return text_list

if is_text:
inputs = generate_random_text_list(10)
else:
inputs = np.random.randn(10, *input_shape)
elif spec_type == "RKMETextSpecification":
inputs = EasyStatChecker._generate_random_text_list(10)
elif spec_type == "RKMEImageSpecification":
inputs = np.random.randint(0, 255, size=(10, *input_shape))
else:
raise ValueError(f"not supported spec type for spec_type = {spec_type}")
outputs = learnware.predict(inputs)
# Check output
if outputs.ndim == 1:
@@ -129,8 +137,9 @@ class EasyStatChecker(BaseChecker):
return self.INVALID_LEARNWARE

# Check output shape
output_dim = int(semantic_spec["Output"]["Dimension"])
if outputs[0].shape[0] != output_dim:
if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != int(
semantic_spec["Output"]["Dimension"]
):
logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!")
return self.INVALID_LEARNWARE



+ 32
- 25
learnware/market/easy2/searcher.py View File

@@ -5,9 +5,10 @@ from cvxopt import solvers, matrix
from typing import Tuple, List, Union

from .organizer import EasyOrganizer
from ..utils import parse_specification_type
from ..base import BaseUserInfo, BaseSearcher
from ...learnware import Learnware
from ...specification import RKMETableSpecification, RKMEImageSpecification
from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification
from ...logger import get_module_logger

logger = get_module_logger("easy_seacher")
@@ -251,7 +252,7 @@ class EasyStatSearcher(BaseSearcher):
The second is the mmd dist between the mixture of learnware rkmes and the user's rkme
"""
learnware_num = len(learnware_list)
RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list]
RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list]

if type(intermediate_K) == np.ndarray:
K = intermediate_K
@@ -318,7 +319,7 @@ class EasyStatSearcher(BaseSearcher):
The second is the intermediate value of C
"""
num = intermediate_K.shape[0] - 1
RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list]
RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list]
for i in range(intermediate_K.shape[0]):
intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i])
intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1])
@@ -373,7 +374,7 @@ class EasyStatSearcher(BaseSearcher):
if len(mixture_list) <= 1:
mixture_list = [learnware_list[sort_by_weight_idx_list[0]]]
mixture_weight = [1]
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_info_name))
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_spec_type))
else:
if len(mixture_list) > max_search_num:
mixture_list = mixture_list[:max_search_num]
@@ -414,16 +415,18 @@ class EasyStatSearcher(BaseSearcher):
idx = idx + 1
return sorted_score_list[:idx], learnware_list[:idx]

def _filter_by_rkme_spec_dimension(
self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification]
def _filter_by_rkme_spec_metadata(
self,
learnware_list: List[Learnware],
user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification],
) -> List[Learnware]:
"""Filter learnwares whose rkme dimension different from user_rkme
"""Filter learnwares whose rkme metadata different from user_rkme

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : Union[RKMETableSpecification, RKMEImageSpecification]
user_rkme : Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification]
User RKME statistical specification

Returns
@@ -435,12 +438,15 @@ class EasyStatSearcher(BaseSearcher):
user_rkme_dim = str(list(user_rkme.get_z().shape)[1:])

for learnware in learnware_list:
if self.stat_info_name not in learnware.specification.stat_spec:
if self.stat_spec_type not in learnware.specification.stat_spec:
continue
rkme = learnware.specification.get_stat_spec_by_name(self.stat_spec_type)
if self.stat_spec_type == "RKMETextSpecification" and not set(user_rkme.language).issubset(
set(rkme.language)
):
continue
rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name)
if self.stat_info_name == "RKMETextSpecification":
if not set(user_rkme.language).issubset(set(rkme.language)):
continue

# TODO: must we check dim for Text and Image specification?
rkme_dim = str(list(rkme.get_z().shape)[1:])
if rkme_dim == user_rkme_dim:
filtered_learnware_list.append(learnware)
@@ -520,7 +526,9 @@ class EasyStatSearcher(BaseSearcher):
return mmd_dist, weight_min, mixture_list

def _search_by_rkme_spec_single(
self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification]
self,
learnware_list: List[Learnware],
user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification],
) -> Tuple[List[float], List[Learnware]]:
"""Calculate the distances between learnwares in the given learnware_list and user_rkme

@@ -528,7 +536,7 @@ class EasyStatSearcher(BaseSearcher):
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : Union[RKMETableSpecification, RKMEImageSpecification]
user_rkme : Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification]
user RKME statistical specification

Returns
@@ -538,7 +546,7 @@ class EasyStatSearcher(BaseSearcher):
the second is the list of Learnware
both lists are sorted by mmd dist
"""
RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list]
RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list]
mmd_dist_list = []
for RKME in RKME_list:
mmd_dist = RKME.dist(user_rkme)
@@ -557,12 +565,12 @@ class EasyStatSearcher(BaseSearcher):
max_search_num: int = 5,
search_method: str = "greedy",
) -> Tuple[List[float], List[Learnware], float, List[Learnware]]:
if "RKMETextSpecification" in user_info.stat_info:
self.stat_info_name = "RKMETextSpecification"
else:
self.stat_info_name = "RKMETableSpecification"
user_rkme = user_info.stat_info[self.stat_info_name]
learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme)
self.stat_spec_type = parse_specification_type(stat_spec=user_info.stat_info)
if self.stat_spec_type is None:
raise KeyError("No supported stat specification is given in the user info")
user_rkme = user_info.stat_info[self.stat_spec_type]
learnware_list = self._filter_by_rkme_spec_metadata(learnware_list, user_rkme)
logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}")

sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
@@ -637,9 +645,8 @@ class EasySearcher(BaseSearcher):

if len(learnware_list) == 0:
return [], [], 0.0, []
elif "RKMETableSpecification" in user_info.stat_info:
return self.stat_searcher(learnware_list, user_info, max_search_num, search_method)
elif "RKMETextSpecification" in user_info.stat_info:

if parse_specification_type(stat_spec=user_info.stat_info) is not None:
return self.stat_searcher(learnware_list, user_info, max_search_num, search_method)
else:
return None, learnware_list, 0.0, None

+ 11
- 0
learnware/market/utils.py View File

@@ -0,0 +1,11 @@
from ..specification import Specification


def parse_specification_type(
stat_spec: Specification, spec_list=["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"]
):
stat_specs = stat_spec.stat_spec
for spec in spec_list:
if spec in stat_specs:
return spec
return None

+ 17
- 18
learnware/reuse/job_selector.py View File

@@ -1,15 +1,17 @@
import torch
import numpy as np

from typing import List
from typing import List, Union
from cvxopt import matrix, solvers
from lightgbm import LGBMClassifier, early_stopping
from sklearn.metrics import accuracy_score

from learnware.learnware import Learnware
import learnware.specification as specification

from .base import BaseReuser
from ..market.utils import parse_specification_type
from ..learnware import Learnware
from ..specification import RKMETableSpecification, RKMETextSpecification
from ..specification.utils import generate_rkme_spec
from ..logger import get_module_logger

logger = get_module_logger("job_selector_reuse")
@@ -32,7 +34,7 @@ class JobSelectorReuser(BaseReuser):
self.herding_num = herding_num
self.use_herding = use_herding

def predict(self, user_data: np.ndarray) -> np.ndarray:
def predict(self, user_data: Union[np.ndarray, List[str]]) -> np.ndarray:
"""Give prediction for user data using baseline job-selector method

Parameters
@@ -41,12 +43,16 @@ class JobSelectorReuser(BaseReuser):
User's unlabeled raw data.

Returns
-------
------
np.ndarray
Prediction given by job-selector method
"""
ori_user_data = user_data
raw_user_data = user_data
if isinstance(user_data[0], str):
stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification())
assert (
stat_spec_type == "RKMETextSpecification"
), "stat_spec_type must be 'RKMETextSpecification' when user data is the List of string."
user_data = RKMETextSpecification.get_sentence_embedding(user_data)

select_result = self.job_selector(user_data)
@@ -56,8 +62,8 @@ class JobSelectorReuser(BaseReuser):
for idx in range(len(self.learnware_list)):
data_idx_list = np.where(select_result == idx)[0]
if len(data_idx_list) > 0:
# pred_y = self.learnware_list[idx].predict(ori_user_data[data_idx_list])
pred_y = self.learnware_list[idx].predict([ori_user_data[i] for i in data_idx_list])
# pred_y = self.learnware_list[idx].predict(raw_user_data[data_idx_list])
pred_y = self.learnware_list[idx].predict([raw_user_data[i] for i in data_idx_list])
if isinstance(pred_y, torch.Tensor):
pred_y = pred_y.detach().cpu().numpy()
# elif isinstance(pred_y, tf.Tensor):
@@ -91,14 +97,9 @@ class JobSelectorReuser(BaseReuser):
user_data_num = len(user_data)
return np.array([0] * user_data_num)
else:
ori_user_data = user_data
if isinstance(user_data[0], str):
user_data = RKMETextSpecification.get_sentence_embedding(user_data)
spec_name = "RKMETableSpecification"
if len(self.learnware_list) and "RKMETextSpecification" in self.learnware_list[0].specification.stat_spec:
spec_name = "RKMETextSpecification"
stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification())
learnware_rkme_spec_list = [
learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list
learnware.specification.get_stat_spec_by_name(stat_spec_type) for learnware in self.learnware_list
]

if self.use_herding:
@@ -179,9 +180,7 @@ class JobSelectorReuser(BaseReuser):
Inner product matrix calculated from task_rkme_list.
"""
task_num = len(task_rkme_list)
if isinstance(user_data[0], str):
user_data = RKMETextSpecification.get_sentence_embedding(user_data)
user_rkme_spec = specification.utils.generate_rkme_spec(X=user_data, reduce=False)
user_rkme_spec = generate_rkme_spec(X=user_data, reduce=False)
K = task_rkme_matrix
v = np.array([user_rkme_spec.inner_prod(task_rkme) for task_rkme in task_rkme_list])



+ 3
- 3
learnware/specification/regular/text/rkme.py View File

@@ -1,8 +1,8 @@
from sentence_transformers import SentenceTransformer
from ..table import RKMETableSpecification
import numpy as np
import os
import langdetect
import numpy as np
from sentence_transformers import SentenceTransformer
from ..table import RKMETableSpecification
from ....logger import get_module_logger

logger = get_module_logger("RKMETextSpecification", "INFO")


+ 5
- 1
tests/test_workflow/test_workflow.py View File

@@ -18,7 +18,7 @@ import learnware.specification as specification
curr_root = os.path.dirname(os.path.abspath(__file__))

user_semantic = {
"Data": {"Values": ["Image"], "Type": "Class"},
"Data": {"Values": ["Table"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
@@ -96,6 +96,10 @@ class TestAllWorkflow(unittest.TestCase):
semantic_spec = copy.deepcopy(user_semantic)
semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
semantic_spec["Input"] = {"Dimension": 64}
semantic_spec["Input"].update(
{f"{i}": f"The value in the digit image with row is {i // 8} and col is {i % 8}." for i in range(64)}
)
semantic_spec["Output"] = {"Dimension": 1, "Description": {"0": "The label of the hand-written digit."}}
easy_market.add_learnware(zip_path, semantic_spec)



Loading…
Cancel
Save