Merge pull request !4956 from XunDeng/pp_issue_branch
@@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
+from mindspore import context
 import mindspore.nn as nn
 import mindspore.nn.probability as msp
@@ -273,7 +274,8 @@ def check_type(data_type, value_type, name):
 @constexpr
 def raise_none_error(name):
-    raise ValueError(f"{name} should be specified. Value cannot be None")
+    raise TypeError(f"the type of {name} should be a subclass of Tensor."
+                    f" It should not be None since it is not specified during initialization.")
 @constexpr
 def raise_not_impl_error(name):
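The reworded error now surfaces whenever a distribution parameter was neither fixed at construction time nor supplied at call time. A minimal, hedged sketch of how it is reached (msd is the public package alias; the exact name echoed in the message depends on the checker that calls raise_none_error):

import mindspore.nn.probability.distribution as msd

n = msd.Normal()   # no mean/sd given at initialization
n.mean()           # raises TypeError: the type of ... should be a subclass of Tensor ...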
@@ -298,15 +300,20 @@ class CheckTuple(PrimitiveWithInfer):
     def __infer__(self, x, name):
         if not isinstance(x['dtype'], tuple):
-            raise TypeError("Input type should be a tuple: " + name["value"])
+            raise TypeError(f"For {name['value']}, the input type should be a tuple.")
         out = {'shape': None,
                'dtype': None,
-               'value': None}
+               'value': x["value"]}
         return out
-    def __call__(self, *args):
-        return
+    def __call__(self, x, name):
+        if context.get_context("mode") == 0:
+            return x["value"]
+        # PyNative mode
+        if isinstance(x, tuple):
+            return x
+        raise TypeError(f"For {name['value']}, the input type should be a tuple.")
 class CheckTensor(PrimitiveWithInfer):
     """
@@ -327,5 +334,5 @@ class CheckTensor(PrimitiveWithInfer):
                'value': None}
         return out
-    def __call__(self, *args):
+    def __call__(self, x, name):
         return
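Net effect of the reworked check primitives: CheckTuple hands the validated tuple back to the caller (through the propagated 'value' in graph mode and directly from __call__ in PyNative mode), and CheckTensor's __call__ now takes the same (x, name) signature. A minimal sketch of the PyNative path, assuming the internal _utils.utils module shown above is importable:

from mindspore import context
from mindspore.nn.probability.distribution._utils.utils import CheckTuple

context.set_context(mode=context.PYNATIVE_MODE)
checktuple = CheckTuple()
shape = checktuple((2, 3), 'shape')   # returns (2, 3); a non-tuple input raises TypeError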
@@ -18,7 +18,6 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
-from ._utils.utils import CheckTensor, CheckTuple
 from ._utils.custom_ops import log_by_step
 class Bernoulli(Distribution):
@@ -125,9 +124,6 @@ class Bernoulli(Distribution):
         self.sqrt = P.Sqrt()
         self.uniform = C.uniform
-        self.checktensor = CheckTensor()
-        self.checktuple = CheckTuple()
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'probs = {self.probs}'
@@ -279,7 +275,7 @@ class Bernoulli(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        self.checktuple(shape, 'shape')
+        shape = self.checktuple(shape, 'shape')
         probs1 = self._check_param(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
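Reassigning the checked tuple lets sample() keep working on the value CheckTuple returns; the same one-line change is repeated for Exponential, Geometric, Normal, and Uniform below. A hedged usage sketch (constructor and sample arguments follow the public msd API; defaults may differ between versions):

import mindspore.nn.probability.distribution as msd
from mindspore import context, dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)
b = msd.Bernoulli(probs=0.5, dtype=mstype.int32)
samples = b.sample((2, 3))   # the (2, 3) tuple is validated, then broadcast with the batch shape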
@@ -17,6 +17,7 @@ from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
+from ._utils.utils import CheckTuple, CheckTensor
 class Distribution(Cell):
     """
@@ -79,6 +80,9 @@ class Distribution(Cell):
         self._set_log_survival()
         self._set_cross_entropy()
+        self.checktuple = CheckTuple()
+        self.checktensor = CheckTensor()
     @property
     def name(self):
         return self._name
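Since the base constructor now owns the two check primitives, every subclass inherits self.checktuple and self.checktensor, which is why the corresponding assignments are deleted from Bernoulli, Exponential, Geometric, Normal, and Uniform in the surrounding hunks. A small, hedged illustration of the inherited attributes:

import mindspore.nn.probability.distribution as msd

n = msd.Normal(0.0, 1.0)
_ = (n.checktuple, n.checktensor)   # both created once in Distribution.__init__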
@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.utils import CheckTensor, CheckTuple
 from ._utils.custom_ops import log_by_step
 class Exponential(Distribution):
@@ -127,8 +126,6 @@ class Exponential(Distribution):
         self.sq = P.Square()
         self.uniform = C.uniform
-        self.checktensor = CheckTensor()
-        self.checktuple = CheckTuple()
     def extend_repr(self):
         if self.is_scalar_batch:
@@ -270,7 +267,7 @@ class Exponential(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        self.checktuple(shape, 'shape')
+        shape = self.checktuple(shape, 'shape')
         rate = self._check_param(rate)
         origin_shape = shape + self.shape(rate)
         if origin_shape == ():
@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.utils import CheckTensor, CheckTuple
 from ._utils.custom_ops import log_by_step
 class Geometric(Distribution):
@@ -131,8 +130,6 @@ class Geometric(Distribution):
         self.sqrt = P.Sqrt()
         self.uniform = C.uniform
-        self.checktensor = CheckTensor()
-        self.checktuple = CheckTuple()
     def extend_repr(self):
         if self.is_scalar_batch:
@@ -278,7 +275,7 @@ class Geometric(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        self.checktuple(shape, 'shape')
+        shape = self.checktuple(shape, 'shape')
         probs1 = self._check_param(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.utils import CheckTensor, CheckTuple
 from ._utils.custom_ops import log_by_step, expm1_by_step
 class Normal(Distribution):
@@ -128,9 +127,6 @@ class Normal(Distribution):
         self.sqrt = P.Sqrt()
         self.zeroslike = P.ZerosLike()
-        self.checktensor = CheckTensor()
-        self.checktuple = CheckTuple()
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
@@ -277,7 +273,7 @@ class Normal(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        self.checktuple(shape, 'shape')
+        shape = self.checktuple(shape, 'shape')
         mean, sd = self._check_param(mean, sd)
         batch_shape = self.shape(mean + sd)
         origin_shape = shape + batch_shape
@@ -116,4 +116,4 @@ class TransformedDistribution(Distribution):
         if not self.is_linear_transformation:
             raise_not_impl_error("mean")
-        return self.bijector("forward", self.distribution("mean"))
+        return self.bijector("forward", self.distribution("mean", *args, **kwargs))
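The forwarding fix matters because the wrapped distribution's statistics accept call-time parameters; before this change those arguments were dropped on the way through the bijector. A hedged sketch of the call-time-parameter pattern that the forwarded *args/**kwargs now reach (values are illustrative):

import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype as mstype

n = msd.Normal()   # parameters supplied per call instead of at construction
m = n.mean(Tensor(3.0, mstype.float32), Tensor(2.0, mstype.float32))   # mean evaluated at mean=3.0, sd=2.0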
@@ -19,7 +19,6 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.utils import CheckTensor, CheckTuple
 from ._utils.custom_ops import log_by_step
 class Uniform(Distribution):
@@ -131,8 +130,6 @@ class Uniform(Distribution):
         self.zeroslike = P.ZerosLike()
         self.uniform = C.uniform
-        self.checktensor = CheckTensor()
-        self.checktuple = CheckTuple()
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'low = {self.low}, high = {self.high}'
@@ -306,7 +302,7 @@ class Uniform(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        self.checktuple(shape, 'shape')
+        shape = self.checktuple(shape, 'shape')
         low, high = self._check_param(low, high)
         broadcast_shape = self.shape(low + high)
         origin_shape = shape + broadcast_shape