Browse Source

Add the function of checking nan or inf

tags/v0.2.0-alpha
leilei_snow 5 years ago
parent
commit
834a407103
2 changed files with 31 additions and 3 deletions
  1. +11
    -0
      mindspore/_checkparam.py
  2. +20
    -3
      mindspore/nn/dynamic_lr.py

+ 11
- 0
mindspore/_checkparam.py View File

@@ -15,6 +15,7 @@
"""Check parameters."""
import re
import inspect
import math
from enum import Enum
from functools import reduce, wraps
from itertools import repeat
@@ -318,6 +319,16 @@ class Validator:
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')

@staticmethod
def check_float_legal_value(arg_name, arg_value, prim_name):
"""Checks whether a legal value of float type"""
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
if isinstance(arg_value, float):
if math.isinf(arg_value) or math.isnan(arg_value):
raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.")
return arg_value
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")


class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""


+ 20
- 3
mindspore/nn/dynamic_lr.py View File

@@ -28,7 +28,7 @@ def piecewise_constant_lr(milestone, learning_rates):
`milestone`. Let the output learning rate be `y`.

.. math::
y[i] = x_t for i \in [M_{t-1}, M_t)
y[i] = x_t,\ for\ i \in [M_{t-1}, M_t)

Args:
milestone (list[int]): A list of milestone. This list is a monotone increasing list.
@@ -52,7 +52,7 @@ def piecewise_constant_lr(milestone, learning_rates):
last_item = 0
for i, item in enumerate(milestone):
validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None)
validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None)
validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None)
if item < last_item:
raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
lr += [learning_rates[i]] * (item - last_item)
@@ -66,7 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_legal_value('learning_rate', learning_rate, None)
validator.check_float_positive('decay_rate', decay_rate, None)
validator.check_float_legal_value('decay_rate', decay_rate, None)
validator.check_value_type('is_stair', is_stair, [bool], None)


@@ -229,7 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
[0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
"""
validator.check_float_positive('min_lr', min_lr, None)
validator.check_float_legal_value('min_lr', min_lr, None)
validator.check_float_positive('max_lr', max_lr, None)
validator.check_float_legal_value('max_lr', max_lr, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
@@ -280,11 +284,14 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
"""
validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_legal_value('learning_rate', learning_rate, None)
validator.check_float_positive('end_learning_rate', end_learning_rate, None)
validator.check_float_legal_value('end_learning_rate', end_learning_rate, None)
validator.check_float_positive('power', power, None)
validator.check_float_legal_value('power', power, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
validator.check_value_type('power', power, [float], None)
validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)

function = lambda x, y: (x, min(x, y))
@@ -298,3 +305,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
decay_epoch, tmp_epoch = function(decay_epoch, current_epoch)
lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate)
return lr


__all__ = [
'piecewise_constant_lr',
'exponential_decay_lr',
'natural_exp_decay_lr',
'inverse_decay_lr',
'cosine_decay_lr',
'polynomial_decay_lr'
]

Loading…
Cancel
Save