|
|
|
@@ -181,6 +181,54 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
|
|
|
|
self.beta = beta |
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
def _update_z_vectorize(self, alpha: float, X: Any, step_size: float, batch_size=8): |
|
|
|
"""Fix beta and update Z using gradient descent. |
|
|
|
Unlike method _update_z, this method updates z by batches. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
alpha : int |
|
|
|
Normalization factor. |
|
|
|
X : np.ndarray or torch.tensor |
|
|
|
Raw data in np.ndarray format or torch.tensor format. |
|
|
|
step_size : float |
|
|
|
Step size for gradient descent. |
|
|
|
batch_size : int |
|
|
|
To prevent exceeding GPU memory, process no more than batch_size at a time. |
|
|
|
""" |
|
|
|
gamma = self.gamma |
|
|
|
Z = self.z |
|
|
|
beta = self.beta |
|
|
|
|
|
|
|
if not torch.is_tensor(Z): |
|
|
|
Z = torch.from_numpy(Z) |
|
|
|
Z = Z.to(self.device).double() |
|
|
|
|
|
|
|
if not torch.is_tensor(beta): |
|
|
|
beta = torch.from_numpy(beta) |
|
|
|
beta = beta.to(self.device).double() |
|
|
|
|
|
|
|
if not torch.is_tensor(X): |
|
|
|
X = torch.from_numpy(X) |
|
|
|
X = X.to(self.device).double() |
|
|
|
|
|
|
|
grad_Z = torch.zeros_like(Z) |
|
|
|
for i in range(0, Z.shape[0], batch_size): |
|
|
|
Z_ = Z[i: i + batch_size] |
|
|
|
term_1 = torch.bmm(torch.unsqueeze((torch.unsqueeze(beta, dim=0) * torch_rbf_kernel(Z_, Z, gamma)), dim=1), |
|
|
|
torch.unsqueeze(Z_, dim=1) - torch.unsqueeze(Z, dim=0)) |
|
|
|
if alpha is not None: |
|
|
|
term_2 = -2 * torch.bmm(torch.unsqueeze(alpha * torch_rbf_kernel(Z_, X, gamma), dim=1), |
|
|
|
torch.unsqueeze(Z_, dim=1) - torch.unsqueeze(X, dim=0)) |
|
|
|
else: |
|
|
|
term_2 = -2 * torch.bmm(torch.unsqueeze(torch_rbf_kernel(Z_, X, gamma) / self.num_points, dim=1), |
|
|
|
torch.unsqueeze(Z_, dim=1) - torch.unsqueeze(X, dim=0)) |
|
|
|
grad_Z[i: i + batch_size] = -2 * gamma * torch.unsqueeze(beta[i: i + batch_size], dim=1) * torch.squeeze(term_1 + term_2) |
|
|
|
|
|
|
|
Z = Z - step_size * grad_Z |
|
|
|
self.z = Z |
|
|
|
|
|
|
|
def _update_z(self, alpha: float, X: Any, step_size: float): |
|
|
|
"""Fix beta and update Z using gradient descent. |
|
|
|
|
|
|
|
|