From 0c347b4575a61626f37af1c173b57df2c3045689 Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Sun, 29 Oct 2023 18:44:55 +0800 Subject: [PATCH 01/10] [ENH] Add text specification and test --- .../example_files/example_init.py | 54 +++++ .../example_files/example_yaml.yaml | 8 + examples/dataset_text_workflow/get_data.py | 14 ++ examples/dataset_text_workflow/main.py | 229 ++++++++++++++++++ examples/dataset_text_workflow/utils.py | 186 ++++++++++++++ learnware/reuse/job_selector.py | 3 +- learnware/specification/__init__.py | 1 + learnware/specification/text/__init__.py | 1 + learnware/specification/text/rkme.py | 43 ++++ 9 files changed, 538 insertions(+), 1 deletion(-) create mode 100644 examples/dataset_text_workflow/example_files/example_init.py create mode 100644 examples/dataset_text_workflow/example_files/example_yaml.yaml create mode 100644 examples/dataset_text_workflow/get_data.py create mode 100644 examples/dataset_text_workflow/main.py create mode 100644 examples/dataset_text_workflow/utils.py create mode 100644 learnware/specification/text/__init__.py create mode 100644 learnware/specification/text/rkme.py diff --git a/examples/dataset_text_workflow/example_files/example_init.py b/examples/dataset_text_workflow/example_files/example_init.py new file mode 100644 index 0000000..99839b2 --- /dev/null +++ b/examples/dataset_text_workflow/example_files/example_init.py @@ -0,0 +1,54 @@ +import os +import joblib +import numpy as np +from learnware.model import BaseModel +import torch +from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER +import torchtext.functional as F +import torchtext.transforms as T +from torch.hub import load_state_dict_from_url + + +class Model(BaseModel): + def __init__(self): + super().__init__(input_shape=None, output_shape=(2,)) + dir_path = os.path.dirname(os.path.abspath(__file__)) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + num_classes = 2 + input_dim = 768 + classifier_head = RobertaClassificationHead(num_classes=num_classes, input_dim=input_dim) + self.model = XLMR_BASE_ENCODER.get_model(head=classifier_head).to(self.device) + self.model.load_state_dict(torch.load(os.path.join(dir_path, "model.pth"))) + + def fit(self, X: np.ndarray, y: np.ndarray): + pass + + def predict(self, X: np.ndarray) -> np.ndarray: + X = sentence_preprocess(X) + X = F.to_tensor(X, padding_value=1).to(self.device) + return self.model(X) + + def finetune(self, X: np.ndarray, y: np.ndarray): + pass + + +def sentence_preprocess(x_datapipe): + padding_idx = 1 + bos_idx = 0 + eos_idx = 2 + max_seq_len = 256 + xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt" + xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model" + + text_transform = T.Sequential( + T.SentencePieceTokenizer(xlmr_spm_model_path), + T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)), + T.Truncate(max_seq_len - 2), + T.AddToken(token=bos_idx, begin=True), + T.AddToken(token=eos_idx, begin=False), + ) + + x_datapipe = [text_transform(x) for x in x_datapipe] + # x_datapipe = x_datapipe.map(text_transform) + return x_datapipe \ No newline at end of file diff --git a/examples/dataset_text_workflow/example_files/example_yaml.yaml b/examples/dataset_text_workflow/example_files/example_yaml.yaml new file mode 100644 index 0000000..6ca01c9 --- /dev/null +++ b/examples/dataset_text_workflow/example_files/example_yaml.yaml @@ -0,0 +1,8 @@ +model: + class_name: Model + kwargs: {} +stat_specifications: + - module_path: learnware.specification + class_name: RKMEStatSpecification + file_name: rkme.json + kwargs: {} \ No newline at end of file diff --git a/examples/dataset_text_workflow/get_data.py b/examples/dataset_text_workflow/get_data.py new file mode 100644 index 0000000..03c413b --- /dev/null +++ b/examples/dataset_text_workflow/get_data.py @@ -0,0 +1,14 @@ +import torch +from torchtext.datasets import SST2 + +def get_sst2(data_root="./data"): + train_datapipe = SST2(root='./data', split="train") + + X_train = [x[0] for x in train_datapipe] + y_train = [x[1] for x in train_datapipe] + + dev_datapipe = SST2(root='./data', split="dev") + + X_test = [x[0] for x in dev_datapipe] + y_test = [x[1] for x in dev_datapipe] + return X_train, y_train, X_test, y_test diff --git a/examples/dataset_text_workflow/main.py b/examples/dataset_text_workflow/main.py new file mode 100644 index 0000000..9933301 --- /dev/null +++ b/examples/dataset_text_workflow/main.py @@ -0,0 +1,229 @@ +import numpy as np +import torch +from get_data import * +import os +import random +from utils import generate_uploader, generate_user, TextDataLoader, train, eval_prediction +from learnware.learnware import Learnware +from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser +import time +import pickle + +from learnware.market import EasyMarket, BaseUserInfo +from learnware.market import database_ops +from learnware.learnware import Learnware +import learnware.specification as specification +from learnware.logger import get_module_logger + +from shutil import copyfile, rmtree +import zipfile + +logger = get_module_logger("text_test", level="INFO") +origin_data_root = "./data/origin_data" +processed_data_root = "./data/processed_data" +tmp_dir = "./data/tmp" +learnware_pool_dir = "./data/learnware_pool" +dataset = "sst2" +n_uploaders = 50 +n_users = 5 +n_classes = 2 +data_root = os.path.join(origin_data_root, dataset) +data_save_root = os.path.join(processed_data_root, dataset) +user_save_root = os.path.join(data_save_root, "user") +uploader_save_root = os.path.join(data_save_root, "uploader") +model_save_root = os.path.join(data_save_root, "uploader_model") +os.makedirs(data_root, exist_ok=True) +os.makedirs(user_save_root, exist_ok=True) +os.makedirs(uploader_save_root, exist_ok=True) +os.makedirs(model_save_root, exist_ok=True) + + +semantic_specs = [ + { + "Data": {"Values": ["Text"], "Type": "Class"}, + "Task": {"Values": ["Classification"], "Type": "Class"}, + "Library": {"Values": ["Pytorch"], "Type": "Class"}, + "Scenario": {"Values": ["Business"], "Type": "Tag"}, + "Description": {"Values": "", "Type": "String"}, + "Name": {"Values": "learnware_1", "Type": "String"}, + } +] + +user_semantic = { + "Data": {"Values": ["Text"], "Type": "Class"}, + "Task": {"Values": ["Classification"], "Type": "Class"}, + "Library": {"Values": ["Pytorch"], "Type": "Class"}, + "Scenario": {"Values": ["Business"], "Type": "Tag"}, + "Description": {"Values": "", "Type": "String"}, + "Name": {"Values": "", "Type": "String"}, +} + + +def prepare_data(): + if dataset == "sst2": + X_train, y_train, X_test, y_test = get_sst2(data_root) + else: + return + generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root) + generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root) + + +def prepare_model(): + dataloader = TextDataLoader(data_save_root, train=True) + for i in range(n_uploaders): + logger.info("Train on uploader: %d" % (i)) + X, y = dataloader.get_idx_data(i) + model = train(X, y, out_classes=n_classes) + model_save_path = os.path.join(model_save_root, "uploader_%d.pth" % (i)) + torch.save(model.state_dict(), model_save_path) + logger.info("Model saved to '%s'" % (model_save_path)) + + +def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_root, zip_name): + os.makedirs(save_root, exist_ok=True) + tmp_spec_path = os.path.join(save_root, "rkme.json") + tmp_model_path = os.path.join(save_root, "model.pth") + tmp_yaml_path = os.path.join(save_root, "learnware.yaml") + tmp_init_path = os.path.join(save_root, "__init__.py") + + with open(data_path, 'rb') as f: + X = pickle.load(f) + semantic_spec = semantic_specs[0] + + st = time.time() + # user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0) + user_spec = specification.TextRKMEStatSpecification() + user_spec.generate_stat_spec_from_text(X=X) + ed = time.time() + logger.info("Stat spec generated in %.3f s" % (ed - st)) + user_spec.save(tmp_spec_path) + copyfile(model_path, tmp_model_path) + copyfile(yaml_path, tmp_yaml_path) + copyfile(init_file_path, tmp_init_path) + zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name)) + with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj: + zip_obj.write(tmp_spec_path, "rkme.json") + zip_obj.write(tmp_model_path, "model.pth") + zip_obj.write(tmp_yaml_path, "learnware.yaml") + zip_obj.write(tmp_init_path, "__init__.py") + rmtree(save_root) + logger.info("New Learnware Saved to %s" % (zip_file_name)) + return zip_file_name + + +def prepare_market(): + text_market = EasyMarket(market_id="sst2", rebuild=True) + try: + rmtree(learnware_pool_dir) + except: + pass + os.makedirs(learnware_pool_dir, exist_ok=True) + for i in range(n_uploaders): + data_path = os.path.join(uploader_save_root, "uploader_%d_X.pkl" % (i)) + model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i)) + init_file_path = "./example_files/example_init.py" + yaml_file_path = "./example_files/example_yaml.yaml" + new_learnware_path = prepare_learnware( + data_path, model_path, init_file_path, yaml_file_path, tmp_dir, "%s_%d" % (dataset, i) + ) + semantic_spec = semantic_specs[0] + semantic_spec["Name"]["Values"] = "learnware_%d" % (i) + semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (i) + text_market.add_learnware(new_learnware_path, semantic_spec) + + logger.info("Total Item: %d" % (len(text_market))) + curr_inds = text_market._get_ids() + logger.info("Available ids: " + str(curr_inds)) + + +def test_search(gamma=0.1, load_market=True): + if load_market: + text_market = EasyMarket(market_id="sst2") + else: + prepare_market() + text_market = EasyMarket(market_id="sst2") + logger.info("Number of items in the market: %d" % len(text_market)) + + select_list = [] + avg_list = [] + improve_list = [] + job_selector_score_list = [] + ensemble_score_list = [] + pruning_score_list = [] + for i in range(n_users): + user_data_path = os.path.join(user_save_root, "user_%d_X.pkl" % (i)) + user_label_path = os.path.join(user_save_root, "user_%d_y.pkl" % (i)) + with open(user_data_path, 'rb') as f: + user_data = pickle.load(f) + with open(user_label_path, 'rb') as f: + user_label = pickle.load(f) + # user_data = np.load(user_data_path) + # user_label = np.load(user_label_path) + # user_stat_spec = specification.utils.generate_rkme_spec(X=user_data, gamma=gamma, cuda_idx=0) + user_stat_spec = specification.TextRKMEStatSpecification() + user_stat_spec.generate_stat_spec_from_text(X=user_data) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec}) + logger.info("Searching Market for user: %d" % (i)) + sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = text_market.search_learnware( + user_info + ) + l = len(sorted_score_list) + acc_list = [] + for idx in range(l): + learnware = single_learnware_list[idx] + score = sorted_score_list[idx] + pred_y = learnware.predict(user_data) + acc = eval_prediction(pred_y, user_label) + acc_list.append(acc) + logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) + + # # test reuse (job selector) + # reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) + # reuse_predict = reuse_baseline.predict(user_data=user_data) + # reuse_score = eval_prediction(reuse_predict, user_label) + # job_selector_score_list.append(reuse_score) + # print(f"mixture reuse loss: {reuse_score}") + + # test reuse (ensemble) + reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label") + ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) + ensemble_score = eval_prediction(ensemble_predict_y, user_label) + ensemble_score_list.append(ensemble_score) + print(f"mixture reuse accuracy (ensemble): {ensemble_score}\n") + + select_list.append(acc_list[0]) + avg_list.append(np.mean(acc_list)) + improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list)) + + # test reuse (ensemblePruning) + reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list) + pruning_predict_y = reuse_pruning.predict(user_data=user_data) + pruning_score = eval_prediction(pruning_predict_y, user_label) + pruning_score_list.append(pruning_score) + print(f"mixture reuse accuracy (ensemble): {pruning_score}\n") + + select_list.append(acc_list[0]) + avg_list.append(np.mean(acc_list)) + improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list)) + + logger.info( + "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f" + % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list)) + ) + logger.info("Average performance improvement: %.3f" % (np.mean(improve_list))) + logger.info( + "Average Job Selector Reuse Performance: %.3f +/- %.3f" + % (np.mean(job_selector_score_list), np.std(job_selector_score_list)) + ) + logger.info( + "Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list)) + ) + logger.info( + "Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(pruning_score_list), np.std(pruning_score_list)) + ) + + +if __name__ == "__main__": + # prepare_data() + # prepare_model() + test_search(load_market=True) diff --git a/examples/dataset_text_workflow/utils.py b/examples/dataset_text_workflow/utils.py new file mode 100644 index 0000000..d3acf24 --- /dev/null +++ b/examples/dataset_text_workflow/utils.py @@ -0,0 +1,186 @@ +import os +import numpy as np +import random +import math + +import torch +import torch.nn as nn +import torch.optim as optim + +import pickle +import torchtext.transforms as T +from torch.hub import load_state_dict_from_url +from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER +import torchtext.functional as F +from torch.optim import AdamW +from torch.utils.data import DataLoader + + + +class TextDataLoader: + def __init__(self, data_root, train: bool = True): + self.data_root = data_root + self.train = train + + def get_idx_data(self, idx=0): + if self.train: + X_path = os.path.join(self.data_root, "uploader", "uploader_%d_X.pkl" % (idx)) + y_path = os.path.join(self.data_root, "uploader", "uploader_%d_y.pkl" % (idx)) + if not (os.path.exists(X_path) and os.path.exists(y_path)): + raise Exception("Index Error") + with open(X_path, 'rb') as f: + X = pickle.load(f) + with open(y_path, 'rb') as f: + y = pickle.load(f) + else: + X_path = os.path.join(self.data_root, "user", "user_%d_X.pkl" % (idx)) + y_path = os.path.join(self.data_root, "user", "user_%d_y.pkl" % (idx)) + if not (os.path.exists(X_path) and os.path.exists(y_path)): + raise Exception("Index Error") + with open(X_path, 'rb') as f: + X = pickle.load(f) + with open(y_path, 'rb') as f: + y = pickle.load(f) + return X, y + + +def generate_uploader(data_x, data_y, n_uploaders=50, data_save_root=None): + if data_save_root is None: + return + os.makedirs(data_save_root, exist_ok=True) + n = len(data_x) + for i in range(n_uploaders): + selected_X = data_x[i * (n // n_uploaders): (i + 1) * (n // n_uploaders)] + selected_y = data_y[i * (n // n_uploaders): (i + 1) * (n // n_uploaders)] + X_save_dir = os.path.join(data_save_root, "uploader_%d_X.pkl" % (i)) + y_save_dir = os.path.join(data_save_root, "uploader_%d_y.pkl" % (i)) + with open(X_save_dir, 'wb') as f: + pickle.dump(selected_X, f) + with open(y_save_dir, 'wb') as f: + pickle.dump(selected_y, f) + print("Saving to %s" % (X_save_dir)) + + +def generate_user(data_x, data_y, n_users=50, data_save_root=None): + if data_save_root is None: + return + os.makedirs(data_save_root, exist_ok=True) + n = len(data_x) + for i in range(n_users): + selected_X = data_x[i * (n // n_users): (i + 1) * (n // n_users)] + selected_y = data_y[i * (n // n_users): (i + 1) * (n // n_users)] + X_save_dir = os.path.join(data_save_root, "user_%d_X.pkl" % (i)) + y_save_dir = os.path.join(data_save_root, "user_%d_y.pkl" % (i)) + with open(X_save_dir, 'wb') as f: + pickle.dump(selected_X, f) + with open(y_save_dir, 'wb') as f: + pickle.dump(selected_y, f) + print("Saving to %s" % (X_save_dir)) + + +def sentence_preprocess(x_datapipe): + padding_idx = 1 + bos_idx = 0 + eos_idx = 2 + max_seq_len = 256 + xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt" + xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model" + + text_transform = T.Sequential( + T.SentencePieceTokenizer(xlmr_spm_model_path), + T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)), + T.Truncate(max_seq_len - 2), + T.AddToken(token=bos_idx, begin=True), + T.AddToken(token=eos_idx, begin=False), + ) + + x_datapipe = [text_transform(x) for x in x_datapipe] + # x_datapipe = x_datapipe.map(text_transform) + return x_datapipe + + +def train_step(model, criteria, optim, input, target): + output = model(input) + loss = criteria(output, target) + optim.zero_grad() + loss.backward() + optim.step() + + +def eval_step(model, criteria, input, target): + output = model(input) + loss = criteria(output, target).item() + return float(loss), (output.argmax(1) == target).type(torch.float).sum().item() + + +def evaluate(model, criteria, dev_dataloader): + model.eval() + total_loss = 0 + correct_predictions = 0 + total_predictions = 0 + counter = 0 + with torch.no_grad(): + for batch in dev_dataloader: + input = F.to_tensor(batch["token_ids"], padding_value=1).to(DEVICE) + target = torch.tensor(batch["target"]).to(DEVICE) + loss, predictions = eval_step(model, criteria, input, target) + total_loss += loss + correct_predictions += predictions + total_predictions += len(target) + counter += 1 + + return total_loss / counter, correct_predictions / total_predictions + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Train Uploaders' models +def train(X, y, out_classes, epochs=35, batch_size=128): + # print(X.shape, y.shape) + from torchdata.datapipes.iter import IterableWrapper + X = sentence_preprocess(X) + data_size = len(X) + train_datapipe = list(zip(X, y)) + train_datapipe = IterableWrapper(train_datapipe) + train_datapipe = train_datapipe.batch(batch_size) + train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"]) + train_dataloader = DataLoader(train_datapipe, batch_size=None) + + num_classes = 2 + input_dim = 768 + classifier_head = RobertaClassificationHead(num_classes=num_classes, input_dim=input_dim) + model = XLMR_BASE_ENCODER.get_model(head=classifier_head) + learning_rate = 1e-5 + optim = AdamW(model.parameters(), lr=learning_rate) + criteria = nn.CrossEntropyLoss() + + model.to(DEVICE) + + num_epochs = 10 + + for e in range(num_epochs): + for batch in train_dataloader: + input = F.to_tensor(batch["token_ids"], padding_value=1).to(DEVICE) + target = torch.tensor(batch["target"]).to(DEVICE) + train_step(model, criteria, optim, input, target) + + loss, accuracy = evaluate(model, criteria, train_dataloader) + print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy)) + return model + + +def eval_prediction(pred_y, target_y): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if not isinstance(pred_y, np.ndarray): + pred_y = pred_y.detach().cpu().numpy() + if len(pred_y.shape) == 1: + predicted = np.array(pred_y) + else: + predicted = np.argmax(pred_y, 1) + annos = np.array(target_y) + # print(predicted, annos) + # annos = target_y + total = predicted.shape[0] + correct = (predicted == annos).sum().item() + criterion = nn.CrossEntropyLoss() + return correct / total diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index e786e15..7768adf 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -82,7 +82,8 @@ class JobSelectorReuser(BaseReuser): User's labeled raw data. """ if len(self.learnware_list) == 1: - user_data_num = user_data.shape[0] + # user_data_num = user_data.shape[0] + user_data_num = len(user_data) return np.array([0] * user_data_num) else: learnware_rkme_spec_list = [ diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 556aefb..17f07bf 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,4 @@ from .utils import generate_stat_spec from .base import Specification, BaseStatSpecification from .rkme import RKMEStatSpecification +from .text import TextRKMEStatSpecification diff --git a/learnware/specification/text/__init__.py b/learnware/specification/text/__init__.py new file mode 100644 index 0000000..8b17514 --- /dev/null +++ b/learnware/specification/text/__init__.py @@ -0,0 +1 @@ +from .rkme import TextRKMEStatSpecification diff --git a/learnware/specification/text/rkme.py b/learnware/specification/text/rkme.py new file mode 100644 index 0000000..e541f58 --- /dev/null +++ b/learnware/specification/text/rkme.py @@ -0,0 +1,43 @@ +from sentence_transformers import SentenceTransformer +from ..rkme import RKMEStatSpecification + +class TextRKMEStatSpecification(RKMEStatSpecification): + """Reduced Kernel Mean Embedding (RKME) Specification for Text""" + + def generate_stat_spec_from_text( + self, + X: list, + K: int = 100, + step_size: float = 0.1, + steps: int = 3, + nonnegative_beta: bool = True, + reduce: bool = True, + ): + """Construct reduced set from raw dataset using iterative optimization. + + Parameters + ---------- + X : np.ndarray or torch.tensor + Raw data in np.ndarray format. + K : int + Size of the construced reduced set. + step_size : float + Step size for gradient descent in the iterative optimization. + steps : int + Total rounds in the iterative optimization. + nonnegative_beta : bool, optional + True if weights for the reduced set are intended to be kept non-negative, by default False. + reduce : bool, optional + Whether shrink original data to a smaller set, by default True + """ + + # Sentence embedding for Text + model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') + X = model.encode(X) + + return self.generate_stat_spec_from_data( + X, K, step_size,steps, + nonnegative_beta, + reduce, + ) + From 02e3b4a7cf52fb4163dd62bc33389771178e3064 Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Mon, 30 Oct 2023 15:48:45 +0800 Subject: [PATCH 02/10] [MNT] job selector for Text --- examples/dataset_pfs_workflow/main.py | 2 +- learnware/reuse/job_selector.py | 15 +++++++++++++-- learnware/specification/__init__.py | 2 +- learnware/specification/text/__init__.py | 2 +- learnware/specification/text/rkme.py | 13 +++++++++++-- .../example_files/example_init.py | 0 .../example_files/example_yaml.yaml | 0 .../test_text_workflow}/get_data.py | 0 .../test_text_workflow}/main.py | 18 +++++++++--------- .../test_text_workflow}/utils.py | 2 +- 10 files changed, 37 insertions(+), 17 deletions(-) rename {examples/dataset_text_workflow => tests/test_text_workflow}/example_files/example_init.py (100%) rename {examples/dataset_text_workflow => tests/test_text_workflow}/example_files/example_yaml.yaml (100%) rename {examples/dataset_text_workflow => tests/test_text_workflow}/get_data.py (100%) rename {examples/dataset_text_workflow => tests/test_text_workflow}/main.py (94%) rename {examples/dataset_text_workflow => tests/test_text_workflow}/utils.py (99%) diff --git a/examples/dataset_pfs_workflow/main.py b/examples/dataset_pfs_workflow/main.py index a465241..b3a7d36 100644 --- a/examples/dataset_pfs_workflow/main.py +++ b/examples/dataset_pfs_workflow/main.py @@ -15,7 +15,7 @@ import learnware.specification as specification from pfs import Dataloader from learnware.logger import get_module_logger -logger = get_module_logger("m5_test", level="INFO") +logger = get_module_logger("pfs_test", level="INFO") semantic_specs = [ diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 7768adf..adbc328 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -9,7 +9,7 @@ from sklearn.metrics import accuracy_score from learnware.learnware import Learnware import learnware.specification as specification from .base import BaseReuser -from ..specification import RKMEStatSpecification +from ..specification import RKMEStatSpecification, TextRKMEStatSpecification, sentence_embedding from ..logger import get_module_logger logger = get_module_logger("job_selector_reuse") @@ -45,6 +45,10 @@ class JobSelectorReuser(BaseReuser): np.ndarray Prediction given by job-selector method """ + ori_user_data = user_data + if isinstance(user_data[0], str): + user_data = sentence_embedding(user_data) + select_result = self.job_selector(user_data) pred_y_list = [] data_idxs_list = [] @@ -52,7 +56,8 @@ class JobSelectorReuser(BaseReuser): for idx in range(len(self.learnware_list)): data_idx_list = np.where(select_result == idx)[0] if len(data_idx_list) > 0: - pred_y = self.learnware_list[idx].predict(user_data[data_idx_list]) + # pred_y = self.learnware_list[idx].predict(ori_user_data[data_idx_list]) + pred_y = self.learnware_list[idx].predict([ori_user_data[i] for i in data_idx_list]) if isinstance(pred_y, torch.Tensor): pred_y = pred_y.detach().cpu().numpy() # elif isinstance(pred_y, tf.Tensor): @@ -86,6 +91,9 @@ class JobSelectorReuser(BaseReuser): user_data_num = len(user_data) return np.array([0] * user_data_num) else: + ori_user_data = user_data + if isinstance(user_data[0], str): + user_data = sentence_embedding(user_data) learnware_rkme_spec_list = [ learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in self.learnware_list @@ -169,6 +177,8 @@ class JobSelectorReuser(BaseReuser): Inner product matrix calculated from task_rkme_list. """ task_num = len(task_rkme_list) + if isinstance(user_data[0], str): + user_data = sentence_embedding(user_data) user_rkme_spec = specification.utils.generate_rkme_spec(X=user_data, reduce=False) K = task_rkme_matrix v = np.array([user_rkme_spec.inner_prod(task_rkme) for task_rkme in task_rkme_list]) @@ -229,6 +239,7 @@ class JobSelectorReuser(BaseReuser): "boosting_type": "gbdt", "n_estimators": 2000, "boost_from_average": False, + "verbose": -1 } if num_class == 2: diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 17f07bf..62c6b7a 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,4 +1,4 @@ from .utils import generate_stat_spec from .base import Specification, BaseStatSpecification from .rkme import RKMEStatSpecification -from .text import TextRKMEStatSpecification +from .text import TextRKMEStatSpecification, sentence_embedding diff --git a/learnware/specification/text/__init__.py b/learnware/specification/text/__init__.py index 8b17514..956907d 100644 --- a/learnware/specification/text/__init__.py +++ b/learnware/specification/text/__init__.py @@ -1 +1 @@ -from .rkme import TextRKMEStatSpecification +from .rkme import TextRKMEStatSpecification, sentence_embedding diff --git a/learnware/specification/text/rkme.py b/learnware/specification/text/rkme.py index e541f58..90fe4f2 100644 --- a/learnware/specification/text/rkme.py +++ b/learnware/specification/text/rkme.py @@ -1,5 +1,7 @@ from sentence_transformers import SentenceTransformer from ..rkme import RKMEStatSpecification +import numpy as np +import os class TextRKMEStatSpecification(RKMEStatSpecification): """Reduced Kernel Mean Embedding (RKME) Specification for Text""" @@ -32,8 +34,7 @@ class TextRKMEStatSpecification(RKMEStatSpecification): """ # Sentence embedding for Text - model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') - X = model.encode(X) + X = sentence_embedding(X) return self.generate_stat_spec_from_data( X, K, step_size,steps, @@ -41,3 +42,11 @@ class TextRKMEStatSpecification(RKMEStatSpecification): reduce, ) +def sentence_embedding(X): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') + X = model.encode(X) + X = np.array(X) + # print(X.shape, np.mean(X, axis=1).shape, np.std(X, axis=1).reshape(X.shape[0], 1).shape) + # X = (X - np.mean(X, axis=1).reshape(X.shape[0], 1)) / np.std(X, axis=1).reshape(X.shape[0], 1) + return X diff --git a/examples/dataset_text_workflow/example_files/example_init.py b/tests/test_text_workflow/example_files/example_init.py similarity index 100% rename from examples/dataset_text_workflow/example_files/example_init.py rename to tests/test_text_workflow/example_files/example_init.py diff --git a/examples/dataset_text_workflow/example_files/example_yaml.yaml b/tests/test_text_workflow/example_files/example_yaml.yaml similarity index 100% rename from examples/dataset_text_workflow/example_files/example_yaml.yaml rename to tests/test_text_workflow/example_files/example_yaml.yaml diff --git a/examples/dataset_text_workflow/get_data.py b/tests/test_text_workflow/get_data.py similarity index 100% rename from examples/dataset_text_workflow/get_data.py rename to tests/test_text_workflow/get_data.py diff --git a/examples/dataset_text_workflow/main.py b/tests/test_text_workflow/main.py similarity index 94% rename from examples/dataset_text_workflow/main.py rename to tests/test_text_workflow/main.py index 9933301..8eb85fc 100644 --- a/examples/dataset_text_workflow/main.py +++ b/tests/test_text_workflow/main.py @@ -177,19 +177,19 @@ def test_search(gamma=0.1, load_market=True): acc_list.append(acc) logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) - # # test reuse (job selector) - # reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) - # reuse_predict = reuse_baseline.predict(user_data=user_data) - # reuse_score = eval_prediction(reuse_predict, user_label) - # job_selector_score_list.append(reuse_score) - # print(f"mixture reuse loss: {reuse_score}") + # test reuse (job selector) + reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) + reuse_predict = reuse_baseline.predict(user_data=user_data) + reuse_score = eval_prediction(reuse_predict, user_label) + job_selector_score_list.append(reuse_score) + print(f"mixture reuse loss(job selector): {reuse_score}") # test reuse (ensemble) reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label") ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) ensemble_score = eval_prediction(ensemble_predict_y, user_label) ensemble_score_list.append(ensemble_score) - print(f"mixture reuse accuracy (ensemble): {ensemble_score}\n") + print(f"mixture reuse accuracy (ensemble): {ensemble_score}") select_list.append(acc_list[0]) avg_list.append(np.mean(acc_list)) @@ -200,7 +200,7 @@ def test_search(gamma=0.1, load_market=True): pruning_predict_y = reuse_pruning.predict(user_data=user_data) pruning_score = eval_prediction(pruning_predict_y, user_label) pruning_score_list.append(pruning_score) - print(f"mixture reuse accuracy (ensemble): {pruning_score}\n") + print(f"mixture reuse accuracy (ensemble Pruning): {pruning_score}\n") select_list.append(acc_list[0]) avg_list.append(np.mean(acc_list)) @@ -226,4 +226,4 @@ def test_search(gamma=0.1, load_market=True): if __name__ == "__main__": # prepare_data() # prepare_model() - test_search(load_market=True) + test_search(load_market=False) diff --git a/examples/dataset_text_workflow/utils.py b/tests/test_text_workflow/utils.py similarity index 99% rename from examples/dataset_text_workflow/utils.py rename to tests/test_text_workflow/utils.py index d3acf24..3be6a6a 100644 --- a/examples/dataset_text_workflow/utils.py +++ b/tests/test_text_workflow/utils.py @@ -156,7 +156,7 @@ def train(X, y, out_classes, epochs=35, batch_size=128): model.to(DEVICE) - num_epochs = 10 + num_epochs = 30 for e in range(num_epochs): for batch in train_dataloader: From 93e734fa4c608b7d3b63fc340c1ee7b68fd87065 Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Mon, 30 Oct 2023 20:06:45 +0800 Subject: [PATCH 03/10] [MNT] text learnware for easy2 market --- learnware/market/easy2/checker.py | 14 +++++++---- learnware/market/easy2/searcher.py | 20 +++++++++++----- learnware/specification/text/rkme.py | 2 +- setup.py | 2 ++ .../example_files/example_yaml.yaml | 2 +- tests/test_text_workflow/main.py | 24 ++++++++++++------- 6 files changed, 44 insertions(+), 20 deletions(-) diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 7f26b91..a18dcf4 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -69,7 +69,6 @@ class EasyStatisticalChecker(BaseChecker): try: learnware_model = learnware.get_model() - # Check input shape if semantic_spec["Data"]["Values"][0] == "Table": input_shape = (semantic_spec["Input"]["Dimension"],) @@ -77,13 +76,20 @@ class EasyStatisticalChecker(BaseChecker): input_shape = learnware_model.input_shape # Check rkme dimension - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") - if stat_spec is not None: + is_text = "TextRKMEStatSpecification" in learnware.get_specification().stat_spec + if is_text: + stat_spec = learnware.get_specification().get_stat_spec_by_name("TextRKMEStatSpecification") + else: + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") + if stat_spec is not None and not is_text: if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") return self.INVALID_LEARNWARE - inputs = np.random.randn(10, *input_shape) + if is_text: + inputs = ["This is an example sentence"] + else: + inputs = np.random.randn(10, *input_shape) outputs = learnware.predict(inputs) # Check output diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 9d758fc..80e8963 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -252,7 +252,7 @@ class EasyTableSearcher(BaseSearcher): """ learnware_num = len(learnware_list) RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list ] if type(intermediate_K) == np.ndarray: @@ -321,7 +321,7 @@ class EasyTableSearcher(BaseSearcher): """ num = intermediate_K.shape[0] - 1 RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list ] for i in range(intermediate_K.shape[0]): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) @@ -377,7 +377,7 @@ class EasyTableSearcher(BaseSearcher): if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) + mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_info_name)) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num] @@ -439,7 +439,9 @@ class EasyTableSearcher(BaseSearcher): user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) for learnware in learnware_list: - rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") + if self.stat_info_name not in learnware.specification.stat_spec: + continue + rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name) rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -538,7 +540,7 @@ class EasyTableSearcher(BaseSearcher): both lists are sorted by mmd dist """ RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list ] mmd_dist_list = [] for RKME in RKME_list: @@ -558,7 +560,11 @@ class EasyTableSearcher(BaseSearcher): max_search_num: int = 5, search_method: str = "greedy", ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - user_rkme = user_info.stat_info["RKMEStatSpecification"] + if "TextRKMEStatSpecification" in user_info.stat_info: + self.stat_info_name = "TextRKMEStatSpecification" + else: + self.stat_info_name = "RKMEStatSpecification" + user_rkme = user_info.stat_info[self.stat_info_name] learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") @@ -633,5 +639,7 @@ class EasySearcher(BaseSearcher): return [], [], 0.0, [] elif "RKMEStatSpecification" in user_info.stat_info: return self.table_searcher(learnware_list, user_info, max_search_num, search_method) + elif "TextRKMEStatSpecification" in user_info.stat_info: + return self.table_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/specification/text/rkme.py b/learnware/specification/text/rkme.py index 90fe4f2..ae3ab36 100644 --- a/learnware/specification/text/rkme.py +++ b/learnware/specification/text/rkme.py @@ -1,5 +1,5 @@ from sentence_transformers import SentenceTransformer -from ..rkme import RKMEStatSpecification +from ..table import RKMEStatSpecification import numpy as np import os diff --git a/setup.py b/setup.py index 52a4299..6f2861e 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,8 @@ REQUIRED = [ "geatpy>=2.7.0", "docker>=6.1.3", "rapidfuzz>=3.4.0", + "torchtext>=0.16.0", + "sentence_transformers>=2.2.2" ] if get_platform() != MACOS: diff --git a/tests/test_text_workflow/example_files/example_yaml.yaml b/tests/test_text_workflow/example_files/example_yaml.yaml index 6ca01c9..1df399b 100644 --- a/tests/test_text_workflow/example_files/example_yaml.yaml +++ b/tests/test_text_workflow/example_files/example_yaml.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: TextRKMEStatSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/tests/test_text_workflow/main.py b/tests/test_text_workflow/main.py index 8eb85fc..ff1c760 100644 --- a/tests/test_text_workflow/main.py +++ b/tests/test_text_workflow/main.py @@ -9,7 +9,7 @@ from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningR import time import pickle -from learnware.market import EasyMarket, BaseUserInfo +from learnware.market import instatiate_learnware_market, BaseUserInfo from learnware.market import database_ops from learnware.learnware import Learnware import learnware.specification as specification @@ -37,25 +37,33 @@ os.makedirs(user_save_root, exist_ok=True) os.makedirs(uploader_save_root, exist_ok=True) os.makedirs(model_save_root, exist_ok=True) - +output_description = { + "Dimension": 2, + "Description": { + "0": "the probability of being negative", + "1": "the probability of being positive", + }, +} semantic_specs = [ { "Data": {"Values": ["Text"], "Type": "Class"}, "Task": {"Values": ["Classification"], "Type": "Class"}, - "Library": {"Values": ["Pytorch"], "Type": "Class"}, + "Library": {"Values": ["PyTorch"], "Type": "Class"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "learnware_1", "Type": "String"}, + "Output": output_description } ] user_semantic = { "Data": {"Values": ["Text"], "Type": "Class"}, "Task": {"Values": ["Classification"], "Type": "Class"}, - "Library": {"Values": ["Pytorch"], "Type": "Class"}, + "Library": {"Values": ["PyTorch"], "Type": "Class"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "", "Type": "String"}, + "Output": output_description } @@ -112,7 +120,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo def prepare_market(): - text_market = EasyMarket(market_id="sst2", rebuild=True) + text_market = instatiate_learnware_market(market_id="sst2", rebuild=True) try: rmtree(learnware_pool_dir) except: @@ -138,10 +146,10 @@ def prepare_market(): def test_search(gamma=0.1, load_market=True): if load_market: - text_market = EasyMarket(market_id="sst2") + text_market = instatiate_learnware_market(market_id="sst2") else: prepare_market() - text_market = EasyMarket(market_id="sst2") + text_market = instatiate_learnware_market(market_id="sst2") logger.info("Number of items in the market: %d" % len(text_market)) select_list = [] @@ -162,7 +170,7 @@ def test_search(gamma=0.1, load_market=True): # user_stat_spec = specification.utils.generate_rkme_spec(X=user_data, gamma=gamma, cuda_idx=0) user_stat_spec = specification.TextRKMEStatSpecification() user_stat_spec.generate_stat_spec_from_text(X=user_data) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"TextRKMEStatSpecification": user_stat_spec}) logger.info("Searching Market for user: %d" % (i)) sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = text_market.search_learnware( user_info From b9d92234f021acc544c1afafa88d1bacde1bb5cf Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Mon, 30 Oct 2023 21:47:14 +0800 Subject: [PATCH 04/10] [FIX] job selector for text --- learnware/reuse/job_selector.py | 5 ++++- tests/test_text_workflow/main.py | 6 ++---- tests/test_text_workflow/utils.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index adbc328..c81bc12 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -94,8 +94,11 @@ class JobSelectorReuser(BaseReuser): ori_user_data = user_data if isinstance(user_data[0], str): user_data = sentence_embedding(user_data) + spec_name = "RKMEStatSpecification" + if "TextRKMEStatSpecification" in self.learnware_list[0].specification.stat_spec: + spec_name = "TextRKMEStatSpecification" learnware_rkme_spec_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") + learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list ] diff --git a/tests/test_text_workflow/main.py b/tests/test_text_workflow/main.py index ff1c760..b75d2d0 100644 --- a/tests/test_text_workflow/main.py +++ b/tests/test_text_workflow/main.py @@ -24,7 +24,7 @@ processed_data_root = "./data/processed_data" tmp_dir = "./data/tmp" learnware_pool_dir = "./data/learnware_pool" dataset = "sst2" -n_uploaders = 50 +n_uploaders = 10 n_users = 5 n_classes = 2 data_root = os.path.join(origin_data_root, dataset) @@ -140,8 +140,6 @@ def prepare_market(): text_market.add_learnware(new_learnware_path, semantic_spec) logger.info("Total Item: %d" % (len(text_market))) - curr_inds = text_market._get_ids() - logger.info("Available ids: " + str(curr_inds)) def test_search(gamma=0.1, load_market=True): @@ -234,4 +232,4 @@ def test_search(gamma=0.1, load_market=True): if __name__ == "__main__": # prepare_data() # prepare_model() - test_search(load_market=False) + test_search(load_market=True) diff --git a/tests/test_text_workflow/utils.py b/tests/test_text_workflow/utils.py index 3be6a6a..d3acf24 100644 --- a/tests/test_text_workflow/utils.py +++ b/tests/test_text_workflow/utils.py @@ -156,7 +156,7 @@ def train(X, y, out_classes, epochs=35, batch_size=128): model.to(DEVICE) - num_epochs = 30 + num_epochs = 10 for e in range(num_epochs): for batch in train_dataloader: From c2313848112fea83cdd20e390671027614a3f2eb Mon Sep 17 00:00:00 2001 From: Gene Date: Tue, 31 Oct 2023 14:28:04 +0800 Subject: [PATCH 05/10] [MNT] format code by black --- learnware/market/easy2/searcher.py | 12 +++------ learnware/reuse/job_selector.py | 10 ++----- learnware/specification/regular/text/rkme.py | 10 ++++--- setup.py | 2 +- tests/test_specification/test_rkme.py | 10 +++---- .../example_files/example_init.py | 2 +- tests/test_text_workflow/get_data.py | 7 ++--- tests/test_text_workflow/main.py | 18 +++++++------ tests/test_text_workflow/utils.py | 27 ++++++++++--------- 9 files changed, 47 insertions(+), 51 deletions(-) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index d95584f..93307e7 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -251,9 +251,7 @@ class EasyTableSearcher(BaseSearcher): The second is the mmd dist between the mixture of learnware rkmes and the user's rkme """ learnware_num = len(learnware_list) - RKME_list = [ - learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list - ] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] if type(intermediate_K) == np.ndarray: K = intermediate_K @@ -320,9 +318,7 @@ class EasyTableSearcher(BaseSearcher): The second is the intermediate value of C """ num = intermediate_K.shape[0] - 1 - RKME_list = [ - learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list - ] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] for i in range(intermediate_K.shape[0]): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1]) @@ -539,9 +535,7 @@ class EasyTableSearcher(BaseSearcher): the second is the list of Learnware both lists are sorted by mmd dist """ - RKME_list = [ - learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list - ] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] mmd_dist_list = [] for RKME in RKME_list: mmd_dist = RKME.dist(user_rkme) diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 72b662e..c1acab0 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -98,8 +98,7 @@ class JobSelectorReuser(BaseReuser): if "RKMETextStatSpecification" in self.learnware_list[0].specification.stat_spec: spec_name = "RKMETextStatSpecification" learnware_rkme_spec_list = [ - learnware.specification.get_stat_spec_by_name(spec_name) - for learnware in self.learnware_list + learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list ] if self.use_herding: @@ -238,12 +237,7 @@ class JobSelectorReuser(BaseReuser): max_depth = [66] params = (0, 0) - lgb_params = { - "boosting_type": "gbdt", - "n_estimators": 2000, - "boost_from_average": False, - "verbose": -1 - } + lgb_params = {"boosting_type": "gbdt", "n_estimators": 2000, "boost_from_average": False, "verbose": -1} if num_class == 2: lgb_params["objective"] = "binary" diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index a137f93..8cb9dbe 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -3,6 +3,7 @@ from ..table import RKMEStatSpecification import numpy as np import os + class RKMETextStatSpecification(RKMEStatSpecification): """Reduced Kernel Mean Embedding (RKME) Specification for Text""" @@ -32,13 +33,16 @@ class RKMETextStatSpecification(RKMEStatSpecification): reduce : bool, optional Whether shrink original data to a smaller set, by default True """ - + # Sentence embedding for Text X = self.get_sentence_embedding(X) - + # Generate specification return super().generate_stat_spec_from_data( - X, K, step_size,steps, + X, + K, + step_size, + steps, nonnegative_beta, reduce, ) diff --git a/setup.py b/setup.py index 6f2861e..552c4f5 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ REQUIRED = [ "docker>=6.1.3", "rapidfuzz>=3.4.0", "torchtext>=0.16.0", - "sentence_transformers>=2.2.2" + "sentence_transformers>=2.2.2", ] if get_platform() != MACOS: diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index c0f8850..bedd64e 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -27,7 +27,7 @@ class TestRKME(unittest.TestCase): rkme2 = RKMEStatSpecification() rkme2.load(rkme_path) assert rkme2.type == "RKMEStatSpecification" - + def test_text_rkme(self): def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): text_list = [] @@ -38,12 +38,12 @@ class TestRKME(unittest.TestCase): result_str = "".join(random.choice(characters) for i in range(length)) text_list.append(result_str) elif text_type == "zh": - result_str = "".join(chr(random.randint(0x4e00, 0x9fff)) for i in range(length)) + result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) text_list.append(result_str) else: raise ValueError("Type should be en or zh") return text_list - + def _test_text_rkme(X): rkme = specification.utils.generate_rkme_text_spec(X) @@ -58,14 +58,14 @@ class TestRKME(unittest.TestCase): rkme2 = RKMETextStatSpecification() rkme2.load(rkme_path) assert rkme2.type == "RKMETextStatSpecification" - + return rkme2.get_z().shape[1] dim1 = _test_text_rkme(generate_random_text_list(3000, "en")) dim2 = _test_text_rkme(generate_random_text_list(4000, "en")) dim3 = _test_text_rkme(generate_random_text_list(2000, "zh")) dim4 = _test_text_rkme(generate_random_text_list(5000, "zh")) - + assert dim1 == dim2 and dim2 == dim3 and dim3 == dim4 diff --git a/tests/test_text_workflow/example_files/example_init.py b/tests/test_text_workflow/example_files/example_init.py index 99839b2..a98b757 100644 --- a/tests/test_text_workflow/example_files/example_init.py +++ b/tests/test_text_workflow/example_files/example_init.py @@ -51,4 +51,4 @@ def sentence_preprocess(x_datapipe): x_datapipe = [text_transform(x) for x in x_datapipe] # x_datapipe = x_datapipe.map(text_transform) - return x_datapipe \ No newline at end of file + return x_datapipe diff --git a/tests/test_text_workflow/get_data.py b/tests/test_text_workflow/get_data.py index 03c413b..0d1412d 100644 --- a/tests/test_text_workflow/get_data.py +++ b/tests/test_text_workflow/get_data.py @@ -1,13 +1,14 @@ import torch from torchtext.datasets import SST2 + def get_sst2(data_root="./data"): - train_datapipe = SST2(root='./data', split="train") - + train_datapipe = SST2(root="./data", split="train") + X_train = [x[0] for x in train_datapipe] y_train = [x[1] for x in train_datapipe] - dev_datapipe = SST2(root='./data', split="dev") + dev_datapipe = SST2(root="./data", split="dev") X_test = [x[0] for x in dev_datapipe] y_test = [x[1] for x in dev_datapipe] diff --git a/tests/test_text_workflow/main.py b/tests/test_text_workflow/main.py index b76907c..ba97c6b 100644 --- a/tests/test_text_workflow/main.py +++ b/tests/test_text_workflow/main.py @@ -52,7 +52,7 @@ semantic_specs = [ "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "learnware_1", "Type": "String"}, - "Output": output_description + "Output": output_description, } ] @@ -63,7 +63,7 @@ user_semantic = { "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "", "Type": "String"}, - "Output": output_description + "Output": output_description, } @@ -93,8 +93,8 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo tmp_model_path = os.path.join(save_root, "model.pth") tmp_yaml_path = os.path.join(save_root, "learnware.yaml") tmp_init_path = os.path.join(save_root, "__init__.py") - - with open(data_path, 'rb') as f: + + with open(data_path, "rb") as f: X = pickle.load(f) semantic_spec = semantic_specs[0] @@ -159,9 +159,9 @@ def test_search(gamma=0.1, load_market=True): for i in range(n_users): user_data_path = os.path.join(user_save_root, "user_%d_X.pkl" % (i)) user_label_path = os.path.join(user_save_root, "user_%d_y.pkl" % (i)) - with open(user_data_path, 'rb') as f: + with open(user_data_path, "rb") as f: user_data = pickle.load(f) - with open(user_label_path, 'rb') as f: + with open(user_label_path, "rb") as f: user_label = pickle.load(f) # user_data = np.load(user_data_path) # user_label = np.load(user_label_path) @@ -222,10 +222,12 @@ def test_search(gamma=0.1, load_market=True): % (np.mean(job_selector_score_list), np.std(job_selector_score_list)) ) logger.info( - "Averaging Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list)) + "Averaging Ensemble Reuse Performance: %.3f +/- %.3f" + % (np.mean(ensemble_score_list), np.std(ensemble_score_list)) ) logger.info( - "Selective Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(pruning_score_list), np.std(pruning_score_list)) + "Selective Ensemble Reuse Performance: %.3f +/- %.3f" + % (np.mean(pruning_score_list), np.std(pruning_score_list)) ) diff --git a/tests/test_text_workflow/utils.py b/tests/test_text_workflow/utils.py index d3acf24..c438d08 100644 --- a/tests/test_text_workflow/utils.py +++ b/tests/test_text_workflow/utils.py @@ -16,7 +16,6 @@ from torch.optim import AdamW from torch.utils.data import DataLoader - class TextDataLoader: def __init__(self, data_root, train: bool = True): self.data_root = data_root @@ -28,18 +27,18 @@ class TextDataLoader: y_path = os.path.join(self.data_root, "uploader", "uploader_%d_y.pkl" % (idx)) if not (os.path.exists(X_path) and os.path.exists(y_path)): raise Exception("Index Error") - with open(X_path, 'rb') as f: + with open(X_path, "rb") as f: X = pickle.load(f) - with open(y_path, 'rb') as f: + with open(y_path, "rb") as f: y = pickle.load(f) else: X_path = os.path.join(self.data_root, "user", "user_%d_X.pkl" % (idx)) y_path = os.path.join(self.data_root, "user", "user_%d_y.pkl" % (idx)) if not (os.path.exists(X_path) and os.path.exists(y_path)): raise Exception("Index Error") - with open(X_path, 'rb') as f: + with open(X_path, "rb") as f: X = pickle.load(f) - with open(y_path, 'rb') as f: + with open(y_path, "rb") as f: y = pickle.load(f) return X, y @@ -50,13 +49,13 @@ def generate_uploader(data_x, data_y, n_uploaders=50, data_save_root=None): os.makedirs(data_save_root, exist_ok=True) n = len(data_x) for i in range(n_uploaders): - selected_X = data_x[i * (n // n_uploaders): (i + 1) * (n // n_uploaders)] - selected_y = data_y[i * (n // n_uploaders): (i + 1) * (n // n_uploaders)] + selected_X = data_x[i * (n // n_uploaders) : (i + 1) * (n // n_uploaders)] + selected_y = data_y[i * (n // n_uploaders) : (i + 1) * (n // n_uploaders)] X_save_dir = os.path.join(data_save_root, "uploader_%d_X.pkl" % (i)) y_save_dir = os.path.join(data_save_root, "uploader_%d_y.pkl" % (i)) - with open(X_save_dir, 'wb') as f: + with open(X_save_dir, "wb") as f: pickle.dump(selected_X, f) - with open(y_save_dir, 'wb') as f: + with open(y_save_dir, "wb") as f: pickle.dump(selected_y, f) print("Saving to %s" % (X_save_dir)) @@ -67,13 +66,13 @@ def generate_user(data_x, data_y, n_users=50, data_save_root=None): os.makedirs(data_save_root, exist_ok=True) n = len(data_x) for i in range(n_users): - selected_X = data_x[i * (n // n_users): (i + 1) * (n // n_users)] - selected_y = data_y[i * (n // n_users): (i + 1) * (n // n_users)] + selected_X = data_x[i * (n // n_users) : (i + 1) * (n // n_users)] + selected_y = data_y[i * (n // n_users) : (i + 1) * (n // n_users)] X_save_dir = os.path.join(data_save_root, "user_%d_X.pkl" % (i)) y_save_dir = os.path.join(data_save_root, "user_%d_y.pkl" % (i)) - with open(X_save_dir, 'wb') as f: + with open(X_save_dir, "wb") as f: pickle.dump(selected_X, f) - with open(y_save_dir, 'wb') as f: + with open(y_save_dir, "wb") as f: pickle.dump(selected_y, f) print("Saving to %s" % (X_save_dir)) @@ -134,10 +133,12 @@ def evaluate(model, criteria, dev_dataloader): DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Train Uploaders' models def train(X, y, out_classes, epochs=35, batch_size=128): # print(X.shape, y.shape) from torchdata.datapipes.iter import IterableWrapper + X = sentence_preprocess(X) data_size = len(X) train_datapipe = list(zip(X, y)) From e0c9b6186b1a6b3b8256d32f1b8b2c864f237a9f Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 21:31:45 +0800 Subject: [PATCH 06/10] [MNT] rename RKMEStatSpecification to RKMETableSpecification --- learnware/market/easy2/searcher.py | 7 +------ learnware/reuse/job_selector.py | 2 +- learnware/specification/regular/text/rkme.py | 4 ++-- tests/test_specification/test_rkme.py | 3 +-- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 13038d2..0308ae8 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -631,14 +631,9 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] -<<<<<<< HEAD - elif "RKMEStatSpecification" in user_info.stat_info: + elif "RKMETableSpecification" in user_info.stat_info: return self.table_searcher(learnware_list, user_info, max_search_num, search_method) elif "RKMETextStatSpecification" in user_info.stat_info: return self.table_searcher(learnware_list, user_info, max_search_num, search_method) -======= - elif "RKMETableSpecification" in user_info.stat_info: - return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) ->>>>>>> b0aaae48e77fb5d49d2b7a1c31a2023580ea2115 else: return None, learnware_list, 0.0, None diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 84d0bb7..46c8951 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -94,7 +94,7 @@ class JobSelectorReuser(BaseReuser): ori_user_data = user_data if isinstance(user_data[0], str): user_data = RKMETextStatSpecification.get_sentence_embedding(user_data) - spec_name = "RKMEStatSpecification" + spec_name = "RKMETableSpecification" if "RKMETextStatSpecification" in self.learnware_list[0].specification.stat_spec: spec_name = "RKMETextStatSpecification" learnware_rkme_spec_list = [ diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index 8cb9dbe..1483fb4 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -1,10 +1,10 @@ from sentence_transformers import SentenceTransformer -from ..table import RKMEStatSpecification +from ..table import RKMETableSpecification import numpy as np import os -class RKMETextStatSpecification(RKMEStatSpecification): +class RKMETextStatSpecification(RKMETableSpecification): """Reduced Kernel Mean Embedding (RKME) Specification for Text""" def generate_stat_spec_from_data( diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index 680e668..f46cbf7 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -8,8 +8,7 @@ import tempfile import numpy as np import learnware.specification as specification -from learnware.specification import RKMEStatSpecification, RKMETextStatSpecification -from learnware.specification import RKMETableSpecification, RKMEImageSpecification +from learnware.specification import RKMETableSpecification, RKMEImageSpecification, RKMETextStatSpecification from learnware.specification import generate_rkme_image_spec, generate_rkme_spec From e02bd2fd294a21b48c0230bd051ab36147c26cd0 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 21:33:14 +0800 Subject: [PATCH 07/10] [MNT] fix rkme base class --- learnware/specification/__init__.py | 7 ++++++- learnware/specification/regular/image/rkme.py | 4 ++-- learnware/specification/utils.py | 4 ++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index bd94261..fcdfaa7 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,8 @@ from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .base import Specification, BaseStatSpecification -from .regular import RegularStatsSpecification, RKMETableSpecification, RKMEImageSpecification, RKMETextStatSpecification +from .regular import ( + RegularStatsSpecification, + RKMETableSpecification, + RKMEImageSpecification, + RKMETextStatSpecification, +) diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index 1f05382..4421f91 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -17,11 +17,11 @@ from torchvision.transforms import Resize from tqdm import tqdm from . import cnn_gp -from ..base import BaseStatSpecification +from ..base import RegularStatsSpecification from ..table.rkme import solve_qp, choose_device, setup_seed -class RKMEImageSpecification(BaseStatSpecification): +class RKMEImageSpecification(RegularStatsSpecification): # INNER_PRODUCT_COUNT = 0 IMAGE_WIDTH = 32 diff --git a/learnware/specification/utils.py b/learnware/specification/utils.py index dcafb69..085d9e8 100644 --- a/learnware/specification/utils.py +++ b/learnware/specification/utils.py @@ -98,6 +98,7 @@ def generate_rkme_spec( rkme_spec.generate_stat_spec_from_data(X, reduced_set_size, step_size, steps, nonnegative_beta, reduce) return rkme_spec + def generate_rkme_image_spec( X: Union[np.ndarray, torch.Tensor], reduced_set_size: int = 50, @@ -162,6 +163,7 @@ def generate_rkme_image_spec( ) return rkme_image_spec + def generate_rkme_text_spec( X: List[str], gamma: float = 0.1, @@ -219,8 +221,6 @@ def generate_rkme_text_spec( return rkme_text_spec - - def generate_stat_spec(X: np.ndarray) -> BaseStatSpecification: """ Interface for users to generate statistical specification. From aabf125a2c4b6f473b0daa80b929c46938fde7ea Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Wed, 1 Nov 2023 09:40:52 +0800 Subject: [PATCH 08/10] [MNT] a more complex text checker --- learnware/market/easy2/checker.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 70a7c15..1ed811b 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -1,6 +1,8 @@ import traceback import numpy as np import torch +import random +import string from ..base import BaseChecker from ...config import C @@ -85,13 +87,27 @@ class EasyStatisticalChecker(BaseChecker): if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") return self.INVALID_LEARNWARE - + + def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): + text_list = [] + for i in range(num): + length = random.randint(min_len, max_len) + if text_type == "en": + characters = string.ascii_letters + string.digits + string.punctuation + result_str = "".join(random.choice(characters) for i in range(length)) + text_list.append(result_str) + elif text_type == "zh": + result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) + text_list.append(result_str) + else: + raise ValueError("Type should be en or zh") + return text_list + if is_text: - inputs = ["This is an example sentence"] + inputs = generate_random_text_list(10) else: inputs = np.random.randn(10, *input_shape) outputs = learnware.predict(inputs) - # Check output if outputs.ndim == 1: outputs = outputs.reshape(-1, 1) From 49b6b475d2afb4dfe6e66b72099cd950cba91a9f Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Wed, 1 Nov 2023 19:55:53 +0800 Subject: [PATCH 09/10] [MNT] language settings for text --- learnware/market/easy2/searcher.py | 7 ++++-- learnware/reuse/job_selector.py | 2 +- learnware/specification/regular/text/rkme.py | 23 ++++++++++++++++++-- setup.py | 1 + 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 0308ae8..5ef4ae3 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -438,6 +438,9 @@ class EasyStatSearcher(BaseSearcher): if self.stat_info_name not in learnware.specification.stat_spec: continue rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name) + if self.stat_info_name == "RKMETextStatSpecification": + if not set(user_rkme.language).issubset(set(rkme.language)): + continue rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -632,8 +635,8 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] elif "RKMETableSpecification" in user_info.stat_info: - return self.table_searcher(learnware_list, user_info, max_search_num, search_method) + return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) elif "RKMETextStatSpecification" in user_info.stat_info: - return self.table_searcher(learnware_list, user_info, max_search_num, search_method) + return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 46c8951..ee58399 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -95,7 +95,7 @@ class JobSelectorReuser(BaseReuser): if isinstance(user_data[0], str): user_data = RKMETextStatSpecification.get_sentence_embedding(user_data) spec_name = "RKMETableSpecification" - if "RKMETextStatSpecification" in self.learnware_list[0].specification.stat_spec: + if len(self.learnware_list) and "RKMETextStatSpecification" in self.learnware_list[0].specification.stat_spec: spec_name = "RKMETextStatSpecification" learnware_rkme_spec_list = [ learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index 1483fb4..7691bb9 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -2,10 +2,17 @@ from sentence_transformers import SentenceTransformer from ..table import RKMETableSpecification import numpy as np import os +import langdetect +from ....logger import get_module_logger + +logger = get_module_logger("RKMETextStatSpecification", "INFO") class RKMETextStatSpecification(RKMETableSpecification): """Reduced Kernel Mean Embedding (RKME) Specification for Text""" + def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): + RKMETableSpecification.__init__(self, gamma, cuda_idx) + self.language = [] def generate_stat_spec_from_data( self, @@ -35,6 +42,8 @@ class RKMETextStatSpecification(RKMETableSpecification): """ # Sentence embedding for Text + self.language = self.get_language_ids(X) + logger.info("The text learnware's language: %s" % (self.language)) X = self.get_sentence_embedding(X) # Generate specification @@ -47,12 +56,22 @@ class RKMETextStatSpecification(RKMETableSpecification): reduce, ) + @staticmethod + def get_language_ids(X): + try: + text = ' '.join(X) + lang = langdetect.detect(text) + langs = langdetect.detect_langs(text) + return [l.lang for l in langs] + except Exception as e: + logger.warning("Language detection failed.") + return [] + @staticmethod def get_sentence_embedding(X): os.environ["TOKENIZERS_PARALLELISM"] = "false" model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") X = model.encode(X) X = np.array(X) - # print(X.shape, np.mean(X, axis=1).shape, np.std(X, axis=1).reshape(X.shape[0], 1).shape) - # X = (X - np.mean(X, axis=1).reshape(X.shape[0], 1)) / np.std(X, axis=1).reshape(X.shape[0], 1) + # X /= np.sqrt(np.sum(X ** 2, axis=1)).reshape((-1, 1)) return X diff --git a/setup.py b/setup.py index 8119fd9..1e8ea72 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ REQUIRED = [ "torchtext>=0.16.0", "sentence_transformers>=2.2.2", "torch-optimizer>=0.3.0", + "langdetect>=1.0.9", ] if get_platform() != MACOS: From 64f59c92b10b0ec64e4d5cc14c607c6ef4cf01b0 Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Wed, 1 Nov 2023 20:52:11 +0800 Subject: [PATCH 10/10] [MNT] rename RKMETextSpecification --- learnware/market/easy2/checker.py | 4 ++-- learnware/market/easy2/searcher.py | 8 ++++---- learnware/reuse/job_selector.py | 12 ++++++------ learnware/specification/__init__.py | 2 +- learnware/specification/regular/__init__.py | 2 +- learnware/specification/regular/text/__init__.py | 2 +- learnware/specification/regular/text/rkme.py | 4 ++-- learnware/specification/utils.py | 12 ++++++------ tests/test_specification/test_rkme.py | 8 ++++---- .../example_files/example_yaml.yaml | 2 +- tests/test_text_workflow/main.py | 12 ++++++------ 11 files changed, 34 insertions(+), 34 deletions(-) diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index c142029..1fc0fef 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -82,9 +82,9 @@ class EasyStatisticalChecker(BaseChecker): input_shape = learnware_model.input_shape # Check rkme dimension - is_text = "RKMETextStatSpecification" in learnware.get_specification().stat_spec + is_text = "RKMETextSpecification" in learnware.get_specification().stat_spec if is_text: - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextStatSpecification") + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextSpecification") else: stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") if stat_spec is not None and not is_text: diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 5ef4ae3..8934fda 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -438,7 +438,7 @@ class EasyStatSearcher(BaseSearcher): if self.stat_info_name not in learnware.specification.stat_spec: continue rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name) - if self.stat_info_name == "RKMETextStatSpecification": + if self.stat_info_name == "RKMETextSpecification": if not set(user_rkme.language).issubset(set(rkme.language)): continue rkme_dim = str(list(rkme.get_z().shape)[1:]) @@ -557,8 +557,8 @@ class EasyStatSearcher(BaseSearcher): max_search_num: int = 5, search_method: str = "greedy", ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - if "RKMETextStatSpecification" in user_info.stat_info: - self.stat_info_name = "RKMETextStatSpecification" + if "RKMETextSpecification" in user_info.stat_info: + self.stat_info_name = "RKMETextSpecification" else: self.stat_info_name = "RKMETableSpecification" user_rkme = user_info.stat_info[self.stat_info_name] @@ -636,7 +636,7 @@ class EasySearcher(BaseSearcher): return [], [], 0.0, [] elif "RKMETableSpecification" in user_info.stat_info: return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) - elif "RKMETextStatSpecification" in user_info.stat_info: + elif "RKMETextSpecification" in user_info.stat_info: return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index ee58399..9de299b 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -9,7 +9,7 @@ from sklearn.metrics import accuracy_score from learnware.learnware import Learnware import learnware.specification as specification from .base import BaseReuser -from ..specification import RKMETableSpecification, RKMETextStatSpecification +from ..specification import RKMETableSpecification, RKMETextSpecification from ..logger import get_module_logger logger = get_module_logger("job_selector_reuse") @@ -47,7 +47,7 @@ class JobSelectorReuser(BaseReuser): """ ori_user_data = user_data if isinstance(user_data[0], str): - user_data = RKMETextStatSpecification.get_sentence_embedding(user_data) + user_data = RKMETextSpecification.get_sentence_embedding(user_data) select_result = self.job_selector(user_data) pred_y_list = [] @@ -93,10 +93,10 @@ class JobSelectorReuser(BaseReuser): else: ori_user_data = user_data if isinstance(user_data[0], str): - user_data = RKMETextStatSpecification.get_sentence_embedding(user_data) + user_data = RKMETextSpecification.get_sentence_embedding(user_data) spec_name = "RKMETableSpecification" - if len(self.learnware_list) and "RKMETextStatSpecification" in self.learnware_list[0].specification.stat_spec: - spec_name = "RKMETextStatSpecification" + if len(self.learnware_list) and "RKMETextSpecification" in self.learnware_list[0].specification.stat_spec: + spec_name = "RKMETextSpecification" learnware_rkme_spec_list = [ learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list ] @@ -180,7 +180,7 @@ class JobSelectorReuser(BaseReuser): """ task_num = len(task_rkme_list) if isinstance(user_data[0], str): - user_data = RKMETextStatSpecification.get_sentence_embedding(user_data) + user_data = RKMETextSpecification.get_sentence_embedding(user_data) user_rkme_spec = specification.utils.generate_rkme_spec(X=user_data, reduce=False) K = task_rkme_matrix v = np.array([user_rkme_spec.inner_prod(task_rkme) for task_rkme in task_rkme_list]) diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 54dae1f..7fbf500 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .base import Specification, BaseStatSpecification -from .regular import RegularStatsSpecification, RKMEStatSpecification, RKMETableSpecification, RKMEImageSpecification +from .regular import RegularStatsSpecification, RKMEStatSpecification, RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index 1731a2f..9007e4d 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,4 +1,4 @@ -from .text import RKMETextStatSpecification +from .text import RKMETextSpecification from .table import RKMETableSpecification, RKMEStatSpecification from .image import RKMEImageSpecification from .base import RegularStatsSpecification diff --git a/learnware/specification/regular/text/__init__.py b/learnware/specification/regular/text/__init__.py index fe9abd0..35b8b0a 100644 --- a/learnware/specification/regular/text/__init__.py +++ b/learnware/specification/regular/text/__init__.py @@ -1 +1 @@ -from .rkme import RKMETextStatSpecification +from .rkme import RKMETextSpecification diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index 7691bb9..cc8659e 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -5,10 +5,10 @@ import os import langdetect from ....logger import get_module_logger -logger = get_module_logger("RKMETextStatSpecification", "INFO") +logger = get_module_logger("RKMETextSpecification", "INFO") -class RKMETextStatSpecification(RKMETableSpecification): +class RKMETextSpecification(RKMETableSpecification): """Reduced Kernel Mean Embedding (RKME) Specification for Text""" def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): RKMETableSpecification.__init__(self, gamma, cuda_idx) diff --git a/learnware/specification/utils.py b/learnware/specification/utils.py index 085d9e8..09d66c1 100644 --- a/learnware/specification/utils.py +++ b/learnware/specification/utils.py @@ -4,7 +4,7 @@ import pandas as pd from typing import Union, List from .base import BaseStatSpecification -from .regular import RKMETableSpecification, RKMEImageSpecification, RKMETextStatSpecification +from .regular import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification from ..config import C @@ -173,10 +173,10 @@ def generate_rkme_text_spec( nonnegative_beta: bool = True, reduce: bool = True, cuda_idx: int = None, -) -> RKMETextStatSpecification: +) -> RKMETextSpecification: """ Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Text. - Return a RKMETextStatSpecification object, use .save() method to save as json file. + Return a RKMETextSpecification object, use .save() method to save as json file. Parameters ---------- @@ -200,8 +200,8 @@ def generate_rkme_text_spec( Returns ------- - RKMETextStatSpecification - A RKMETextStatSpecification object + RKMETextSpecification + A RKMETextSpecification object """ # Check input type if not isinstance(X, list) or not all(isinstance(item, str) for item in X): @@ -216,7 +216,7 @@ def generate_rkme_text_spec( cuda_idx = 0 # Generate rkme text spec - rkme_text_spec = RKMETextStatSpecification(gamma=gamma, cuda_idx=cuda_idx) + rkme_text_spec = RKMETextSpecification(gamma=gamma, cuda_idx=cuda_idx) rkme_text_spec.generate_stat_spec_from_data(X, reduced_set_size, step_size, steps, nonnegative_beta, reduce) return rkme_text_spec diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index f46cbf7..143bf22 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -8,7 +8,7 @@ import tempfile import numpy as np import learnware.specification as specification -from learnware.specification import RKMETableSpecification, RKMEImageSpecification, RKMETextStatSpecification +from learnware.specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification from learnware.specification import generate_rkme_image_spec, generate_rkme_spec @@ -79,11 +79,11 @@ class TestRKME(unittest.TestCase): with open(rkme_path, "r") as f: data = json.load(f) - assert data["type"] == "RKMETextStatSpecification" + assert data["type"] == "RKMETextSpecification" - rkme2 = RKMETextStatSpecification() + rkme2 = RKMETextSpecification() rkme2.load(rkme_path) - assert rkme2.type == "RKMETextStatSpecification" + assert rkme2.type == "RKMETextSpecification" return rkme2.get_z().shape[1] diff --git a/tests/test_text_workflow/example_files/example_yaml.yaml b/tests/test_text_workflow/example_files/example_yaml.yaml index 73474a2..f9817c7 100644 --- a/tests/test_text_workflow/example_files/example_yaml.yaml +++ b/tests/test_text_workflow/example_files/example_yaml.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMETextStatSpecification + class_name: RKMETextSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/tests/test_text_workflow/main.py b/tests/test_text_workflow/main.py index 9ed6fb2..baa54f4 100644 --- a/tests/test_text_workflow/main.py +++ b/tests/test_text_workflow/main.py @@ -100,7 +100,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo st = time.time() # user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0) - user_spec = specification.RKMETextStatSpecification() + user_spec = specification.RKMETextSpecification() user_spec.generate_stat_spec_from_data(X=X) ed = time.time() logger.info("Stat spec generated in %.3f s" % (ed - st)) @@ -166,9 +166,9 @@ def test_search(gamma=0.1, load_market=True): # user_data = np.load(user_data_path) # user_label = np.load(user_label_path) # user_stat_spec = specification.utils.generate_rkme_spec(X=user_data, gamma=gamma, cuda_idx=0) - user_stat_spec = specification.RKMETextStatSpecification() + user_stat_spec = specification.RKMETextSpecification() user_stat_spec.generate_stat_spec_from_data(X=user_data) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETextStatSpecification": user_stat_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETextSpecification": user_stat_spec}) logger.info("Searching Market for user: %d" % (i)) sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = text_market.search_learnware( user_info @@ -232,6 +232,6 @@ def test_search(gamma=0.1, load_market=True): if __name__ == "__main__": - # prepare_data() - # prepare_model() - test_search(load_market=True) + prepare_data() + prepare_model() + test_search(load_market=False)