diff --git a/mindspore/context.py b/mindspore/context.py
index 6126bd0473..1d4b76e57c 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -219,7 +219,7 @@ class _Context:
         self.set_param(ms_ctx_param.profiling_options, option)
 
     def set_variable_memory_max_size(self, variable_memory_max_size):
-        if not check_input_format(variable_memory_max_size):
+        if not _check_input_format(variable_memory_max_size):
             raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
         if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
             raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
@@ -230,7 +230,7 @@ class _Context:
         self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
 
     def set_max_device_memory(self, max_device_memory):
-        if not check_input_format(max_device_memory):
+        if not _check_input_format(max_device_memory):
             raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
         max_device_memory_value = float(max_device_memory[:-2])
         if max_device_memory_value == 0:
@@ -289,7 +289,7 @@ class _Context:
     thread_info.debug_runtime = enable
 
 
-def check_input_format(x):
+def _check_input_format(x):
     import re
     pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
     result = re.match(pattern, x)
diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index 1654322738..365fe8f1c5 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -51,6 +51,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
     Returns:
         Tensor, the new value of v after updating.
     """
+    success = True
     if optim_filter:
         op_mul = P.Mul()
         op_square = P.Square()
@@ -80,8 +81,8 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
         next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
         next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
         next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
-        return next_param
-    return gradient
+        success = F.depend(success, next_param)
+    return success
 
 
 @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
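
For quick reference, below is a minimal, self-contained sketch of the size-format check performed by the renamed `_check_input_format` helper. The regex pattern is copied verbatim from the diff; the wrapper name `looks_like_memory_size`, the boolean return, and the sample strings are illustrative assumptions, not MindSpore code.

```python
import re

# Pattern taken verbatim from _check_input_format in mindspore/context.py.
_SIZE_PATTERN = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'

def looks_like_memory_size(x):
    """Return True if x looks like a memory-size string such as "5GB" or "3.5GB"."""
    # re.match anchors at the start of the string only, matching the behavior
    # shown in the diff (result = re.match(pattern, x)).
    return re.match(_SIZE_PATTERN, x) is not None

if __name__ == "__main__":
    for sample in ("5GB", "3.5GB", "0.5GB", "31GB", "GB", "five gigabytes"):
        print(sample, looks_like_memory_size(sample))  # True for the first four only
```

On the adam.py side, `F.depend(success, next_param)` returns its first argument while keeping the second (and the `F.assign` updates of `param`, `m` and `v` chained onto it) in the graph, so as I read the change the parameter update still executes even though `_update_run_op` now reports a success flag on both branches instead of returning `next_param` or the raw `gradient`.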