Browse Source

[ENH] Provide with faster _update_z_vectorize.

tags/v0.3.2
shihy 2 years ago
parent
commit
8a5c6661e4
1 changed files with 48 additions and 0 deletions
  1. +48
    -0
      learnware/specification/rkme.py

+ 48
- 0
learnware/specification/rkme.py View File

@@ -181,6 +181,54 @@ class RKMEStatSpecification(BaseStatSpecification):

self.beta = beta

@torch.no_grad()
def _update_z_vectorize(self, alpha: float, X: Any, step_size: float, batch_size=8):
"""Fix beta and update Z using gradient descent.
Unlike method _update_z, this method updates z by batches.

Parameters
----------
alpha : int
Normalization factor.
X : np.ndarray or torch.tensor
Raw data in np.ndarray format or torch.tensor format.
step_size : float
Step size for gradient descent.
batch_size : int
To prevent exceeding GPU memory, process no more than batch_size at a time.
"""
gamma = self.gamma
Z = self.z
beta = self.beta

if not torch.is_tensor(Z):
Z = torch.from_numpy(Z)
Z = Z.to(self.device).double()

if not torch.is_tensor(beta):
beta = torch.from_numpy(beta)
beta = beta.to(self.device).double()

if not torch.is_tensor(X):
X = torch.from_numpy(X)
X = X.to(self.device).double()

grad_Z = torch.zeros_like(Z)
for i in range(0, Z.shape[0], batch_size):
Z_ = Z[i: i + batch_size]
term_1 = torch.bmm(torch.unsqueeze((torch.unsqueeze(beta, dim=0) * torch_rbf_kernel(Z_, Z, gamma)), dim=1),
torch.unsqueeze(Z_, dim=1) - torch.unsqueeze(Z, dim=0))
if alpha is not None:
term_2 = -2 * torch.bmm(torch.unsqueeze(alpha * torch_rbf_kernel(Z_, X, gamma), dim=1),
torch.unsqueeze(Z_, dim=1) - torch.unsqueeze(X, dim=0))
else:
term_2 = -2 * torch.bmm(torch.unsqueeze(torch_rbf_kernel(Z_, X, gamma) / self.num_points, dim=1),
torch.unsqueeze(Z_, dim=1) - torch.unsqueeze(X, dim=0))
grad_Z[i: i + batch_size] = -2 * gamma * torch.unsqueeze(beta[i: i + batch_size], dim=1) * torch.squeeze(term_1 + term_2)

Z = Z - step_size * grad_Z
self.z = Z

def _update_z(self, alpha: float, X: Any, step_size: float):
"""Fix beta and update Z using gradient descent.



Loading…
Cancel
Save