# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """momentum"""
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
from model.grad_reducer_thor import DistributedGradReducerThor

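# The two MultitypeFuncGraph helpers below are mapped over every
# (gradient, weight, moment) triple by C.HyperMap in THOR.construct():
# momentum_opt wraps the fused ApplyMomentum kernel, and apply_decay adds
# weight_decay * weight to the gradient of parameters whose decay flag is set.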
momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
    """Apply momentum optimizer to the weight parameter using Tensor."""
    success = True
    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
    return success


op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient


class THOR(Optimizer):
    """
    THOR second-order optimizer.

    Preconditions each layer's weight gradient with the layer's cached
    second-order factors (matrix_A, matrix_G) and their max values, then
    applies a fused momentum update. Written for the ResNet-50 network in
    this repository, which exposes 54 such layers (53 convolutions plus the
    final fully-connected layer).
    """
    def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
                 loss_scale=1.0,
                 decay_filter=lambda x: x.name not in []):
        super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
        if isinstance(momentum, float) and momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.matrix_A = ParameterTuple(matrix_A)
        self.matrix_G = ParameterTuple(matrix_G)
        self.A_inv_max = ParameterTuple(A_inv_max)
        self.G_inv_max = ParameterTuple(G_inv_max)
        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.weight_idx = []
        for i in range(len(self.params)):
            if "conv" in self.params[i].name or "end_point" in self.params[i].name:
                self.weight_idx.append(i)
        self.weight_idx.append(len(self.params))
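        # weight_idx records the indices of the conv / end_point weights, with
        # len(params) appended as an end sentinel. feature_map below holds
        # 1 / (H * W) of each of those layers' output feature maps in the
        # ResNet-50 used here (112*112, 56*56, 28*28, 14*14, 7*7), and 1.0 for
        # the final fully-connected layer; it scales the per-layer max term in
        # construct().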
        self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                            1.0]
        mean = _get_mirror_mean()
        degree = _get_device_num()
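        # One reducer per group of THOR statistics so that the A factors, the
        # G factors and their respective max values are all-reduced across
        # devices independently; the integer argument appears to act as a
        # distinct fusion id for each group.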
        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
        self.matrix_A_inv = ()
        self.matrix_G_inv = ()
        self.matrix_max_inv = ()

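        # 54 = number of THOR-handled layers in this ResNet-50 (53 convolutions
        # plus the final fully-connected layer, treated as the i == 53 special
        # case in construct()). matrix_max_inv caches one scalar per layer.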
        for i in range(54):
            self.matrix_max_inv = self.matrix_max_inv + (
                Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
        self.log = P.Log()
        self.exp = P.Exp()
        self.sqrt = P.Sqrt()
        self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
        self.assign = P.Assign()
        self.cast = P.Cast()
        self.thor = True
        self.weight_decay = weight_decay * loss_scale
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)

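    # What construct() computes per THOR layer i, roughly:
    #     g_i  <-  G_i * g_i * A_i * feature_map_i
    # where matrix_G[i] / matrix_A[i] hold the layer's (already inverted)
    # second-order factors produced by the THOR network. Before the float16
    # cube matmuls the factors are normalized by their max values (A_inv_max,
    # G_inv_max), and the product of those maxes is multiplied back in via
    # temp_max, keeping the intermediates in a safe numeric range.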
    def construct(self, gradients):
        """Precondition the gradients with the THOR factors, then apply momentum."""
        params = self.params
        moments = self.moments
        if self.thor:
            matrix_A_allreduce = ()
            matrix_G_allreduce = ()
            matrix_A_max_allreduce = ()
            matrix_G_max_allreduce = ()
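            # Gather this device's freshly computed factors. Gradients arrive so
            # that the weight gradient of layer i sits at gradients[i * 3], with
            # the following two entries (the layer's batch-norm parameters)
            # passed through unchanged later; F.depend ties each factor read to
            # that gradient so the factors are fetched only after the backward
            # pass that refreshed them.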
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                A_max = self.A_inv_max[i]
                G_max = self.G_inv_max[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                A_max = F.depend(A_max, g)
                G_max = F.depend(G_max, g)
                matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
                matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
                matrix_A_max_allreduce = matrix_A_max_allreduce + (A_max,)
                matrix_G_max_allreduce = matrix_G_max_allreduce + (G_max,)
            matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
            matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
            matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
            matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
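            # With the factors synchronized across devices, normalize each one by
            # its max (the reciprocal is computed as exp(-log(max))), precondition
            # the weight gradient with the float16 cube matmul kernels, and cache
            # the normalized factors and combined max back into the Parameters so
            # the non-THOR branch below can reuse them on intermediate steps.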
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                temp_a = matrix_A_allreduce[i]
                temp_g = matrix_G_allreduce[i]
                temp_a = self.cast(temp_a, mstype.float32)
                temp_g = self.cast(temp_g, mstype.float32)
                matrix_A_inv_max = self.log(matrix_A_max_allreduce[i])
                matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
                matrix_A_inv_max = self.exp(matrix_A_inv_max)
                temp_a = self.mul(temp_a, matrix_A_inv_max)
                matrix_G_inv_max = self.log(matrix_G_max_allreduce[i])
                matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
                matrix_G_inv_max = self.exp(matrix_G_inv_max)
                temp_g = self.mul(temp_g, matrix_G_inv_max)
                temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
                temp_max = self.mul(temp_max, self.feature_map[i])
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
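                # Layer 53 is the fully-connected end_point layer; it uses the
                # dense cube kernels, while the convolutions use the FracZ ones.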
                if i == 53:
                    g = self.cube_matmul_left_fc(temp_g, g)
                    g = self.cube_matmul_right_fc(g, temp_a, temp_max)
                else:
                    g = self.cube_matmul_left(temp_g, g)
                    g = self.cube_matmul_right_mul(g, temp_a, temp_max)
                fake_A = self.assign(self.matrix_A[i], temp_a)
                fake_G = self.assign(self.matrix_G[i], temp_g)
                fake_max = self.assign(self.matrix_max_inv[i], temp_max)
                g = F.depend(g, fake_A)
                g = F.depend(g, fake_G)
                g = F.depend(g, fake_max)
                if i == 53:
                    new_grads = new_grads + (g,)
                else:
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads
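        # When self.thor is False (ordinary steps between factor refreshes),
        # reuse the factors and max values cached by the Assign calls above:
        # no all-reduce or renormalization is needed, only the matmuls.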
        else:
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                matrix_max = self.matrix_max_inv[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                matrix_max = F.depend(matrix_max, g)
                if i == 53:
                    g = self.cube_matmul_left_fc(matrix_G, g)
                    g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g,)
                else:
                    g = self.cube_matmul_left(matrix_G, g)
                    g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads

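        # Common tail: add weight decay to the flagged parameters, undo the loss
        # scaling, then map the fused ApplyMomentum update over every
        # (gradient, weight, moment) triple with the current learning rate.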
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
                                       params, gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
        return success
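

# A minimal sketch of how this optimizer is typically constructed from a THOR
# ResNet-50 whose layers register parameters named "matrix_A", "matrix_G",
# "A_inv_max" and "G_inv_max". The exact network constructor, parameter names
# and config fields here are assumptions for illustration, not part of this
# module:
#
#     net = resnet50_thor(class_num=config.class_num)
#     opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()),
#                Tensor(lr), config.momentum,
#                filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
#                filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
#                filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
#                filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
#                config.weight_decay, config.loss_scale)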