From 4426be2126ebab4236c533b85f37ab665cc051fe Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Wed, 9 Dec 2020 16:05:03 -0500 Subject: [PATCH] fixed minor issues --- mindspore/nn/probability/distribution/beta.py | 6 ++---- mindspore/nn/probability/distribution/categorical.py | 9 +++++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/probability/distribution/beta.py b/mindspore/nn/probability/distribution/beta.py index 75fd7d65cc..2a07375bfe 100644 --- a/mindspore/nn/probability/distribution/beta.py +++ b/mindspore/nn/probability/distribution/beta.py @@ -109,13 +109,11 @@ class Beta(Distribution): >>> ans = b1.kl_loss('Beta', concentration1_b, concentration0_b) >>> print(ans) [0.34434414 0.24721336 0.26786423] - >>> ans = b1.kl_loss('Beta', concentration1_b, concentration0_b, - >>> concentration1_a, concentration0_a) + >>> ans = b1.kl_loss('Beta', concentration1_b, concentration0_b, concentration1_a, concentration0_a) >>> print(ans) [0.12509346 0.13629508 0.26527953] >>> # Additional `concentration1` and `concentration0` must be passed in. - >>> ans = b2.kl_loss('Beta', concentration1_b, concentration0_b, - >>> concentration1_a, concentration0_a) + >>> ans = b2.kl_loss('Beta', concentration1_b, concentration0_b, concentration1_a, concentration0_a) >>> print(ans) [0.12509346 0.13629508 0.26527953] >>> # Examples of `sample`. diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py index 8229ae4028..242ef8680f 100644 --- a/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/nn/probability/distribution/categorical.py @@ -175,6 +175,7 @@ class Categorical(Distribution): self.squeeze_last_axis = P.Squeeze(-1) self.square = P.Square() self.transpose = P.Transpose() + self.is_nan = P.IsNan() self.index_type = mstype.int32 self.nan = np.nan @@ -290,6 +291,10 @@ class Categorical(Distribution): value = self.cast(value, self.dtypeop(probs)) zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) + neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0) + value = self.select(self.is_nan(value), + neg_one, + value) between_zero_neone = self.logicand(self.less(value, 0,), self.greater(value, -1.)) value = self.select(between_zero_neone, @@ -354,6 +359,10 @@ class Categorical(Distribution): value = self.cast(value, self.dtypeop(probs)) zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) + neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0) + value = self.select(self.is_nan(value), + neg_one, + value) between_zero_neone = self.logicand(self.less(value, 0,), self.greater(value, -1.)) value = self.select(between_zero_neone,