thor.py 9.9 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""momentum"""
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
from model.grad_reducer_thor import DistributedGradReducerThor

momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
    """Apply momentum optimizer to the weight parameter using Tensor."""
    success = True
    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
    return success


op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient
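

# THOR preconditions each layer's weight gradient with per-layer Kronecker factors
# (matrix_A, derived from the layer inputs, and matrix_G, derived from the output
# gradients, apparently already inverted by the network, plus A_inv_max / G_inv_max
# used for rescaling) via the custom cube matmul kernels, and then applies a
# standard momentum update to the preconditioned gradients.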
class THOR(Optimizer):
    """THOR"""
    def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
                 loss_scale=1.0,
                 decay_filter=lambda x: x.name not in []):
        super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
        if isinstance(momentum, float) and momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.matrix_A = ParameterTuple(matrix_A)
        self.matrix_G = ParameterTuple(matrix_G)
        self.A_inv_max = ParameterTuple(A_inv_max)
        self.G_inv_max = ParameterTuple(G_inv_max)
        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.weight_idx = []
        for i in range(len(self.params)):
            if "conv" in self.params[i].name or "end_point" in self.params[i].name:
                self.weight_idx.append(i)
        self.weight_idx.append(len(self.params))
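        # Per-layer scaling constants, apparently 1 / (H * W) of each layer's output
        # feature map in ResNet-50 (112*112 = 12544, 56*56 = 3136, 28*28 = 784,
        # 14*14 = 196, 7*7 = 49), with 1.0 for the final fully-connected layer.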
        self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                            1.0]
        mean = _get_mirror_mean()
        degree = _get_device_num()
        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
        self.matrix_A_inv = ()
        self.matrix_G_inv = ()
        self.matrix_max_inv = ()
        for i in range(54):
            self.matrix_max_inv = self.matrix_max_inv + (
                Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
        self.log = P.Log()
        self.exp = P.Exp()
        self.sqrt = P.Sqrt()
        self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
        self.assign = P.Assign()
        self.cast = P.Cast()
        self.thor = True
        self.weight_decay = weight_decay * loss_scale
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
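
    # construct() is the per-step update. When self.thor is True, the freshly
    # computed factors are all-reduced across devices, used to precondition each
    # layer's weight gradient, and cached via Assign; otherwise the cached factors
    # from the last THOR update step are reused. gradients[i * 3] is the weight
    # gradient of the i-th conv/fc layer; the two following entries (presumably the
    # BatchNorm gamma/beta gradients) are passed through unchanged, and index 53 is
    # the final fully-connected layer, which uses the dense matmul kernels.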
    def construct(self, gradients):
        params = self.params
        moments = self.moments
        if self.thor:
            matrix_A_allreduce = ()
            matrix_G_allreduce = ()
            matrix_A_max_allreduce = ()
            matrix_G_max_allreduce = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                A_max = self.A_inv_max[i]
                G_max = self.G_inv_max[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                A_max = F.depend(A_max, g)
                G_max = F.depend(G_max, g)
                matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
                matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
                matrix_A_max_allreduce = matrix_A_max_allreduce + (A_max,)
                matrix_G_max_allreduce = matrix_G_max_allreduce + (G_max,)
            matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
            matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
            matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
            matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                temp_a = matrix_A_allreduce[i]
                temp_g = matrix_G_allreduce[i]
                temp_a = self.cast(temp_a, mstype.float32)
                temp_g = self.cast(temp_g, mstype.float32)
                matrix_A_inv_max = self.log(matrix_A_max_allreduce[i])
                matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
                matrix_A_inv_max = self.exp(matrix_A_inv_max)
                temp_a = self.mul(temp_a, matrix_A_inv_max)
                matrix_G_inv_max = self.log(matrix_G_max_allreduce[i])
                matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
                matrix_G_inv_max = self.exp(matrix_G_inv_max)
                temp_g = self.mul(temp_g, matrix_G_inv_max)
                temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
                temp_max = self.mul(temp_max, self.feature_map[i])
                if i == 53:
                    g = self.cube_matmul_left_fc(temp_g, g)
                    g = self.cube_matmul_right_fc(g, temp_a, temp_max)
                else:
                    g = self.cube_matmul_left(temp_g, g)
                    g = self.cube_matmul_right_mul(g, temp_a, temp_max)
                fake_A = self.assign(self.matrix_A[i], temp_a)
                fake_G = self.assign(self.matrix_G[i], temp_g)
                fake_max = self.assign(self.matrix_max_inv[i], temp_max)
                g = F.depend(g, fake_A)
                g = F.depend(g, fake_G)
                g = F.depend(g, fake_max)
                if i == 53:
                    new_grads = new_grads + (g,)
                else:
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads
        else:
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                matrix_max = self.matrix_max_inv[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                matrix_max = F.depend(matrix_max, g)
                if i == 53:
                    g = self.cube_matmul_left_fc(matrix_G, g)
                    g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g,)
                else:
                    g = self.cube_matmul_left(matrix_G, g)
                    g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads

        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
                                       params, gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
        return success
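
For reference, a minimal usage sketch of this optimizer, assuming the THOR-enabled ResNet-50 registers its factor matrices as network parameters whose names contain matrix_A, matrix_G, A_inv_max and G_inv_max; the network constructor, learning-rate array and hyperparameter values below are illustrative assumptions, not taken from this file:

    import numpy as np
    from mindspore import Tensor
    import mindspore.common.dtype as mstype

    net = resnet50_thor(class_num=1001)                 # hypothetical THOR-enabled ResNet-50
    params = list(net.get_parameters())
    lr = Tensor(np.full(5004, 0.05), mstype.float32)    # placeholder per-step LR schedule
    opt = THOR(filter(lambda x: x.requires_grad, params), lr, 0.9,
               filter(lambda x: 'matrix_A' in x.name, params),
               filter(lambda x: 'matrix_G' in x.name, params),
               filter(lambda x: 'A_inv_max' in x.name, params),
               filter(lambda x: 'G_inv_max' in x.name, params),
               weight_decay=5e-4, loss_scale=128.0)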