|
|
|
@@ -7,10 +7,10 @@ from tqdm import trange |
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
from ..align import AlignLearnware |
|
|
|
from ...utils import choose_device |
|
|
|
from ...logger import get_module_logger |
|
|
|
from ...learnware import Learnware |
|
|
|
from ...specification import RKMETableSpecification |
|
|
|
from ...specification.regular.table.rkme import choose_device |
|
|
|
|
|
|
|
logger = get_module_logger("feature_align") |
|
|
|
|
|
|
|
@@ -60,7 +60,7 @@ class FeatureAlignLearnware(AlignLearnware): |
|
|
|
user_rkme : RKMETableSpecification |
|
|
|
The RKME specification from the user dataset. |
|
|
|
""" |
|
|
|
target_rkme = self.learnware.specification.get_stat_spec()["RKMETableSpecification"] |
|
|
|
target_rkme = self.specification.get_stat_spec()["RKMETableSpecification"] |
|
|
|
trainer = FeatureAlignTrainer( |
|
|
|
target_rkme=target_rkme, user_rkme=user_rkme, cuda_idx=self.cuda_idx, **self.align_arguments |
|
|
|
) |
|
|
|
@@ -86,7 +86,7 @@ class FeatureAlignLearnware(AlignLearnware): |
|
|
|
transformed_user_data = ( |
|
|
|
self.align_model(torch.tensor(user_data, device=self.device).float()).detach().cpu().numpy() |
|
|
|
) |
|
|
|
y_pred = self.learnware.predict(transformed_user_data) |
|
|
|
y_pred = super(FeatureAlignLearnware, self).predict(transformed_user_data) |
|
|
|
return y_pred |
|
|
|
|
|
|
|
def _fill_data(self, X: np.ndarray): |
|
|
|
|