diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index 4421f91..df7a2a3 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -18,7 +18,7 @@ from tqdm import tqdm from . import cnn_gp from ..base import RegularStatsSpecification -from ..table.rkme import solve_qp, choose_device, setup_seed +from ..table.rkme import rkme_solve_qp, choose_device, setup_seed class RKMEImageSpecification(RegularStatsSpecification): @@ -97,6 +97,11 @@ class RKMEImageSpecification(RegularStatsSpecification): ------- """ + if len(X.shape) != 4: + raise ValueError("X should be in shape of [N, C, {0:d}, {0:d}]. ".format( + RKMEImageSpecification.IMAGE_WIDTH + )) + if ( X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH ) and not resize: @@ -175,7 +180,7 @@ class RKMEImageSpecification(RegularStatsSpecification): C = torch.sum(C, dim=1) / x_features.shape[0] if nonnegative_beta: - beta = solve_qp(K.double(), C.double()).to(self.device) + beta = rkme_solve_qp(K.double(), C.double())[0].to(self.device) else: beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C