|
|
|
@@ -15,13 +15,13 @@ from typing import Tuple, Any, List, Union, Dict |
|
|
|
import scipy |
|
|
|
from sklearn.cluster import MiniBatchKMeans |
|
|
|
|
|
|
|
def _faiss_version_at_least(version, minimum=(1, 7, 1)):
    """Return True when the dotted ``version`` string is numerically >= ``minimum``.

    Plain string comparison orders versions lexicographically, so "1.10.0"
    would sort below "1.7.1". Each dot-separated component is therefore
    reduced to its leading ASCII digits and compared as an integer; parsing
    stops at the first component that has no leading digit (e.g. "post1").

    Parameters
    ----------
    version : str
        Version string such as "1.7.4".
    minimum : tuple of int
        Lowest acceptable version, one integer per component.

    Returns
    -------
    bool
        Whether ``version`` is at least ``minimum``.
    """
    numbers = []
    for piece in str(version).split("."):
        digits = ""
        for ch in piece:
            if not "0" <= ch <= "9":
                break
            digits += ch
        if not digits:
            break
        numbers.append(int(digits))
    return tuple(numbers) >= tuple(minimum)


try:
    import faiss
except ImportError:
    _FAISS_INSTALLED = False
else:
    # Kept as a module-level name, as in the original code.
    ver = faiss.__version__
    _FAISS_INSTALLED = _faiss_version_at_least(ver)

# try:
#     import faiss
|
|
|
# ver = faiss.__version__ |
|
|
|
# _FAISS_INSTALLED = ver >= "1.7.1" |
|
|
|
# except ImportError: |
|
|
|
# _FAISS_INSTALLED = False |
|
|
|
|
|
|
|
from ..base import RegularStatsSpecification |
|
|
|
from ....logger import get_module_logger |
|
|
|
@@ -129,7 +129,7 @@ class RKMETableSpecification(RegularStatsSpecification): |
|
|
|
return |
|
|
|
|
|
|
|
# Initialize Z by clustering, utilizing kmeans or faiss to speed up the process. |
|
|
|
self._init_z_by_faiss(X, K) |
|
|
|
self._init_z_by_kmeans(X, K) |
|
|
|
self._update_beta(X, nonnegative_beta) |
|
|
|
|
|
|
|
# Alternately optimize Z and beta |
|
|
|
@@ -140,22 +140,22 @@ class RKMETableSpecification(RegularStatsSpecification): |
|
|
|
# Reshape to original dimensions |
|
|
|
self.z = self.z.reshape(Z_shape) |
|
|
|
|
|
|
|
def _init_z_by_faiss(self, X: Union[np.ndarray, torch.tensor], K: int):
    """Initialize Z by faiss clustering.

    Runs faiss k-means (100 iterations) on ``X`` and stores the resulting
    centroids in ``self.z`` as a double-precision tensor.

    Parameters
    ----------
    X : np.ndarray or torch.tensor
        Raw data in np.ndarray format or torch.tensor format.
    K : int
        Size of the constructed reduced set.
    """
    # torch tensors have no ``astype``; bring them to host memory as a NumPy
    # array first so both documented input types are accepted.
    if torch.is_tensor(X):
        X = X.detach().cpu().numpy()
    # The data is handed to faiss as a C-contiguous float32 matrix.
    X = np.ascontiguousarray(X, dtype=np.float32)
    numDim = X.shape[1]
    kmeans = faiss.Kmeans(numDim, K, niter=100, verbose=False)
    kmeans.train(X)
    # Centroids come back as a NumPy array; store them as a float64 tensor.
    center = torch.from_numpy(kmeans.centroids).double()
    self.z = center
|
|
|
# def _init_z_by_faiss(self, X: Union[np.ndarray, torch.tensor], K: int): |
|
|
|
# """Initialize Z by faiss clustering. |
|
|
|
|
|
|
|
# Parameters |
|
|
|
# ---------- |
|
|
|
# X : np.ndarray or torch.tensor |
|
|
|
# Raw data in np.ndarray format or torch.tensor format. |
|
|
|
# K : int |
|
|
|
# Size of the constructed reduced set. |
|
|
|
# """ |
|
|
|
# X = X.astype("float32") |
|
|
|
# numDim = X.shape[1] |
|
|
|
# kmeans = faiss.Kmeans(numDim, K, niter=100, verbose=False) |
|
|
|
# kmeans.train(X) |
|
|
|
# center = torch.from_numpy(kmeans.centroids).double() |
|
|
|
# self.z = center |
|
|
|
|
|
|
|
def _init_z_by_kmeans(self, X: Union[np.ndarray, torch.tensor], K: int): |
|
|
|
"""Intialize Z by kmeans clustering. |
|
|
|
@@ -168,7 +168,7 @@ class RKMETableSpecification(RegularStatsSpecification): |
|
|
|
Size of the constructed reduced set. |
|
|
|
""" |
|
|
|
X = X.astype("float32") |
|
|
|
kmeans = MiniBatchKMeans(n_clusters=K, max_iter=100, verbose=False, n_init="auto") |
|
|
|
kmeans = MiniBatchKMeans(n_clusters=K, max_iter=100, verbose=False, n_init='auto') |
|
|
|
kmeans.fit(X) |
|
|
|
center = torch.from_numpy(kmeans.cluster_centers_).double() |
|
|
|
self.z = center |
|
|
|
@@ -578,6 +578,7 @@ def rkme_solve_qp(K: np.ndarray, C: np.ndarray): |
|
|
|
K = K.cpu().numpy() |
|
|
|
if torch.is_tensor(C): |
|
|
|
C = C.cpu().numpy() |
|
|
|
|
|
|
|
n = K.shape[0] |
|
|
|
P = np.array(K) |
|
|
|
P = scipy.sparse.csc_matrix(P) |
|
|
|
@@ -588,11 +589,25 @@ def rkme_solve_qp(K: np.ndarray, C: np.ndarray): |
|
|
|
A = np.array(np.ones((1, n))) |
|
|
|
A = scipy.sparse.csc_matrix(A) |
|
|
|
b = np.array(np.ones((1, 1))) |
|
|
|
|
|
|
|
# sol = solve_qp(P, q, G, h, A, b, solver="clarabel") # Requires the sum of x to be 1 |
|
|
|
# sol = solve_qp(P, q, G, h, solver="clarabel") # Otherwise |
|
|
|
problem = Problem(P, q, G, h, A, b) |
|
|
|
sol = solve_problem(problem, solver="clarabel") |
|
|
|
w = sol.x |
|
|
|
solution = solve_problem(problem, solver="clarabel") |
|
|
|
w = solution.x |
|
|
|
w = torch.from_numpy(w).reshape(-1) |
|
|
|
return w, sol.obj |
|
|
|
return w, solution.obj |
|
|
|
|
|
|
|
# from cvxopt import solvers, matrix |
|
|
|
# n = K.shape[0] |
|
|
|
# P = matrix(K) |
|
|
|
# q = matrix(-C) |
|
|
|
# G = matrix(-np.eye(n)) |
|
|
|
# h = matrix(np.zeros((n, 1))) |
|
|
|
# A = matrix(np.ones((1, n))) |
|
|
|
# b = matrix(np.ones((1, 1))) |
|
|
|
# solvers.options["show_progress"] = False |
|
|
|
# sol = solvers.qp(P, q, G, h, A, b) # Requires the sum of x to be 1 |
|
|
|
# # sol = solvers.qp(P, q, G, h) # Otherwise |
|
|
|
# w = np.array(sol["x"]) |
|
|
|
# w = torch.from_numpy(w).reshape(-1) |
|
|
|
# return w, sol["primal objective"] |