Merge pull request !4365 from peixu_ren/custom_bijector
@@ -15,6 +15,7 @@
 """Bernoulli Distribution"""
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type
@@ -116,7 +117,7 @@ class Bernoulli(Distribution):
         self.select = P.Select()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
-        self.uniform = P.UniformReal(seed=seed)
+        self.uniform = C.uniform

     def extend_repr(self):
         if self.is_scalar_batch:
@@ -256,7 +257,6 @@ class Bernoulli(Distribution):
         probs1 = self.probs if probs is None else probs
         l_zero = self.const(0.0)
         h_one = self.const(1.0)
-        sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one)
+        sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed)
         sample = self.less(sample_uniform, probs1)
-        sample = self.cast(sample, self.dtype)
-        return sample
+        return self.cast(sample, self.dtype)
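Context for the Bernoulli hunk: `self.uniform` is now the functional `C.uniform`, so the sample call passes the shape, bounds, and seed explicitly, and the boolean comparison is cast to the distribution dtype before returning. The algorithm itself is unchanged: draw U ~ Uniform(0, 1) and compare it with the success probability. A minimal NumPy sketch of that transform (illustration only, not the MindSpore implementation; `bernoulli_sample` is a hypothetical helper):

import numpy as np

def bernoulli_sample(probs, shape, seed=None):
    # U ~ Uniform(0, 1), broadcast over the batch shape of `probs`
    rng = np.random.default_rng(seed)
    u = rng.uniform(0.0, 1.0, size=shape + np.shape(probs))
    # X = 1 exactly when U < probs, mirroring self.less(sample_uniform, probs1)
    return (u < probs).astype(np.float32)

print(bernoulli_sample(0.7, (5,), seed=0))  # five 0/1 draws with P(X=1) = 0.7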
@@ -15,6 +15,7 @@
 """Exponential Distribution"""
 import numpy as np
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type
@@ -107,7 +108,8 @@ class Exponential(Distribution):
         self.minval = np.finfo(np.float).tiny
         # ops needed for the class
+        self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()
         self.exp = P.Exp()
@@ -118,7 +120,7 @@ class Exponential(Distribution):
         self.shape = P.Shape()
         self.sqrt = P.Sqrt()
         self.sq = P.Square()
-        self.uniform = P.UniformReal(seed=seed)
+        self.uniform = C.uniform

     def extend_repr(self):
         if self.is_scalar_batch:
@@ -251,5 +253,6 @@ class Exponential(Distribution):
         rate = self.rate if rate is None else rate
         minval = self.const(self.minval)
         maxval = self.const(1.0)
-        sample = self.uniform(shape + self.shape(rate), minval, maxval)
-        return -self.log(sample) / rate
+        sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed)
+        sample = -self.log(sample_uniform) / rate
+        return self.cast(sample, self.dtype)
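The Exponential hunk keeps inverse-CDF sampling: with U ~ Uniform(0, 1), -log(U) / rate follows Exponential(rate); `minval` keeps the draw away from zero so the log stays finite, and the result is now cast to the distribution dtype. A NumPy sketch under those assumptions (`exponential_sample` is a hypothetical name, not part of the diff):

import numpy as np

def exponential_sample(rate, shape, seed=None):
    rng = np.random.default_rng(seed)
    # the lower bound plays the role of self.minval: keeps log() away from zero
    u = rng.uniform(np.finfo(np.float64).tiny, 1.0, size=shape + np.shape(rate))
    # inverse CDF: -ln(U) / rate ~ Exponential(rate)
    return (-np.log(u) / rate).astype(np.float32)

print(exponential_sample(2.0, (3,), seed=0))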
@@ -15,6 +15,7 @@
 """Geometric Distribution"""
 import numpy as np
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type
@@ -109,6 +110,7 @@ class Geometric(Distribution):
         self.minval = np.finfo(np.float).tiny
         # ops needed for the class
+        self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()
         self.fill = P.Fill()
@@ -121,7 +123,7 @@ class Geometric(Distribution):
         self.shape = P.Shape()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
-        self.uniform = P.UniformReal(seed=seed)
+        self.uniform = C.uniform

     def extend_repr(self):
         if self.is_scalar_batch:
@@ -269,5 +271,6 @@ class Geometric(Distribution):
         probs = self.probs if probs is None else probs
         minval = self.const(self.minval)
         maxval = self.const(1.0)
-        sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
-        return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
+        sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval, self.seed)
+        sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
+        return self.cast(sample, self.dtype)
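The Geometric hunk uses the same inverse-transform trick: floor(log(U) / log(1 - p)) for U ~ Uniform(0, 1) gives the number of failures before the first success. A NumPy sketch of the formula (hypothetical `geometric_sample` helper, illustration only):

import numpy as np

def geometric_sample(probs, shape, seed=None):
    rng = np.random.default_rng(seed)
    u = rng.uniform(np.finfo(np.float64).tiny, 1.0, size=shape + np.shape(probs))
    # floor(ln(U) / ln(1 - p)) = number of failures before the first success
    return np.floor(np.log(u) / np.log(1.0 - probs)).astype(np.float32)

print(geometric_sample(0.5, (4,), seed=0))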
@@ -14,6 +14,7 @@
 # ============================================================================
 """Uniform Distribution"""
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater, check_type
@@ -108,7 +109,8 @@ class Uniform(Distribution):
         self._low = low
         self._high = high
         # ops needed for the class
+        self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()
         self.exp = P.Exp()
@@ -121,8 +123,8 @@ class Uniform(Distribution):
         self.shape = P.Shape()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
-        self.uniform = P.UniformReal(seed=seed)
         self.zeroslike = P.ZerosLike()
+        self.uniform = C.uniform

     def extend_repr(self):
         if self.is_scalar_batch:
@@ -284,6 +286,6 @@ class Uniform(Distribution):
         broadcast_shape = self.shape(low + high)
         l_zero = self.const(0.0)
         h_one = self.const(1.0)
-        sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one)
+        sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed)
         sample = (high - low) * sample_uniform + low
-        return sample
+        return self.cast(sample, self.dtype)
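The Uniform hunk rescales a standard uniform draw: (high - low) * U + low for U ~ Uniform(0, 1), broadcast over the shape of low + high, then cast to the distribution dtype. A NumPy sketch (hypothetical `uniform_sample` helper, not the MindSpore code):

import numpy as np

def uniform_sample(low, high, shape, seed=None):
    rng = np.random.default_rng(seed)
    # broadcast_shape mirrors self.shape(low + high) in the diff
    broadcast_shape = np.broadcast(np.asarray(low), np.asarray(high)).shape
    u = rng.uniform(0.0, 1.0, size=shape + broadcast_shape)
    # scale and shift the standard uniform draw into [low, high)
    return ((high - low) * u + low).astype(np.float32)

print(uniform_sample(-1.0, 3.0, (2,), seed=0))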