|
|
|
@@ -212,6 +212,7 @@ class Poisson(Distribution): |
|
|
|
""" |
|
|
|
value = self._check_value(value, "value") |
|
|
|
value = self.cast(value, self.dtype) |
|
|
|
value = self.floor(value) |
|
|
|
rate = self._check_param_type(rate) |
|
|
|
log_rate = self.log(rate) |
|
|
|
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) |
|
|
|
@@ -239,6 +240,7 @@ class Poisson(Distribution): |
|
|
|
""" |
|
|
|
value = self._check_value(value, 'value') |
|
|
|
value = self.cast(value, self.dtype) |
|
|
|
value = self.floor(value) |
|
|
|
rate = self._check_param_type(rate) |
|
|
|
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) |
|
|
|
comp = self.less(value, zeros) |
|
|
|
@@ -259,6 +261,9 @@ class Poisson(Distribution): |
|
|
|
""" |
|
|
|
shape = self.checktuple(shape, 'shape') |
|
|
|
rate = self._check_param_type(rate) |
|
|
|
|
|
|
|
# now Poisson sampler supports only fp32 |
|
|
|
rate = self.cast(rate, mstype.float32) |
|
|
|
origin_shape = shape + self.shape(rate) |
|
|
|
if origin_shape == (): |
|
|
|
sample_shape = (1,) |
|
|
|
|