| @@ -138,14 +138,14 @@ class Optimizer(Cell): | |||||
| if self.is_group: | if self.is_group: | ||||
| self.parameters = ParameterTuple(self.group_params) | self.parameters = ParameterTuple(self.group_params) | ||||
| self.weight_decay = tuple(self.group_weight_decay) | self.weight_decay = tuple(self.group_weight_decay) | ||||
| self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay) | |||||
| self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float16) for x in self.group_weight_decay) | |||||
| decay_filter = lambda x: x > 0 | decay_filter = lambda x: x > 0 | ||||
| self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) | self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) | ||||
| self.exec_weight_decay = any(self.decay_flags) | self.exec_weight_decay = any(self.decay_flags) | ||||
| else: | else: | ||||
| self.parameters = ParameterTuple(parameters) | self.parameters = ParameterTuple(parameters) | ||||
| self.weight_decay = weight_decay * loss_scale | self.weight_decay = weight_decay * loss_scale | ||||
| self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32) | |||||
| self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float16) | |||||
| decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name | decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name | ||||
| self.decay_flags = tuple(decay_filter(x) for x in self.parameters) | self.decay_flags = tuple(decay_filter(x) for x in self.parameters) | ||||
| self.exec_weight_decay = self.weight_decay > 0 | self.exec_weight_decay = self.weight_decay > 0 | ||||
| @@ -156,7 +156,7 @@ class Optimizer(Cell): | |||||
| break | break | ||||
| ps_filter = lambda x: x.is_param_ps | ps_filter = lambda x: x.is_param_ps | ||||
| self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) | self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) | ||||
| self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32) | |||||
| self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float16) | |||||
| self.need_scale = loss_scale != 1.0 | self.need_scale = loss_scale != 1.0 | ||||
| self.param_length = len(self.parameters) | self.param_length = len(self.parameters) | ||||
| self.map_ = C.Map() | self.map_ = C.Map() | ||||