Browse Source

!7254 [ME] reused `check_type` function

Merge pull request !7254 from chenzhongming/zomi_master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
044a511726
11 changed files with 25 additions and 28 deletions
  1. +3
    -6
      mindspore/_checkparam.py
  2. +1
    -1
      mindspore/common/dtype.py
  3. +0
    -1
      mindspore/nn/probability/distribution/_utils/__init__.py
  4. +0
    -6
      mindspore/nn/probability/distribution/_utils/utils.py
  5. +3
    -2
      mindspore/nn/probability/distribution/bernoulli.py
  6. +3
    -2
      mindspore/nn/probability/distribution/categorical.py
  7. +3
    -2
      mindspore/nn/probability/distribution/exponential.py
  8. +3
    -2
      mindspore/nn/probability/distribution/geometric.py
  9. +3
    -2
      mindspore/nn/probability/distribution/logistic.py
  10. +3
    -2
      mindspore/nn/probability/distribution/normal.py
  11. +3
    -2
      mindspore/nn/probability/distribution/uniform.py

+ 3
- 6
mindspore/_checkparam.py View File

@@ -375,17 +375,14 @@ class Validator:
"""Type checking.""" """Type checking."""
def raise_error_msg(): def raise_error_msg():
"""func for raising error message when check failed""" """func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.')


if isinstance(arg_value, type(mstype.tensor)): if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type() arg_value = arg_value.element_type()
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types): if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg() raise_error_msg()
if arg_value in valid_types:
return arg_value
if isinstance(arg_value, tuple(valid_types)): if isinstance(arg_value, tuple(valid_types)):
return arg_value return arg_value
raise_error_msg() raise_error_msg()


+ 1
- 1
mindspore/common/dtype.py View File

@@ -118,7 +118,7 @@ number_type = (int8,
float64,) float64,)


int_type = (int8, int16, int32, int64,) int_type = (int8, int16, int32, int64,)
uint_type = (uint8, uint16, uint32, uint64)
uint_type = (uint8, uint16, uint32, uint64,)
float_type = (float16, float32, float64,) float_type = (float16, float32, float64,)


implicit_conversion_seq = {t: idx for idx, t in enumerate(( implicit_conversion_seq = {t: idx for idx, t in enumerate((


+ 0
- 1
mindspore/nn/probability/distribution/_utils/__init__.py View File

@@ -24,7 +24,6 @@ __all__ = [
'check_greater_equal_zero', 'check_greater_equal_zero',
'check_greater_zero', 'check_greater_zero',
'check_prob', 'check_prob',
'check_type',
'exp_generic', 'exp_generic',
'expm1_generic', 'expm1_generic',
'log_generic', 'log_generic',


+ 0
- 6
mindspore/nn/probability/distribution/_utils/utils.py View File

@@ -206,12 +206,6 @@ def probs_to_logits(probs, is_binary=False):
return P.Log()(ps_clamped) return P.Log()(ps_clamped)




def check_type(data_type, value_type, name):
if not data_type in value_type:
raise TypeError(
f"For {name}, valid type include {value_type}, {data_type} is invalid")


@constexpr @constexpr
def raise_none_error(name): def raise_none_error(name):
raise TypeError(f"the type {name} should be subclass of Tensor." raise TypeError(f"the type {name} should be subclass of Tensor."


+ 3
- 2
mindspore/nn/probability/distribution/bernoulli.py View File

@@ -16,8 +16,9 @@
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_prob, check_type, check_distribution_name
from ._utils.utils import check_prob, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic from ._utils.custom_ops import exp_generic, log_generic




@@ -118,7 +119,7 @@ class Bernoulli(Distribution):
param = dict(locals()) param = dict(locals())
param['param_dict'] = {'probs': probs} param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Bernoulli, self).__init__(seed, dtype, name, param) super(Bernoulli, self).__init__(seed, dtype, name, param)


self._probs = self._add_parameter(probs, 'probs') self._probs = self._add_parameter(probs, 'probs')


+ 3
- 2
mindspore/nn/probability/distribution/categorical.py View File

@@ -16,10 +16,11 @@
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore._checkparam import Validator
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
check_distribution_name, raise_not_implemented_util check_distribution_name, raise_not_implemented_util
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
@@ -107,7 +108,7 @@ class Categorical(Distribution):
param = dict(locals()) param = dict(locals())
param['param_dict'] = {'probs': probs} param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type valid_dtype = mstype.int_type
check_type(dtype, valid_dtype, "Categorical")
Validator.check_type("Categorical", dtype, valid_dtype)
super(Categorical, self).__init__(seed, dtype, name, param) super(Categorical, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs') self._probs = self._add_parameter(probs, 'probs')


+ 3
- 2
mindspore/nn/probability/distribution/exponential.py View File

@@ -16,9 +16,10 @@
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
from ._utils.utils import check_greater_zero, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic from ._utils.custom_ops import exp_generic, log_generic




@@ -120,7 +121,7 @@ class Exponential(Distribution):
param = dict(locals()) param = dict(locals())
param['param_dict'] = {'rate': rate} param['param_dict'] = {'rate': rate}
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Exponential, self).__init__(seed, dtype, name, param) super(Exponential, self).__init__(seed, dtype, name, param)


self._rate = self._add_parameter(rate, 'rate') self._rate = self._add_parameter(rate, 'rate')


+ 3
- 2
mindspore/nn/probability/distribution/geometric.py View File

@@ -16,9 +16,10 @@
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_prob, check_type, check_distribution_name
from ._utils.utils import check_prob, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic from ._utils.custom_ops import exp_generic, log_generic




@@ -121,7 +122,7 @@ class Geometric(Distribution):
param = dict(locals()) param = dict(locals())
param['param_dict'] = {'probs': probs} param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Geometric, self).__init__(seed, dtype, name, param) super(Geometric, self).__init__(seed, dtype, name, param)


self._probs = self._add_parameter(probs, 'probs') self._probs = self._add_parameter(probs, 'probs')


+ 3
- 2
mindspore/nn/probability/distribution/logistic.py View File

@@ -16,9 +16,10 @@
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_type
from ._utils.utils import check_greater_zero
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic




@@ -110,7 +111,7 @@ class Logistic(Distribution):
param = dict(locals()) param = dict(locals())
param['param_dict'] = {'loc': loc, 'scale': scale} param['param_dict'] = {'loc': loc, 'scale': scale}
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Logistic, self).__init__(seed, dtype, name, param) super(Logistic, self).__init__(seed, dtype, name, param)


self._loc = self._add_parameter(loc, 'loc') self._loc = self._add_parameter(loc, 'loc')


+ 3
- 2
mindspore/nn/probability/distribution/normal.py View File

@@ -16,9 +16,10 @@
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
from ._utils.utils import check_greater_zero, check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic from ._utils.custom_ops import exp_generic, expm1_generic, log_generic




@@ -126,7 +127,7 @@ class Normal(Distribution):
param = dict(locals()) param = dict(locals())
param['param_dict'] = {'mean': mean, 'sd': sd} param['param_dict'] = {'mean': mean, 'sd': sd}
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Normal, self).__init__(seed, dtype, name, param) super(Normal, self).__init__(seed, dtype, name, param)


self._mean_value = self._add_parameter(mean, 'mean') self._mean_value = self._add_parameter(mean, 'mean')


+ 3
- 2
mindspore/nn/probability/distribution/uniform.py View File

@@ -15,9 +15,10 @@
"""Uniform Distribution""" """Uniform Distribution"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_greater, check_type, check_distribution_name
from ._utils.utils import check_greater, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic from ._utils.custom_ops import exp_generic, log_generic




@@ -125,7 +126,7 @@ class Uniform(Distribution):
param = dict(locals()) param = dict(locals())
param['param_dict'] = {'low': low, 'high': high} param['param_dict'] = {'low': low, 'high': high}
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Uniform, self).__init__(seed, dtype, name, param) super(Uniform, self).__init__(seed, dtype, name, param)


self._low = self._add_parameter(low, 'low') self._low = self._add_parameter(low, 'low')


Loading…
Cancel
Save