|
|
|
@@ -89,6 +89,9 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
""" |
|
|
|
alpha = None |
|
|
|
self.num_points = X.shape[0] |
|
|
|
X_shape = X.shape |
|
|
|
Z_shape = tuple([K] + list(X_shape)[1:]) |
|
|
|
X = X.reshape(self.num_points, -1) |
|
|
|
|
|
|
|
# fill np.nan |
|
|
|
X_nan = np.isnan(X) |
|
|
|
@@ -98,7 +101,7 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
X[:, col] = np.where(X_nan[:, col], col_mean, X[:, col]) |
|
|
|
|
|
|
|
if not reduce: |
|
|
|
self.z = X |
|
|
|
self.z = X.reshape(X_shape) |
|
|
|
self.beta = 1 / self.num_points * np.ones(self.num_points) |
|
|
|
self.z = torch.from_numpy(self.z).double().to(self.device) |
|
|
|
self.beta = torch.from_numpy(self.beta).double().to(self.device) |
|
|
|
@@ -113,6 +116,9 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
self._update_z(alpha, X, step_size) |
|
|
|
self._update_beta(X, nonnegative_beta) |
|
|
|
|
|
|
|
# Reshape to original dimensions |
|
|
|
self.z = self.z.reshape(Z_shape) |
|
|
|
|
|
|
|
def _init_z_by_faiss(self, X: Union[np.ndarray, torch.tensor], K: int): |
|
|
|
"""Intialize Z by faiss clustering. |
|
|
|
|
|
|
|
@@ -268,10 +274,10 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
""" |
|
|
|
beta_1 = self.beta.reshape(1, -1).double().to(self.device) |
|
|
|
beta_2 = Phi2.beta.reshape(1, -1).double().to(self.device) |
|
|
|
Z1 = self.z.double().to(self.device) |
|
|
|
Z2 = Phi2.z.double().to(self.device) |
|
|
|
|
|
|
|
Z1 = self.z.double().reshape(self.z.shape[0], -1).to(self.device) |
|
|
|
Z2 = Phi2.z.double().reshape(Phi2.z.shape[0], -1).to(self.device) |
|
|
|
v = torch.sum(torch_rbf_kernel(Z1, Z2, self.gamma) * (beta_1.T @ beta_2)) |
|
|
|
|
|
|
|
return float(v) |
|
|
|
|
|
|
|
def dist(self, Phi2: RKMEStatSpecification, omit_term1: bool = False) -> float: |
|
|
|
@@ -306,6 +312,10 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
np.ndarray |
|
|
|
A collection of examples which approximate the unknown distribution. |
|
|
|
""" |
|
|
|
# Flatten z |
|
|
|
Z_shape = self.z.shape |
|
|
|
self.z = self.z.reshape(self.z.shape[0], -1) |
|
|
|
|
|
|
|
Nstart = 100 * T |
|
|
|
Xstart = self._sampling_candidates(Nstart).to(self.device) |
|
|
|
D = self.z[0].shape[0] |
|
|
|
@@ -318,6 +328,11 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
fs = (i + 1) * fsX - fsS |
|
|
|
idx = torch.argmax(fs) |
|
|
|
S[i, :] = Xstart[idx, :] |
|
|
|
|
|
|
|
# Reshape to orignial dimensions |
|
|
|
self.z = self.z.reshape(Z_shape) |
|
|
|
S_shape = tuple([S.shape[0]] + list(Z_shape)[1:]) |
|
|
|
S = S.reshape(S_shape) |
|
|
|
|
|
|
|
return S |
|
|
|
|
|
|
|
|