diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 31253259e5..028ee175f1 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -15,6 +15,7 @@ """Bernoulli Distribution""" from mindspore.common import dtype as mstype from mindspore.ops import operations as P +from mindspore.ops import composite as C from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type @@ -116,7 +117,7 @@ class Bernoulli(Distribution): self.select = P.Select() self.sq = P.Square() self.sqrt = P.Sqrt() - self.uniform = P.UniformReal(seed=seed) + self.uniform = C.uniform def extend_repr(self): if self.is_scalar_batch: @@ -256,7 +257,6 @@ class Bernoulli(Distribution): probs1 = self.probs if probs is None else probs l_zero = self.const(0.0) h_one = self.const(1.0) - sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one) + sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed) sample = self.less(sample_uniform, probs1) - sample = self.cast(sample, self.dtype) - return sample + return self.cast(sample, self.dtype) diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 6f15958f95..8564935e09 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -15,6 +15,7 @@ """Exponential Distribution""" import numpy as np from mindspore.ops import operations as P +from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_greater_zero, check_type @@ -107,7 +108,8 @@ class Exponential(Distribution): self.minval = np.finfo(np.float).tiny - # ops needed for the class + # ops needed for the class + self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() self.exp = P.Exp() @@ -118,7 +120,7 @@ class Exponential(Distribution): self.shape = P.Shape() self.sqrt = P.Sqrt() self.sq = P.Square() - self.uniform = P.UniformReal(seed=seed) + self.uniform = C.uniform def extend_repr(self): if self.is_scalar_batch: @@ -251,5 +253,6 @@ class Exponential(Distribution): rate = self.rate if rate is None else rate minval = self.const(self.minval) maxval = self.const(1.0) - sample = self.uniform(shape + self.shape(rate), minval, maxval) - return -self.log(sample) / rate + sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed) + sample = -self.log(sample_uniform) / rate + return self.cast(sample, self.dtype) diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index 87ad7ad8a4..2c67bb5588 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -15,6 +15,7 @@ """Geometric Distribution""" import numpy as np from mindspore.ops import operations as P +from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type @@ -109,6 +110,7 @@ class Geometric(Distribution): self.minval = np.finfo(np.float).tiny # ops needed for the class + self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() self.fill = P.Fill() @@ -121,7 +123,7 @@ class Geometric(Distribution): self.shape = P.Shape() self.sq = P.Square() self.sqrt = P.Sqrt() - self.uniform = P.UniformReal(seed=seed) + self.uniform = C.uniform def extend_repr(self): if self.is_scalar_batch: @@ -269,5 +271,6 @@ class Geometric(Distribution): probs = self.probs if probs is None else probs minval = self.const(self.minval) maxval = self.const(1.0) - sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval) - return self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) + sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval, self.seed) + sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) + return self.cast(sample, self.dtype) diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index b1f2aba90a..6aff1ef775 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -14,6 +14,7 @@ # ============================================================================ """Uniform Distribution""" from mindspore.ops import operations as P +from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import convert_to_batch, check_greater, check_type @@ -108,7 +109,8 @@ class Uniform(Distribution): self._low = low self._high = high - # ops needed for the class + # ops needed for the class + self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() self.exp = P.Exp() @@ -121,8 +123,8 @@ class Uniform(Distribution): self.shape = P.Shape() self.sq = P.Square() self.sqrt = P.Sqrt() - self.uniform = P.UniformReal(seed=seed) self.zeroslike = P.ZerosLike() + self.uniform = C.uniform def extend_repr(self): if self.is_scalar_batch: @@ -284,6 +286,6 @@ class Uniform(Distribution): broadcast_shape = self.shape(low + high) l_zero = self.const(0.0) h_one = self.const(1.0) - sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) + sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed) sample = (high - low) * sample_uniform + low - return sample + return self.cast(sample, self.dtype)