From 8a5c6661e493aaee7e2c25027cec2d67d08ecf26 Mon Sep 17 00:00:00 2001
From: shihy
Date: Fri, 28 Jul 2023 13:38:41 +0800
Subject: [PATCH] [ENH] Provide with faster _update_z_vectorize.

---
 learnware/specification/rkme.py | 60 +++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/learnware/specification/rkme.py b/learnware/specification/rkme.py
index 2271b03..541eab3 100644
--- a/learnware/specification/rkme.py
+++ b/learnware/specification/rkme.py
@@ -181,6 +181,66 @@ class RKMEStatSpecification(BaseStatSpecification):
 
         self.beta = beta
 
+    @torch.no_grad()
+    def _update_z_vectorize(self, alpha: float, X: Any, step_size: float, batch_size=8):
+        """Fix beta and update Z using gradient descent.
+        Unlike method _update_z, this method updates z by batches.
+
+        Parameters
+        ----------
+        alpha : float or None
+            Normalization factor multiplied into the kernel between Z and X.
+            If None, that kernel is divided by self.num_points instead.
+        X : np.ndarray or torch.tensor
+            Raw data in np.ndarray format or torch.tensor format.
+        step_size : float
+            Step size for gradient descent.
+        batch_size : int
+            To prevent exceeding GPU memory, process no more than batch_size rows of Z at a time.
+        """
+        gamma = self.gamma
+        Z = self.z
+        beta = self.beta
+
+        # Bring every operand to a float64 tensor on self.device before doing any math.
+        if not torch.is_tensor(Z):
+            Z = torch.from_numpy(Z)
+        Z = Z.to(self.device).double()
+
+        if not torch.is_tensor(beta):
+            beta = torch.from_numpy(beta)
+        beta = beta.to(self.device).double()
+
+        if not torch.is_tensor(X):
+            X = torch.from_numpy(X)
+        X = X.to(self.device).double()
+
+        grad_Z = torch.zeros_like(Z)
+        for i in range(0, Z.shape[0], batch_size):
+            Z_ = Z[i: i + batch_size]
+            # Kernel-weighted sums of pairwise differences via bmm; each result keeps a singleton dim 1.
+            term_1 = torch.bmm(
+                torch.unsqueeze(torch.unsqueeze(beta, dim=0) * torch_rbf_kernel(Z_, Z, gamma), dim=1),
+                torch.unsqueeze(Z_, dim=1) - torch.unsqueeze(Z, dim=0),
+            )
+            if alpha is not None:
+                weighted_K_x = alpha * torch_rbf_kernel(Z_, X, gamma)
+            else:
+                weighted_K_x = torch_rbf_kernel(Z_, X, gamma) / self.num_points
+            term_2 = -2 * torch.bmm(
+                torch.unsqueeze(weighted_K_x, dim=1),
+                torch.unsqueeze(Z_, dim=1) - torch.unsqueeze(X, dim=0),
+            )
+            # Squeeze only the singleton dim introduced for bmm. A bare squeeze() would also
+            # drop a size-1 feature dim, after which the product with the (batch, 1) beta
+            # slice broadcasts to (batch, batch) and no longer fits the grad_Z slice.
+            grad_Z[i: i + batch_size] = (
+                -2 * gamma * torch.unsqueeze(beta[i: i + batch_size], dim=1) * torch.squeeze(term_1 + term_2, dim=1)
+            )
+
+        Z = Z - step_size * grad_Z
+        self.z = Z
+
     def _update_z(self, alpha: float, X: Any, step_size: float):
         """Fix beta and update Z using gradient descent.