You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_learning_rate_schedule.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ Test Dynamic Learning Rate """
  16. import pytest
  17. from mindspore import Tensor
  18. from mindspore.nn import learning_rate_schedule as lr_schedules
  19. from mindspore.common.api import _cell_graph_executor
  20. import mindspore.common.dtype as mstype
  21. learning_rate = 0.1
  22. end_learning_rate = 0.01
  23. decay_rate = 0.9
  24. decay_steps = 4
  25. warmup_steps = 2
  26. min_lr = 0.01
  27. max_lr = 0.1
  28. power = 0.5
  29. global_step = Tensor(2, mstype.int32)
  30. class TestInit:
  31. def test_learning_rate_type(self):
  32. lr = True
  33. with pytest.raises(TypeError):
  34. lr_schedules.ExponentialDecayLR(lr, decay_rate, decay_steps)
  35. with pytest.raises(TypeError):
  36. lr_schedules.PolynomialDecayLR(lr, end_learning_rate, decay_steps, power)
  37. def test_learning_rate_value(self):
  38. lr = -1.0
  39. with pytest.raises(ValueError):
  40. lr_schedules.ExponentialDecayLR(lr, decay_rate, decay_steps)
  41. with pytest.raises(ValueError):
  42. lr_schedules.PolynomialDecayLR(lr, end_learning_rate, decay_steps, power)
  43. def test_end_learning_rate_type(self):
  44. lr = True
  45. with pytest.raises(TypeError):
  46. lr_schedules.PolynomialDecayLR(learning_rate, lr, decay_steps, power)
  47. def test_end_learning_rate_value(self):
  48. lr = -1.0
  49. with pytest.raises(ValueError):
  50. lr_schedules.PolynomialDecayLR(learning_rate, lr, decay_steps, power)
  51. def test_decay_rate_type(self):
  52. rate = 'a'
  53. with pytest.raises(TypeError):
  54. lr_schedules.ExponentialDecayLR(learning_rate, rate, decay_steps)
  55. def test_decay_rate_value(self):
  56. rate = -1.0
  57. with pytest.raises(ValueError):
  58. lr_schedules.ExponentialDecayLR(learning_rate, rate, decay_steps)
  59. def test_decay_steps_type(self):
  60. decay_steps_e = 'm'
  61. with pytest.raises(TypeError):
  62. lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps_e)
  63. with pytest.raises(TypeError):
  64. lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps_e)
  65. with pytest.raises(TypeError):
  66. lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps_e, power)
  67. def test_decay_steps_value(self):
  68. decay_steps_e = -2
  69. with pytest.raises(ValueError):
  70. lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps_e)
  71. with pytest.raises(ValueError):
  72. lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps_e)
  73. with pytest.raises(ValueError):
  74. lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps_e, power)
  75. def test_is_stair(self):
  76. is_stair = 1
  77. with pytest.raises(TypeError):
  78. lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps, is_stair)
  79. def test_min_lr_type(self):
  80. min_lr1 = True
  81. with pytest.raises(TypeError):
  82. lr_schedules.CosineDecayLR(min_lr1, max_lr, decay_steps)
  83. def test_min_lr_value(self):
  84. min_lr1 = -1.0
  85. with pytest.raises(ValueError):
  86. lr_schedules.CosineDecayLR(min_lr1, max_lr, decay_steps)
  87. def test_max_lr_type(self):
  88. max_lr1 = 'a'
  89. with pytest.raises(TypeError):
  90. lr_schedules.CosineDecayLR(min_lr, max_lr1, decay_steps)
  91. def test_max_lr_value(self):
  92. max_lr1 = -1.0
  93. with pytest.raises(ValueError):
  94. lr_schedules.CosineDecayLR(min_lr, max_lr1, decay_steps)
  95. def test_power(self):
  96. power1 = True
  97. with pytest.raises(TypeError):
  98. lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power1)
  99. def test_exponential_decay():
  100. lr_schedule = lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps, True)
  101. _cell_graph_executor.compile(lr_schedule, global_step)
  102. def test_enatural_exp_decay():
  103. lr_schedule = lr_schedules.NaturalExpDecayLR(learning_rate, decay_rate, decay_steps, True)
  104. _cell_graph_executor.compile(lr_schedule, global_step)
  105. def test_inverse_decay():
  106. lr_schedule = lr_schedules.InverseDecayLR(learning_rate, decay_rate, decay_steps, True)
  107. _cell_graph_executor.compile(lr_schedule, global_step)
  108. def test_cosine_decay():
  109. lr_schedule = lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps)
  110. _cell_graph_executor.compile(lr_schedule, global_step)
  111. def test_polynomial_decay():
  112. lr_schedule = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
  113. _cell_graph_executor.compile(lr_schedule, global_step)
  114. def test_polynomial_decay2():
  115. lr_schedule = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power, True)
  116. _cell_graph_executor.compile(lr_schedule, global_step)
  117. def test_warmup():
  118. lr_schedule = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
  119. _cell_graph_executor.compile(lr_schedule, global_step)