diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py
index beed7a0186..3826f6a548 100644
--- a/mindspore/nn/dynamic_lr.py
+++ b/mindspore/nn/dynamic_lr.py
@@ -257,11 +257,11 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
 
     .. math::
         decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) *
-        (1 - tmp\_epoch / decay\_epoch)^{power} + end\_learning\_rate
+        (1 - tmp\_epoch / tmp\_decay\_epoch)^{power} + end\_learning\_rate
 
-    Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch), current\_epoch=floor(\frac{i}{step\_per\_epoch})`.
-    If `update_decay_epoch` is true, update the value of `decay_epoch` every epoch. The formula is
-    :math:`decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)`
+    Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch),\ current\_epoch=floor(\frac{i}{step\_per\_epoch})`, and
+    :math:`tmp\_decay\_epoch = decay\_epoch`. If `update_decay_epoch` is true, update the value of `tmp_decay_epoch`
+    every epoch. The formula is :math:`tmp\_decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)`
 
     Args:
         learning_rate (float): The initial value of learning rate.
@@ -296,9 +296,10 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
     validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
     validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)
 
+    origin_decay_epoch = decay_epoch
     function = lambda x, y: (x, min(x, y))
     if update_decay_epoch:
-        function = lambda x, y: (x * max(math.ceil(y / x), 1), y)
+        function = lambda x, y: (origin_decay_epoch * max(math.ceil(y / origin_decay_epoch), 1), y)
 
     lr = []
     delta = learning_rate - end_learning_rate
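
For reviewers, a minimal standalone sketch of the schedule after this change. The name polynomial_decay_lr_sketch is illustrative only, and the surrounding loop is reconstructed from the docstring formula and the lr/delta context lines above (the real function also runs the validator checks), so treat it as an assumption rather than the library code. The point of the fix: the old lambda fed the already-updated decay_epoch back into itself, compounding the ceil(...) scaling across steps, while the fixed lambda always rescales from the untouched original value, matching the documented tmp_decay_epoch = decay_epoch * ceil(current_epoch / decay_epoch).

    import math

    def polynomial_decay_lr_sketch(learning_rate, end_learning_rate, total_step,
                                   step_per_epoch, decay_epoch, power,
                                   update_decay_epoch=False):
        # Keep the user-supplied period; the fix rescales from this value,
        # never from a decay_epoch that was already updated on an earlier step.
        origin_decay_epoch = decay_epoch
        function = lambda x, y: (x, min(x, y))
        if update_decay_epoch:
            # tmp_decay_epoch = decay_epoch * ceil(current_epoch / decay_epoch),
            # clamped to at least one period so current_epoch == 0 is well defined.
            function = lambda x, y: (origin_decay_epoch * max(math.ceil(y / origin_decay_epoch), 1), y)
        lr = []
        delta = learning_rate - end_learning_rate
        for i in range(total_step):
            current_epoch = math.floor(i / step_per_epoch)
            decay_epoch, tmp_epoch = function(decay_epoch, current_epoch)
            lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate)
        return lr

Concretely, with decay_epoch=2 and step_per_epoch=1, the old code had already grown decay_epoch to 4 by current_epoch=3, so at current_epoch=5 it used a divisor of 4 * ceil(5 / 4) = 8; the documented formula gives 2 * ceil(5 / 2) = 6, which the fixed code produces.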