thor.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """momentum"""
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean

from cus_ops.cus_matmul_cube_dense_right import CusMatMulCubeDenseRight
from cus_ops.cus_matmul_cube_fracz_left_cast import CusMatMulCubeFraczLeftCast
from cus_ops.cus_matmul_cube_dense_left import CusMatMulCubeDenseLeft
from cus_ops.cus_matmul_cube_fracz_right_mul import CusMatMulCubeFraczRightMul
from model.grad_reducer_thor import DistributedGradReducerThor

momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
    """Apply momentum optimizer to the weight parameter using Tensor."""
    success = True
    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
    return success
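
# `apply_decay` folds L2 weight decay into the gradient of every parameter
# whose decay flag is True; HyperMap broadcasts it across the parameter tuple.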
op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient


class THOR(Optimizer):
    """THOR (Trace-based Hardware-driven layer-ORiented Natural gradient descent computation) optimizer."""
    def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max,
                 weight_decay=0.0, loss_scale=1.0,
                 decay_filter=lambda x: x.name not in []):
        super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
        if isinstance(momentum, float) and momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.matrix_A = ParameterTuple(matrix_A)
        self.matrix_G = ParameterTuple(matrix_G)
        self.A_inv_max = ParameterTuple(A_inv_max)
        self.G_inv_max = ParameterTuple(G_inv_max)
        self.cube_matmul_left = CusMatMulCubeFraczLeftCast()
        self.cube_matmul_left_fc = CusMatMulCubeDenseLeft()
        self.cube_matmul_right_fc = CusMatMulCubeDenseRight()
        self.cube_matmul_right_mul = CusMatMulCubeFraczRightMul()
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.weight_idx = []
        for i in range(len(self.params)):
            if "conv" in self.params[i].name or "end_point" in self.params[i].name:
                self.weight_idx.append(i)
        self.weight_idx.append(len(self.params))
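        # Each entry below appears to be 1 / (H * W) of a layer's output feature
        # map in ResNet-50 (e.g. 1/3136 = 1/56**2), one per weight layer: 53
        # convolutions plus the final fully-connected layer, 54 in total.
        # (Interpretation inferred from the values; not documented in this file.)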
        self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                            1.0]
        mean = _get_mirror_mean()
        degree = _get_device_num()
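        # Four separate reducers all-reduce the A/G matrices and their max
        # values across devices; the second argument is presumably an allreduce
        # fusion group id (an assumption; see model/grad_reducer_thor.py).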
        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
        self.matrix_A_inv = ()
        self.matrix_G_inv = ()
        self.matrix_max_inv = ()
        for i in range(54):
            self.matrix_max_inv = self.matrix_max_inv + (
                Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
        self.log = P.Log()
        self.exp = P.Exp()
        self.sqrt = P.Sqrt()
        self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
        self.assign = P.Assign()
        self.cast = P.Cast()
        self.thor = True
        self.weight_decay = weight_decay * loss_scale
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
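
    # construct runs in two modes: when self.thor is True (a second-order update
    # step), the newly computed A/G inverse blocks are all-reduced across
    # devices, rescaled by their max values, cached into matrix_A / matrix_G /
    # matrix_max_inv, and used to precondition the gradients; otherwise the
    # cached matrices from the last THOR step are reused.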
    def construct(self, gradients):
        params = self.params
        moments = self.moments
        if self.thor:
            matrix_A_allreduce = ()
            matrix_G_allreduce = ()
            matrix_A_max_allreduce = ()
            matrix_G_max_allreduce = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                A_max = self.A_inv_max[i]
                G_max = self.G_inv_max[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                A_max = F.depend(A_max, g)
                G_max = F.depend(G_max, g)
                matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
                matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
                matrix_A_max_allreduce = matrix_A_max_allreduce + (A_max,)
                matrix_G_max_allreduce = matrix_G_max_allreduce + (G_max,)
            matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
            matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
            matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
            matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                temp_a = matrix_A_allreduce[i]
                temp_g = matrix_G_allreduce[i]
                temp_a = self.cast(temp_a, mstype.float32)
                temp_g = self.cast(temp_g, mstype.float32)
                matrix_A_inv_max = self.log(matrix_A_max_allreduce[i])
                matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
                matrix_A_inv_max = self.exp(matrix_A_inv_max)
                temp_a = self.mul(temp_a, matrix_A_inv_max)
                matrix_G_inv_max = self.log(matrix_G_max_allreduce[i])
                matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
                matrix_G_inv_max = self.exp(matrix_G_inv_max)
                temp_g = self.mul(temp_g, matrix_G_inv_max)
                temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
                temp_max = self.mul(temp_max, self.feature_map[i])
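                # Index 53 is the final fully-connected layer, handled with the
                # dense cube matmul ops; convolution layers use the
                # FracZ-format cube ops instead.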
                if i == 53:
                    g = self.cube_matmul_left_fc(temp_g, g)
                    g = self.cube_matmul_right_fc(g, temp_a, temp_max)
                else:
                    g = self.cube_matmul_left(temp_g, g)
                    g = self.cube_matmul_right_mul(g, temp_a, temp_max)
                fake_A = self.assign(self.matrix_A[i], temp_a)
                fake_G = self.assign(self.matrix_G[i], temp_g)
                fake_max = self.assign(self.matrix_max_inv[i], temp_max)
                g = F.depend(g, fake_A)
                g = F.depend(g, fake_G)
                g = F.depend(g, fake_max)
                if i == 53:
                    new_grads = new_grads + (g,)
                else:
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads
        else:
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                matrix_max = self.matrix_max_inv[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                matrix_max = F.depend(matrix_max, g)
                if i == 53:
                    g = self.cube_matmul_left_fc(matrix_G, g)
                    g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g,)
                else:
                    g = self.cube_matmul_left(matrix_G, g)
                    g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
                                       params, gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
        return success
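
A minimal sketch of how this optimizer might be wired up. The THOR-enabled
ResNet-50 constructor, its hyper-parameter values, and the 'matrix_A' /
'matrix_G' / 'A_inv_max' / 'G_inv_max' parameter naming are assumptions drawn
from the surrounding repository layout, not definitions in this file:

    from mindspore import Tensor
    import mindspore.common.dtype as mstype
    from model.resnet import resnet50  # hypothetical THOR-enabled ResNet-50

    net = resnet50(class_num=1001, damping=0.03, loss_scale=128.0, frequency=834)
    opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(0.045, mstype.float32),  # learning rate
               0.9,                            # momentum
               filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
               filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
               filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
               filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
               weight_decay=5e-4, loss_scale=128.0)

Note that the gradient layout assumed by construct is three tensors per weight
layer (weight, plus two auxiliary gradients), which is why it indexes
gradients[i * 3] for 54 layers.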