Browse Source

[MNT] RKME auto dimension

tags/v0.3.2
Gene 3 years ago
parent
commit
ed9fc5cf01
2 changed files with 22 additions and 5 deletions
  1. +3
    -1
      examples/examples1/example_rkme.py
  2. +19
    -4
      learnware/specification/rkme.py

+ 3
- 1
examples/examples1/example_rkme.py View File

@@ -3,7 +3,7 @@ import learnware.specification as specification


if __name__ == "__main__":
data_X = np.random.randn(10000, 20)
data_X = np.random.randn(10000, 20, 10, 5)
for i in range(10):
data_X[i, i] = np.nan
spec1 = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=-1)
@@ -24,3 +24,5 @@ if __name__ == "__main__":

print(spec1.inner_prod(spec2))
print(spec1.dist(spec2))
print(spec1.get_z().shape)
print(spec2.get_z().shape)

+ 19
- 4
learnware/specification/rkme.py View File

@@ -89,6 +89,9 @@ class RKMEStatSpecification(BaseStatSpecification):
"""
alpha = None
self.num_points = X.shape[0]
X_shape = X.shape
Z_shape = tuple([K] + list(X_shape)[1:])
X = X.reshape(self.num_points, -1)

# fill np.nan
X_nan = np.isnan(X)
@@ -98,7 +101,7 @@ class RKMEStatSpecification(BaseStatSpecification):
X[:, col] = np.where(X_nan[:, col], col_mean, X[:, col])

if not reduce:
self.z = X
self.z = X.reshape(X_shape)
self.beta = 1 / self.num_points * np.ones(self.num_points)
self.z = torch.from_numpy(self.z).double().to(self.device)
self.beta = torch.from_numpy(self.beta).double().to(self.device)
@@ -113,6 +116,9 @@ class RKMEStatSpecification(BaseStatSpecification):
self._update_z(alpha, X, step_size)
self._update_beta(X, nonnegative_beta)

# Reshape to original dimensions
self.z = self.z.reshape(Z_shape)

def _init_z_by_faiss(self, X: Union[np.ndarray, torch.tensor], K: int):
"""Intialize Z by faiss clustering.

@@ -268,10 +274,10 @@ class RKMEStatSpecification(BaseStatSpecification):
"""
beta_1 = self.beta.reshape(1, -1).double().to(self.device)
beta_2 = Phi2.beta.reshape(1, -1).double().to(self.device)
Z1 = self.z.double().to(self.device)
Z2 = Phi2.z.double().to(self.device)

Z1 = self.z.double().reshape(self.z.shape[0], -1).to(self.device)
Z2 = Phi2.z.double().reshape(Phi2.z.shape[0], -1).to(self.device)
v = torch.sum(torch_rbf_kernel(Z1, Z2, self.gamma) * (beta_1.T @ beta_2))

return float(v)

def dist(self, Phi2: RKMEStatSpecification, omit_term1: bool = False) -> float:
@@ -306,6 +312,10 @@ class RKMEStatSpecification(BaseStatSpecification):
np.ndarray
A collection of examples which approximate the unknown distribution.
"""
# Flatten z
Z_shape = self.z.shape
self.z = self.z.reshape(self.z.shape[0], -1)
Nstart = 100 * T
Xstart = self._sampling_candidates(Nstart).to(self.device)
D = self.z[0].shape[0]
@@ -318,6 +328,11 @@ class RKMEStatSpecification(BaseStatSpecification):
fs = (i + 1) * fsX - fsS
idx = torch.argmax(fs)
S[i, :] = Xstart[idx, :]
# Reshape to orignial dimensions
self.z = self.z.reshape(Z_shape)
S_shape = tuple([S.shape[0]] + list(Z_shape)[1:])
S = S.reshape(S_shape)

return S



Loading…
Cancel
Save