Browse Source

Change const tensor dtype to fp16

tags/v1.1.0
caifubi 5 years ago
parent
commit
702ab2bac2
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/nn/optim/optimizer.py

+ 3
- 3
mindspore/nn/optim/optimizer.py View File

@@ -138,14 +138,14 @@ class Optimizer(Cell):
if self.is_group:
self.parameters = ParameterTuple(self.group_params)
self.weight_decay = tuple(self.group_weight_decay)
self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float16) for x in self.group_weight_decay)
decay_filter = lambda x: x > 0
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
self.exec_weight_decay = any(self.decay_flags)
else:
self.parameters = ParameterTuple(parameters)
self.weight_decay = weight_decay * loss_scale
self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float16)
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.exec_weight_decay = self.weight_decay > 0
@@ -156,7 +156,7 @@ class Optimizer(Cell):
break
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float16)
self.need_scale = loss_scale != 1.0
self.param_length = len(self.parameters)
self.map_ = C.Map()


Loading…
Cancel
Save