Browse Source

[Fix] Adapting to function rkme_solve_qp updates

tags/v0.3.2
shihy 2 years ago
parent
commit
c7bbef5388
1 changed files with 7 additions and 2 deletions
  1. +7
    -2
      learnware/specification/regular/image/rkme.py

+ 7
- 2
learnware/specification/regular/image/rkme.py View File

@@ -18,7 +18,7 @@ from tqdm import tqdm

from . import cnn_gp
from ..base import RegularStatsSpecification
from ..table.rkme import solve_qp, choose_device, setup_seed
from ..table.rkme import rkme_solve_qp, choose_device, setup_seed


class RKMEImageSpecification(RegularStatsSpecification):
@@ -97,6 +97,11 @@ class RKMEImageSpecification(RegularStatsSpecification):
-------

"""
if len(X.shape) != 4:
raise ValueError("X should be in shape of [N, C, {0:d}, {0:d}]. ".format(
RKMEImageSpecification.IMAGE_WIDTH
))

if (
X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH
) and not resize:
@@ -175,7 +180,7 @@ class RKMEImageSpecification(RegularStatsSpecification):
C = torch.sum(C, dim=1) / x_features.shape[0]

if nonnegative_beta:
beta = solve_qp(K.double(), C.double()).to(self.device)
beta = rkme_solve_qp(K.double(), C.double())[0].to(self.device)
else:
beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C



Loading…
Cancel
Save