
Fix bug of adam optimizer

tags/v1.0.0
fary86 committed 5 years ago · commit ef55cefb2d

2 changed files with 6 additions and 5 deletions:
  1. mindspore/context.py        +3 -3
  2. mindspore/nn/optim/adam.py  +3 -2

mindspore/context.py  (+3 -3)

@@ -219,7 +219,7 @@ class _Context:
         self.set_param(ms_ctx_param.profiling_options, option)
 
     def set_variable_memory_max_size(self, variable_memory_max_size):
-        if not check_input_format(variable_memory_max_size):
+        if not _check_input_format(variable_memory_max_size):
             raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
         if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
             raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
@@ -230,7 +230,7 @@ class _Context:
         self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
 
     def set_max_device_memory(self, max_device_memory):
-        if not check_input_format(max_device_memory):
+        if not _check_input_format(max_device_memory):
             raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
         max_device_memory_value = float(max_device_memory[:-2])
         if max_device_memory_value == 0:
@@ -289,7 +289,7 @@ class _Context:
         thread_info.debug_runtime = enable
 
 
-def check_input_format(x):
+def _check_input_format(x):
     import re
     pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
     result = re.match(pattern, x)
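
For reference, the renamed helper only validates memory-size strings such as "5GB" with a regular expression. Below is a minimal standalone sketch of that check; the pattern is copied from the diff above, while the boolean return and the example calls are illustrative, since the function's tail is not shown in this hunk.

import re

def _check_input_format(x):
    # Same pattern as in context.py: accepts sizes such as "5GB", "31GB" or "0.5GB".
    pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
    result = re.match(pattern, x)
    return result is not None  # assumed return value; the diff cuts the function off here

print(_check_input_format("5GB"))    # True
print(_check_input_format("0.5GB"))  # True
print(_check_input_format("5 GB"))   # False -> the calling setters raise ValueError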


mindspore/nn/optim/adam.py  (+3 -2)

@@ -51,6 +51,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
     Returns:
         Tensor, the new value of v after updating.
     """
+    success = True
     if optim_filter:
         op_mul = P.Mul()
         op_square = P.Square()
@@ -80,8 +81,8 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
         next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
         next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
         next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
-        return next_param
-    return gradient
+        success = F.depend(success, next_param)
+    return success
 
 
 @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
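
The functional fix is in the return path: instead of returning `next_param` (or falling through when `optim_filter` is false), `_update_run_op` now always returns a boolean `success` flag and attaches the parameter update to it with `F.depend`, so the assignments still execute while every branch returns the same kind of value. A rough, framework-free sketch of that pattern, using a stand-in `depend` helper and placeholder update arithmetic rather than MindSpore's real ops:

def depend(value, expr):
    # Stand-in for mindspore.ops.functional.depend: evaluate `expr` for its
    # side effects, then pass `value` through unchanged.
    _ = expr
    return value

def update_run_op_sketch(param, update, optim_filter):
    # Illustrative only: the real Adam step (moments, bias correction, casts)
    # is elided; what matters is the uniform boolean return value.
    success = True
    if optim_filter:
        next_param = param - update          # placeholder for the Adam update
        success = depend(success, next_param)
    return success

print(update_run_op_sketch(1.0, 0.1, True))   # True
print(update_run_op_sketch(1.0, 0.1, False))  # True (parameter left untouched)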

