You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dynamic_weight_decay_gpu.py 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import mindspore.context as context
  16. import mindspore.nn as nn
  17. from .weight_decay_utils import dynamic_weight_decay_cmp, WeightDecaySchdule, Net
  18. def test_momentum_dynamic_weight_decay_pynative():
  19. """
  20. Feature: Dynamic weight decay
  21. Description: Test dynamic weight decay for Momentum
  22. Expectation: The value of decay changes according to preset weight decay schedule
  23. """
  24. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  25. net1, net2 = Net(), Net()
  26. weight_decay_schedule = WeightDecaySchdule()
  27. optimizer1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
  28. optimizer2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001,
  29. weight_decay=weight_decay_schedule)
  30. dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
  31. def test_momentum_dynamic_weight_decay_graph():
  32. """
  33. Feature: Dynamic weight decay
  34. Description: Test dynamic weight decay for Momentum
  35. Expectation: The value of decay changes according to preset weight decay schedule
  36. """
  37. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  38. net1, net2 = Net(), Net()
  39. weight_decay_schedule = WeightDecaySchdule()
  40. optimizer1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
  41. optimizer2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001,
  42. weight_decay=weight_decay_schedule)
  43. dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
  44. def test_momentum_dynamic_weight_decay_graph_group():
  45. """
  46. Feature: Dynamic weight decay
  47. Description: Test dynamic weight decay for Momentum
  48. Expectation: The value of decay changes according to preset weight decay schedule
  49. """
  50. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  51. weight_decay_schedule = WeightDecaySchdule()
  52. net1, net2 = Net(), Net()
  53. net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
  54. net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))
  55. net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
  56. net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))
  57. params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
  58. {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]
  59. params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
  60. {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]
  61. optimizer1 = nn.Momentum(params1, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
  62. optimizer2 = nn.Momentum(params2, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
  63. dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
  64. def test_adamweightdecay_dynamic_weight_decay_pynative():
  65. """
  66. Feature: Dynamic weight decay
  67. Description: Test dynamic weight decay for AdamWeightDecay
  68. Expectation: The value of decay changes according to preset weight decay schedule
  69. """
  70. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  71. net1, net2 = Net(), Net()
  72. weight_decay_schedule = WeightDecaySchdule()
  73. optimizer1 = nn.AdamWeightDecay(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
  74. optimizer2 = nn.AdamWeightDecay(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
  75. dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
  76. def test_adamweightdecay_dynamic_weight_decay_graph():
  77. """
  78. Feature: Dynamic weight decay
  79. Description: Test dynamic weight decay for AdamWeightDecay
  80. Expectation: The value of decay changes according to preset weight decay schedule
  81. """
  82. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  83. net1, net2 = Net(), Net()
  84. weight_decay_schedule = WeightDecaySchdule()
  85. optimizer1 = nn.AdamWeightDecay(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
  86. optimizer2 = nn.AdamWeightDecay(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
  87. dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
  88. def test_adamweightdecay_dynamic_weight_decay_graph_group():
  89. """
  90. Feature: Dynamic weight decay
  91. Description: Test dynamic weight decay for Momentum
  92. Expectation: The value of decay changes according to preset weight decay schedule
  93. """
  94. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  95. weight_decay_schedule = WeightDecaySchdule()
  96. net1, net2 = Net(), Net()
  97. net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
  98. net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))
  99. net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
  100. net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))
  101. params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
  102. {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]
  103. params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
  104. {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]
  105. optimizer1 = nn.AdamWeightDecay(params1, learning_rate=0.001, weight_decay=0.001)
  106. optimizer2 = nn.AdamWeightDecay(params2, learning_rate=0.001, weight_decay=0.001)
  107. dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
  108. def test_lamb_dynamic_weight_decay_pynative():
  109. """
  110. Feature: Dynamic weight decay
  111. Description: Test dynamic weight decay for Lamb
  112. Expectation: The value of decay changes according to preset weight decay schedule
  113. """
  114. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  115. net1, net2 = Net(), Net()
  116. weight_decay_schedule = WeightDecaySchdule()
  117. optimizer1 = nn.Lamb(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
  118. optimizer2 = nn.Lamb(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
  119. dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
  120. def test_lamb_dynamic_weight_decay_graph():
  121. """
  122. Feature: Dynamic weight decay
  123. Description: Test dynamic weight decay for Lamb
  124. Expectation: The value of decay changes according to preset weight decay schedule
  125. """
  126. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  127. net1, net2 = Net(), Net()
  128. weight_decay_schedule = WeightDecaySchdule()
  129. optimizer1 = nn.Lamb(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
  130. optimizer2 = nn.Lamb(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
  131. dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
  132. def test_lamb_dynamic_weight_decay_graph_group():
  133. """
  134. Feature: Dynamic weight decay
  135. Description: Test dynamic weight decay for Momentum
  136. Expectation: The value of decay changes according to preset weight decay schedule
  137. """
  138. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  139. weight_decay_schedule = WeightDecaySchdule()
  140. net1, net2 = Net(), Net()
  141. net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
  142. net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))
  143. net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
  144. net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))
  145. params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
  146. {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]
  147. params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
  148. {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]
  149. optimizer1 = nn.Lamb(params1, learning_rate=0.001, weight_decay=0.001)
  150. optimizer2 = nn.Lamb(params2, learning_rate=0.001, weight_decay=0.001)
  151. dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)