| @@ -166,7 +166,8 @@ class Adam(Optimizer): | |||||
| """ | """ | ||||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, | def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, | ||||
| use_nesterov=False, weight_decay=0.0, loss_scale=1.0): | |||||
| use_nesterov=False, weight_decay=0.0, loss_scale=1.0, | |||||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||||
| super(Adam, self).__init__(learning_rate, params) | super(Adam, self).__init__(learning_rate, params) | ||||
| _check_param_value(beta1, beta2, eps, weight_decay) | _check_param_value(beta1, beta2, eps, weight_decay) | ||||
| validator.check_type("use_locking", use_locking, [bool]) | validator.check_type("use_locking", use_locking, [bool]) | ||||
| @@ -192,6 +193,7 @@ class Adam(Optimizer): | |||||
| self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') | self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') | ||||
| self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') | self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') | ||||
| self.decay_tf = tuple(decay_filter(x) for x in self.parameters) | |||||
| self.hyper_map = C.HyperMap() | self.hyper_map = C.HyperMap() | ||||
| self.opt = P.Adam(use_locking, use_nesterov) | self.opt = P.Adam(use_locking, use_nesterov) | ||||
| self.weight_decay = weight_decay * loss_scale | self.weight_decay = weight_decay * loss_scale | ||||