You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

exponential.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Exponential Distribution"""
  16. import numpy as np
  17. from mindspore.ops import operations as P
  18. from mindspore.ops import composite as C
  19. from mindspore._checkparam import Validator
  20. from mindspore.common import dtype as mstype
  21. from .distribution import Distribution
  22. from ._utils.utils import check_greater_zero, check_distribution_name
  23. from ._utils.custom_ops import exp_generic, log_generic
  class Exponential(Distribution):
      """
      Example class: Exponential Distribution.

      Args:
          rate (float, list, numpy.ndarray, Tensor, Parameter): The inverse scale.
          seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
          dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
          name (str): The name of the distribution. Default: 'Exponential'.

      Note:
          `rate` must be strictly greater than 0.
          `dist_spec_args` is `rate`.
          `dtype` must be a float type because Exponential distributions are continuous.

      Examples:
          >>> # To initialize an Exponential distribution of the rate 0.5.
          >>> import mindspore.nn.probability.distribution as msd
          >>> e = msd.Exponential(0.5, dtype=mstype.float32)
          >>>
          >>> # The following creates two independent Exponential distributions.
          >>> e = msd.Exponential([0.5, 0.5], dtype=mstype.float32)
          >>>
          >>> # An Exponential distribution can be initialized without arguments.
          >>> # In this case, `rate` must be passed in through `args` during function calls.
          >>> e = msd.Exponential(dtype=mstype.float32)
          >>>
          >>> # To use an Exponential distribution in a network.
          >>> class net(Cell):
          >>>     def __init__(self):
          >>>         super(net, self).__init__()
          >>>         self.e1 = msd.Exponential(0.5, dtype=mstype.float32)
          >>>         self.e2 = msd.Exponential(dtype=mstype.float32)
          >>>
          >>>     # All the following calls in construct are valid.
          >>>     def construct(self, value, rate_b, rate_a):
          >>>
          >>>         # Private interfaces of probability functions corresponding to public interfaces, including
          >>>         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`,
          >>>         # are the same as follows.
          >>>         # Args:
          >>>         #     value (Tensor): the value to be evaluated.
          >>>         #     rate (Tensor): the rate of the distribution. Default: self.rate.
          >>>
          >>>         # Examples of `prob`.
          >>>         # Similar calls can be made to other probability functions
          >>>         # by replacing `prob` by the name of the function.
          >>>         ans = self.e1.prob(value)
          >>>         # Evaluate with respect to distribution b.
          >>>         ans = self.e1.prob(value, rate_b)
          >>>         # `rate` must be passed in during function calls.
          >>>         ans = self.e2.prob(value, rate_a)
          >>>
          >>>
          >>>         # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments as follows.
          >>>         # Args:
          >>>         #     rate (Tensor): the rate of the distribution. Default: self.rate.
          >>>
          >>>         # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
          >>>         ans = self.e1.mean() # return 2
          >>>         ans = self.e1.mean(rate_b) # return 1 / rate_b
          >>>         # `rate` must be passed in during function calls.
          >>>         ans = self.e2.mean(rate_a)
          >>>
          >>>
          >>>         # Interfaces of `kl_loss` and `cross_entropy` are the same.
          >>>         # Args:
          >>>         #     dist (str): The name of the distribution. Only 'Exponential' is supported.
          >>>         #     rate_b (Tensor): the rate of distribution b.
          >>>         #     rate_a (Tensor): the rate of distribution a. Default: self.rate.
          >>>
          >>>         # Examples of `kl_loss`. `cross_entropy` is similar.
          >>>         ans = self.e1.kl_loss('Exponential', rate_b)
          >>>         ans = self.e1.kl_loss('Exponential', rate_b, rate_a)
          >>>         # An additional `rate` must be passed in.
          >>>         ans = self.e2.kl_loss('Exponential', rate_b, rate_a)
          >>>
          >>>
          >>>         # Examples of `sample`.
          >>>         # Args:
          >>>         #     shape (tuple): the shape of the sample. Default: ()
          >>>         #     rate (Tensor): the rate of the distribution. Default: self.rate.
          >>>         ans = self.e1.sample()
          >>>         ans = self.e1.sample((2,3))
          >>>         ans = self.e1.sample((2,3), rate_b)
          >>>         ans = self.e2.sample((2,3), rate_a)
      """

      def __init__(self,
                   rate=None,
                   seed=None,
                   dtype=mstype.float32,
                   name="Exponential"):
          """
          Constructor of Exponential.

          Validates `dtype` against the float types, registers `rate` as a
          parameter and, when `rate` is given, checks that it is strictly positive.
          """
          # Must stay the first statement: captures exactly the constructor arguments.
          param = dict(locals())
          param['param_dict'] = {'rate': rate}
          valid_dtype = mstype.float_type
          Validator.check_type(type(self).__name__, dtype, valid_dtype)
          super(Exponential, self).__init__(seed, dtype, name, param)
          self._rate = self._add_parameter(rate, 'rate')
          if self.rate is not None:
              check_greater_zero(self.rate, 'rate')
          # Lower bound of the uniform draw in `_sample`, keeping the draw away from 0
          # so that log() of it is finite. `np.float` was only an alias of the builtin
          # `float` and was removed in NumPy 1.24, so use `float` directly (same value).
          # NOTE(review): if ScalarToArray turns this into a float32 tensor, the float64
          # tiny value underflows to 0.0 -- confirm the dtype it produces.
          self.minval = np.finfo(float).tiny
          # ops needed for the class
          self.exp = exp_generic
          self.log = log_generic
          self.squeeze = P.Squeeze(0)
          self.cast = P.Cast()
          self.const = P.ScalarToArray()
          self.dtypeop = P.DType()
          self.fill = P.Fill()
          self.less = P.Less()
          self.select = P.Select()
          self.shape = P.Shape()
          self.uniform = C.uniform

      def extend_repr(self):
          """
          Return a short description of the parameters: the rate for a scalar
          batch, otherwise the batch shape.
          """
          if self.is_scalar_batch:
              str_info = f'rate = {self.rate}'
          else:
              str_info = f'batch_shape = {self._broadcast_shape}'
          return str_info

      @property
      def rate(self):
          """
          Return `rate` of the distribution.
          """
          return self._rate

      def _mean(self, rate=None):
          r"""
          Mean of the distribution.

          .. math::
              MEAN(EXP) = \frac{1.0}{\lambda}.
          """
          rate = self._check_param_type(rate)
          return 1.0 / rate

      def _mode(self, rate=None):
          r"""
          Mode of the distribution: a tensor of zeros shaped like `rate`.

          .. math::
              MODE(EXP) = 0.
          """
          rate = self._check_param_type(rate)
          return self.fill(self.dtype, self.shape(rate), 0.)

      def _sd(self, rate=None):
          r"""
          Standard deviation of the distribution.

          .. math::
              sd(EXP) = \frac{1.0}{\lambda}.
          """
          rate = self._check_param_type(rate)
          return 1.0 / rate

      def _entropy(self, rate=None):
          r"""
          Entropy of the distribution.

          .. math::
              H(Exp) = 1 - \log(\lambda).
          """
          rate = self._check_param_type(rate)
          return 1.0 - self.log(rate)

      def _cross_entropy(self, dist, rate_b, rate=None):
          r"""
          Evaluate cross entropy between Exponential distributions.

          Args:
              dist (str): The type of the distributions. Should be "Exponential" in this case.
              rate_b (Tensor): The rate of distribution b.
              rate (Tensor): The rate of distribution a. Default: self.rate.

          .. math::
              H(a, b) = H(a) + KL(a||b)
          """
          check_distribution_name(dist, 'Exponential')
          return self._entropy(rate) + self._kl_loss(dist, rate_b, rate)

      def _log_prob(self, value, rate=None):
          r"""
          Log probability density function of Exponential distributions.

          Args:
              value (Tensor): The value to be evaluated.
              rate (Tensor): The rate of the distribution. Default: self.rate.

          Note:
              The density is zero for negative `value`, so -inf is returned there.

          .. math::
              log_pdf(x) = \log(\lambda) - \lambda * x if x >= 0 else -\infty
          """
          value = self._check_value(value, "value")
          value = self.cast(value, self.dtype)
          rate = self._check_param_type(rate)
          prob = self.log(rate) - rate * value
          zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
          neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
          # Mask out the region outside the support (value < 0).
          comp = self.less(value, zeros)
          return self.select(comp, neginf, prob)

      def _cdf(self, value, rate=None):
          r"""
          Cumulative distribution function (cdf) of Exponential distributions.

          Args:
              value (Tensor): The value to be evaluated.
              rate (Tensor): The rate of the distribution. Default: self.rate.

          Note:
              0 is returned for negative `value`.

          .. math::
              cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
          """
          value = self._check_value(value, 'value')
          value = self.cast(value, self.dtype)
          rate = self._check_param_type(rate)
          cdf = 1.0 - self.exp(-1. * rate * value)
          zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
          comp = self.less(value, zeros)
          return self.select(comp, zeros, cdf)

      def _log_survival(self, value, rate=None):
          r"""
          Log survival function of Exponential distributions.

          Args:
              value (Tensor): The value to be evaluated.
              rate (Tensor): The rate of the distribution. Default: self.rate.

          Note:
              0 (i.e. log(1)) is returned for negative `value`.

          .. math::
              log_survival_function(x) = -1 * \lambda * x if x >= 0 else 0
          """
          value = self._check_value(value, 'value')
          value = self.cast(value, self.dtype)
          rate = self._check_param_type(rate)
          sf = -1. * rate * value
          zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
          comp = self.less(value, zeros)
          return self.select(comp, zeros, sf)

      def _kl_loss(self, dist, rate_b, rate=None):
          r"""
          Evaluate exp-exp kl divergence, i.e. KL(a||b).

          Args:
              dist (str): The type of the distributions. Should be "Exponential" in this case.
              rate_b (Tensor): The rate of distribution b.
              rate (Tensor): The rate of distribution a. Default: self.rate.

          .. math::
              KL(a||b) = \log(\lambda_a) - \log(\lambda_b) + \frac{\lambda_b}{\lambda_a} - 1
          """
          check_distribution_name(dist, 'Exponential')
          rate_b = self._check_value(rate_b, 'rate_b')
          rate_b = self.cast(rate_b, self.parameter_type)
          rate_a = self._check_param_type(rate)
          return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0

      def _sample(self, shape=(), rate=None):
          """
          Sampling by inverse transform: -log(U) / rate, where U is drawn by
          `C.uniform` between `self.minval` and 1.0.

          Args:
              shape (tuple): The shape of the sample. Default: ().
              rate (Tensor): The rate of the distribution. Default: self.rate.

          Returns:
              Tensor, shape is shape + batch_shape.
          """
          shape = self.checktuple(shape, 'shape')
          rate = self._check_param_type(rate)
          origin_shape = shape + self.shape(rate)
          # A scalar request is drawn as shape (1,) and squeezed back afterwards.
          if origin_shape == ():
              sample_shape = (1,)
          else:
              sample_shape = origin_shape
          minval = self.const(self.minval)
          maxval = self.const(1.0)
          sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
          sample = self.log(sample_uniform) / rate
          value = self.cast(-sample, self.dtype)
          if origin_shape == ():
              value = self.squeeze(value)
          return value