From a058881b9097553def6e6314d00b1ad71c044ab0 Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Wed, 25 Nov 2020 22:39:26 -0500 Subject: [PATCH] fix minor bug in categorical distribution --- .../probability/distribution/_utils/utils.py | 6 +++--- .../probability/distribution/categorical.py | 20 ++++++++++++------- .../distribution/test_categorical.py | 2 +- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 1336e9fd1d..7bef5ff347 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -285,11 +285,11 @@ class CheckTuple(PrimitiveWithInfer): return out def __call__(self, x, name): - if context.get_context("mode") == 0: - return x["value"] - # Pynative mode + # The op is not used in a cell if isinstance(x, tuple): return x + if context.get_context("mode") == 0: + return x["value"] raise TypeError(f"For {name}, input type should be a tuple.") diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py index 40ac3b7ef2..44b53ff4d0 100644 --- a/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/nn/probability/distribution/categorical.py @@ -108,7 +108,7 @@ class Categorical(Distribution): name="Categorical"): param = dict(locals()) param['param_dict'] = {'probs': probs} - valid_dtype = mstype.int_type + valid_dtype = mstype.int_type + mstype.float_type Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) super(Categorical, self).__init__(seed, dtype, name, param) @@ -134,8 +134,8 @@ class Categorical(Distribution): self.exp = exp_generic self.expand_dim = P.ExpandDims() self.fill = P.Fill() - self.floor = P.Floor() self.gather = P.GatherNd() + self.issubclass = P.IsSubClass() self.less = P.Less() self.log = log_generic self.log_softmax = P.LogSoftmax() @@ -153,6 +153,7 @@
class Categorical(Distribution): self.transpose = P.Transpose() self.index_type = mstype.int32 + self.nan = np.nan def extend_repr(self): @@ -255,7 +256,10 @@ class Categorical(Distribution): probs (Tensor): Event probabilities. Default: self.probs. """ value = self._check_value(value, 'value') - value = self.cast(value, self.parameter_type) + if self.issubclass(self.dtype, mstype.float_): + value = self.cast(value, self.index_type) + else: + value = self.cast(value, self.dtype) probs = self._check_param_type(probs) logits = self.log(probs) @@ -294,8 +298,8 @@ class Categorical(Distribution): # index into logit_pmf, fill in out_of_bound places with -inf # reshape into label shape N logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index) - neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), -np.inf) - logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf) + nan = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), self.nan) + logits_pmf = self.select(out_of_bound, nan, logits_pmf) ans = self.reshape(logits_pmf, label_shape) if drop_dim: return self.squeeze(ans) @@ -310,8 +314,10 @@ class Categorical(Distribution): probs (Tensor): Event probabilities. Default: self.probs. """ value = self._check_value(value, 'value') - value = self.cast(value, self.parameter_type) - value = self.floor(value) + if self.issubclass(self.dtype, mstype.float_): + value = self.cast(value, self.index_type) + else: + value = self.cast(value, self.dtype) probs = self._check_param_type(probs) # handle the case when value is of shape () and probs is a scalar batch diff --git a/tests/st/probability/distribution/test_categorical.py b/tests/st/probability/distribution/test_categorical.py index 7dd972749a..b8a0f5d8b3 100644 --- a/tests/st/probability/distribution/test_categorical.py +++ b/tests/st/probability/distribution/test_categorical.py @@ -219,7 +219,7 @@ def test_log_survival(): Test log survival funciton. 
""" expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3]) - x_ = Tensor(np.array([-0.1, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32) + x_ = Tensor(np.array([-2, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32) log_sf = LogSF() output = log_sf(x_) tol = 1e-6