You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Optim.py 1.1 kB

1234567891011121314151617181920212223242526272829303132333435
  1. '''A wrapper class for optimizer '''
  2. import numpy as np
  3. class ScheduledOptim():
  4. '''A simple wrapper class for learning rate scheduling'''
  5. def __init__(self, optimizer, d_model, n_warmup_steps):
  6. self._optimizer = optimizer
  7. self.n_warmup_steps = n_warmup_steps
  8. self.n_current_steps = 0
  9. self.init_lr = np.power(d_model, -0.5)
  10. def step_and_update_lr(self):
  11. "Step with the inner optimizer"
  12. self._update_learning_rate()
  13. self._optimizer.step()
  14. def zero_grad(self):
  15. "Zero out the gradients by the inner optimizer"
  16. self._optimizer.zero_grad()
  17. def _get_lr_scale(self):
  18. return np.min([
  19. np.power(self.n_current_steps, -0.5),
  20. np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
  21. def _update_learning_rate(self):
  22. ''' Learning rate scheduling per step '''
  23. self.n_current_steps += 1
  24. lr = self.init_lr * self._get_lr_scale()
  25. for param_group in self._optimizer.param_groups:
  26. param_group['lr'] = lr