Browse Source

!9050 Fix bugs in Categorical distribution

From: @shallydeng
Reviewed-by: @zichun_ye,@sunnybeike
Signed-off-by: @sunnybeike
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
67505ef198
3 changed files with 17 additions and 11 deletions
  1. +3
    -3
      mindspore/nn/probability/distribution/_utils/utils.py
  2. +13
    -7
      mindspore/nn/probability/distribution/categorical.py
  3. +1
    -1
      tests/st/probability/distribution/test_categorical.py

+ 3
- 3
mindspore/nn/probability/distribution/_utils/utils.py View File

@@ -285,11 +285,11 @@ class CheckTuple(PrimitiveWithInfer):
return out

def __call__(self, x, name):
if context.get_context("mode") == 0:
return x["value"]
# Pynative mode
# The op is not used in a cell
if isinstance(x, tuple):
return x
if context.get_context("mode") == 0:
return x["value"]
raise TypeError(f"For {name}, input type should be a tuple.")




+ 13
- 7
mindspore/nn/probability/distribution/categorical.py View File

@@ -108,7 +108,7 @@ class Categorical(Distribution):
name="Categorical"):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type
valid_dtype = mstype.int_type + mstype.float_type
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Categorical, self).__init__(seed, dtype, name, param)
@@ -134,8 +134,8 @@ class Categorical(Distribution):
self.exp = exp_generic
self.expand_dim = P.ExpandDims()
self.fill = P.Fill()
self.floor = P.Floor()
self.gather = P.GatherNd()
self.issubclass = P.IsSubClass()
self.less = P.Less()
self.log = log_generic
self.log_softmax = P.LogSoftmax()
@@ -153,6 +153,7 @@ class Categorical(Distribution):
self.transpose = P.Transpose()
self.index_type = mstype.int32
self.nan = np.nan
def extend_repr(self):
@@ -255,7 +256,10 @@ class Categorical(Distribution):
probs (Tensor): Event probabilities. Default: self.probs.
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.parameter_type)
if self.issubclass(self.dtype, mstype.float_):
value = self.cast(value, self.index_type)
else:
value = self.cast(value, self.dtype)
probs = self._check_param_type(probs)
logits = self.log(probs)
@@ -294,8 +298,8 @@ class Categorical(Distribution):
# index into logit_pmf, fill in out_of_bound places with -inf
# reshape into label shape N
logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index)
neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), -np.inf)
logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf)
nan = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), self.nan)
logits_pmf = self.select(out_of_bound, nan, logits_pmf)
ans = self.reshape(logits_pmf, label_shape)
if drop_dim:
return self.squeeze(ans)
@@ -310,8 +314,10 @@ class Categorical(Distribution):
probs (Tensor): Event probabilities. Default: self.probs.
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.parameter_type)
value = self.floor(value)
if self.issubclass(self.dtype, mstype.float_):
value = self.cast(value, self.index_type)
else:
value = self.cast(value, self.dtype)
probs = self._check_param_type(probs)
# handle the case when value is of shape () and probs is a scalar batch


+ 1
- 1
tests/st/probability/distribution/test_categorical.py View File

@@ -219,7 +219,7 @@ def test_log_survival():
Test log survival function.
"""
expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3])
x_ = Tensor(np.array([-0.1, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32)
x_ = Tensor(np.array([-2, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32)
log_sf = LogSF()
output = log_sf(x_)
tol = 1e-6


Loading…
Cancel
Save