Merge pull request !7254 from chenzhongming/zomi_mastertags/v1.1.0
| @@ -375,17 +375,14 @@ class Validator: | |||||
| """Type checking.""" | """Type checking.""" | ||||
| def raise_error_msg(): | def raise_error_msg(): | ||||
| """func for raising error message when check failed""" | """func for raising error message when check failed""" | ||||
| type_names = [t.__name__ for t in valid_types] | |||||
| num_types = len(valid_types) | |||||
| raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||||
| raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.') | |||||
| if isinstance(arg_value, type(mstype.tensor)): | if isinstance(arg_value, type(mstype.tensor)): | ||||
| arg_value = arg_value.element_type() | arg_value = arg_value.element_type() | ||||
| # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and | |||||
| # `check_type('x', True, [bool, int])` will check pass | |||||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | if isinstance(arg_value, bool) and bool not in tuple(valid_types): | ||||
| raise_error_msg() | raise_error_msg() | ||||
| if arg_value in valid_types: | |||||
| return arg_value | |||||
| if isinstance(arg_value, tuple(valid_types)): | if isinstance(arg_value, tuple(valid_types)): | ||||
| return arg_value | return arg_value | ||||
| raise_error_msg() | raise_error_msg() | ||||
| @@ -118,7 +118,7 @@ number_type = (int8, | |||||
| float64,) | float64,) | ||||
| int_type = (int8, int16, int32, int64,) | int_type = (int8, int16, int32, int64,) | ||||
| uint_type = (uint8, uint16, uint32, uint64) | |||||
| uint_type = (uint8, uint16, uint32, uint64,) | |||||
| float_type = (float16, float32, float64,) | float_type = (float16, float32, float64,) | ||||
| implicit_conversion_seq = {t: idx for idx, t in enumerate(( | implicit_conversion_seq = {t: idx for idx, t in enumerate(( | ||||
| @@ -24,7 +24,6 @@ __all__ = [ | |||||
| 'check_greater_equal_zero', | 'check_greater_equal_zero', | ||||
| 'check_greater_zero', | 'check_greater_zero', | ||||
| 'check_prob', | 'check_prob', | ||||
| 'check_type', | |||||
| 'exp_generic', | 'exp_generic', | ||||
| 'expm1_generic', | 'expm1_generic', | ||||
| 'log_generic', | 'log_generic', | ||||
| @@ -206,12 +206,6 @@ def probs_to_logits(probs, is_binary=False): | |||||
| return P.Log()(ps_clamped) | return P.Log()(ps_clamped) | ||||
| def check_type(data_type, value_type, name): | |||||
| if not data_type in value_type: | |||||
| raise TypeError( | |||||
| f"For {name}, valid type include {value_type}, {data_type} is invalid") | |||||
| @constexpr | @constexpr | ||||
| def raise_none_error(name): | def raise_none_error(name): | ||||
| raise TypeError(f"the type {name} should be subclass of Tensor." | raise TypeError(f"the type {name} should be subclass of Tensor." | ||||
| @@ -16,8 +16,9 @@ | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore._checkparam import Validator | |||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import check_prob, check_type, check_distribution_name | |||||
| from ._utils.utils import check_prob, check_distribution_name | |||||
| from ._utils.custom_ops import exp_generic, log_generic | from ._utils.custom_ops import exp_generic, log_generic | ||||
| @@ -118,7 +119,7 @@ class Bernoulli(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'probs': probs} | param['param_dict'] = {'probs': probs} | ||||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | |||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| super(Bernoulli, self).__init__(seed, dtype, name, param) | super(Bernoulli, self).__init__(seed, dtype, name, param) | ||||
| self._probs = self._add_parameter(probs, 'probs') | self._probs = self._add_parameter(probs, 'probs') | ||||
| @@ -16,10 +16,11 @@ | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore._checkparam import Validator | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\ | |||||
| from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\ | |||||
| check_distribution_name, raise_not_implemented_util | check_distribution_name, raise_not_implemented_util | ||||
| from ._utils.custom_ops import exp_generic, log_generic, broadcast_to | from ._utils.custom_ops import exp_generic, log_generic, broadcast_to | ||||
| @@ -107,7 +108,7 @@ class Categorical(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'probs': probs} | param['param_dict'] = {'probs': probs} | ||||
| valid_dtype = mstype.int_type | valid_dtype = mstype.int_type | ||||
| check_type(dtype, valid_dtype, "Categorical") | |||||
| Validator.check_type("Categorical", dtype, valid_dtype) | |||||
| super(Categorical, self).__init__(seed, dtype, name, param) | super(Categorical, self).__init__(seed, dtype, name, param) | ||||
| self._probs = self._add_parameter(probs, 'probs') | self._probs = self._add_parameter(probs, 'probs') | ||||
| @@ -16,9 +16,10 @@ | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore._checkparam import Validator | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import check_greater_zero, check_type, check_distribution_name | |||||
| from ._utils.utils import check_greater_zero, check_distribution_name | |||||
| from ._utils.custom_ops import exp_generic, log_generic | from ._utils.custom_ops import exp_generic, log_generic | ||||
| @@ -120,7 +121,7 @@ class Exponential(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'rate': rate} | param['param_dict'] = {'rate': rate} | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | |||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| super(Exponential, self).__init__(seed, dtype, name, param) | super(Exponential, self).__init__(seed, dtype, name, param) | ||||
| self._rate = self._add_parameter(rate, 'rate') | self._rate = self._add_parameter(rate, 'rate') | ||||
| @@ -16,9 +16,10 @@ | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore._checkparam import Validator | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import check_prob, check_type, check_distribution_name | |||||
| from ._utils.utils import check_prob, check_distribution_name | |||||
| from ._utils.custom_ops import exp_generic, log_generic | from ._utils.custom_ops import exp_generic, log_generic | ||||
| @@ -121,7 +122,7 @@ class Geometric(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'probs': probs} | param['param_dict'] = {'probs': probs} | ||||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | |||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| super(Geometric, self).__init__(seed, dtype, name, param) | super(Geometric, self).__init__(seed, dtype, name, param) | ||||
| self._probs = self._add_parameter(probs, 'probs') | self._probs = self._add_parameter(probs, 'probs') | ||||
| @@ -16,9 +16,10 @@ | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore._checkparam import Validator | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import check_greater_zero, check_type | |||||
| from ._utils.utils import check_greater_zero | |||||
| from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic | from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic | ||||
| @@ -110,7 +111,7 @@ class Logistic(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'loc': loc, 'scale': scale} | param['param_dict'] = {'loc': loc, 'scale': scale} | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | |||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| super(Logistic, self).__init__(seed, dtype, name, param) | super(Logistic, self).__init__(seed, dtype, name, param) | ||||
| self._loc = self._add_parameter(loc, 'loc') | self._loc = self._add_parameter(loc, 'loc') | ||||
| @@ -16,9 +16,10 @@ | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore._checkparam import Validator | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import check_greater_zero, check_type, check_distribution_name | |||||
| from ._utils.utils import check_greater_zero, check_distribution_name | |||||
| from ._utils.custom_ops import exp_generic, expm1_generic, log_generic | from ._utils.custom_ops import exp_generic, expm1_generic, log_generic | ||||
| @@ -126,7 +127,7 @@ class Normal(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'mean': mean, 'sd': sd} | param['param_dict'] = {'mean': mean, 'sd': sd} | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | |||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| super(Normal, self).__init__(seed, dtype, name, param) | super(Normal, self).__init__(seed, dtype, name, param) | ||||
| self._mean_value = self._add_parameter(mean, 'mean') | self._mean_value = self._add_parameter(mean, 'mean') | ||||
| @@ -15,9 +15,10 @@ | |||||
| """Uniform Distribution""" | """Uniform Distribution""" | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore._checkparam import Validator | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import check_greater, check_type, check_distribution_name | |||||
| from ._utils.utils import check_greater, check_distribution_name | |||||
| from ._utils.custom_ops import exp_generic, log_generic | from ._utils.custom_ops import exp_generic, log_generic | ||||
| @@ -125,7 +126,7 @@ class Uniform(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'low': low, 'high': high} | param['param_dict'] = {'low': low, 'high': high} | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | |||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| super(Uniform, self).__init__(seed, dtype, name, param) | super(Uniform, self).__init__(seed, dtype, name, param) | ||||
| self._low = self._add_parameter(low, 'low') | self._low = self._add_parameter(low, 'low') | ||||