Browse Source

!4365 Update random uniform op invocation

Merge pull request !4365 from peixu_ren/custom_bijector
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
ddec0c0f96
4 changed files with 23 additions and 15 deletions
  1. +4
    -4
      mindspore/nn/probability/distribution/bernoulli.py
  2. +7
    -4
      mindspore/nn/probability/distribution/exponential.py
  3. +6
    -3
      mindspore/nn/probability/distribution/geometric.py
  4. +6
    -4
      mindspore/nn/probability/distribution/uniform.py

+ 4
- 4
mindspore/nn/probability/distribution/bernoulli.py View File

@@ -15,6 +15,7 @@
"""Bernoulli Distribution"""
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type

@@ -116,7 +117,7 @@ class Bernoulli(Distribution):
self.select = P.Select()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)
self.uniform = C.uniform

def extend_repr(self):
if self.is_scalar_batch:
@@ -256,7 +257,6 @@ class Bernoulli(Distribution):
probs1 = self.probs if probs is None else probs
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed)
sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self.dtype)
return sample
return self.cast(sample, self.dtype)

+ 7
- 4
mindspore/nn/probability/distribution/exponential.py View File

@@ -15,6 +15,7 @@
"""Exponential Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type
@@ -107,7 +108,8 @@ class Exponential(Distribution):

self.minval = np.finfo(np.float).tiny

# ops needed for the class
# ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
@@ -118,7 +120,7 @@ class Exponential(Distribution):
self.shape = P.Shape()
self.sqrt = P.Sqrt()
self.sq = P.Square()
self.uniform = P.UniformReal(seed=seed)
self.uniform = C.uniform

def extend_repr(self):
if self.is_scalar_batch:
@@ -251,5 +253,6 @@ class Exponential(Distribution):
rate = self.rate if rate is None else rate
minval = self.const(self.minval)
maxval = self.const(1.0)
sample = self.uniform(shape + self.shape(rate), minval, maxval)
return -self.log(sample) / rate
sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed)
sample = -self.log(sample_uniform) / rate
return self.cast(sample, self.dtype)

+ 6
- 3
mindspore/nn/probability/distribution/geometric.py View File

@@ -15,6 +15,7 @@
"""Geometric Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type
@@ -109,6 +110,7 @@ class Geometric(Distribution):
self.minval = np.finfo(np.float).tiny

# ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.fill = P.Fill()
@@ -121,7 +123,7 @@ class Geometric(Distribution):
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)
self.uniform = C.uniform

def extend_repr(self):
if self.is_scalar_batch:
@@ -269,5 +271,6 @@ class Geometric(Distribution):
probs = self.probs if probs is None else probs
minval = self.const(self.minval)
maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval, self.seed)
sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
return self.cast(sample, self.dtype)

+ 6
- 4
mindspore/nn/probability/distribution/uniform.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""Uniform Distribution"""
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type
@@ -108,7 +109,8 @@ class Uniform(Distribution):
self._low = low
self._high = high

# ops needed for the class
# ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
@@ -121,8 +123,8 @@ class Uniform(Distribution):
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)
self.zeroslike = P.ZerosLike()
self.uniform = C.uniform

def extend_repr(self):
if self.is_scalar_batch:
@@ -284,6 +286,6 @@ class Uniform(Distribution):
broadcast_shape = self.shape(low + high)
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed)
sample = (high - low) * sample_uniform + low
return sample
return self.cast(sample, self.dtype)

Loading…
Cancel
Save