You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

autocast.py 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import functools
  2. from ..core.tensor import amp
  3. class autocast:
  4. r"""A class to control autocast mode for amp as a context manager or a decorator.
  5. Args:
  6. enabled: Whether autocast mode is enabled.
  7. low_prec_dtype: Set amp autocast mode's lower precision dtype. It will change
  8. the target dtype in tensor casting for better speed and memory. Default: float16.
  9. high_prec_dtype: Set amp autocast mode's higher precision dtype. It will
  10. change the target dtype in tensor casting for better precision. Default: float32.
  11. Examples:
  12. .. code-block::
  13. # used as decorator
  14. @autocast()
  15. def train_step(image, label):
  16. with gm:
  17. logits = model(image)
  18. loss = F.nn.cross_entropy(logits, label)
  19. gm.backward(loss)
  20. opt.step().clear_grad()
  21. return loss
  22. # used as context manager
  23. def train_step(image, label):
  24. with autocast():
  25. with gm:
  26. logits = model(image)
  27. loss = F.nn.cross_entropy(logits, label)
  28. gm.backward(loss)
  29. opt.step().clear_grad()
  30. return loss
  31. """
  32. def __init__(
  33. self,
  34. enabled: bool = True,
  35. low_prec_dtype: str = "float16",
  36. high_prec_dtype: str = "float32",
  37. ):
  38. self.enabled = enabled
  39. self.high_prec_dtype = high_prec_dtype
  40. self.low_prec_dtype = low_prec_dtype
  41. self._origin_enabled = None
  42. self._origin_high = None
  43. self._origin_low = None
  44. def __enter__(self):
  45. self._origin_enabled = amp._enabled
  46. self._origin_high = amp._get_amp_high_prec_dtype()
  47. self._origin_low = amp._get_amp_low_prec_dtype()
  48. amp._enabled = self.enabled
  49. amp._set_amp_dtype_autocast(self.enabled)
  50. amp._set_amp_high_prec_dtype(self.high_prec_dtype)
  51. amp._set_amp_low_prec_dtype(self.low_prec_dtype)
  52. def __exit__(self, *args):
  53. amp._enabled = self._origin_enabled
  54. amp._set_amp_dtype_autocast(self._origin_enabled)
  55. amp._set_amp_high_prec_dtype(self._origin_high)
  56. amp._set_amp_low_prec_dtype(self._origin_low)
  57. def __call__(self, func):
  58. @functools.wraps(func)
  59. def wrapper(*args, **kwargs):
  60. with self:
  61. return func(*args, **kwargs)
  62. return wrapper