Merge pull request !2116 from ghzl/change-multitypefungraph-to-internal-interface
@@ -26,10 +26,10 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from .optimizer import Optimizer
-adam_opt = C.MultitypeFuncGraph("adam_opt")
+_adam_opt = C.MultitypeFuncGraph("adam_opt")
-@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
 def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
 """
 Update parameters.
@@ -101,8 +101,8 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
 validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
-@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
-"Tensor", "Tensor", "Tensor")
+@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
+"Tensor", "Tensor", "Tensor")
 def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
 moment1, moment2):
 """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
@@ -112,8 +112,8 @@ def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2
 return success
-@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
-"Tensor", "Tensor", "Tensor")
+@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
+"Tensor", "Tensor", "Tensor")
 def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
 moment1, moment2):
 """Apply adam optimizer to the weight parameter using Tensor."""
@@ -261,11 +261,11 @@ class Adam(Optimizer):
 beta2_power = self.beta2_power * self.beta2
 self.beta2_power = beta2_power
 if self.is_group_lr:
-success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
+success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
 self.beta1, self.beta2, self.eps),
 lr, gradients, params, moment1, moment2)
 else:
-success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
+success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
 self.beta1, self.beta2, self.eps, lr),
 gradients, params, moment1, moment2)
 return success
@@ -328,7 +328,7 @@ class AdamWeightDecay(Optimizer):
 def construct(self, gradients):
 lr = self.get_lr()
-updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
+updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
 self.weight_decay_tensor),
 self.params, self.moments1, self.moments2, gradients, self.decay_flag)
@@ -424,7 +424,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
 warmup_lr = self.start_learning_rate * warmup_percent
 is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
 lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
-updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
+updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
 self.weight_decay_tensor),
 self.params, self.moments1, self.moments2, gradients, self.decay_flag)
@@ -18,13 +18,13 @@ from mindspore.common import Tensor
 import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
-from .optimizer import Optimizer, apply_decay, grad_scale
+from .optimizer import Optimizer, _apply_decay, _grad_scale
-ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
+_ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
-@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
-"Tensor")
+@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
+"Tensor")
 def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
 """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
 success = True
@@ -32,8 +32,8 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power,
 return success
-@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
-"Tensor")
+@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
+"Tensor")
 def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
 """Apply ftrl optimizer to the weight parameter."""
 success = True
@@ -124,9 +124,9 @@ class FTRL(Optimizer):
 linear = self.linear
 lr = self.learning_rate
 if self.weight_decay > 0.0:
-grads = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, grads)
+grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
 grads = self.scale_grad(grads)
-success = self.map_(F.partial(ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
+success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
 linear, grads, params, moments)
 return success
@@ -28,10 +28,10 @@ from .. import layer
 num_one = Tensor(np.ones([1]), mstype.float32)
-lamb_opt = C.MultitypeFuncGraph("lamb_opt")
+_lamb_opt = C.MultitypeFuncGraph("lamb_opt")
-@lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-"Tensor", "Bool")
+@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+"Tensor", "Bool")
 def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
 gradient, decay_flag):
 """
@@ -227,7 +227,7 @@ class Lamb(Optimizer):
 warmup_lr = self.start_learning_rate * warmup_percent
 is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
 lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
-updated_velocity = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, lr,
+updated_velocity = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, lr,
 self.weight_decay_tensor, self.global_step),
 self.params, self.moments1, self.moments2, gradients, self.decay_flag)
@@ -22,12 +22,12 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore._checkparam import Validator as validator
-from .optimizer import grad_scale, Optimizer
+from .optimizer import _grad_scale, Optimizer
-lars_opt = C.MultitypeFuncGraph("lars_opt")
+_lars_opt = C.MultitypeFuncGraph("lars_opt")
-@lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
+@_lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
 def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
 """Apply lars optimizer to the weight parameter."""
 if lars_flag:
@@ -119,9 +119,9 @@ class LARS(Optimizer):
 else:
 lr = self.learning_rate
 if self.reciprocal_scale != 1.0:
-gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
+gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
-grad_t = self.hyper_map(F.partial(lars_opt, self.lars, self.weight_decay, lr),
+grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, self.weight_decay, lr),
 gradients, params, self.decay_flag, self.lars_flag)
 success = self.opt(grad_t)
@@ -24,11 +24,11 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from .optimizer import Optimizer
-lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
+_lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
-@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
-"Tensor", "Tensor", "Tensor")
+@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
+"Tensor", "Tensor", "Tensor")
 def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
 moment1, moment2):
 """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
@@ -38,8 +38,8 @@ def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2
 return success
-@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
-"Tensor", "Tensor", "Tensor")
+@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
+"Tensor", "Tensor", "Tensor")
 def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
 moment1, moment2):
 """Apply adam optimizer to the weight parameter using Tensor."""
@@ -189,11 +189,11 @@ class LazyAdam(Optimizer):
 self.beta2_power = self.beta2_power * self.beta2
 if self.is_group_lr:
-success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
+success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
 self.beta2_power, self.beta1, self.beta2, self.eps),
 lr, gradients, self.parameters, self.moment1, self.moment2)
 else:
-success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
+success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
 self.beta2_power, self.beta1, self.beta2, self.eps, lr),
 gradients, self.parameters, self.moment1, self.moment2)
 return success
@@ -21,10 +21,10 @@ from mindspore._checkparam import check_bool
 from mindspore._checkparam import Validator as validator
 from .optimizer import Optimizer
-momentum_opt = C.MultitypeFuncGraph("momentum_opt")
+_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
-@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
+@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
 """Apply momentum optimizer to the weight parameter using Tensor."""
 success = True
@@ -129,7 +129,7 @@ class Momentum(Optimizer):
 gradients = self.scale_grad(gradients)
 lr = self.get_lr()
 if self.is_group_lr:
-success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
+success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
 else:
-success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
+success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
 return success
@@ -171,11 +171,11 @@ class Optimizer(Cell):
 params = self.parameters
 if self.is_group:
 if self.exec_weight_decay:
-gradients = self.hyper_map(F.partial(apply_decay), self.weight_decay, self.decay_flags,
+gradients = self.hyper_map(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
 params, gradients)
 else:
 if self.weight_decay > 0:
-gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
+gradients = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
 params, gradients)
 return gradients
@@ -196,7 +196,7 @@ class Optimizer(Cell):
 """
 if self.reciprocal_scale != 1.0:
-gradients = self.map_(F.partial(grad_scale, self.reciprocal_scale), gradients)
+gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)
 return gradients
@@ -390,10 +390,10 @@ class Optimizer(Cell):
 op_add = P.AddN()
-apply_decay = C.MultitypeFuncGraph("apply_decay")
+_apply_decay = C.MultitypeFuncGraph("apply_decay")
-@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
+@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
 """Get grad with weight_decay."""
 if if_apply:
@@ -401,10 +401,10 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
 return gradient
-grad_scale = C.MultitypeFuncGraph("grad_scale")
+_grad_scale = C.MultitypeFuncGraph("grad_scale")
-@grad_scale.register("Number", "Tensor")
+@_grad_scale.register("Number", "Tensor")
 def tensor_grad_scale(scale, grad):
 """Get grad with scale."""
 if scale == 1.0:
@@ -412,7 +412,7 @@ def tensor_grad_scale(scale, grad):
 return grad * scale
-@grad_scale.register("Number", "Tuple")
+@_grad_scale.register("Number", "Tuple")
 def tensor_grad_scale_with_sparse(scale, grad):
 """Get grad with scale."""
 if scale == 1.0:
@@ -20,10 +20,10 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from .optimizer import Optimizer
-proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
+_proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
-@proximal_ada_grad_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
+@_proximal_ada_grad_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt(opt, learning_rate, l1, l2, gradient, weight, accum):
 """Apply proximal_ada_grad optimizer to the weight parameter."""
 success = True
@@ -94,6 +94,6 @@ class ProximalAdagrad(Optimizer):
 grads = self.decay_weight(grads)
 grads = self.scale_grad(grads)
 lr = self.learning_rate
-success = self.hyper_map(F.partial(proximal_ada_grad_opt, self.opt, lr, self.l1, self.l2),
+success = self.hyper_map(F.partial(_proximal_ada_grad_opt, self.opt, lr, self.l1, self.l2),
 grads, params, accum)
 return success
@@ -18,21 +18,21 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from .optimizer import Optimizer
-rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
-centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
+_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
+_centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
-@rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
-def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
+@_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
+def _rmsprop_opt_(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
 """Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
 success = True
 success = F.depend(success, opt(weight, ms, mom, learning_rate, grad, decay, momentum, epsilon))
 return success
-@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
-"Tensor", "Tensor")
-def _centered_rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, mg, ms, mom, grad):
+@_centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
+"Tensor", "Tensor")
+def _centered_rmsprop_opt_(opt, decay, epsilon, momentum, learning_rate, weight, mg, ms, mom, grad):
 """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
 success = True
 success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
@@ -187,17 +187,17 @@ class RMSProp(Optimizer):
 lr = self.get_lr()
 if self.centered:
 if self.is_group_lr:
-success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
+success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
 self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
 else:
-success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
+success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
 self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)
 else:
 if self.is_group_lr:
-success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
+success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
 self.momentum), lr, params, self.ms, self.moment, gradients)
 else:
-success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
+success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
 self.momentum, lr), params, self.ms, self.moment, gradients)
 return success
@@ -20,10 +20,10 @@ import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from .optimizer import Optimizer
-sgd_opt = C.MultitypeFuncGraph("sgd_opt")
+_sgd_opt = C.MultitypeFuncGraph("sgd_opt")
-@sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
+@_sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
 """Apply sgd optimizer to the weight parameter using Tensor."""
 success = True
@@ -154,7 +154,7 @@ class SGD(Optimizer):
 gradients = self.scale_grad(gradients)
 lr = self.get_lr()
 if self.is_group_lr:
-success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
+success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
 else:
-success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
+success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
 return success
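
For readers unfamiliar with the pattern being renamed above: each optimizer module defines a C.MultitypeFuncGraph at module level, registers type-specialized implementations on it with the register decorator, and then applies it across the parameter/gradient tuples via F.partial plus HyperMap or Map inside construct(). The sketch below is a minimal illustration of that pattern with a made-up dispatcher; _scale_grad, _tensor_scale and _number_scale are illustrative names that are not part of this PR, and the sketch assumes the same mindspore.ops composite/functional modules imported by the files in the diff.

from mindspore.ops import composite as C
from mindspore.ops import functional as F

# Module-private dispatcher: the leading underscore is the convention this PR
# applies to adam_opt, ftrl_opt, lamb_opt, momentum_opt, etc.
_scale_grad = C.MultitypeFuncGraph("scale_grad")

@_scale_grad.register("Number", "Tensor")
def _tensor_scale(scale, grad):
    """Scale a dense gradient Tensor by a Python number."""
    return grad * scale

@_scale_grad.register("Number", "Number")
def _number_scale(scale, grad):
    """Overload for a plain number, to show the per-type dispatch."""
    return grad * scale

# Inside an optimizer's construct(), the dispatcher is typically bound to the
# per-step constants with F.partial and mapped over the gradient tuple, e.g.:
#     gradients = self.hyper_map(F.partial(_scale_grad, reciprocal_scale), gradients)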