|
|
|
@@ -175,6 +175,7 @@ class Categorical(Distribution): |
|
|
|
self.squeeze_last_axis = P.Squeeze(-1)
|
|
|
|
self.square = P.Square()
|
|
|
|
self.transpose = P.Transpose()
|
|
|
|
self.is_nan = P.IsNan()
|
|
|
|
|
|
|
|
self.index_type = mstype.int32
|
|
|
|
self.nan = np.nan
|
|
|
|
@@ -290,6 +291,10 @@ class Categorical(Distribution): |
|
|
|
value = self.cast(value, self.dtypeop(probs))
|
|
|
|
|
|
|
|
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
|
|
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
|
|
|
|
value = self.select(self.is_nan(value),
|
|
|
|
neg_one,
|
|
|
|
value)
|
|
|
|
between_zero_neone = self.logicand(self.less(value, 0,),
|
|
|
|
self.greater(value, -1.))
|
|
|
|
value = self.select(between_zero_neone,
|
|
|
|
@@ -354,6 +359,10 @@ class Categorical(Distribution): |
|
|
|
value = self.cast(value, self.dtypeop(probs))
|
|
|
|
|
|
|
|
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
|
|
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
|
|
|
|
value = self.select(self.is_nan(value),
|
|
|
|
neg_one,
|
|
|
|
value)
|
|
|
|
between_zero_neone = self.logicand(self.less(value, 0,),
|
|
|
|
self.greater(value, -1.))
|
|
|
|
value = self.select(between_zero_neone,
|
|
|
|
|