Merge pull request !3098 from wangnan39/sparse_optimizer_adapter_indexedslices
@@ -108,24 +108,26 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
-@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
+@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "IndexedSlices",
                     "Tensor", "Tensor", "Tensor", "Bool")
 def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                          moment1, moment2, ps_parameter):
     """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
+    indices = gradient.indices()
+    values = gradient.values()
     if ps_parameter:
         op_shape = P.Shape()
         _ps_pull = P.Pull()
         _ps_push = P.Push("Adam", [0, 1, 2])
         shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
                   op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
-                  op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0]))
+                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
         success = F.depend(success, _ps_pull(_ps_push((beta1_power, beta2_power, lr, beta1, beta2,
-                                                       eps, gradient[1], gradient[0]), shapes), params))
+                                                       eps, values, indices), shapes), params))
     else:
         success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                               eps, gradient[1], gradient[0]))
+                                               eps, values, indices))
     return success
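Every hunk in this PR performs the same substitution: the old `(indices, values, dense_shape)` tuple convention for sparse gradients is replaced by an `IndexedSlices` object read through accessor methods. Below is a minimal, illustrative stand-in for the interface the diff relies on; the real class is imported from `mindspore.common.tensor`, and only the constructor order and method names are taken from the hunks themselves.

```python
class IndexedSlicesStub:
    """Illustrative stand-in for the IndexedSlices interface used throughout this PR."""

    def __init__(self, indices, values, dense_shape):
        self._indices = indices          # 1-D row ids into the dense parameter
        self._values = values            # gradient rows: shape (len(indices),) + param.shape[1:]
        self._dense_shape = dense_shape  # shape of the equivalent dense gradient

    def indices(self):
        return self._indices

    def values(self):
        return self._values

    def dense_shape(self):
        return self._dense_shape
```

Under this contract, `gradient[0]`, `gradient[1]`, and `gradient[2]` from the old tuple format map to `gradient.indices()`, `gradient.values()`, and `gradient.dense_shape()` respectively, which is exactly the rewrite each hunk applies.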
@@ -149,17 +151,19 @@ def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, b
 @_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                              "Tensor", "Tuple", "Tensor", "Tensor", "Tensor")
+                              "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor")
 def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                                    moment1, moment2):
     """Apply sparse adam optimizer by push and pull to the weight parameter when the gradient is sparse."""
     success = True
     op_shape = P.Shape()
+    values = gradient.values()
+    indices = gradient.indices()
     shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
               op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
-              op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0]))
+              op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
     success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
-                                           eps, gradient[1], gradient[0]), shapes), params))
+                                           eps, values, indices), shapes), params))
     return success
@@ -25,20 +25,22 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
 _ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")
-@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
+@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices", "Tensor",
                     "Tensor", "Bool")
 def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment,
                                 ps_parameter):
     """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
     success = True
+    indices = gradient.indices()
+    values = gradient.values()
     if ps_parameter:
         op_shape = P.Shape()
         _ps_pull = P.Pull()
         _ps_push = P.Push("Ftrl", [0, 1, 2])
-        shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0]))
-        success = F.depend(success, _ps_pull(_ps_push((gradient[1], gradient[0]), shapes), weight))
+        shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
+        success = F.depend(success, _ps_pull(_ps_push((values, indices), shapes), weight))
     else:
-        success = F.depend(success, spars_opt(weight, moment, linear, gradient[1], gradient[0]))
+        success = F.depend(success, spars_opt(weight, moment, linear, values, indices))
     return success
@@ -58,14 +60,16 @@ def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gra
     return success
-@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple",
+@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices",
                               "Tensor", "Tensor")
 def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
                                           weight, moment):
     success = True
     op_shape = P.Shape()
-    shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0]))
-    success = F.depend(success, pull(push((gradient[1], gradient[0]), shapes), weight))
+    values = gradient.values()
+    indices = gradient.indices()
+    shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
+    success = F.depend(success, pull(push((values, indices), shapes), weight))
     return success
@@ -27,14 +27,14 @@ from .optimizer import Optimizer
 _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
-@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
-                         "Tensor", "Tensor", "Tensor")
+@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor",
+                         "IndexedSlices", "Tensor", "Tensor", "Tensor")
 def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                          moment1, moment2):
     """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
     success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                           eps, gradient[1], gradient[0]))
+                                           eps, gradient.values(), gradient.indices()))
     return success
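The lazy sparse variant only touches the parameter rows named by `gradient.indices()`; rows that receive no gradient keep stale moments. A rough NumPy sketch of that per-row update follows, using the standard Adam formulas and ignoring the bias correction and fusion details of the underlying sparse kernel; it is an illustration of the idea, not the fused op.

```python
import numpy as np

def lazy_adam_rows(param, m, v, indices, values, lr, beta1, beta2, eps):
    """Sketch: update only the rows listed in `indices`; all other rows stay untouched."""
    for i, g in zip(indices, values):
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        param[i] -= lr * m[i] / (np.sqrt(v[i]) + eps)
    return param, m, v
```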
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.nn.cell import Cell
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
-from mindspore.common.tensor import Tensor
+from mindspore.common.tensor import Tensor, IndexedSlices
 import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
@@ -490,12 +490,14 @@ op_gather = P.GatherV2()
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
-@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
+@_apply_decay.register("Number", "Bool", "Tensor", "IndexedSlices")
 def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        weight = op_gather(weight, gradient[0], 0)
-        return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2]
+        indices = gradient.indices()
+        values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values()))
+        shape = gradient.dense_shape()
+        return IndexedSlices(indices, values, shape)
     return gradient
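The rewritten `_tensor_apply_decay_with_sparse` adds the decay term only for rows that actually appear in the sparse gradient, then repackages the result as a new `IndexedSlices`. The same idea in plain NumPy, purely for illustration (duplicate row ids are ignored for brevity):

```python
import numpy as np

def apply_decay_sparse(weight_decay, weight, indices, values):
    """Sketch: values[k] += weight_decay * weight[indices[k]]; only gathered rows are decayed."""
    decayed = values + weight_decay * weight[indices]
    return indices, decayed  # dense_shape passes through unchanged
```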
@@ -518,9 +520,9 @@ def tensor_grad_scale(scale, grad):
     return grad * scale
-@_grad_scale.register("Number", "Tuple")
+@_grad_scale.register("Number", "IndexedSlices")
 def tensor_grad_scale_with_sparse(scale, grad):
     """Get grad with scale."""
     if scale == 1.0:
         return grad
-    return grad[0], grad[1] * scale, grad[2]
+    return IndexedSlices(grad.indices(), grad.values() * scale, grad.dense_shape())
@@ -23,11 +23,12 @@ from .optimizer import Optimizer
 _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
-@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tuple", "Tensor", "Tensor")
+@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
+                                 "Tensor")
 def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
     """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
     success = True
-    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient[1], gradient[0]))
+    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
     return success
@@ -16,6 +16,7 @@
 from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.communication.management import GlobalComm, get_group_size
+from mindspore.common.tensor import IndexedSlices
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
@@ -77,7 +78,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc
     return grad
-@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tuple", "Function")
+@reduce_opt.register("Number", "Bool", "Function", "Bool", "IndexedSlices", "Function")
 def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce):
     """
     Apply allgather on gradient instead of allreduce for sparse feature.
@@ -88,21 +89,21 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, gr
         mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
         allgather (Primitive): The communication operator for sparse gradients.
         allreduce_filter (bool): When it is true, allgather would apply.
-        grad (tuple): The indices, gradient tensor and tensor_shape before operation.
+        grad (IndexedSlices): The gradient before operation.
         allreduce (Primitive): The communication operator for gradients.
     Returns:
-        Tuple, include indices, the gradient tensor and tensor_shape after operation.
+        IndexedSlices, the gradient after operation.
     """
     if allreduce_filter:
-        indices = allgather(grad[0])
-        dout = allgather(grad[1])
+        indices = allgather(grad.indices())
+        dout = allgather(grad.values())
         if mean:
-            degree = F.scalar_cast(degree, F.dtype(grad[1]))
+            degree = F.scalar_cast(degree, F.dtype(grad.values()))
             cast_op = P.Cast()
             mul_op = P.Mul()
             dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
-        grad = (indices, dout, grad[2])
+        grad = IndexedSlices(indices, dout, grad.dense_shape())
     return grad
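For sparse features the reducer concatenates every device's row ids and gradient rows with AllGather instead of summing a dense tensor with AllReduce; with `mean` enabled the rows are additionally divided by the device count (`degree`). A NumPy sketch of the resulting gradient, assuming `per_device` holds the (indices, values) pair produced on each device (names and structure are illustrative):

```python
import numpy as np

def allgather_sparse(per_device, degree, mean=True):
    """Sketch: concatenate (indices, values) across devices, optionally averaging the rows."""
    indices = np.concatenate([idx for idx, _ in per_device])
    values = np.concatenate([val for _, val in per_device])
    if mean:
        values = values * (1.0 / degree)
    return indices, values  # duplicate row ids are left for the sparse optimizer to combine
```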
@@ -123,18 +124,18 @@ def _tensors_get_datatype(grad):
     return F.dtype(grad)
-@_get_datatype.register("Tuple")
+@_get_datatype.register("IndexedSlices")
 def _tensors_get_datatype_with_sparse(grad):
     """
     Acquire gradient datatype.
     Args:
-        grad (Tuple): The gradient tensor before operation.
+        grad (IndexedSlices): The gradient before operation.
     Returns:
         mstype, the datatype of gradient.
     """
-    return F.dtype(grad[1])
+    return F.dtype(grad.values())
 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
@@ -155,20 +156,20 @@ def _tensors_cast_datatype(datatype, grad):
     return F.cast(grad, datatype)
-@_cast_datatype.register("TypeType", "Tuple")
+@_cast_datatype.register("TypeType", "IndexedSlices")
 def _tensors_cast_datatype_with_sparse(datatype, grad):
     """
     Cast gradient to datatype.
     Args:
         datatype (mstype): the destination datatype of gradient.
-        grad (Tuple): The gradient tensor before operation.
+        grad (IndexedSlices): The gradient before operation.
     Returns:
-        Tuple, the gradient tuple after operation.
+        IndexedSlices, the gradient after operation.
     """
-    dout = F.cast(grad[1], datatype)
-    return (grad[0], dout, grad[2])
+    dout = F.cast(grad.values(), datatype)
+    return IndexedSlices(grad.indices(), dout, grad.dense_shape())
 class DistributedGradReducer(Cell):
@@ -25,6 +25,7 @@ from .grad_base import bprop_getters
 from ..primitive import constexpr
 from ... import context
 from ...common import dtype as mstype
+from ...common.tensor import IndexedSlices
 reduce_sum = P.ReduceSum()
 unsorted_segment_sum = P.UnsortedSegmentSum()
@@ -206,7 +207,7 @@ def get_bprop_embedding_lookup(self):
         actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
         # Reshape the 'actual_dout' on device
         actual_dout = reshape_op(dout, actual_dout_shape_changed)
-        return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
+        return IndexedSlices(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
     return bprop_sparse
@@ -335,7 +336,7 @@ def get_bprop_sparse_gather_v2(self):
             values_shape = indices_size + x_tail_shp
             values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
-            return (indices, values, x_shp), zeros_like(indices), zeros_like(axis)
+            return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
         if F.rank(dout) == 0:
             dout = P.ExpandDims()(dout, -1)
         if F.rank(indices) == 0:
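Both bprops exploit the fact that the gradient of a gather is naturally row-sparse: the incoming `dout` rows belong at positions `indices` of the (potentially huge) input, so returning `IndexedSlices(indices, dout, x_shp)` avoids materializing the dense gradient. For reference, the dense tensor such an `IndexedSlices` represents can be recovered with a scatter-add; a small NumPy sketch:

```python
import numpy as np

def densify(indices, values, dense_shape):
    """Sketch: scatter-add the sparse rows back into a dense gradient of shape dense_shape."""
    dense = np.zeros(dense_shape, dtype=values.dtype)
    np.add.at(dense, indices, values)  # duplicate indices accumulate
    return dense
```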
@@ -17,6 +17,7 @@
 import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from .. import operations as P
+from ...common.tensor import IndexedSlices
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, ReduceOp,
@@ -46,9 +47,9 @@ def get_bprop_all_reduce(self):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
            else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
-                dx = (indices, grad, dout[2])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
+                dx = IndexedSlices(indices, grad, dout.dense_shape())
            return (dx,)
    else:
@@ -59,12 +60,12 @@ def get_bprop_all_reduce(self):
                z = cast(z, dtype(dx))
                dx = mul(dx, z)
            else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
                z = equal(x, out)
                z = cast(z, dtype(grad))
                grad = mul(grad, z)
-                dx = (indices, grad, dout[2])
+                dx = IndexedSlices(indices, grad, dout.dense_shape())
            return (dx,)
    return bprop
@@ -194,19 +195,19 @@ def get_bprop_mirror_operator(self):
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
-                dx = (indices, grad, dout[2])
+                dx = (indices, grad, dout.dense_shape())
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
            else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
-                dx = (indices, grad, dout[2])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
+                dx = (indices, grad, dout.dense_shape())
        return (dx,)
    return bprop
@@ -1,174 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-""" test adam """
-import numpy as np
-import mindspore.nn as nn
-from mindspore import Tensor, Parameter, context
-from mindspore.common.api import _executor
-from mindspore.common import dtype as mstype
-from mindspore.nn import TrainOneStepCell, WithLossCell
-from mindspore.nn.optim import Optimizer
-from mindspore.ops import operations as P
-from mindspore.ops import composite as C
-from mindspore.ops import functional as F
-from mindspore._checkparam import Validator as validator
-from mindspore._checkparam import Rel
-context.set_context(enable_sparse=True)
-adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
-@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                           "Tensor", "Tensor", "Tensor", "Bool")
-def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
-    op_mul = P.Mul()
-    op_square = P.Square()
-    op_sqrt = P.Sqrt()
-    op_cast = P.Cast()
-    op_reshape = P.Reshape()
-    op_shape = P.Shape()
-    param_fp32 = op_cast(param, mstype.float32)
-    m_fp32 = op_cast(m, mstype.float32)
-    v_fp32 = op_cast(v, mstype.float32)
-    gradient_fp32 = op_cast(gradient, mstype.float32)
-    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
-    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
-                                            - beta2, op_square(gradient_fp32))
-    update = next_m / (op_sqrt(next_v) + eps)
-    if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param_fp32)
-    update_with_lr = op_mul(lr, update)
-    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
-    next_v = F.depend(next_v, F.assign(param, next_param))
-    next_v = F.depend(next_v, F.assign(m, next_m))
-    next_v = F.depend(next_v, F.assign(v, next_v))
-    return next_v
-@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                           "Tensor", "Tensor", "Tuple", "Bool")
-def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
-    return gradient[2][2]
-def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
-    """Check the type of inputs."""
-    validator.check_value_type("beta1", beta1, [float], prim_name)
-    validator.check_value_type("beta2", beta2, [float], prim_name)
-    validator.check_value_type("eps", eps, [float], prim_name)
-    validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
-    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
-    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
-    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
-    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
-class AdamWeightDecaySparse(Optimizer):
-    """
-    Implements Adam algorithm weight decay fix.
-    Args:
-        params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
-            should be class mindspore.Parameter.
-        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
-                                                        Iterable or a Tensor and the dims of the Tensor is 1,
-                                                        use dynamic learning rate, then the i-th step will
-                                                        take the i-th value as the learning rate.
-                                                        When the learning_rate is float or learning_rate is a Tensor
-                                                        but the dims of the Tensor is 0, use fixed learning rate.
-                                                        Other cases are not supported. Default: 1e-3.
-        beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
-            Should be in range (0.0, 1.0).
-        beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
-            Should be in range (0.0, 1.0).
-        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
-            Should be greater than 0.
-        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
-        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
-            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
-    Inputs:
-        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`,
-          and might be in sparse format.
-    Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
-    Examples:
-        >>> net = Net()
-        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
-        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
-        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
-    """
-    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
-                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
-        if self.is_group:
-            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
-        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
-        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
-        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
-        self.eps = Tensor(np.array([eps]).astype(np.float32))
-        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
-        self.params = self.parameters
-        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
-        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
-        self.decay_flag = tuple(decay_filter(x) for x in self.params)
-        self.map = C.Map()
-    def construct(self, gradients):
-        lr = self.get_lr()
-        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
-                                              self.weight_decay_tensor),
-                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)
-        return updated_velocity
-def test_AdamWeightDecaySparse():
-    """ test_AdamWeightDecaySparse """
-    context.set_context(mode=context.GRAPH_MODE)
-    class Loss(nn.Cell):
-        def __init__(self):
-            super(Loss, self).__init__()
-        def construct(self, base, target):
-            return base
-    class NetWithSparseGatherV2(nn.Cell):
-        def __init__(self):
-            super(NetWithSparseGatherV2, self).__init__()
-            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
-            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
-            self.gatherv2 = P.SparseGatherV2()
-            self.axis = 0
-        def construct(self, indices):
-            return self.gatherv2(self.w1, indices, self.axis) * self.w2
-    inputs = Tensor(np.array([0, 1]).astype(np.int32))
-    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
-    net = NetWithSparseGatherV2()
-    net.set_train()
-    loss = Loss()
-    optimizer = AdamWeightDecaySparse(net.trainable_params())
-    net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepCell(net_with_loss, optimizer)
-    _executor.compile(train_network, inputs, label)
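The deleted test carried its own inline AdamWeightDecay update (the `next_m` / `next_v` / `update` block near the top of the removed file). The same arithmetic in plain NumPy, for reference only; like the removed code it applies no bias correction:

```python
import numpy as np

def adam_weight_decay_step(param, m, v, grad, lr, beta1, beta2, eps, weight_decay, decay_flag):
    """Sketch of the update rule the removed test implemented with MindSpore ops."""
    next_m = beta1 * m + (1 - beta1) * grad
    next_v = beta2 * v + (1 - beta2) * grad * grad
    update = next_m / (np.sqrt(next_v) + eps)
    if decay_flag:
        update = update + weight_decay * param
    next_param = param - lr * update
    return next_param, next_m, next_v
```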
@@ -19,8 +19,8 @@ import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype
-from mindspore.common.tensor import Tensor
-from mindspore.ops import composite as C
+from mindspore.common.tensor import Tensor, IndexedSlices
+from mindspore.ops import composite as C, operations as P
 from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
 from mindspore.ops._grad.grad_base import bprop_getters
 from mindspore._checkparam import Validator as validator
@@ -65,7 +65,7 @@ def get_bprop_gather_v2(self):
     """Generate bprop for GatherV2"""
     def bprop(x, indices, axis, out, dout):
-        return (indices, dout, x), axis, out
+        return IndexedSlices(indices, dout, x), axis, out
     return bprop
@@ -78,7 +78,7 @@ def test_bprop_with_sparse_feature_allreduce():
            if shape is None:
                shape = [8, 8]
            self.all_reduce = AllReduce()
-            self.gatherv2 = VirtualGatherV2()
+            self.gatherv2 = P.GatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis
@@ -102,7 +102,7 @@ def test_bprop_with_sparse_feature_mirror():
            if shape is None:
                shape = [8, 8]
            self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
-            self.gatherv2 = VirtualGatherV2()
+            self.gatherv2 = P.GatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis