
dynamic_lr.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dynamic Learning Rate"""
import math

from mindspore._checkparam import Validator as validator

def piecewise_constant_lr(milestone, learning_rates):
    r"""
    Get piecewise constant learning rate.

    Calculate the learning rate from the given `milestone` and `learning_rates`. Let the value of `milestone` be
    :math:`(M_1, M_2, ..., M_N)` and the value of `learning_rates` be :math:`(x_1, x_2, ..., x_N)`. N is the length
    of `milestone`. Let the output learning rate be `y`.

    .. math::
        y[i] = x_t,\ for\ i \in [M_{t-1}, M_t)

    Args:
        milestone (Union[list[int], tuple[int]]): A list of milestone steps. This list must be monotonically
            increasing and every element must be greater than 0.
        learning_rates (Union[list[float], tuple[float]]): A list of learning rates.

    Returns:
        list[float]. The size of the list is :math:`M_N`.

    Examples:
        >>> milestone = [2, 5, 10]
        >>> learning_rates = [0.1, 0.05, 0.01]
        >>> piecewise_constant_lr(milestone, learning_rates)
        [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
    """
    validator.check_value_type('milestone', milestone, (tuple, list))
    validator.check_value_type('learning_rates', learning_rates, (tuple, list))
    if len(milestone) != len(learning_rates):
        raise ValueError('The size of `milestone` must equal the size of `learning_rates`.')

    lr = []
    last_item = 0
    for i, item in enumerate(milestone):
        validator.check_positive_int(item, f'milestone[{i}]')
        validator.check_is_float(learning_rates[i], f'learning_rates[{i}]')
        if item < last_item:
            raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
        lr += [learning_rates[i]] * (item - last_item)
        last_item = item

    return lr

def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
    validator.check_positive_int(total_step, 'total_step')
    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
    validator.check_positive_int(decay_epoch, 'decay_epoch')
    validator.check_positive_float(learning_rate, 'learning_rate')
    validator.check_is_float(learning_rate, 'learning_rate')
    validator.check_positive_float(decay_rate, 'decay_rate')
    validator.check_is_float(decay_rate, 'decay_rate')
    validator.check_value_type('is_stair', is_stair, [bool])

def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculate learning rate based on an exponential decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate * decay\_rate^{\frac{current\_epoch}{decay\_epoch}}

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of the learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.
        is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.9
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 1
        >>> exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
        [0.1, 0.1, 0.09000000000000001, 0.09000000000000001, 0.08100000000000002, 0.08100000000000002]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    lr = []
    for i in range(total_step):
        if is_stair:
            lr.append(learning_rate * decay_rate ** math.floor(math.floor(i / step_per_epoch) / decay_epoch))
        else:
            lr.append(learning_rate * decay_rate ** (math.floor(i / step_per_epoch) / decay_epoch))
    return lr

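# The docstring example above leaves `is_stair` at its default of False; with
# decay_epoch=2 the flag switches the exponent floor(i / step_per_epoch) / decay_epoch
# between its fractional and its floored value:
#     exponential_decay_lr(0.1, 0.9, 6, 2, 2, is_stair=False)
#     # exponents 0, 0, 0.5, 0.5, 1, 1 -> approx. [0.1, 0.1, 0.0949, 0.0949, 0.09, 0.09]
#     exponential_decay_lr(0.1, 0.9, 6, 2, 2, is_stair=True)
#     # exponents 0, 0, 0, 0, 1, 1 -> approx. [0.1, 0.1, 0.1, 0.1, 0.09, 0.09]
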
def natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculate learning rate based on a natural exponential decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate * e^{-decay\_rate * current\_epoch}

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of the learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.
        is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.9
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
        [0.1, 0.1, 0.1, 0.1, 0.016529888822158657, 0.016529888822158657]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    function = lambda x, y: x
    if is_stair:
        function = lambda x, y: math.floor(x / y) * y

    lr = []
    for i in range(total_step):
        lr.append(learning_rate * math.e ** (-decay_rate * function(math.floor(i / step_per_epoch), decay_epoch)))
    return lr

def inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculate learning rate based on an inverse-time decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate / (1 + decay\_rate * current\_epoch / decay\_epoch)

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of the learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.
        is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.5
        >>> total_step = 6
        >>> step_per_epoch = 1
        >>> decay_epoch = 1
        >>> inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
        [0.1, 0.06666666666666667, 0.05, 0.04, 0.03333333333333333, 0.028571428571428574]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    lr = []
    for i in range(total_step):
        if is_stair:
            lr.append(learning_rate / (1 + decay_rate * math.floor(math.floor(i / step_per_epoch) / decay_epoch)))
        else:
            lr.append(learning_rate / (1 + decay_rate * math.floor(i / step_per_epoch) / decay_epoch))
    return lr

def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    r"""
    Calculate learning rate based on a cosine decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = min\_learning\_rate + 0.5 * (max\_learning\_rate - min\_learning\_rate) *
        (1 + cos(\frac{current\_epoch}{decay\_epoch}\pi))

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        min_lr (float): The minimum value of the learning rate.
        max_lr (float): The maximum value of the learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> min_lr = 0.01
        >>> max_lr = 0.1
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
        [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
    """
    if not isinstance(min_lr, float):
        raise TypeError("min_lr must be float.")
    validator.check_non_negative_float(min_lr, "min_lr", None)
    validator.check_positive_float(max_lr, 'max_lr')
    validator.check_is_float(max_lr, 'max_lr')
    validator.check_positive_int(total_step, 'total_step')
    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
    validator.check_positive_int(decay_epoch, 'decay_epoch')
    if min_lr >= max_lr:
        raise ValueError('`max_lr` should be greater than `min_lr`.')

    delta = 0.5 * (max_lr - min_lr)
    lr = []
    for i in range(total_step):
        tmp_epoch = min(math.floor(i / step_per_epoch), decay_epoch)
        lr.append(min_lr + delta * (1 + math.cos(math.pi * tmp_epoch / decay_epoch)))
    return lr

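# In cosine_decay_lr above, `tmp_epoch` is clamped at `decay_epoch`, so the
# cosine schedule anneals from max_lr down to min_lr over the first
# `decay_epoch` epochs and then holds at min_lr for any remaining steps.
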
def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
                        update_decay_epoch=False):
    r"""
    Calculate learning rate based on a polynomial decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) *
        (1 - tmp\_epoch / tmp\_decay\_epoch)^{power} + end\_learning\_rate

    Where:

    .. math::
        tmp\_epoch = min(current\_epoch, decay\_epoch)

    .. math::
        current\_epoch = floor(\frac{i}{step\_per\_epoch})

    .. math::
        tmp\_decay\_epoch = decay\_epoch

    If `update_decay_epoch` is true, the value of `tmp_decay_epoch` is updated every epoch. The formula is:

    .. math::
        tmp\_decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)

    Args:
        learning_rate (float): The initial value of the learning rate.
        end_learning_rate (float): The end value of the learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.
        power (float): A value used to calculate the decayed learning rate. This parameter must be greater than 0.
        update_decay_epoch (bool): If true, update `decay_epoch`. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> end_learning_rate = 0.01
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> power = 0.5
        >>> polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
        [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
    """
    validator.check_positive_float(learning_rate, 'learning_rate')
    validator.check_is_float(learning_rate, 'learning_rate')
    if not isinstance(end_learning_rate, float):
        raise TypeError("end_learning_rate must be float.")
    validator.check_non_negative_float(end_learning_rate, "end_learning_rate", None)
    validator.check_positive_float(power, 'power')
    validator.check_is_float(power, 'power')
    validator.check_positive_int(total_step, 'total_step')
    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
    validator.check_positive_int(decay_epoch, 'decay_epoch')
    validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool])

    origin_decay_epoch = decay_epoch
    function = lambda x, y: (x, min(x, y))
    if update_decay_epoch:
        function = lambda x, y: (origin_decay_epoch * max(math.ceil(y / origin_decay_epoch), 1), y)

    lr = []
    delta = learning_rate - end_learning_rate
    for i in range(total_step):
        current_epoch = math.floor(i / step_per_epoch)
        decay_epoch, tmp_epoch = function(decay_epoch, current_epoch)
        lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate)
    return lr

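# With update_decay_epoch=True, polynomial_decay_lr stretches tmp_decay_epoch
# to the next multiple of the original decay_epoch instead of clamping
# tmp_epoch. For example, with step_per_epoch=1 and decay_epoch=2 the
# per-epoch horizon becomes 2, 2, 2, 4, 4, 6, ..., so the learning rate
# reaches end_learning_rate at every multiple of decay_epoch, jumps back up
# as the horizon extends, and decays again.
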
def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch):
    r"""
    Get the warm-up learning rate.

    For the i-th step, the formula for computing warmup_learning_rate[i] is:

    .. math::
        warmup\_learning\_rate[i] = learning\_rate * tmp\_epoch / warmup\_epoch

    Where :math:`tmp\_epoch=min(current\_epoch, warmup\_epoch),\ current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of the learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        warmup_epoch (int): The number of epochs over which the learning rate is warmed up.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> warmup_epoch = 2
        >>> warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch)
        [0.0, 0.0, 0.05, 0.05, 0.1, 0.1]
    """
    if not isinstance(learning_rate, float):
        raise TypeError("learning_rate must be float.")
    validator.check_non_negative_float(learning_rate, "learning_rate", None)
    validator.check_positive_int(warmup_epoch, 'warmup_epoch')
    validator.check_positive_int(total_step, 'total_step')
    validator.check_positive_int(step_per_epoch, 'step_per_epoch')

    function = lambda x, y: (x, min(x, y))

    lr = []
    for i in range(total_step):
        current_epoch = math.floor(i / step_per_epoch)
        warmup_epoch, tmp_epoch = function(warmup_epoch, current_epoch)
        lr.append(learning_rate * tmp_epoch / warmup_epoch)
    return lr

__all__ = [
    'piecewise_constant_lr',
    'exponential_decay_lr',
    'natural_exp_decay_lr',
    'inverse_decay_lr',
    'cosine_decay_lr',
    'polynomial_decay_lr',
    'warmup_lr'
]
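
# Minimal usage sketch. The helpers above return plain Python lists with one
# learning rate per training step, so warm-up and decay phases can simply be
# concatenated into a single schedule. Feeding the list to an optimizer is
# shown only as a comment because it depends on a model built elsewhere;
# MindSpore optimizers such as mindspore.nn.Momentum accept an iterable of
# per-step learning rates as `learning_rate`.
if __name__ == "__main__":
    # Two epochs of linear warm-up, then a cosine anneal from 0.1 down to
    # 0.01; two steps per epoch (values match the docstring examples above).
    warmup = warmup_lr(0.1, total_step=4, step_per_epoch=2, warmup_epoch=2)
    decay = cosine_decay_lr(0.01, 0.1, total_step=6, step_per_epoch=2, decay_epoch=2)
    schedule = warmup + decay
    print(schedule)
    # -> [0.0, 0.0, 0.05, 0.05, 0.1, 0.1, 0.055, 0.055, 0.01, 0.01] (up to float rounding)
    # e.g. opt = mindspore.nn.Momentum(net.trainable_params(),
    #                                  learning_rate=schedule, momentum=0.9)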