
categorical.py 16 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Categorical Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.functional import stop_gradient
from mindspore._checkparam import Validator
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_prob, check_sum_equal_one, check_rank, \
    check_distribution_name, raise_not_implemented_util
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to


class Categorical(Distribution):
    """
    Create a categorical distribution parameterized by event probabilities.

    Args:
        probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities.
        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
        dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
        name (str): The name of the distribution. Default: Categorical.

    Note:
        `probs` must have rank at least 1, its values must be proper probabilities,
        and they must sum to 1.

    Examples:
        >>> # To initialize a Categorical distribution of probs [0.5, 0.5]
        >>> import mindspore.nn.probability.distribution as msd
        >>> b = msd.Categorical(probs=[0.5, 0.5], dtype=mstype.int32)
        >>>
        >>> # To use a Categorical distribution in a network
        >>> class net(Cell):
        >>>     def __init__(self, probs):
        >>>         super(net, self).__init__()
        >>>         self.ca = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
        >>>         self.ca1 = msd.Categorical(dtype=mstype.int32)
        >>>
        >>>     # All the following calls in construct are valid
        >>>     def construct(self, value):
        >>>
        >>>         # Private interfaces of probability functions corresponding to public interfaces,
        >>>         # including `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and
        >>>         # `log_survival`, have the same arguments as follows.
        >>>         # Args:
        >>>         #     value (Tensor): the value to be evaluated.
        >>>         #     probs (Tensor): event probabilities. Default: self.probs.
        >>>
        >>>         # Examples of `prob`.
        >>>         # Similar calls can be made to other probability functions
        >>>         # by replacing `prob` by the name of the function.
        >>>         ans = self.ca.prob(value)
        >>>         # Evaluate `prob` with respect to distribution b.
        >>>         ans = self.ca.prob(value, probs_b)
        >>>         # `probs` must be passed in during function calls.
        >>>         ans = self.ca1.prob(value, probs_a)
        >>>
        >>>         # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
        >>>         # Args:
        >>>         #     probs (Tensor): event probabilities. Default: self.probs.
        >>>
        >>>         # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
        >>>         ans = self.ca.mean()  # return 0.8
        >>>         ans = self.ca.mean(probs_b)
        >>>         # `probs` must be passed in during function calls.
        >>>         ans = self.ca1.mean(probs_a)
        >>>
        >>>         # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
        >>>         # Args:
        >>>         #     dist (str): the name of the distribution. Only 'Categorical' is supported.
        >>>         #     probs_b (Tensor): event probabilities of distribution b.
        >>>         #     probs (Tensor): event probabilities of distribution a. Default: self.probs.
        >>>
        >>>         # Examples of `kl_loss`. `cross_entropy` is similar.
        >>>         ans = self.ca.kl_loss('Categorical', probs_b)
        >>>         ans = self.ca.kl_loss('Categorical', probs_b, probs_a)
        >>>         # An additional `probs` must be passed in.
        >>>         ans = self.ca1.kl_loss('Categorical', probs_b, probs_a)
        >>>
        >>>         # Examples of `sample`.
        >>>         # Args:
        >>>         #     shape (tuple): the shape of the sample. Default: ().
        >>>         #     probs (Tensor): event probabilities. Default: self.probs.
        >>>         ans = self.ca.sample()
        >>>         ans = self.ca.sample((2, 3))
        >>>         ans = self.ca.sample((2, 3), probs_b)
        >>>         ans = self.ca1.sample((2, 3), probs_a)
    """

    def __init__(self,
                 probs=None,
                 seed=None,
                 dtype=mstype.int32,
                 name="Categorical"):
        param = dict(locals())
        param['param_dict'] = {'probs': probs}
        valid_dtype = mstype.int_type
        Validator.check_type("Categorical", dtype, valid_dtype)
        super(Categorical, self).__init__(seed, dtype, name, param)

        self._probs = self._add_parameter(probs, 'probs')
        if self.probs is not None:
            check_rank(self.probs)
            check_prob(self.probs)
            check_sum_equal_one(self.probs)

            # update is_scalar_batch and broadcast_shape
            # drop one dimension
            if self.probs.shape[:-1] == ():
                self._is_scalar_batch = True
            self._broadcast_shape = self._broadcast_shape[:-1]

        self.argmax = P.Argmax()
        self.broadcast = broadcast_to
        self.cast = P.Cast()
        self.clip_by_value = C.clip_by_value
        self.concat = P.Concat(-1)
        self.cumsum = P.CumSum()
        self.dtypeop = P.DType()
        self.exp = exp_generic
        self.expand_dim = P.ExpandDims()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.gather = P.GatherNd()
        self.less = P.Less()
        self.log = log_generic
        self.log_softmax = P.LogSoftmax()
        self.logicor = P.LogicalOr()
        self.multinomial = P.Multinomial(seed=self.seed)
        self.reshape = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.select = P.Select()
        self.shape = P.Shape()
        self.softmax = P.Softmax()
        self.squeeze = P.Squeeze()
        self.squeeze_first_axis = P.Squeeze(0)
        self.squeeze_last_axis = P.Squeeze(-1)
        self.square = P.Square()
        self.transpose = P.Transpose()

        self.index_type = mstype.int32

    def extend_repr(self):
        if self.is_scalar_batch:
            s = f'probs = {self.probs}'
        else:
            s = f'batch_shape = {self._broadcast_shape}'
        return s

    @property
    def probs(self):
        """
        Return the probability.
        """
        return self._probs

    def _get_dist_type(self):
        return "Categorical"

    def _get_dist_args(self, probs=None):
        if probs is not None:
            self.checktensor(probs, 'probs')
        else:
            probs = self.probs
        return (probs,)

    def _mean(self, probs=None):
        r"""
        .. math::
            E[X] = \sum_{i=0}^{\text{num\_classes}-1} i \cdot p_i
        """
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        index = nn.Range(0., num_classes, 1.)()
        return self.reduce_sum(index * probs, -1)
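
    # Worked example: with probs = [0.2, 0.8],
    # E[X] = 0 * 0.2 + 1 * 0.8 = 0.8, matching the
    # `ans = self.ca.mean()  # return 0.8` line in the class docstring.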

    def _mode(self, probs=None):
        probs = self._check_param_type(probs)
        mode = self.cast(self.argmax(probs), self.dtype)
        return self.squeeze(mode)
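
    # For example, probs = [0.2, 0.8] has its largest mass at index 1,
    # so mode = argmax(probs) = 1.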

    def _var(self, probs=None):
        r"""
        .. math::
            VAR(X) = E[X^{2}] - (E[X])^{2}
        """
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        index = nn.Range(0., num_classes, 1.)()
        return self.reduce_sum(self.square(index) * probs, -1) - \
            self.square(self.reduce_sum(index * probs, -1))
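
    # Worked example: with probs = [0.2, 0.8],
    # E[X^2] = 0^2 * 0.2 + 1^2 * 0.8 = 0.8 and (E[X])^2 = 0.8^2 = 0.64,
    # so VAR(X) = 0.8 - 0.64 = 0.16.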

    def _entropy(self, probs=None):
        r"""
        Evaluate entropy.

        .. math::
            H(X) = -\sum_{i} p_i \log(p_i)
        """
        probs = self._check_param_type(probs)
        logits = self.log(probs)
        return self.squeeze(-self.reduce_sum(logits * probs, -1))
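
    # Worked example: with probs = [0.2, 0.8],
    # H(X) = -(0.2 * ln 0.2 + 0.8 * ln 0.8) ≈ 0.5004 nats.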

    def _kl_loss(self, dist, probs_b, probs=None):
        """
        Evaluate the KL divergence between two Categorical distributions.

        Args:
            dist (str): The type of the distributions. Should be "Categorical" in this case.
            probs_b (Tensor): Event probabilities of distribution b.
            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
        """
        check_distribution_name(dist, 'Categorical')
        probs_b = self._check_value(probs_b, 'probs_b')
        probs_b = self.cast(probs_b, self.parameter_type)
        probs_a = self._check_param_type(probs)
        logits_a = self.log(probs_a)
        logits_b = self.log(probs_b)
        # KL(a || b) = sum_i p_a_i * (log p_a_i - log p_b_i)
        return self.squeeze(self.reduce_sum(
            self.softmax(logits_a) * (self.log_softmax(logits_a) - self.log_softmax(logits_b)), -1))
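
    # Worked example: for p_a = [0.2, 0.8] and p_b = [0.5, 0.5],
    # KL(a||b) = 0.2 * ln(0.2/0.5) + 0.8 * ln(0.8/0.5) ≈ 0.1927 nats.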

    def _cross_entropy(self, dist, probs_b, probs=None):
        """
        Evaluate the cross entropy between two Categorical distributions.

        Args:
            dist (str): The type of the distributions. Should be "Categorical" in this case.
            probs_b (Tensor): Event probabilities of distribution b.
            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
        """
        check_distribution_name(dist, 'Categorical')
        return self._entropy(probs) + self._kl_loss(dist, probs_b, probs)
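
    # This uses the identity H(a, b) = H(a) + KL(a || b). For p_a = [0.2, 0.8]
    # and p_b = [0.5, 0.5]: H(a) ≈ 0.5004 and KL(a||b) ≈ 0.1927, so
    # H(a, b) ≈ 0.6931 (= ln 2, as expected when b is uniform over 2 classes).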

    def _log_prob(self, value, probs=None):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.parameter_type)
        probs = self._check_param_type(probs)
        logits = self.log(probs)

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)
        broadcast_shape_tensor = logits * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        # broadcast_shape (N, C)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        # broadcasting logits and value
        # logit_pmf shape (num of labels, C)
        logits = self.broadcast(logits, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        # clip value to be in range from 0 to num_classes - 1 and cast into int32
        value = self.reshape(value, (-1, 1))
        out_of_bound = self.squeeze_last_axis(self.logicor(
            self.less(value, 0.0), self.less(num_classes - 1, value)))
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)
        # create index from 0 ... NumOfLabels
        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # index into logit_pmf, fill in out_of_bound places with -inf
        # reshape into label shape N
        logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index)
        neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), -np.inf)
        logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf)
        ans = self.reshape(logits_pmf, label_shape)
        if drop_dim:
            return self.squeeze(ans)
        return ans
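
    # Shape walkthrough (illustrative): for a scalar batch with probs of shape
    # (C,) and value of shape (N,), expand_dim gives value shape (N, 1);
    # broadcasting against logits of shape (C,) yields broadcast_shape (N, C),
    # so label_shape is (N,) and each row of `index` pairs a flattened label
    # position with its clipped class index for the GatherNd lookup.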

    def _cdf(self, value, probs=None):
        r"""
        Cumulative distribution function (cdf) of the Categorical distribution.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.parameter_type)
        value = self.floor(value)
        probs = self._check_param_type(probs)

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)
        broadcast_shape_tensor = probs * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        # broadcast_shape (N, C)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        probs = self.broadcast(probs, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        value = self.reshape(value, (-1, 1))

        # drop one dimension to match cdf
        # clip value to be in range from 0 to num_classes - 1 and cast into int32
        less_than_zero = self.squeeze_last_axis(self.less(value, 0.0))
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)

        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # reshape probs and fill less_than_zero places with 0
        probs = self.reshape(probs, (-1, num_classes))
        cdf = self.gather(self.cumsum(probs, 1), index)
        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
        cdf = self.select(less_than_zero, zeros, cdf)
        cdf = self.reshape(cdf, label_shape)

        if drop_dim:
            return self.squeeze(cdf)
        return cdf
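
    # Worked example: with probs = [0.2, 0.3, 0.5], cumsum gives
    # [0.2, 0.5, 1.0], so cdf(0) = 0.2, cdf(1) = 0.5, cdf(2) = 1.0,
    # and any value below 0 maps to 0.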

    def _sample(self, shape=(), probs=None):
        """
        Sampling.

        Args:
            shape (tuple): The shape of the sample. Default: ().
            probs (Tensor): Event probabilities. Default: self.probs.

        Returns:
            Tensor, with shape `shape + batch_shape`, where `batch_shape` is `shape(probs)[:-1]`.
        """
        if self.device_target == 'Ascend':
            raise_not_implemented_util('On d backend, sample', self.name)
        shape = self.checktuple(shape, 'shape')
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        batch_shape = self.shape(probs)[:-1]

        sample_shape = shape + batch_shape
        drop_dim = False
        if sample_shape == ():
            drop_dim = True
            sample_shape = (1,)

        probs_2d = self.reshape(probs, (-1, num_classes))
        sample_tensor = self.fill(self.dtype, shape, 1.0)
        sample_tensor = self.reshape(sample_tensor, (-1, 1))
        num_sample = self.shape(sample_tensor)[0]
        samples = self.multinomial(probs_2d, num_sample)
        samples = self.squeeze(self.transpose(samples, (1, 0)))
        samples = self.cast(self.reshape(samples, sample_shape), self.dtype)
        if drop_dim:
            return self.squeeze_first_axis(samples)
        samples = stop_gradient(samples)
        return samples
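

# Minimal usage sketch (illustrative, assuming a MindSpore install where this
# module is exposed as mindspore.nn.probability.distribution and a backend
# other than Ascend, since `sample` above is not implemented there). The
# expected values in the comments follow from probs = [0.2, 0.8].
if __name__ == '__main__':
    from mindspore import Tensor
    import mindspore.nn.probability.distribution as msd

    ca = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
    value = Tensor([0.0, 1.0], mstype.float32)
    print(ca.prob(value))     # [0.2, 0.8]
    print(ca.mean())          # 0.8
    print(ca.var())           # 0.16
    print(ca.entropy())       # ~0.5004 nats
    print(ca.sample((2, 3)))  # a 2x3 tensor of class indices in {0, 1}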