diff --git a/mindspore/nn/probability/distribution/beta.py b/mindspore/nn/probability/distribution/beta.py index 2a07375bfe..ff9ced01cd 100644 --- a/mindspore/nn/probability/distribution/beta.py +++ b/mindspore/nn/probability/distribution/beta.py @@ -238,7 +238,8 @@ class Beta(Distribution): comp1 = self.greater(concentration1, 1.) comp2 = self.greater(concentration0, 1.) cond = self.logicaland(comp1, comp2) - nan = self.fill(self.dtype, self.broadcast_shape, np.nan) + batch_shape = self.shape(concentration1 + concentration0) + nan = self.fill(self.dtype, batch_shape, np.nan) mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.) return self.select(cond, mode, nan) diff --git a/mindspore/nn/probability/distribution/poisson.py b/mindspore/nn/probability/distribution/poisson.py index b91ffa94f3..a543e5b4c0 100644 --- a/mindspore/nn/probability/distribution/poisson.py +++ b/mindspore/nn/probability/distribution/poisson.py @@ -212,6 +212,7 @@ class Poisson(Distribution): """ value = self._check_value(value, "value") value = self.cast(value, self.dtype) + value = self.floor(value) rate = self._check_param_type(rate) log_rate = self.log(rate) zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) @@ -239,6 +240,7 @@ class Poisson(Distribution): """ value = self._check_value(value, 'value') value = self.cast(value, self.dtype) + value = self.floor(value) rate = self._check_param_type(rate) zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) comp = self.less(value, zeros) @@ -259,6 +261,9 @@ class Poisson(Distribution): """ shape = self.checktuple(shape, 'shape') rate = self._check_param_type(rate) + + # now Poisson sampler supports only fp32 + rate = self.cast(rate, mstype.float32) origin_shape = shape + self.shape(rate) if origin_shape == (): sample_shape = (1,)