
added add_parameter function into distribution base class

tags/v1.1.0
Xun Deng, 5 years ago
commit 2d8be64696
10 changed files with 146 additions and 192 deletions
  1. mindspore/nn/probability/distribution/_utils/__init__.py  (+4, -3)
  2. mindspore/nn/probability/distribution/_utils/custom_ops.py  (+9, -0)
  3. mindspore/nn/probability/distribution/_utils/utils.py  (+6, -83)
  4. mindspore/nn/probability/distribution/bernoulli.py  (+12, -20)
  5. mindspore/nn/probability/distribution/distribution.py  (+87, -24)
  6. mindspore/nn/probability/distribution/exponential.py  (+5, -11)
  7. mindspore/nn/probability/distribution/geometric.py  (+8, -14)
  8. mindspore/nn/probability/distribution/normal.py  (+6, -19)
  9. mindspore/nn/probability/distribution/transformed_distribution.py  (+2, -0)
  10. mindspore/nn/probability/distribution/uniform.py  (+7, -18)

+4 -3  mindspore/nn/probability/distribution/_utils/__init__.py

@@ -19,17 +19,18 @@ from .utils import *
 from .custom_ops import *


 __all__ = [
-    'convert_to_batch',
     'cast_to_tensor',
     'check_greater',
     'check_greater_equal_zero',
     'check_greater_zero',
-    'calc_broadcast_shape_from_param',
-    'check_scalar_from_param',
     'check_prob',
     'check_type',
     'exp_generic',
     'expm1_generic',
     'log_generic',
     'log1p_generic',
+    'broadcast_to',
+    'set_param_type',
+    'CheckTensor',
+    'CheckTuple',
 ]

+9 -0  mindspore/nn/probability/distribution/_utils/custom_ops.py

@@ -72,3 +72,12 @@ def log1p_generic(x):
     Log1p ops on GPU device or when device_target == GPU.
     """
     return log_generic(x + 1.0)
+
+
+def broadcast_to(x, target):
+    """
+    Broadcast x to the shape of target.
+    """
+    shape = P.Shape()
+    if shape(x) == shape(target):
+        return x
+    return P.BroadcastTo(shape(target))(x)
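A quick sketch of how the new helper behaves (a usage illustration, assuming a standard MindSpore install; the tensors and the private import path are for demonstration only):

    import numpy as np
    import mindspore as ms
    from mindspore.nn.probability.distribution._utils.custom_ops import broadcast_to

    x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)   # shape (3,)
    target = ms.Tensor(np.ones((2, 3)), ms.float32)        # shape (2, 3)
    print(broadcast_to(x, target).shape)                    # (2, 3)
    print(broadcast_to(target, target) is target)           # True: shapes match, returned as-is

The early return makes the helper a no-op when the shapes already agree, which is what lets callers apply it unconditionally.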

+6 -83  mindspore/nn/probability/distribution/_utils/utils.py

@@ -19,13 +19,10 @@ from mindspore._checkparam import Validator as validator
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.ops import _utils as utils
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
 import mindspore.nn as nn
-import mindspore.nn.probability as msp


 def cast_to_tensor(t, hint_type=mstype.float32):
     """
@@ -46,41 +43,13 @@ def cast_to_tensor(t, hint_type=mstype.float32):
         raise ValueError(f'Input cannot be None in cast_to_tensor')
     if isinstance(t, Parameter):
         return t
-    t_type = hint_type
-    if isinstance(t, Tensor):
-        # convert the type of tensor to dtype
-        return Tensor(t.asnumpy(), dtype=t_type)
-    if isinstance(t, (list, np.ndarray)):
-        return Tensor(t, dtype=t_type)
     if isinstance(t, bool):
         raise TypeError(f'Input cannot be Type Bool')
-    if isinstance(t, (int, float)):
-        return Tensor(t, dtype=t_type)
+    if isinstance(t, (Tensor, np.ndarray, list, int, float)):
+        return Tensor(t, dtype=hint_type)
     invalid_type = type(t)
     raise TypeError(
-        f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
-
-
-def convert_to_batch(t, batch_shape, required_type):
-    """
-    Convert a Tensor to a given batch shape.
-
-    Args:
-        t (int, float, list, numpy.ndarray, Tensor, Parameter): Tensor to be converted.
-        batch_shape (tuple): desired batch shape.
-        dtype (mindspore.dtype): desired dtype.
-
-    Raises:
-        RuntimeError: if the converison cannot be done.
-
-    Returns:
-        Tensor, with shape of batch_shape.
-    """
-    if isinstance(t, Parameter):
-        return t
-    t = cast_to_tensor(t, required_type)
-    return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
+        f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")


 def cast_type_for_device(dtype):
     """
@@ -100,54 +69,6 @@ def cast_type_for_device(dtype):
     return dtype


-def check_scalar_from_param(params):
-    """
-    Check if params are all scalars.
-
-    Args:
-        params (dict): parameters used to initialize distribution.
-
-    Notes: String parameters are excluded.
-    """
-    for value in params.values():
-        if value is None:
-            continue
-        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
-            return params['distribution'].is_scalar_batch
-        if isinstance(value, Parameter):
-            return False
-        if not isinstance(value, (int, float, str, type(params['dtype']))):
-            return False
-    return True
-
-
-def calc_broadcast_shape_from_param(params):
-    """
-    Calculate the broadcast shape from params.
-
-    Args:
-        params (dict): parameters used to initialize distribution.
-
-    Returns:
-        tuple.
-    """
-    broadcast_shape = []
-    for value in params.values():
-        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
-            return params['distribution'].broadcast_shape
-        if isinstance(value, (str, type(params['dtype']))):
-            continue
-        if value is None:
-            return None
-        if isinstance(value, Parameter):
-            value_t = value.data
-        else:
-            value_t = cast_to_tensor(value, mstype.float32)
-        broadcast_shape = utils.get_broadcast_shape(
-            broadcast_shape, list(value_t.shape), params['name'])
-    return tuple(broadcast_shape)
-
-
 def check_greater_equal_zero(value, name):
     """
     Check if the given Tensor is greater zero.
@@ -371,6 +292,9 @@ def set_param_type(args, hint_type):
     Raises:
         TypeError: if tensors in args are not the same dtype.
     """
+    int_type = mstype.int_type + mstype.uint_type
+    if hint_type in int_type:
+        hint_type = mstype.float32
     common_dtype = None
     for name, arg in args.items():
         if hasattr(arg, 'dtype'):
@@ -382,7 +306,6 @@ def set_param_type(args, hint_type):
                 common_dtype = cur_dtype
             elif cur_dtype != common_dtype:
                 raise TypeError(f"{name} should have the same dtype as other arguments.")
-    int_type = mstype.int_type + mstype.uint_type
     if common_dtype in int_type or common_dtype == mstype.float64:
         return mstype.float32
     return hint_type if common_dtype is None else common_dtype
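The set_param_type change moves the integer handling to the front: an integer or unsigned hint_type is now coerced to float32 before any argument is inspected, not only when an integer common dtype is found. This matters because the base class now derives parameter_type from the distribution's own dtype, which defaults to an integer type for Bernoulli and Geometric. A small sketch of the difference (calling the private helper directly, for illustration only):

    import mindspore as ms
    from mindspore.nn.probability.distribution._utils.utils import set_param_type

    # A plain Python float carries no dtype, so no common dtype is inferred
    # and the hint decides; before this commit the call returned Int32.
    print(set_param_type({'probs1': 0.5}, ms.int32))   # Float32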

+12 -20  mindspore/nn/probability/distribution/bernoulli.py

@@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
+from ._utils.utils import check_prob, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic

@@ -116,18 +116,14 @@ class Bernoulli(Distribution):
         Constructor of Bernoulli.
         """
         param = dict(locals())
+        param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Bernoulli, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
-        if probs is not None:
-            self._probs = cast_to_tensor(probs, self.parameter_type)
-            check_prob(self.probs)
-        else:
-            self._probs = probs
-
-        self.default_parameters = [self.probs]
-        self.parameter_names = ['probs1']
+        self._probs = self._add_parameter(probs, 'probs')
+        if self._probs is not None:
+            check_prob(self.probs)

         # ops needed for the class
         self.exp = exp_generic
@@ -135,14 +131,11 @@ class Bernoulli(Distribution):
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
-        self.dtypeop = P.DType()
         self.floor = P.Floor()
         self.fill = P.Fill()
         self.less = P.Less()
         self.shape = P.Shape()
         self.select = P.Select()
-        self.sq = P.Square()
-        self.sqrt = P.Sqrt()
         self.uniform = C.uniform

     def extend_repr(self):
@@ -173,9 +166,8 @@
             MODE(B) = 1 if probs1 > 0.5 else = 0
         """
         probs1 = self._check_param_type(probs1)
-        prob_type = self.dtypeop(probs1)
-        zeros = self.fill(prob_type, self.shape(probs1), 0.0)
-        ones = self.fill(prob_type, self.shape(probs1), 1.0)
+        zeros = self.fill(self.dtype, self.shape(probs1), 0.0)
+        ones = self.fill(self.dtype, self.shape(probs1), 1.0)
         comp = self.less(0.5, probs1)
         return self.select(comp, ones, zeros)

@@ -244,13 +236,13 @@
         value = self.cast(value, self.parameter_type)
         value = self.floor(value)
         probs1 = self._check_param_type(probs1)
-        prob_type = self.dtypeop(probs1)
-        value = value * self.fill(prob_type, self.shape(probs1), 1.0)
-        probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
+        broadcast_shape_tensor = value * probs1
+        value = self.broadcast(value, broadcast_shape_tensor)
+        probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
         comp_zero = self.less(value, 0.0)
         comp_one = self.less(value, 1.0)
-        zeros = self.fill(prob_type, self.shape(value), 0.0)
-        ones = self.fill(prob_type, self.shape(value), 1.0)
+        zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0)
+        ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0)
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)
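Net effect for Bernoulli: probs is registered through the new _add_parameter hook, and _cdf broadcasts value and probs1 against each other via broadcast_to instead of multiplying by filled ones. A usage sketch (illustrative values, assuming a standard MindSpore install):

    import numpy as np
    import mindspore as ms
    import mindspore.nn.probability.distribution as msd

    b = msd.Bernoulli(probs=0.5)                        # 0.5 becomes the default probs1
    x = ms.Tensor(np.array([-1.0, 0.0, 1.0]), ms.float32)
    print(b.cdf(x))                                      # uses the stored default probs1
    print(b.cdf(x, ms.Tensor(0.7, ms.float32)))          # a per-call probs1 overrides it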




+87 -24  mindspore/nn/probability/distribution/distribution.py

@@ -14,13 +14,14 @@
 # ============================================================================
 """basic"""
 from mindspore import context
+from mindspore.ops import operations as P
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common import get_seed
-from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\
-    raise_none_error
+from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device
 from ._utils.utils import CheckTuple, CheckTensor
+from ._utils.custom_ops import broadcast_to, exp_generic, log_generic


 class Distribution(Cell):
@@ -68,14 +69,16 @@ class Distribution(Cell):
         self._seed = seed
         self._dtype = cast_type_for_device(dtype)
         self._parameters = {}
+
         # parsing parameters
         for k in param.keys():
             if not(k == 'self' or k.startswith('_')):
                 self._parameters[k] = param[k]
+
         # some attributes
-        self._broadcast_shape = calc_broadcast_shape_from_param(
-            self.parameters)
-        self._is_scalar_batch = check_scalar_from_param(self.parameters)
+        self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
+        self._broadcast_shape = self._calc_broadcast_shape()
+        self._is_scalar_batch = self._check_is_scalar_batch()

         # set the function to call according to the derived class's attributes
         self._set_prob()
@@ -91,6 +94,18 @@ class Distribution(Cell):
         self.context_mode = context.get_context('mode')
         self.checktuple = CheckTuple()
         self.checktensor = CheckTensor()
+        self.broadcast = broadcast_to
+
+        # ops needed for the base class
+        self.cast_base = P.Cast()
+        self.dtype_base = P.DType()
+        self.exp_base = exp_generic
+        self.fill_base = P.Fill()
+        self.log_base = log_generic
+        self.sametypeshape_base = P.SameTypeShape()
+        self.sq_base = P.Square()
+        self.sqrt_base = P.Sqrt()
+        self.shape_base = P.Shape()

     @property
     def name(self):
@@ -116,6 +131,21 @@ class Distribution(Cell):
     def broadcast_shape(self):
         return self._broadcast_shape

+    def _add_parameter(self, value, name):
+        """
+        Cast `value` to a tensor and add it to `self.default_parameters`.
+        Add `name` into and `self.parameter_names`.
+        """
+        # initialize the attributes if they do not exist yet
+        if not hasattr(self, 'default_parameters'):
+            self.default_parameters = []
+            self.parameter_names = []
+        # cast value to a tensor if it is not None
+        value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
+        self.default_parameters += [value_t,]
+        self.parameter_names += [name,]
+        return value_t
+
     def _check_param_type(self, *args):
         """
         Check the availability and validity of default parameters and `dist_spec_args`.
@@ -123,6 +153,7 @@ class Distribution(Cell):
         are None, the parameters must be passed in through `args`.
         """
         broadcast_shape = None
+        broadcast_shape_tensor = None
         common_dtype = None
         out = []

@@ -139,17 +170,17 @@

             # broadcast if the number of args > 1
             if broadcast_shape is None:
-                broadcast_shape = self.shape(arg)
-                common_dtype = self.dtypeop(arg)
+                broadcast_shape = self.shape_base(arg)
+                common_dtype = self.dtype_base(arg)
+                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
             else:
-                ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
-                broadcast_shape = self.shape(arg + ones)
+                broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
+                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
+                arg = self.broadcast(arg, broadcast_shape_tensor)
             # check if the arguments have the same dtype
-            arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
-            dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0)
-            self.sametypeshape(arg, dtype_tensor)
-            arg = self.cast(arg, self.parameter_type)
+            self.sametypeshape_base(arg, broadcast_shape_tensor)
+
+            arg = self.cast_base(arg, self.parameter_type)
             out.append(arg)

@@ -158,7 +189,7 @@
         # broadcast all args to broadcast_shape
         result = ()
         for arg in out:
-            arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+            arg = self.broadcast(arg, broadcast_shape_tensor)
             result = result + (arg,)
         return result

@@ -171,6 +202,38 @@
             return value
         return self.checktensor(value, name)

+    def _check_is_scalar_batch(self):
+        """
+        Check if the parameters used during initialization are scalars.
+        """
+        if hasattr(self, 'distribution'):
+            return self._distribution.is_scalar_batch
+        param_dict = self.parameters['param_dict']
+        for value in param_dict.values():
+            if value is None:
+                continue
+            if not isinstance(value, (int, float)):
+                return False
+        return True
+
+    def _calc_broadcast_shape(self):
+        """
+        Calculate the broadcast shape of the parameters used during initialization.
+        """
+        if hasattr(self, 'distribution'):
+            return self._distribution.broadcast_shape
+        param_dict = self.parameters['param_dict']
+        broadcast_shape_tensor = None
+        for value in param_dict.values():
+            if value is None:
+                return None
+            if broadcast_shape_tensor is None:
+                broadcast_shape_tensor = cast_to_tensor(value)
+            else:
+                value = cast_to_tensor(value)
+                broadcast_shape_tensor = (value + broadcast_shape_tensor)
+        return broadcast_shape_tensor.shape
+
     def _set_prob(self):
         """
         Set probability funtion based on the availability of `_prob` and `_log_likehood`.
@@ -280,7 +343,7 @@
         .. math::
             probability(x) = \exp(log_likehood(x))
         """
-        return self.exp(self._log_prob(value, *args, **kwargs))
+        return self.exp_base(self._log_prob(value, *args, **kwargs))

     def prob(self, value, *args, **kwargs):
         """
@@ -304,7 +367,7 @@
         .. math::
             log_prob(x) = \log(prob(x))
         """
-        return self.log(self._prob(value, *args, **kwargs))
+        return self.log_base(self._prob(value, *args, **kwargs))

     def cdf(self, value, *args, **kwargs):
         """
@@ -328,7 +391,7 @@
         .. math::
             cdf(x) = \exp(log_cdf(x))
         """
-        return self.exp(self._log_cdf(value, *args, **kwargs))
+        return self.exp_base(self._log_cdf(value, *args, **kwargs))

     def _calc_cdf_from_survival(self, value, *args, **kwargs):
         r"""
@@ -346,7 +409,7 @@
         .. math::
             cdf(x) = 1 - (\exp(log_survival(x)))
         """
-        return 1.0 - self.exp(self._log_survival(value, *args, **kwargs))
+        return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))

     def log_cdf(self, value, *args, **kwargs):
         """
@@ -370,7 +433,7 @@
         .. math::
             log_cdf(x) = \log(cdf(x))
         """
-        return self.log(self._call_cdf(value, *args, **kwargs))
+        return self.log_base(self._call_cdf(value, *args, **kwargs))

     def survival_function(self, value, *args, **kwargs):
         """
@@ -403,7 +466,7 @@
         .. math::
             survival(x) = \exp(survival_function(x))
         """
-        return self.exp(self._log_survival(value, *args, **kwargs))
+        return self.exp_base(self._log_survival(value, *args, **kwargs))

     def log_survival(self, value, *args, **kwargs):
         """
@@ -427,7 +490,7 @@
         .. math::
             log_survival(x) = \log(survival_function(x))
         """
-        return self.log(self._call_survival(value, *args, **kwargs))
+        return self.log_base(self._call_survival(value, *args, **kwargs))

     def kl_loss(self, dist, *args, **kwargs):
         """
@@ -507,7 +570,7 @@
         .. math::
             STD(x) = \sqrt(VAR(x))
         """
-        return self.sqrt(self._var(*args, **kwargs))
+        return self.sqrt_base(self._var(*args, **kwargs))

     def _calc_var_from_sd(self, *args, **kwargs):
         r"""
@@ -516,7 +579,7 @@
         .. math::
             VAR(x) = STD(x) ^ 2
         """
-        return self.sq(self._sd(*args, **kwargs))
+        return self.sq_base(self._sd(*args, **kwargs))

     def entropy(self, *args, **kwargs):
         """

+5 -11  mindspore/nn/probability/distribution/exponential.py

@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
+from ._utils.utils import check_greater_zero, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic

@@ -118,18 +118,14 @@ class Exponential(Distribution):
         Constructor of Exponential.
         """
         param = dict(locals())
+        param['param_dict'] = {'rate': rate}
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Exponential, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type({'rate': rate}, self.dtype)
-        if rate is not None:
-            self._rate = cast_to_tensor(rate, self.parameter_type)
-            check_greater_zero(self._rate, "rate")
-        else:
-            self._rate = rate
-
-        self.default_parameters = [self.rate]
-        self.parameter_names = ['rate']
+        self._rate = self._add_parameter(rate, 'rate')
+        if self.rate is not None:
+            check_greater_zero(self.rate, 'rate')

         self.minval = np.finfo(np.float).tiny

@@ -144,8 +140,6 @@ class Exponential(Distribution):
         self.less = P.Less()
         self.select = P.Select()
         self.shape = P.Shape()
-        self.sqrt = P.Sqrt()
-        self.sq = P.Square()
         self.uniform = C.uniform

     def extend_repr(self):


+8 -14  mindspore/nn/probability/distribution/geometric.py

@@ -18,8 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
-    set_param_type
+from ._utils.utils import check_prob, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic

@@ -120,18 +119,14 @@ class Geometric(Distribution):
         Constructor of Geometric distribution.
         """
         param = dict(locals())
+        param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Geometric, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
-        if probs is not None:
-            self._probs = cast_to_tensor(probs, self.parameter_type)
-            check_prob(self._probs)
-        else:
-            self._probs = probs
-
-        self.default_parameters = [self.probs]
-        self.parameter_names = ['probs1']
+        self._probs = self._add_parameter(probs, 'probs')
+        if self._probs is not None:
+            check_prob(self.probs)

         self.minval = np.finfo(np.float).tiny

@@ -150,7 +145,6 @@ class Geometric(Distribution):
         self.select = P.Select()
         self.shape = P.Shape()
         self.sq = P.Square()
-        self.sqrt = P.Sqrt()
         self.uniform = C.uniform

     def extend_repr(self):
@@ -181,7 +175,7 @@ class Geometric(Distribution):
             MODE(Geo) = 0
         """
         probs1 = self._check_param_type(probs1)
-        return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
+        return self.fill(self.dtype, self.shape(probs1), 0.)

     def _var(self, probs1=None):
         r"""
@@ -229,7 +223,7 @@ class Geometric(Distribution):
         value = self.floor(value)
         probs1 = self._check_param_type(probs1)
         pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
-        zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
+        zeros = self.fill(self.dtypeop(pmf), self.shape(pmf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, pmf)

@@ -252,7 +246,7 @@ class Geometric(Distribution):
         probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         cdf = 1.0 - self.pow(probs0, value + 1.0)
-        zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
+        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, cdf)




+6 -19  mindspore/nn/probability/distribution/normal.py

@@ -18,8 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    set_param_type
+from ._utils.utils import check_greater_zero, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, expm1_generic, log_generic

@@ -125,23 +124,15 @@ class Normal(Distribution):
         Constructor of Normal.
         """
         param = dict(locals())
+        param['param_dict'] = {'mean': mean, 'sd': sd}
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type(
-            {'mean': mean, 'sd': sd}, self.dtype)
-        if mean is not None and sd is not None:
-            self._mean_value = cast_to_tensor(mean, self.parameter_type)
-            self._sd_value = cast_to_tensor(sd, self.parameter_type)
-            check_greater_zero(self._sd_value, "Standard deviation")
-        else:
-            self._mean_value = mean if mean is None else cast_to_tensor(
-                mean, self.parameter_type)
-            self._sd_value = sd if sd is None else cast_to_tensor(
-                sd, self.parameter_type)
-
-        self.default_parameters = [self._mean_value, self._sd_value]
-        self.parameter_names = ['mean', 'sd']
+        self._mean_value = self._add_parameter(mean, 'mean')
+        self._sd_value = self._add_parameter(sd, 'sd')
+        if self._sd_value is not None:
+            check_greater_zero(self._sd_value, "Standard deviation")

         # ops needed for the class
         self.exp = exp_generic
@@ -151,13 +142,9 @@ class Normal(Distribution):
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
-        self.fill = P.Fill()
         self.shape = P.Shape()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
-        self.zeroslike = P.ZerosLike()
-        self.dtypeop = P.DType()
-        self.sametypeshape = P.SameTypeShape()

     def extend_repr(self):
         if self.is_scalar_batch:


+2 -0  mindspore/nn/probability/distribution/transformed_distribution.py

@@ -81,6 +81,8 @@ class TransformedDistribution(Distribution):
         self._bijector = bijector
         self._distribution = distribution
         self._is_linear_transformation = bijector.is_constant_jacobian
+        self.default_parameters = distribution.default_parameters
+        self.parameter_names = distribution.parameter_names
         self.exp = exp_generic
         self.log = log_generic
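These two lines hand the wrapped distribution's parameter bookkeeping to the wrapper, so _check_param_type keeps working on transformed distributions. A hedged sketch in the style of the LogNormal-via-Exp examples from this period of MindSpore (APIs assumed, not verified against this exact revision):

    import mindspore.nn.probability.bijector as msb
    import mindspore.nn.probability.distribution as msd
    from mindspore.common import dtype as mstype

    normal = msd.Normal(0.0, 1.0, dtype=mstype.float32)
    log_normal = msd.TransformedDistribution(msb.Exp(), normal, seed=0, name='LogNormal')
    print(log_normal.parameter_names)   # ['mean', 'sd'], inherited from the wrapped Normal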




+7 -18  mindspore/nn/probability/distribution/uniform.py

@@ -17,8 +17,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
-    set_param_type
+from ._utils.utils import check_greater, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic

@@ -124,23 +123,16 @@ class Uniform(Distribution):
         Constructor of Uniform distribution.
         """
         param = dict(locals())
+        param['param_dict'] = {'low': low, 'high': high}
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type(
-            {'low': low, 'high': high}, self.dtype)
-        if low is not None and high is not None:
-            self._low = cast_to_tensor(low, self.parameter_type)
-            self._high = cast_to_tensor(high, self.parameter_type)
-            check_greater(self.low, self.high, "low value", "high value")
-        else:
-            self._low = low if low is None else cast_to_tensor(
-                low, self.parameter_type)
-            self._high = high if high is None else cast_to_tensor(
-                high, self.parameter_type)
-
-        self.default_parameters = [self.low, self.high]
-        self.parameter_names = ['low', 'high']
+        self._low = self._add_parameter(low, 'low')
+        self._high = self._add_parameter(high, 'high')
+        if self.low is not None and self.high is not None:
+            check_greater(self.low, self.high, 'low', 'high')

         # ops needed for the class
         self.exp = exp_generic
@@ -156,12 +148,9 @@ class Uniform(Distribution):
         self.select = P.Select()
         self.shape = P.Shape()
         self.sq = P.Square()
-        self.sqrt = P.Sqrt()
         self.zeroslike = P.ZerosLike()
         self.uniform = C.uniform
-
-        self.sametypeshape = P.SameTypeShape()

     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'low = {self.low}, high = {self.high}'

