Browse Source

!7254 [ME] reused `check_type` function

Merge pull request !7254 from chenzhongming/zomi_master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
044a511726
11 changed files with 25 additions and 28 deletions
  1. +3
    -6
      mindspore/_checkparam.py
  2. +1
    -1
      mindspore/common/dtype.py
  3. +0
    -1
      mindspore/nn/probability/distribution/_utils/__init__.py
  4. +0
    -6
      mindspore/nn/probability/distribution/_utils/utils.py
  5. +3
    -2
      mindspore/nn/probability/distribution/bernoulli.py
  6. +3
    -2
      mindspore/nn/probability/distribution/categorical.py
  7. +3
    -2
      mindspore/nn/probability/distribution/exponential.py
  8. +3
    -2
      mindspore/nn/probability/distribution/geometric.py
  9. +3
    -2
      mindspore/nn/probability/distribution/logistic.py
  10. +3
    -2
      mindspore/nn/probability/distribution/normal.py
  11. +3
    -2
      mindspore/nn/probability/distribution/uniform.py

+ 3
- 6
mindspore/_checkparam.py View File

@@ -375,17 +375,14 @@ class Validator:
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.')

if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if arg_value in valid_types:
return arg_value
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()


+ 1
- 1
mindspore/common/dtype.py View File

@@ -118,7 +118,7 @@ number_type = (int8,
float64,)

int_type = (int8, int16, int32, int64,)
uint_type = (uint8, uint16, uint32, uint64)
uint_type = (uint8, uint16, uint32, uint64,)
float_type = (float16, float32, float64,)

implicit_conversion_seq = {t: idx for idx, t in enumerate((


+ 0
- 1
mindspore/nn/probability/distribution/_utils/__init__.py View File

@@ -24,7 +24,6 @@ __all__ = [
'check_greater_equal_zero',
'check_greater_zero',
'check_prob',
'check_type',
'exp_generic',
'expm1_generic',
'log_generic',


+ 0
- 6
mindspore/nn/probability/distribution/_utils/utils.py View File

@@ -206,12 +206,6 @@ def probs_to_logits(probs, is_binary=False):
return P.Log()(ps_clamped)


def check_type(data_type, value_type, name):
if not data_type in value_type:
raise TypeError(
f"For {name}, valid type include {value_type}, {data_type} is invalid")


@constexpr
def raise_none_error(name):
raise TypeError(f"the type {name} should be subclass of Tensor."


+ 3
- 2
mindspore/nn/probability/distribution/bernoulli.py View File

@@ -16,8 +16,9 @@
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from .distribution import Distribution
from ._utils.utils import check_prob, check_type, check_distribution_name
from ._utils.utils import check_prob, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic


@@ -118,7 +119,7 @@ class Bernoulli(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Bernoulli, self).__init__(seed, dtype, name, param)

self._probs = self._add_parameter(probs, 'probs')


+ 3
- 2
mindspore/nn/probability/distribution/categorical.py View File

@@ -16,10 +16,11 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
check_distribution_name, raise_not_implemented_util
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
@@ -107,7 +108,7 @@ class Categorical(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type
check_type(dtype, valid_dtype, "Categorical")
Validator.check_type("Categorical", dtype, valid_dtype)
super(Categorical, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs')


+ 3
- 2
mindspore/nn/probability/distribution/exponential.py View File

@@ -16,9 +16,10 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
from ._utils.utils import check_greater_zero, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic


@@ -120,7 +121,7 @@ class Exponential(Distribution):
param = dict(locals())
param['param_dict'] = {'rate': rate}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Exponential, self).__init__(seed, dtype, name, param)

self._rate = self._add_parameter(rate, 'rate')


+ 3
- 2
mindspore/nn/probability/distribution/geometric.py View File

@@ -16,9 +16,10 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_prob, check_type, check_distribution_name
from ._utils.utils import check_prob, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic


@@ -121,7 +122,7 @@ class Geometric(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Geometric, self).__init__(seed, dtype, name, param)

self._probs = self._add_parameter(probs, 'probs')


+ 3
- 2
mindspore/nn/probability/distribution/logistic.py View File

@@ -16,9 +16,10 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_type
from ._utils.utils import check_greater_zero
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic


@@ -110,7 +111,7 @@ class Logistic(Distribution):
param = dict(locals())
param['param_dict'] = {'loc': loc, 'scale': scale}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Logistic, self).__init__(seed, dtype, name, param)

self._loc = self._add_parameter(loc, 'loc')


+ 3
- 2
mindspore/nn/probability/distribution/normal.py View File

@@ -16,9 +16,10 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
from ._utils.utils import check_greater_zero, check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic


@@ -126,7 +127,7 @@ class Normal(Distribution):
param = dict(locals())
param['param_dict'] = {'mean': mean, 'sd': sd}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Normal, self).__init__(seed, dtype, name, param)

self._mean_value = self._add_parameter(mean, 'mean')


+ 3
- 2
mindspore/nn/probability/distribution/uniform.py View File

@@ -15,9 +15,10 @@
"""Uniform Distribution"""
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater, check_type, check_distribution_name
from ._utils.utils import check_greater, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic


@@ -125,7 +126,7 @@ class Uniform(Distribution):
param = dict(locals())
param['param_dict'] = {'low': low, 'high': high}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Uniform, self).__init__(seed, dtype, name, param)

self._low = self._add_parameter(low, 'low')


Loading…
Cancel
Save