| @@ -17,7 +17,7 @@ import math | |||||
| import numpy as np | import numpy as np | ||||
| def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps): | |||||
| def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps, global_step=0): | |||||
| """ | """ | ||||
| Applies three steps decay to generate learning rate array. | Applies three steps decay to generate learning rate array. | ||||
| @@ -45,6 +45,7 @@ def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps): | |||||
| else: | else: | ||||
| lr = lr_max * 0.001 | lr = lr_max * 0.001 | ||||
| lr_each_step.append(lr) | lr_each_step.append(lr) | ||||
| lr_each_step = np.array(lr_each_step).astype(np.float32)[global_step:] | |||||
| return lr_each_step | return lr_each_step | ||||
| @@ -131,7 +132,7 @@ def _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps): | |||||
| return lr_each_step | return lr_each_step | ||||
| def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode): | |||||
| def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode, global_step=0): | |||||
| """ | """ | ||||
| generate learning rate array | generate learning rate array | ||||
| @@ -150,7 +151,7 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch | |||||
| total_steps = steps_per_epoch * total_epochs | total_steps = steps_per_epoch * total_epochs | ||||
| warmup_steps = steps_per_epoch * warmup_epochs | warmup_steps = steps_per_epoch * warmup_epochs | ||||
| if lr_decay_mode == 'steps': | if lr_decay_mode == 'steps': | ||||
| lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps) | |||||
| lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps, global_step) | |||||
| elif lr_decay_mode == 'steps_decay': | elif lr_decay_mode == 'steps_decay': | ||||
| lr_each_step = _generate_exponential_lr(lr_init, lr_max, total_steps, warmup_steps, steps_per_epoch) | lr_each_step = _generate_exponential_lr(lr_init, lr_max, total_steps, warmup_steps, steps_per_epoch) | ||||
| elif lr_decay_mode == 'cosine': | elif lr_decay_mode == 'cosine': | ||||