Browse Source

!9751 Fix minor issues in beta and categorical distribution

From: @shallydeng
Reviewed-by: @zichun_ye,@sunnybeike
Signed-off-by: @sunnybeike
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
91114c6b26
2 changed files with 11 additions and 4 deletions
  1. +2
    -4
      mindspore/nn/probability/distribution/beta.py
  2. +9
    -0
      mindspore/nn/probability/distribution/categorical.py

+ 2
- 4
mindspore/nn/probability/distribution/beta.py View File

@@ -109,13 +109,11 @@ class Beta(Distribution):
>>> ans = b1.kl_loss('Beta', concentration1_b, concentration0_b)
>>> print(ans)
[0.34434414 0.24721336 0.26786423]
>>> ans = b1.kl_loss('Beta', concentration1_b, concentration0_b,
>>> concentration1_a, concentration0_a)
>>> ans = b1.kl_loss('Beta', concentration1_b, concentration0_b, concentration1_a, concentration0_a)
>>> print(ans)
[0.12509346 0.13629508 0.26527953]
>>> # Additional `concentration1` and `concentration0` must be passed in.
>>> ans = b2.kl_loss('Beta', concentration1_b, concentration0_b,
>>> concentration1_a, concentration0_a)
>>> ans = b2.kl_loss('Beta', concentration1_b, concentration0_b, concentration1_a, concentration0_a)
>>> print(ans)
[0.12509346 0.13629508 0.26527953]
>>> # Examples of `sample`.


+ 9
- 0
mindspore/nn/probability/distribution/categorical.py View File

@@ -175,6 +175,7 @@ class Categorical(Distribution):
self.squeeze_last_axis = P.Squeeze(-1)
self.square = P.Square()
self.transpose = P.Transpose()
self.is_nan = P.IsNan()
self.index_type = mstype.int32
self.nan = np.nan
@@ -290,6 +291,10 @@ class Categorical(Distribution):
value = self.cast(value, self.dtypeop(probs))
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
value = self.select(self.is_nan(value),
neg_one,
value)
between_zero_neone = self.logicand(self.less(value, 0,),
self.greater(value, -1.))
value = self.select(between_zero_neone,
@@ -354,6 +359,10 @@ class Categorical(Distribution):
value = self.cast(value, self.dtypeop(probs))
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
value = self.select(self.is_nan(value),
neg_one,
value)
between_zero_neone = self.logicand(self.less(value, 0,),
self.greater(value, -1.))
value = self.select(between_zero_neone,


Loading…
Cancel
Save