Browse Source

fix minor issues

tags/v1.1.0
Xun Deng 5 years ago
parent
commit
855cb855af
2 changed files with 7 additions and 1 deletions
  1. +2
    -1
      mindspore/nn/probability/distribution/beta.py
  2. +5
    -0
      mindspore/nn/probability/distribution/poisson.py

+ 2
- 1
mindspore/nn/probability/distribution/beta.py View File

@@ -238,7 +238,8 @@ class Beta(Distribution):
comp1 = self.greater(concentration1, 1.) comp1 = self.greater(concentration1, 1.)
comp2 = self.greater(concentration0, 1.) comp2 = self.greater(concentration0, 1.)
cond = self.logicaland(comp1, comp2) cond = self.logicaland(comp1, comp2)
nan = self.fill(self.dtype, self.broadcast_shape, np.nan)
batch_shape = self.shape(concentration1 + concentration0)
nan = self.fill(self.dtype, batch_shape, np.nan)
mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.) mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
return self.select(cond, mode, nan) return self.select(cond, mode, nan)




+ 5
- 0
mindspore/nn/probability/distribution/poisson.py View File

@@ -212,6 +212,7 @@ class Poisson(Distribution):
""" """
value = self._check_value(value, "value") value = self._check_value(value, "value")
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
value = self.floor(value)
rate = self._check_param_type(rate) rate = self._check_param_type(rate)
log_rate = self.log(rate) log_rate = self.log(rate)
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
@@ -239,6 +240,7 @@ class Poisson(Distribution):
""" """
value = self._check_value(value, 'value') value = self._check_value(value, 'value')
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
value = self.floor(value)
rate = self._check_param_type(rate) rate = self._check_param_type(rate)
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
@@ -259,6 +261,9 @@ class Poisson(Distribution):
""" """
shape = self.checktuple(shape, 'shape') shape = self.checktuple(shape, 'shape')
rate = self._check_param_type(rate) rate = self._check_param_type(rate)

# now Poisson sampler supports only fp32
rate = self.cast(rate, mstype.float32)
origin_shape = shape + self.shape(rate) origin_shape = shape + self.shape(rate)
if origin_shape == (): if origin_shape == ():
sample_shape = (1,) sample_shape = (1,)


Loading…
Cancel
Save