diff --git a/examples/examples1/example_rkme.py b/examples/examples1/example_rkme.py index f7842b3..4d14914 100644 --- a/examples/examples1/example_rkme.py +++ b/examples/examples1/example_rkme.py @@ -3,7 +3,7 @@ import learnware.specification as specification if __name__ == "__main__": - data_X = np.random.randn(10000, 20) + data_X = np.random.randn(10000, 20, 10, 5) for i in range(10): data_X[i, i] = np.nan spec1 = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=-1) @@ -24,3 +24,5 @@ if __name__ == "__main__": print(spec1.inner_prod(spec2)) print(spec1.dist(spec2)) + print(spec1.get_z().shape) + print(spec2.get_z().shape) \ No newline at end of file diff --git a/learnware/specification/rkme.py b/learnware/specification/rkme.py index f1aec90..f1157b5 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/rkme.py @@ -89,6 +89,9 @@ class RKMEStatSpecification(BaseStatSpecification): """ alpha = None self.num_points = X.shape[0] + X_shape = X.shape + Z_shape = tuple([K] + list(X_shape)[1:]) + X = X.reshape(self.num_points, -1) # fill np.nan X_nan = np.isnan(X) @@ -98,7 +101,7 @@ class RKMEStatSpecification(BaseStatSpecification): X[:, col] = np.where(X_nan[:, col], col_mean, X[:, col]) if not reduce: - self.z = X + self.z = X.reshape(X_shape) self.beta = 1 / self.num_points * np.ones(self.num_points) self.z = torch.from_numpy(self.z).double().to(self.device) self.beta = torch.from_numpy(self.beta).double().to(self.device) @@ -113,6 +116,9 @@ class RKMEStatSpecification(BaseStatSpecification): self._update_z(alpha, X, step_size) self._update_beta(X, nonnegative_beta) + # Reshape to original dimensions + self.z = self.z.reshape(Z_shape) + def _init_z_by_faiss(self, X: Union[np.ndarray, torch.tensor], K: int): """Intialize Z by faiss clustering. @@ -268,10 +274,10 @@ class RKMEStatSpecification(BaseStatSpecification): """ beta_1 = self.beta.reshape(1, -1).double().to(self.device) beta_2 = Phi2.beta.reshape(1, -1).double().to(self.device) - Z1 = self.z.double().to(self.device) - Z2 = Phi2.z.double().to(self.device) - + Z1 = self.z.double().reshape(self.z.shape[0], -1).to(self.device) + Z2 = Phi2.z.double().reshape(Phi2.z.shape[0], -1).to(self.device) v = torch.sum(torch_rbf_kernel(Z1, Z2, self.gamma) * (beta_1.T @ beta_2)) + return float(v) def dist(self, Phi2: RKMEStatSpecification, omit_term1: bool = False) -> float: @@ -306,6 +312,10 @@ class RKMEStatSpecification(BaseStatSpecification): np.ndarray A collection of examples which approximate the unknown distribution. """ + # Flatten z + Z_shape = self.z.shape + self.z = self.z.reshape(self.z.shape[0], -1) + Nstart = 100 * T Xstart = self._sampling_candidates(Nstart).to(self.device) D = self.z[0].shape[0] @@ -318,6 +328,11 @@ class RKMEStatSpecification(BaseStatSpecification): fs = (i + 1) * fsX - fsS idx = torch.argmax(fs) S[i, :] = Xstart[idx, :] + + # Reshape to orignial dimensions + self.z = self.z.reshape(Z_shape) + S_shape = tuple([S.shape[0]] + list(Z_shape)[1:]) + S = S.reshape(S_shape) return S