You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

optimizer_utils.py 11 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore
  17. from mindspore import nn, Tensor
  18. from mindspore.ops import operations as P
  19. from mindspore.nn.optim import ASGD
  20. from mindspore.nn.optim import Rprop
# Seed NumPy so any stochastic behavior in these tests is reproducible.
np.random.seed(1024)

# Fixed initial parameters for FakeNet's two Dense layers.  Pinning them
# makes the optimizer trajectories deterministic, which is what the
# expected-loss / expected-parameter tables at the bottom of this file rely on.
# fc1: Dense(8 -> 4) weight, shape (4, 8), float32.
fc1_weight = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149,
                        0.6942514, 0.39767185, 0.24918061, 0.4548748],
                       [0.7203382, 0.19086994, 0.76286614, 0.87920564,
                        0.3169892, 0.9462494, 0.62827677, 0.27504718],
                       [0.3544535, 0.2524781, 0.5370583, 0.8313121,
                        0.6670143, 0.0488653, 0.62225235, 0.7546456],
                       [0.17985944, 0.05106374, 0.31064633, 0.4863033,
                        0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32")
# fc1 bias, shape (4,).
fc1_bias = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32")
# fc2: Dense(4 -> 1) weight, shape (1, 4).
fc2_weight = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32")
# fc2 bias, shape (1,).
fc2_bias = np.array([0.09996348]).astype("float32")
  33. def make_fake_data():
  34. """
  35. make fake data
  36. """
  37. data, label = [], []
  38. for i in range(20):
  39. data.append(mindspore.Tensor(np.array(np.ones((2, 8)) * i, dtype=np.float32)))
  40. label.append(mindspore.Tensor(np.array(np.ones((2, 1)) * (i + 1), dtype=np.float32)))
  41. return data, label
  42. class NetWithLoss(nn.Cell):
  43. """
  44. build net with loss
  45. """
  46. def __init__(self, network):
  47. super(NetWithLoss, self).__init__()
  48. self.network = network
  49. self.loss = nn.MSELoss(reduction='sum')
  50. def construct(self, x, label):
  51. out = self.network(x)
  52. loss = self.loss(out, label)
  53. return loss
  54. class FakeNet(nn.Cell):
  55. """
  56. build fake net
  57. """
  58. def __init__(self):
  59. super(FakeNet, self).__init__()
  60. self.fc1 = nn.Dense(in_channels=8, out_channels=4, weight_init=Tensor(fc1_weight), bias_init=Tensor(fc1_bias))
  61. self.fc2 = nn.Dense(in_channels=4, out_channels=1, weight_init=Tensor(fc2_weight), bias_init=Tensor(fc2_bias))
  62. self.relu = nn.ReLU()
  63. self.reducemean = P.ReduceMean()
  64. def construct(self, x):
  65. x = self.relu(self.fc1(x))
  66. x = self.fc2(x)
  67. return x
  68. def _initialize_weights(self):
  69. """
  70. parameter initialization
  71. """
  72. self.init_parameters_data()
  73. for name, m in self.cells_and_names():
  74. if name == 'fc1':
  75. m.weight.set_data(Tensor(fc1_weight))
  76. m.bias.set_data(Tensor(fc1_bias))
  77. elif name == 'fc2':
  78. m.weight.set_data(Tensor(fc2_weight))
  79. m.bias.set_data(Tensor(fc2_bias))
  80. def build_network(opt_config, is_group=False):
  81. """
  82. Construct training
  83. """
  84. losses = []
  85. net = FakeNet()
  86. networkwithloss = NetWithLoss(net)
  87. networkwithloss.set_train()
  88. if is_group:
  89. fc1_params = list(filter(lambda x: 'fc1' in x.name, networkwithloss.trainable_params()))
  90. fc2_params = list(filter(lambda x: 'fc1' not in x.name, networkwithloss.trainable_params()))
  91. if opt_config['name'] == 'ASGD':
  92. params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
  93. else:
  94. params = [{'params': fc1_params, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
  95. else:
  96. params = networkwithloss.trainable_params()
  97. if opt_config['name'] == 'ASGD':
  98. net_opt = ASGD(params, learning_rate=opt_config['lr'], lambd=opt_config['lambd'], alpha=opt_config['alpha'],
  99. t0=opt_config['t0'], weight_decay=opt_config['weight_decay'])
  100. elif opt_config['name'] == 'Rprop':
  101. net_opt = Rprop(params, learning_rate=opt_config['lr'], etas=opt_config['etas'],
  102. step_sizes=opt_config['step_sizes'], weight_decay=0.0)
  103. trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt)
  104. data, label = make_fake_data()
  105. for i in range(20):
  106. loss = trainonestepcell(data[i], label[i])
  107. losses.append(loss.asnumpy())
  108. if opt_config['name'] == 'ASGD':
  109. return np.array(losses), net_opt
  110. return np.array(losses)
# Golden per-step loss trajectories (20 steps each) produced by
# build_network for the corresponding optimizer configuration; tests compare
# fresh runs against these float32 tables.
# ASGD with default hyper-parameters.
loss_default_asgd = np.array([3.01246792e-01, 1.20041794e+02, 1.38681079e+03, 2.01250820e+01,
                              3.27283554e+01, 4.76963005e+01, 6.47094269e+01, 8.34786530e+01,
                              1.03742706e+02, 1.25265739e+02, 1.47835190e+02, 1.71259613e+02,
                              1.95367035e+02, 2.20003204e+02, 2.45029831e+02, 2.70323456e+02,
                              2.95774048e+02, 3.21283752e+02, 3.46765594e+02, 3.72143097e+02], dtype=np.float32)
# ASGD with non-default hyper-parameters.
loss_not_default_asgd = np.array([3.01246792e-01, 1.26019104e+02, 1.90600449e+02, 9.70605755e+00,
                                  2.98419113e+01, 3.68430023e+02, 1.06318066e+04, 1.35017746e+02,
                                  1.68673813e+02, 2.05914215e+02, 2.46694992e+02, 2.90972443e+02,
                                  3.38703430e+02, 3.89845123e+02, 4.44355103e+02, 5.02191406e+02,
                                  5.63312500e+02, 6.27676941e+02, 6.95244202e+02, 7.65973816e+02], dtype=np.float32)
# ASGD with grouped (fc1/fc2) parameters.
loss_group_asgd = np.array([3.01246792e-01, 7.26708527e+01, 2.84905312e+05, 4.17499258e+04,
                            1.46797949e+04, 5.07966602e+03, 1.70935132e+03, 5.47094910e+02,
                            1.59216995e+02, 3.78818207e+01, 5.18196869e+00, 2.62275129e-03,
                            2.09768796e+00, 5.23108435e+00, 7.78943682e+00, 9.57108879e+00,
                            1.07310610e+01, 1.14618425e+01, 1.19147835e+01, 1.21936722e+01], dtype=np.float32)
# Rprop with default hyper-parameters.
loss_default_rprop = np.array([3.01246792e-01, 1.19871742e+02, 4.13467163e+02, 8.09146179e+02,
                               1.22364807e+03, 1.56787573e+03, 1.75733594e+03, 1.72866272e+03,
                               1.46183936e+03, 1.00406335e+03, 4.84076874e+02, 9.49734650e+01,
                               2.00592804e+01, 1.87920704e+01, 1.53733969e+01, 1.85836582e+01,
                               5.21527790e-02, 2.01522671e-02, 7.19913816e+00, 8.52459526e+00], dtype=np.float32)
# Rprop with non-default hyper-parameters.
loss_not_default_rprop = np.array([3.0124679e-01, 1.2600269e+02, 4.7351608e+02, 1.0220379e+03,
                                   1.7181555e+03, 2.4367019e+03, 2.9170872e+03, 2.7243464e+03,
                                   1.4999669e+03, 7.5820435e+01, 1.0590715e+03, 5.4336096e+02,
                                   7.0162407e+01, 8.2754419e+02, 9.6329260e+02, 3.4475109e+01,
                                   5.3843134e+02, 6.0064526e+02, 1.1046149e+02, 3.5530117e+03], dtype=np.float32)
# Rprop with grouped (fc1/fc2) parameters.
loss_group_rprop = np.array([3.0124679e-01, 7.1360558e+01, 4.8910957e+01, 2.1730331e+02,
                             3.0747052e+02, 5.2734237e+00, 5.6865869e+00, 1.7116127e+02,
                             2.0539343e+02, 2.2993685e+01, 2.6194101e+02, 2.8772815e+02,
                             2.4236647e+01, 3.9299741e+02, 3.5600668e+02, 1.4759110e+01,
                             7.2244568e+02, 8.1952783e+02, 9.8913864e+01, 1.1141744e+03], dtype=np.float32)
# Golden post-training parameter values for the ASGD configurations, matching
# the fc1/fc2 layer shapes defined at the top of this file ((4, 8), (4,),
# (1, 4), (1,)).  Tests compare trained parameters against these tables.
# ASGD, default hyper-parameters.
default_fc1_weight_asgd = np.array([[-0.9451941, -0.71258026, -1.2602371, -1.4823773,
                                     -0.974408, -1.2709816, -1.4194703, -1.2137808],
                                    [-1.5341775, -2.0636342, -1.4916497, -1.3753126,
                                     -1.9375193, -1.308271, -1.6262367, -1.9794592],
                                    [-1.9886293, -2.0906024, -1.8060291, -1.5117803,
                                     -1.6760755, -2.2942104, -1.7208353, -1.5884445],
                                    [-2.071215, -2.2000103, -1.9404325, -1.7647781,
                                     -1.4022746, -1.6987679, -2.0481179, -1.5297506]], dtype=np.float32)
default_fc1_bias_asgd = np.array([-0.17978168, -1.0764512, -0.578816, -0.2928958], dtype=np.float32)
default_fc2_weight_asgd = np.array([[4.097412, 6.2694297, 5.9203916, 5.3845487]], dtype=np.float32)
default_fc2_bias_asgd = np.array([6.904814], dtype=np.float32)
# ASGD, non-default hyper-parameters.
no_default_fc1_weight_asgd = np.array([[-1.3406217, -1.1080127, -1.655658, -1.8777936,
                                        -1.3698348, -1.6664025, -1.8148884, -1.6092018],
                                       [-1.1475986, -1.6770473, -1.1050745, -0.98873824,
                                        -1.5509329, -0.9216978, -1.2396574, -1.5928726],
                                       [-1.2329121, -1.334883, -1.050313, -0.756071,
                                        -0.92036265, -1.5384867, -0.96512324, -0.8327349],
                                       [-1.0685704, -1.1973612, -0.9377885, -0.7621386,
                                        -0.39964262, -0.69612867, -1.0454736, -0.52711576]], dtype=np.float32)
no_default_fc1_bias_asgd = np.array([0.41264832, -0.19961096, 0.37743938, 0.65807366], dtype=np.float32)
no_default_fc2_weight_asgd = np.array([[-5.660916, -5.9415145, -5.1402636, -4.199707]], dtype=np.float32)
no_default_fc2_bias_asgd = np.array([0.5082278], dtype=np.float32)
# ASGD with grouped (fc1/fc2) parameters.
no_default_group_fc1_weight_asgd = np.array([[-32.526627, -32.29401, -32.8416, -33.06367, -32.55584,
                                              -32.852345, -33.000767, -32.795143],
                                             [-33.164936, -33.69432, -33.12241, -33.006073, -33.568207,
                                              -32.9391, -33.256996, -33.61015],
                                             [-33.118973, -33.220943, -32.936436, -32.642193, -32.806488,
                                              -33.424484, -32.85125, -32.718857],
                                             [-30.155754, -30.284513, -30.025005, -29.849358, -29.486917,
                                              -29.783375, -30.132658, -29.614393]], dtype=np.float32)
no_default_group_fc1_bias_asgd = np.array([-15.838092, -16.811989, -16.078112, -14.289094], dtype=np.float32)
no_default_group_fc2_weight_asgd = np.array([[1288.7146, 1399.3041, 1292.8445, 1121.4629]], dtype=np.float32)
no_default_group_fc2_bias_asgd = np.array([18.513494], dtype=np.float32)