Browse Source

Add the function of checking nan or inf

tags/v0.2.0-alpha
leilei_snow 5 years ago
parent
commit
834a407103
2 changed files with 31 additions and 3 deletions
  1. +11
    -0
      mindspore/_checkparam.py
  2. +20
    -3
      mindspore/nn/dynamic_lr.py

+ 11
- 0
mindspore/_checkparam.py View File

@@ -15,6 +15,7 @@
"""Check parameters.""" """Check parameters."""
import re import re
import inspect import inspect
import math
from enum import Enum from enum import Enum
from functools import reduce, wraps from functools import reduce, wraps
from itertools import repeat from itertools import repeat
@@ -318,6 +319,16 @@ class Validator:
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.') f' but got {get_typename(arg_type)}.')


@staticmethod
def check_float_legal_value(arg_name, arg_value, prim_name):
"""Checks whether a legal value of float type"""
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
if isinstance(arg_value, float):
if math.isinf(arg_value) or math.isnan(arg_value):
raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.")
return arg_value
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")



class ParamValidator: class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`""" """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""


+ 20
- 3
mindspore/nn/dynamic_lr.py View File

@@ -28,7 +28,7 @@ def piecewise_constant_lr(milestone, learning_rates):
`milestone`. Let the output learning rate be `y`. `milestone`. Let the output learning rate be `y`.


.. math:: .. math::
y[i] = x_t for i \in [M_{t-1}, M_t)
y[i] = x_t,\ for\ i \in [M_{t-1}, M_t)


Args: Args:
milestone (list[int]): A list of milestone. This list is a monotone increasing list. milestone (list[int]): A list of milestone. This list is a monotone increasing list.
@@ -52,7 +52,7 @@ def piecewise_constant_lr(milestone, learning_rates):
last_item = 0 last_item = 0
for i, item in enumerate(milestone): for i, item in enumerate(milestone):
validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None) validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None)
validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None)
validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None)
if item < last_item: if item < last_item:
raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
lr += [learning_rates[i]] * (item - last_item) lr += [learning_rates[i]] * (item - last_item)
@@ -66,7 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
validator.check_float_positive('learning_rate', learning_rate, None) validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_legal_value('learning_rate', learning_rate, None)
validator.check_float_positive('decay_rate', decay_rate, None) validator.check_float_positive('decay_rate', decay_rate, None)
validator.check_float_legal_value('decay_rate', decay_rate, None)
validator.check_value_type('is_stair', is_stair, [bool], None) validator.check_value_type('is_stair', is_stair, [bool], None)




@@ -229,7 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
[0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
""" """
validator.check_float_positive('min_lr', min_lr, None) validator.check_float_positive('min_lr', min_lr, None)
validator.check_float_legal_value('min_lr', min_lr, None)
validator.check_float_positive('max_lr', max_lr, None) validator.check_float_positive('max_lr', max_lr, None)
validator.check_float_legal_value('max_lr', max_lr, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
@@ -280,11 +284,14 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
""" """
validator.check_float_positive('learning_rate', learning_rate, None) validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_legal_value('learning_rate', learning_rate, None)
validator.check_float_positive('end_learning_rate', end_learning_rate, None) validator.check_float_positive('end_learning_rate', end_learning_rate, None)
validator.check_float_legal_value('end_learning_rate', end_learning_rate, None)
validator.check_float_positive('power', power, None)
validator.check_float_legal_value('power', power, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
validator.check_value_type('power', power, [float], None)
validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None) validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)


function = lambda x, y: (x, min(x, y)) function = lambda x, y: (x, min(x, y))
@@ -298,3 +305,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
decay_epoch, tmp_epoch = function(decay_epoch, current_epoch) decay_epoch, tmp_epoch = function(decay_epoch, current_epoch)
lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate) lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate)
return lr return lr


__all__ = [
'piecewise_constant_lr',
'exponential_decay_lr',
'natural_exp_decay_lr',
'inverse_decay_lr',
'cosine_decay_lr',
'polynomial_decay_lr'
]

Loading…
Cancel
Save