diff --git a/examples/dataset_pfs_workflow/main.py b/examples/dataset_pfs_workflow/main.py index 48ff7d0..b5cbdd8 100644 --- a/examples/dataset_pfs_workflow/main.py +++ b/examples/dataset_pfs_workflow/main.py @@ -15,7 +15,7 @@ import learnware.specification as specification from pfs import Dataloader from learnware.logger import get_module_logger -logger = get_module_logger("m5_test", level="INFO") +logger = get_module_logger("pfs_test", level="INFO") semantic_specs = [ diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index fea3e9c..1fc0fef 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -1,6 +1,8 @@ import traceback import numpy as np import torch +import random +import string from ..base import BaseChecker from ...config import C @@ -73,7 +75,6 @@ class EasyStatisticalChecker(BaseChecker): try: learnware_model = learnware.get_model() - # Check input shape if semantic_spec["Data"]["Values"][0] == "Table": input_shape = (semantic_spec["Input"]["Dimension"],) @@ -81,15 +82,36 @@ class EasyStatisticalChecker(BaseChecker): input_shape = learnware_model.input_shape # Check rkme dimension - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") - if stat_spec is not None: + is_text = "RKMETextSpecification" in learnware.get_specification().stat_spec + if is_text: + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextSpecification") + else: + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") + if stat_spec is not None and not is_text: if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") return self.INVALID_LEARNWARE - - inputs = np.random.randn(10, *input_shape) + + def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): + text_list = [] + for i in range(num): + length = 
random.randint(min_len, max_len) + if text_type == "en": + characters = string.ascii_letters + string.digits + string.punctuation + result_str = "".join(random.choice(characters) for i in range(length)) + text_list.append(result_str) + elif text_type == "zh": + result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) + text_list.append(result_str) + else: + raise ValueError("Type should be en or zh") + return text_list + + if is_text: + inputs = generate_random_text_list(10) + else: + inputs = np.random.randn(10, *input_shape) outputs = learnware.predict(inputs) - # Check output if outputs.ndim == 1: outputs = outputs.reshape(-1, 1) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index aa741e3..8934fda 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -251,9 +251,7 @@ class EasyStatSearcher(BaseSearcher): The second is the mmd dist between the mixture of learnware rkmes and the user's rkme """ learnware_num = len(learnware_list) - RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list - ] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] if type(intermediate_K) == np.ndarray: K = intermediate_K @@ -320,9 +318,7 @@ class EasyStatSearcher(BaseSearcher): The second is the intermediate value of C """ num = intermediate_K.shape[0] - 1 - RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list - ] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] for i in range(intermediate_K.shape[0]): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1]) @@ -377,7 +373,7 @@ class EasyStatSearcher(BaseSearcher): if len(mixture_list) <= 1: mixture_list = 
[learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification")) + mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_info_name)) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num] @@ -439,7 +435,12 @@ class EasyStatSearcher(BaseSearcher): user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) for learnware in learnware_list: - rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification") + if self.stat_info_name not in learnware.specification.stat_spec: + continue + rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name) + if self.stat_info_name == "RKMETextSpecification": + if not set(user_rkme.language).issubset(set(rkme.language)): + continue rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -537,9 +538,7 @@ class EasyStatSearcher(BaseSearcher): the second is the list of Learnware both lists are sorted by mmd dist """ - RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list - ] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] mmd_dist_list = [] for RKME in RKME_list: mmd_dist = RKME.dist(user_rkme) @@ -558,7 +557,11 @@ class EasyStatSearcher(BaseSearcher): max_search_num: int = 5, search_method: str = "greedy", ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - user_rkme = user_info.stat_info["RKMETableSpecification"] + if "RKMETextSpecification" in user_info.stat_info: + self.stat_info_name = "RKMETextSpecification" + else: + self.stat_info_name = "RKMETableSpecification" + user_rkme = user_info.stat_info[self.stat_info_name] learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) logger.info(f"After filter 
by rkme dimension, learnware_list length is {len(learnware_list)}") @@ -633,5 +636,7 @@ class EasySearcher(BaseSearcher): return [], [], 0.0, [] elif "RKMETableSpecification" in user_info.stat_info: return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) + elif "RKMETextSpecification" in user_info.stat_info: + return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 21745fe..9de299b 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -9,7 +9,7 @@ from sklearn.metrics import accuracy_score from learnware.learnware import Learnware import learnware.specification as specification from .base import BaseReuser -from ..specification import RKMETableSpecification +from ..specification import RKMETableSpecification, RKMETextSpecification from ..logger import get_module_logger logger = get_module_logger("job_selector_reuse") @@ -45,6 +45,10 @@ class JobSelectorReuser(BaseReuser): np.ndarray Prediction given by job-selector method """ + ori_user_data = user_data + if isinstance(user_data[0], str): + user_data = RKMETextSpecification.get_sentence_embedding(user_data) + select_result = self.job_selector(user_data) pred_y_list = [] data_idxs_list = [] @@ -52,7 +56,8 @@ class JobSelectorReuser(BaseReuser): for idx in range(len(self.learnware_list)): data_idx_list = np.where(select_result == idx)[0] if len(data_idx_list) > 0: - pred_y = self.learnware_list[idx].predict(user_data[data_idx_list]) + # pred_y = self.learnware_list[idx].predict(ori_user_data[data_idx_list]) + pred_y = self.learnware_list[idx].predict([ori_user_data[i] for i in data_idx_list]) if isinstance(pred_y, torch.Tensor): pred_y = pred_y.detach().cpu().numpy() # elif isinstance(pred_y, tf.Tensor): @@ -82,12 +87,18 @@ class JobSelectorReuser(BaseReuser): User's labeled raw data. 
""" if len(self.learnware_list) == 1: - user_data_num = user_data.shape[0] + # user_data_num = user_data.shape[0] + user_data_num = len(user_data) return np.array([0] * user_data_num) else: + ori_user_data = user_data + if isinstance(user_data[0], str): + user_data = RKMETextSpecification.get_sentence_embedding(user_data) + spec_name = "RKMETableSpecification" + if len(self.learnware_list) and "RKMETextSpecification" in self.learnware_list[0].specification.stat_spec: + spec_name = "RKMETextSpecification" learnware_rkme_spec_list = [ - learnware.specification.get_stat_spec_by_name("RKMETableSpecification") - for learnware in self.learnware_list + learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list ] if self.use_herding: @@ -168,6 +179,8 @@ class JobSelectorReuser(BaseReuser): Inner product matrix calculated from task_rkme_list. """ task_num = len(task_rkme_list) + if isinstance(user_data[0], str): + user_data = RKMETextSpecification.get_sentence_embedding(user_data) user_rkme_spec = specification.utils.generate_rkme_spec(X=user_data, reduce=False) K = task_rkme_matrix v = np.array([user_rkme_spec.inner_prod(task_rkme) for task_rkme in task_rkme_list]) @@ -224,11 +237,7 @@ class JobSelectorReuser(BaseReuser): max_depth = [66] params = (0, 0) - lgb_params = { - "boosting_type": "gbdt", - "n_estimators": 2000, - "boost_from_average": False, - } + lgb_params = {"boosting_type": "gbdt", "n_estimators": 2000, "boost_from_average": False, "verbose": -1} if num_class == 2: lgb_params["objective"] = "binary" diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 54dae1f..7fbf500 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .base import Specification, BaseStatSpecification -from .regular import RegularStatsSpecification, RKMEStatSpecification, 
RKMETableSpecification, RKMEImageSpecification +from .regular import RegularStatsSpecification, RKMEStatSpecification, RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index eeb4b3f..9007e4d 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,3 +1,4 @@ +from .text import RKMETextSpecification from .table import RKMETableSpecification, RKMEStatSpecification from .image import RKMEImageSpecification from .base import RegularStatsSpecification diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index 1f05382..4421f91 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -17,11 +17,11 @@ from torchvision.transforms import Resize from tqdm import tqdm from . import cnn_gp -from ..base import BaseStatSpecification +from ..base import RegularStatsSpecification from ..table.rkme import solve_qp, choose_device, setup_seed -class RKMEImageSpecification(BaseStatSpecification): +class RKMEImageSpecification(RegularStatsSpecification): # INNER_PRODUCT_COUNT = 0 IMAGE_WIDTH = 32 diff --git a/learnware/specification/regular/text/__init__.py b/learnware/specification/regular/text/__init__.py new file mode 100644 index 0000000..35b8b0a --- /dev/null +++ b/learnware/specification/regular/text/__init__.py @@ -0,0 +1 @@ +from .rkme import RKMETextSpecification diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py new file mode 100644 index 0000000..cc8659e --- /dev/null +++ b/learnware/specification/regular/text/rkme.py @@ -0,0 +1,77 @@ +from sentence_transformers import SentenceTransformer +from ..table import RKMETableSpecification +import numpy as np +import os +import langdetect +from ....logger import get_module_logger + +logger = 
get_module_logger("RKMETextSpecification", "INFO") + + +class RKMETextSpecification(RKMETableSpecification): + """Reduced Kernel Mean Embedding (RKME) Specification for Text""" + def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): + RKMETableSpecification.__init__(self, gamma, cuda_idx) + self.language = [] + + def generate_stat_spec_from_data( + self, + X: list, + K: int = 100, + step_size: float = 0.1, + steps: int = 3, + nonnegative_beta: bool = True, + reduce: bool = True, + ): + """Construct reduced set from raw dataset using iterative optimization. + + Parameters + ---------- + X : np.ndarray or torch.tensor + Raw data in np.ndarray format. + K : int + Size of the construced reduced set. + step_size : float + Step size for gradient descent in the iterative optimization. + steps : int + Total rounds in the iterative optimization. + nonnegative_beta : bool, optional + True if weights for the reduced set are intended to be kept non-negative, by default False. + reduce : bool, optional + Whether shrink original data to a smaller set, by default True + """ + + # Sentence embedding for Text + self.language = self.get_language_ids(X) + logger.info("The text learnware's language: %s" % (self.language)) + X = self.get_sentence_embedding(X) + + # Generate specification + return super().generate_stat_spec_from_data( + X, + K, + step_size, + steps, + nonnegative_beta, + reduce, + ) + + @staticmethod + def get_language_ids(X): + try: + text = ' '.join(X) + lang = langdetect.detect(text) + langs = langdetect.detect_langs(text) + return [l.lang for l in langs] + except Exception as e: + logger.warning("Language detection failed.") + return [] + + @staticmethod + def get_sentence_embedding(X): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") + X = model.encode(X) + X = np.array(X) + # X /= np.sqrt(np.sum(X ** 2, axis=1)).reshape((-1, 1)) + return X diff --git 
a/learnware/specification/utils.py b/learnware/specification/utils.py index 91fe226..09d66c1 100644 --- a/learnware/specification/utils.py +++ b/learnware/specification/utils.py @@ -1,10 +1,10 @@ import torch import numpy as np import pandas as pd -from typing import Union +from typing import Union, List from .base import BaseStatSpecification -from .regular import RKMETableSpecification, RKMEImageSpecification +from .regular import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification from ..config import C @@ -164,6 +164,63 @@ def generate_rkme_image_spec( return rkme_image_spec +def generate_rkme_text_spec( + X: List[str], + gamma: float = 0.1, + reduced_set_size: int = 100, + step_size: float = 0.1, + steps: int = 3, + nonnegative_beta: bool = True, + reduce: bool = True, + cuda_idx: int = None, +) -> RKMETextSpecification: + """ + Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Text. + Return a RKMETextSpecification object, use .save() method to save as json file. + + Parameters + ---------- + X : List[str] + Raw data of text. + gamma : float + Bandwidth in gaussian kernel, by default 0.1. + reduced_set_size : int + Size of the construced reduced set. + step_size : float + Step size for gradient descent in the iterative optimization. + steps : int + Total rounds in the iterative optimization. + nonnegative_beta : bool, optional + True if weights for the reduced set are intended to be kept non-negative, by default False. + reduce : bool, optional + Whether shrink original data to a smaller set, by default True + cuda_idx : int + A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used. + None indicates that CUDA is automatically selected. 
+ + Returns + ------- + RKMETextSpecification + A RKMETextSpecification object + """ + # Check input type + if not isinstance(X, list) or not all(isinstance(item, str) for item in X): + raise TypeError("Input data must be a list of strings.") + + # Check cuda_idx + if not torch.cuda.is_available() or cuda_idx == -1: + cuda_idx = -1 + else: + num_cuda_devices = torch.cuda.device_count() + if cuda_idx is None or not (cuda_idx >= 0 and cuda_idx < num_cuda_devices): + cuda_idx = 0 + + # Generate rkme text spec + rkme_text_spec = RKMETextSpecification(gamma=gamma, cuda_idx=cuda_idx) + rkme_text_spec.generate_stat_spec_from_data(X, reduced_set_size, step_size, steps, nonnegative_beta, reduce) + return rkme_text_spec + + def generate_stat_spec(X: np.ndarray) -> BaseStatSpecification: """ Interface for users to generate statistical specification. diff --git a/setup.py b/setup.py index 80e48bd..1e8ea72 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,10 @@ REQUIRED = [ "geatpy>=2.7.0", "docker>=6.1.3", "rapidfuzz>=3.4.0", + "torchtext>=0.16.0", + "sentence_transformers>=2.2.2", "torch-optimizer>=0.3.0", + "langdetect>=1.0.9", ] if get_platform() != MACOS: diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index c77e654..143bf22 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -1,11 +1,14 @@ import os import json +import string +import random import torch import unittest import tempfile import numpy as np -from learnware.specification import RKMETableSpecification, RKMEImageSpecification +import learnware.specification as specification +from learnware.specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification from learnware.specification import generate_rkme_image_spec, generate_rkme_spec @@ -51,6 +54,46 @@ class TestRKME(unittest.TestCase): _test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 128))) _test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 
128)) / 255) + def test_text_rkme(self): + def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): + text_list = [] + for i in range(num): + length = random.randint(min_len, max_len) + if text_type == "en": + characters = string.ascii_letters + string.digits + string.punctuation + result_str = "".join(random.choice(characters) for i in range(length)) + text_list.append(result_str) + elif text_type == "zh": + result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) + text_list.append(result_str) + else: + raise ValueError("Type should be en or zh") + return text_list + + def _test_text_rkme(X): + rkme = specification.utils.generate_rkme_text_spec(X) + + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + rkme_path = os.path.join(tempdir, "rkme.json") + rkme.save(rkme_path) + + with open(rkme_path, "r") as f: + data = json.load(f) + assert data["type"] == "RKMETextSpecification" + + rkme2 = RKMETextSpecification() + rkme2.load(rkme_path) + assert rkme2.type == "RKMETextSpecification" + + return rkme2.get_z().shape[1] + + dim1 = _test_text_rkme(generate_random_text_list(3000, "en")) + dim2 = _test_text_rkme(generate_random_text_list(4000, "en")) + dim3 = _test_text_rkme(generate_random_text_list(2000, "zh")) + dim4 = _test_text_rkme(generate_random_text_list(5000, "zh")) + + assert dim1 == dim2 and dim2 == dim3 and dim3 == dim4 + if __name__ == "__main__": unittest.main() diff --git a/tests/test_text_workflow/example_files/example_init.py b/tests/test_text_workflow/example_files/example_init.py new file mode 100644 index 0000000..a98b757 --- /dev/null +++ b/tests/test_text_workflow/example_files/example_init.py @@ -0,0 +1,54 @@ +import os +import joblib +import numpy as np +from learnware.model import BaseModel +import torch +from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER +import torchtext.functional as F +import torchtext.transforms as T +from torch.hub import 
load_state_dict_from_url + + +class Model(BaseModel): + def __init__(self): + super().__init__(input_shape=None, output_shape=(2,)) + dir_path = os.path.dirname(os.path.abspath(__file__)) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + num_classes = 2 + input_dim = 768 + classifier_head = RobertaClassificationHead(num_classes=num_classes, input_dim=input_dim) + self.model = XLMR_BASE_ENCODER.get_model(head=classifier_head).to(self.device) + self.model.load_state_dict(torch.load(os.path.join(dir_path, "model.pth"))) + + def fit(self, X: np.ndarray, y: np.ndarray): + pass + + def predict(self, X: np.ndarray) -> np.ndarray: + X = sentence_preprocess(X) + X = F.to_tensor(X, padding_value=1).to(self.device) + return self.model(X) + + def finetune(self, X: np.ndarray, y: np.ndarray): + pass + + +def sentence_preprocess(x_datapipe): + padding_idx = 1 + bos_idx = 0 + eos_idx = 2 + max_seq_len = 256 + xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt" + xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model" + + text_transform = T.Sequential( + T.SentencePieceTokenizer(xlmr_spm_model_path), + T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)), + T.Truncate(max_seq_len - 2), + T.AddToken(token=bos_idx, begin=True), + T.AddToken(token=eos_idx, begin=False), + ) + + x_datapipe = [text_transform(x) for x in x_datapipe] + # x_datapipe = x_datapipe.map(text_transform) + return x_datapipe diff --git a/tests/test_text_workflow/example_files/example_yaml.yaml b/tests/test_text_workflow/example_files/example_yaml.yaml new file mode 100644 index 0000000..f9817c7 --- /dev/null +++ b/tests/test_text_workflow/example_files/example_yaml.yaml @@ -0,0 +1,8 @@ +model: + class_name: Model + kwargs: {} +stat_specifications: + - module_path: learnware.specification + class_name: RKMETextSpecification + file_name: rkme.json + kwargs: {} \ No newline at end of file diff --git 
a/tests/test_text_workflow/get_data.py b/tests/test_text_workflow/get_data.py new file mode 100644 index 0000000..0d1412d --- /dev/null +++ b/tests/test_text_workflow/get_data.py @@ -0,0 +1,15 @@ +import torch +from torchtext.datasets import SST2 + + +def get_sst2(data_root="./data"): + train_datapipe = SST2(root="./data", split="train") + + X_train = [x[0] for x in train_datapipe] + y_train = [x[1] for x in train_datapipe] + + dev_datapipe = SST2(root="./data", split="dev") + + X_test = [x[0] for x in dev_datapipe] + y_test = [x[1] for x in dev_datapipe] + return X_train, y_train, X_test, y_test diff --git a/tests/test_text_workflow/main.py b/tests/test_text_workflow/main.py new file mode 100644 index 0000000..baa54f4 --- /dev/null +++ b/tests/test_text_workflow/main.py @@ -0,0 +1,237 @@ +import numpy as np +import torch +from get_data import * +import os +import random +from utils import generate_uploader, generate_user, TextDataLoader, train, eval_prediction +from learnware.learnware import Learnware +from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser +import time +import pickle + +from learnware.market import instatiate_learnware_market, BaseUserInfo +from learnware.market import database_ops +from learnware.learnware import Learnware +import learnware.specification as specification +from learnware.logger import get_module_logger + +from shutil import copyfile, rmtree +import zipfile + +logger = get_module_logger("text_test", level="INFO") +origin_data_root = "./data/origin_data" +processed_data_root = "./data/processed_data" +tmp_dir = "./data/tmp" +learnware_pool_dir = "./data/learnware_pool" +dataset = "sst2" +n_uploaders = 10 +n_users = 5 +n_classes = 2 +data_root = os.path.join(origin_data_root, dataset) +data_save_root = os.path.join(processed_data_root, dataset) +user_save_root = os.path.join(data_save_root, "user") +uploader_save_root = os.path.join(data_save_root, "uploader") +model_save_root = 
os.path.join(data_save_root, "uploader_model") +os.makedirs(data_root, exist_ok=True) +os.makedirs(user_save_root, exist_ok=True) +os.makedirs(uploader_save_root, exist_ok=True) +os.makedirs(model_save_root, exist_ok=True) + +output_description = { + "Dimension": 2, + "Description": { + "0": "the probability of being negative", + "1": "the probability of being positive", + }, +} +semantic_specs = [ + { + "Data": {"Values": ["Text"], "Type": "Class"}, + "Task": {"Values": ["Classification"], "Type": "Class"}, + "Library": {"Values": ["PyTorch"], "Type": "Class"}, + "Scenario": {"Values": ["Business"], "Type": "Tag"}, + "Description": {"Values": "", "Type": "String"}, + "Name": {"Values": "learnware_1", "Type": "String"}, + "Output": output_description, + } +] + +user_semantic = { + "Data": {"Values": ["Text"], "Type": "Class"}, + "Task": {"Values": ["Classification"], "Type": "Class"}, + "Library": {"Values": ["PyTorch"], "Type": "Class"}, + "Scenario": {"Values": ["Business"], "Type": "Tag"}, + "Description": {"Values": "", "Type": "String"}, + "Name": {"Values": "", "Type": "String"}, + "Output": output_description, +} + + +def prepare_data(): + if dataset == "sst2": + X_train, y_train, X_test, y_test = get_sst2(data_root) + else: + return + generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root) + generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root) + + +def prepare_model(): + dataloader = TextDataLoader(data_save_root, train=True) + for i in range(n_uploaders): + logger.info("Train on uploader: %d" % (i)) + X, y = dataloader.get_idx_data(i) + model = train(X, y, out_classes=n_classes) + model_save_path = os.path.join(model_save_root, "uploader_%d.pth" % (i)) + torch.save(model.state_dict(), model_save_path) + logger.info("Model saved to '%s'" % (model_save_path)) + + +def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_root, zip_name): + os.makedirs(save_root, 
exist_ok=True) + tmp_spec_path = os.path.join(save_root, "rkme.json") + tmp_model_path = os.path.join(save_root, "model.pth") + tmp_yaml_path = os.path.join(save_root, "learnware.yaml") + tmp_init_path = os.path.join(save_root, "__init__.py") + + with open(data_path, "rb") as f: + X = pickle.load(f) + semantic_spec = semantic_specs[0] + + st = time.time() + # user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0) + user_spec = specification.RKMETextSpecification() + user_spec.generate_stat_spec_from_data(X=X) + ed = time.time() + logger.info("Stat spec generated in %.3f s" % (ed - st)) + user_spec.save(tmp_spec_path) + copyfile(model_path, tmp_model_path) + copyfile(yaml_path, tmp_yaml_path) + copyfile(init_file_path, tmp_init_path) + zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name)) + with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj: + zip_obj.write(tmp_spec_path, "rkme.json") + zip_obj.write(tmp_model_path, "model.pth") + zip_obj.write(tmp_yaml_path, "learnware.yaml") + zip_obj.write(tmp_init_path, "__init__.py") + rmtree(save_root) + logger.info("New Learnware Saved to %s" % (zip_file_name)) + return zip_file_name + + +def prepare_market(): + text_market = instatiate_learnware_market(market_id="sst2", rebuild=True) + try: + rmtree(learnware_pool_dir) + except: + pass + os.makedirs(learnware_pool_dir, exist_ok=True) + for i in range(n_uploaders): + data_path = os.path.join(uploader_save_root, "uploader_%d_X.pkl" % (i)) + model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i)) + init_file_path = "./example_files/example_init.py" + yaml_file_path = "./example_files/example_yaml.yaml" + new_learnware_path = prepare_learnware( + data_path, model_path, init_file_path, yaml_file_path, tmp_dir, "%s_%d" % (dataset, i) + ) + semantic_spec = semantic_specs[0] + semantic_spec["Name"]["Values"] = "learnware_%d" % (i) + semantic_spec["Description"]["Values"] = 
"test_learnware_number_%d" % (i) + text_market.add_learnware(new_learnware_path, semantic_spec) + + logger.info("Total Item: %d" % (len(text_market))) + + +def test_search(gamma=0.1, load_market=True): + if load_market: + text_market = instatiate_learnware_market(market_id="sst2") + else: + prepare_market() + text_market = instatiate_learnware_market(market_id="sst2") + logger.info("Number of items in the market: %d" % len(text_market)) + + select_list = [] + avg_list = [] + improve_list = [] + job_selector_score_list = [] + ensemble_score_list = [] + pruning_score_list = [] + for i in range(n_users): + user_data_path = os.path.join(user_save_root, "user_%d_X.pkl" % (i)) + user_label_path = os.path.join(user_save_root, "user_%d_y.pkl" % (i)) + with open(user_data_path, "rb") as f: + user_data = pickle.load(f) + with open(user_label_path, "rb") as f: + user_label = pickle.load(f) + # user_data = np.load(user_data_path) + # user_label = np.load(user_label_path) + # user_stat_spec = specification.utils.generate_rkme_spec(X=user_data, gamma=gamma, cuda_idx=0) + user_stat_spec = specification.RKMETextSpecification() + user_stat_spec.generate_stat_spec_from_data(X=user_data) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETextSpecification": user_stat_spec}) + logger.info("Searching Market for user: %d" % (i)) + sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = text_market.search_learnware( + user_info + ) + l = len(sorted_score_list) + acc_list = [] + for idx in range(l): + learnware = single_learnware_list[idx] + score = sorted_score_list[idx] + pred_y = learnware.predict(user_data) + acc = eval_prediction(pred_y, user_label) + acc_list.append(acc) + logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) + + # test reuse (job selector) + reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) + reuse_predict = 
reuse_baseline.predict(user_data=user_data) + reuse_score = eval_prediction(reuse_predict, user_label) + job_selector_score_list.append(reuse_score) + print(f"mixture reuse loss(job selector): {reuse_score}") + + # test reuse (ensemble) + reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label") + ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) + ensemble_score = eval_prediction(ensemble_predict_y, user_label) + ensemble_score_list.append(ensemble_score) + print(f"mixture reuse accuracy (ensemble): {ensemble_score}") + + select_list.append(acc_list[0]) + avg_list.append(np.mean(acc_list)) + improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list)) + + # test reuse (ensemblePruning) + reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list) + pruning_predict_y = reuse_pruning.predict(user_data=user_data) + pruning_score = eval_prediction(pruning_predict_y, user_label) + pruning_score_list.append(pruning_score) + print(f"mixture reuse accuracy (ensemble Pruning): {pruning_score}\n") + + select_list.append(acc_list[0]) + avg_list.append(np.mean(acc_list)) + improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list)) + + logger.info( + "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f" + % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list)) + ) + logger.info("Average performance improvement: %.3f" % (np.mean(improve_list))) + logger.info( + "Average Job Selector Reuse Performance: %.3f +/- %.3f" + % (np.mean(job_selector_score_list), np.std(job_selector_score_list)) + ) + logger.info( + "Averaging Ensemble Reuse Performance: %.3f +/- %.3f" + % (np.mean(ensemble_score_list), np.std(ensemble_score_list)) + ) + logger.info( + "Selective Ensemble Reuse Performance: %.3f +/- %.3f" + % (np.mean(pruning_score_list), np.std(pruning_score_list)) + ) + + +if __name__ == "__main__": + prepare_data() + 
"""Helpers for the text learnware workflow test (tests/test_text_workflow/utils.py).

Provides: splitting a labelled corpus into pickled shards, loading those
shards back, tokenising sentences for an XLM-R encoder, and fine-tuning /
scoring a classifier built on that encoder.
"""
import os
import numpy as np
import random
import math

import torch
import torch.nn as nn
import torch.optim as optim

import pickle
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
import torchtext.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader


# Device used by train() and evaluate(); CUDA when available, CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TextDataLoader:
    """Load pickled ``(X, y)`` shards from ``<data_root>/uploader`` or ``<data_root>/user``.

    Args:
        data_root: Directory that contains the ``uploader`` and ``user`` sub-directories.
        train: When True read ``uploader_<idx>_*.pkl`` files, otherwise ``user_<idx>_*.pkl``.
    """

    def __init__(self, data_root, train: bool = True):
        self.data_root = data_root
        self.train = train

    def get_idx_data(self, idx=0):
        """Return the ``(X, y)`` pair stored for shard ``idx``.

        Raises:
            Exception: With message "Index Error" if either pickle file is missing.
        """
        role = "uploader" if self.train else "user"
        X_path = os.path.join(self.data_root, role, "%s_%d_X.pkl" % (role, idx))
        y_path = os.path.join(self.data_root, role, "%s_%d_y.pkl" % (role, idx))
        if not (os.path.exists(X_path) and os.path.exists(y_path)):
            raise Exception("Index Error")
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # data_root at shards produced by a trusted source.
        with open(X_path, "rb") as f:
            X = pickle.load(f)
        with open(y_path, "rb") as f:
            y = pickle.load(f)
        return X, y


def _save_shards(data_x, data_y, n_shards, data_save_root, prefix):
    """Write ``n_shards`` equal, contiguous slices of the data as ``<prefix>_<i>_{X,y}.pkl``."""
    if data_save_root is None:
        return
    os.makedirs(data_save_root, exist_ok=True)
    if n_shards <= 0:
        # Nothing to write; also keeps the integer division below well-defined.
        return
    # Floor division: trailing samples that do not fill a whole shard are dropped.
    shard_size = len(data_x) // n_shards
    for i in range(n_shards):
        start, stop = i * shard_size, (i + 1) * shard_size
        X_save_dir = os.path.join(data_save_root, "%s_%d_X.pkl" % (prefix, i))
        y_save_dir = os.path.join(data_save_root, "%s_%d_y.pkl" % (prefix, i))
        with open(X_save_dir, "wb") as f:
            pickle.dump(data_x[start:stop], f)
        with open(y_save_dir, "wb") as f:
            pickle.dump(data_y[start:stop], f)
        print("Saving to %s" % (X_save_dir))


def generate_uploader(data_x, data_y, n_uploaders=50, data_save_root=None):
    """Split the data into ``n_uploaders`` shards saved as ``uploader_<i>_{X,y}.pkl``.

    Does nothing when ``data_save_root`` is None; the directory is created if needed.
    """
    _save_shards(data_x, data_y, n_uploaders, data_save_root, "uploader")


def generate_user(data_x, data_y, n_users=50, data_save_root=None):
    """Split the data into ``n_users`` shards saved as ``user_<i>_{X,y}.pkl``.

    Does nothing when ``data_save_root`` is None; the directory is created if needed.
    """
    _save_shards(data_x, data_y, n_users, data_save_root, "user")


def sentence_preprocess(x_datapipe):
    """Convert each item of ``x_datapipe`` into a list of XLM-R token ids.

    Each item is SentencePiece-tokenised, mapped through the XLM-R vocabulary,
    truncated to ``max_seq_len - 2`` tokens and wrapped with the BOS / EOS ids.
    The tokenizer model and vocabulary are referenced by URL.

    Returns:
        A list with one transformed entry per input item.
    """
    bos_idx = 0
    eos_idx = 2
    max_seq_len = 256
    xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
    xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"

    text_transform = T.Sequential(
        T.SentencePieceTokenizer(xlmr_spm_model_path),
        T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)),
        # Leave room for the BOS and EOS tokens added below.
        T.Truncate(max_seq_len - 2),
        T.AddToken(token=bos_idx, begin=True),
        T.AddToken(token=eos_idx, begin=False),
    )

    return [text_transform(x) for x in x_datapipe]


def train_step(model, criteria, optim, input, target):
    """Run one optimisation step: forward, loss, zero grads, backward, optimiser step."""
    output = model(input)
    loss = criteria(output, target)
    optim.zero_grad()
    loss.backward()
    optim.step()


def eval_step(model, criteria, input, target):
    """Return ``(loss, n_correct)`` for one batch.

    ``loss`` is the criterion value as a Python float; ``n_correct`` is the
    number of rows whose argmax over dimension 1 equals ``target``.
    """
    output = model(input)
    loss = criteria(output, target).item()
    return float(loss), (output.argmax(1) == target).type(torch.float).sum().item()


def evaluate(model, criteria, dev_dataloader):
    """Score ``model`` on every batch of ``dev_dataloader`` without gradients.

    Switches the model to eval mode and leaves it there.

    Returns:
        ``(mean loss per batch, fraction of correct predictions)``.

    Raises:
        ZeroDivisionError: If ``dev_dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    counter = 0
    with torch.no_grad():
        for batch in dev_dataloader:
            input = F.to_tensor(batch["token_ids"], padding_value=1).to(DEVICE)
            target = torch.tensor(batch["target"]).to(DEVICE)
            loss, predictions = eval_step(model, criteria, input, target)
            total_loss += loss
            correct_predictions += predictions
            total_predictions += len(target)
            counter += 1

    return total_loss / counter, correct_predictions / total_predictions


# Train Uploaders' models
def train(X, y, out_classes, epochs=35, batch_size=128):
    """Fine-tune an XLM-R base encoder with a classification head on ``(X, y)``.

    Args:
        X: Iterable of raw sentences; passed through ``sentence_preprocess``.
        y: Labels aligned with ``X``.
        out_classes: Number of output classes of the classification head.
        epochs: Accepted for interface compatibility.
            NOTE(review): currently not honoured - training always runs 10
            epochs, as before; confirm with callers before wiring it in.
        batch_size: Number of samples per training batch.

    Returns:
        The trained model, on ``DEVICE`` and in eval mode (the last call made
        on it is ``evaluate``).
    """
    from torchdata.datapipes.iter import IterableWrapper

    X = sentence_preprocess(X)
    train_datapipe = IterableWrapper(list(zip(X, y)))
    train_datapipe = train_datapipe.batch(batch_size)
    train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"])
    # Batching is done by the datapipe, so the DataLoader must not batch again.
    train_dataloader = DataLoader(train_datapipe, batch_size=None)

    input_dim = 768
    classifier_head = RobertaClassificationHead(num_classes=out_classes, input_dim=input_dim)
    model = XLMR_BASE_ENCODER.get_model(head=classifier_head)
    learning_rate = 1e-5
    optim = AdamW(model.parameters(), lr=learning_rate)
    criteria = nn.CrossEntropyLoss()

    model.to(DEVICE)

    num_epochs = 10

    for e in range(num_epochs):
        # evaluate() below leaves the model in eval mode; switch back so that
        # dropout is active again for every training epoch, not just the first.
        model.train()
        for batch in train_dataloader:
            input = F.to_tensor(batch["token_ids"], padding_value=1).to(DEVICE)
            target = torch.tensor(batch["target"]).to(DEVICE)
            train_step(model, criteria, optim, input, target)

        loss, accuracy = evaluate(model, criteria, train_dataloader)
        print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy))
    return model


def eval_prediction(pred_y, target_y):
    """Return the fraction of predictions in ``pred_y`` that match ``target_y``.

    Args:
        pred_y: A numpy array, or an object supporting ``.detach().cpu().numpy()``.
            A 1-D input is taken as predicted labels; otherwise the argmax over
            axis 1 is used.
        target_y: Ground-truth labels, converted with ``np.array``.
    """
    if not isinstance(pred_y, np.ndarray):
        pred_y = pred_y.detach().cpu().numpy()
    if len(pred_y.shape) == 1:
        predicted = np.array(pred_y)
    else:
        predicted = np.argmax(pred_y, 1)
    annos = np.array(target_y)
    total = predicted.shape[0]
    correct = (predicted == annos).sum().item()
    return correct / total