Merge pull request !4365 from peixu_ren/custom_bijectortags/v0.7.0-beta
| @@ -15,6 +15,7 @@ | |||
| """Bernoulli Distribution""" | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_prob, check_type | |||
| @@ -116,7 +117,7 @@ class Bernoulli(Distribution): | |||
| self.select = P.Select() | |||
| self.sq = P.Square() | |||
| self.sqrt = P.Sqrt() | |||
| self.uniform = P.UniformReal(seed=seed) | |||
| self.uniform = C.uniform | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| @@ -256,7 +257,6 @@ class Bernoulli(Distribution): | |||
| probs1 = self.probs if probs is None else probs | |||
| l_zero = self.const(0.0) | |||
| h_one = self.const(1.0) | |||
| sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one) | |||
| sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed) | |||
| sample = self.less(sample_uniform, probs1) | |||
| sample = self.cast(sample, self.dtype) | |||
| return sample | |||
| return self.cast(sample, self.dtype) | |||
| @@ -15,6 +15,7 @@ | |||
| """Exponential Distribution""" | |||
| import numpy as np | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type | |||
| @@ -107,7 +108,8 @@ class Exponential(Distribution): | |||
| self.minval = np.finfo(np.float).tiny | |||
| # ops needed for the class | |||
| # ops needed for the class | |||
| self.cast = P.Cast() | |||
| self.const = P.ScalarToArray() | |||
| self.dtypeop = P.DType() | |||
| self.exp = P.Exp() | |||
| @@ -118,7 +120,7 @@ class Exponential(Distribution): | |||
| self.shape = P.Shape() | |||
| self.sqrt = P.Sqrt() | |||
| self.sq = P.Square() | |||
| self.uniform = P.UniformReal(seed=seed) | |||
| self.uniform = C.uniform | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| @@ -251,5 +253,6 @@ class Exponential(Distribution): | |||
| rate = self.rate if rate is None else rate | |||
| minval = self.const(self.minval) | |||
| maxval = self.const(1.0) | |||
| sample = self.uniform(shape + self.shape(rate), minval, maxval) | |||
| return -self.log(sample) / rate | |||
| sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed) | |||
| sample = -self.log(sample_uniform) / rate | |||
| return self.cast(sample, self.dtype) | |||
| @@ -15,6 +15,7 @@ | |||
| """Geometric Distribution""" | |||
| import numpy as np | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_prob, check_type | |||
| @@ -109,6 +110,7 @@ class Geometric(Distribution): | |||
| self.minval = np.finfo(np.float).tiny | |||
| # ops needed for the class | |||
| self.cast = P.Cast() | |||
| self.const = P.ScalarToArray() | |||
| self.dtypeop = P.DType() | |||
| self.fill = P.Fill() | |||
| @@ -121,7 +123,7 @@ class Geometric(Distribution): | |||
| self.shape = P.Shape() | |||
| self.sq = P.Square() | |||
| self.sqrt = P.Sqrt() | |||
| self.uniform = P.UniformReal(seed=seed) | |||
| self.uniform = C.uniform | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| @@ -269,5 +271,6 @@ class Geometric(Distribution): | |||
| probs = self.probs if probs is None else probs | |||
| minval = self.const(self.minval) | |||
| maxval = self.const(1.0) | |||
| sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval) | |||
| return self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) | |||
| sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval, self.seed) | |||
| sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) | |||
| return self.cast(sample, self.dtype) | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """Uniform Distribution""" | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import convert_to_batch, check_greater, check_type | |||
| @@ -108,7 +109,8 @@ class Uniform(Distribution): | |||
| self._low = low | |||
| self._high = high | |||
| # ops needed for the class | |||
| # ops needed for the class | |||
| self.cast = P.Cast() | |||
| self.const = P.ScalarToArray() | |||
| self.dtypeop = P.DType() | |||
| self.exp = P.Exp() | |||
| @@ -121,8 +123,8 @@ class Uniform(Distribution): | |||
| self.shape = P.Shape() | |||
| self.sq = P.Square() | |||
| self.sqrt = P.Sqrt() | |||
| self.uniform = P.UniformReal(seed=seed) | |||
| self.zeroslike = P.ZerosLike() | |||
| self.uniform = C.uniform | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| @@ -284,6 +286,6 @@ class Uniform(Distribution): | |||
| broadcast_shape = self.shape(low + high) | |||
| l_zero = self.const(0.0) | |||
| h_one = self.const(1.0) | |||
| sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) | |||
| sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed) | |||
| sample = (high - low) * sample_uniform + low | |||
| return sample | |||
| return self.cast(sample, self.dtype) | |||