|
|
|
@@ -18,7 +18,7 @@ from tqdm import tqdm |
|
|
|
|
|
|
|
from . import cnn_gp |
|
|
|
from ..base import RegularStatsSpecification |
|
|
|
from ..table.rkme import solve_qp, choose_device, setup_seed |
|
|
|
from ..table.rkme import rkme_solve_qp, choose_device, setup_seed |
|
|
|
|
|
|
|
|
|
|
|
class RKMEImageSpecification(RegularStatsSpecification): |
|
|
|
@@ -97,6 +97,11 @@ class RKMEImageSpecification(RegularStatsSpecification): |
|
|
|
------- |
|
|
|
|
|
|
|
""" |
|
|
|
if len(X.shape) != 4: |
|
|
|
raise ValueError("X should be in shape of [N, C, {0:d}, {0:d}]. ".format( |
|
|
|
RKMEImageSpecification.IMAGE_WIDTH |
|
|
|
)) |
|
|
|
|
|
|
|
if ( |
|
|
|
X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH |
|
|
|
) and not resize: |
|
|
|
@@ -175,7 +180,7 @@ class RKMEImageSpecification(RegularStatsSpecification): |
|
|
|
C = torch.sum(C, dim=1) / x_features.shape[0] |
|
|
|
|
|
|
|
if nonnegative_beta: |
|
|
|
beta = solve_qp(K.double(), C.double()).to(self.device) |
|
|
|
beta = rkme_solve_qp(K.double(), C.double())[0].to(self.device) |
|
|
|
else: |
|
|
|
beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C |
|
|
|
|
|
|
|
|