| @@ -18,7 +18,7 @@ from tqdm import tqdm | |||||
| from . import cnn_gp | from . import cnn_gp | ||||
| from ..base import RegularStatsSpecification | from ..base import RegularStatsSpecification | ||||
| from ..table.rkme import solve_qp, choose_device, setup_seed | |||||
| from ..table.rkme import rkme_solve_qp, choose_device, setup_seed | |||||
| class RKMEImageSpecification(RegularStatsSpecification): | class RKMEImageSpecification(RegularStatsSpecification): | ||||
| @@ -97,6 +97,11 @@ class RKMEImageSpecification(RegularStatsSpecification): | |||||
| ------- | ------- | ||||
| """ | """ | ||||
| if len(X.shape) != 4: | |||||
| raise ValueError("X should be in shape of [N, C, {0:d}, {0:d}]. ".format( | |||||
| RKMEImageSpecification.IMAGE_WIDTH | |||||
| )) | |||||
| if ( | if ( | ||||
| X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH | X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH | ||||
| ) and not resize: | ) and not resize: | ||||
| @@ -175,7 +180,7 @@ class RKMEImageSpecification(RegularStatsSpecification): | |||||
| C = torch.sum(C, dim=1) / x_features.shape[0] | C = torch.sum(C, dim=1) / x_features.shape[0] | ||||
| if nonnegative_beta: | if nonnegative_beta: | ||||
| beta = solve_qp(K.double(), C.double()).to(self.device) | |||||
| beta = rkme_solve_qp(K.double(), C.double())[0].to(self.device) | |||||
| else: | else: | ||||
| beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C | beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C | ||||