
uniform.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Uniform Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic


class Uniform(Distribution):
    """
    Example class: Uniform Distribution.

    Args:
        low (int, float, list, numpy.ndarray, Tensor): The lower bound of the distribution.
        high (int, float, list, numpy.ndarray, Tensor): The upper bound of the distribution.
        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
        dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
        name (str): The name of the distribution. Default: 'Uniform'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Note:
        `low` must be strictly less than `high`.
        `dist_spec_args` are `high` and `low`.
        `dtype` must be a float type because Uniform distributions are continuous.

    Examples:
        >>> import mindspore
        >>> import mindspore.context as context
        >>> import mindspore.nn as nn
        >>> import mindspore.nn.probability.distribution as msd
        >>> from mindspore import Tensor
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> # To initialize a Uniform distribution with the lower bound 0.0 and the upper bound 1.0.
        >>> u1 = msd.Uniform(0.0, 1.0, dtype=mindspore.float32)
        >>> # A Uniform distribution can be initialized without arguments.
        >>> # In this case, `high` and `low` must be passed in through arguments during function calls.
        >>> u2 = msd.Uniform(dtype=mindspore.float32)
        >>>
        >>> # Here are some tensors used below for testing
        >>> value = Tensor([0.5, 0.8], dtype=mindspore.float32)
        >>> low_a = Tensor([0., 0.], dtype=mindspore.float32)
        >>> high_a = Tensor([2.0, 4.0], dtype=mindspore.float32)
        >>> low_b = Tensor([-1.5], dtype=mindspore.float32)
        >>> high_b = Tensor([2.5, 5.], dtype=mindspore.float32)
        >>> # Private interfaces of probability functions corresponding to public interfaces, including
        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments.
        >>> # Args:
        >>> #     value (Tensor): the value to be evaluated.
        >>> #     low (Tensor): the lower bound of the distribution. Default: self.low.
        >>> #     high (Tensor): the upper bound of the distribution. Default: self.high.
        >>> # Examples of `prob`.
        >>> # Similar calls can be made to other probability functions
        >>> # by replacing 'prob' by the name of the function.
        >>> ans = u1.prob(value)
        >>> print(ans)
        [1. 1.]
        >>> # Evaluate with respect to distribution b.
        >>> ans = u1.prob(value, low_b, high_b)
        >>> print(ans)
        [0.25 0.15384614]
        >>> # `high` and `low` must be passed in during function calls.
        >>> ans = u2.prob(value, low_a, high_a)
        >>> print(ans)
        [0.5 0.25]
        >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
        >>> # Args:
        >>> #     low (Tensor): the lower bound of the distribution. Default: self.low.
        >>> #     high (Tensor): the upper bound of the distribution. Default: self.high.
        >>> # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
        >>> ans = u1.mean() # return 0.5
        >>> print(ans)
        0.5
        >>> ans = u1.mean(low_b, high_b) # return (low_b + high_b) / 2
        >>> print(ans)
        [0.5 1.75]
        >>> # `high` and `low` must be passed in during function calls.
        >>> ans = u2.mean(low_a, high_a)
        >>> print(ans)
        [1. 2.]
        >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same.
        >>> # Args:
        >>> #     dist (str): the type of the distributions. Should be "Uniform" in this case.
        >>> #     low_b (Tensor): the lower bound of distribution b.
        >>> #     high_b (Tensor): the upper bound of distribution b.
        >>> #     low_a (Tensor): the lower bound of distribution a. Default: self.low.
        >>> #     high_a (Tensor): the upper bound of distribution a. Default: self.high.
        >>> # Examples of `kl_loss`. `cross_entropy` is similar.
        >>> ans = u1.kl_loss('Uniform', low_b, high_b)
        >>> print(ans)
        [1.3862944 1.8718022]
        >>> ans = u1.kl_loss('Uniform', low_b, high_b, low_a, high_a)
        >>> print(ans)
        [0.6931472 0.48550785]
        >>> # Additional `high` and `low` must be passed in.
        >>> ans = u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
        >>> print(ans)
        [0.6931472 0.48550785]
        >>> # Examples of `sample`.
        >>> # Args:
        >>> #     shape (tuple): the shape of the sample. Default: ()
        >>> #     low (Tensor): the lower bound of the distribution. Default: self.low.
        >>> #     high (Tensor): the upper bound of the distribution. Default: self.high.
        >>> ans = u1.sample()
        >>> print(ans.shape)
        ()
        >>> ans = u1.sample((2,3))
        >>> print(ans.shape)
        (2, 3)
        >>> ans = u1.sample((2,3), low_b, high_b)
        >>> print(ans.shape)
        (2, 3, 2)
        >>> ans = u2.sample((2,3), low_a, high_a)
        >>> print(ans.shape)
        (2, 3, 2)
    """

    def __init__(self,
                 low=None,
                 high=None,
                 seed=None,
                 dtype=mstype.float32,
                 name="Uniform"):
        """
        Constructor of Uniform distribution.
        """
        param = dict(locals())
        param['param_dict'] = {'low': low, 'high': high}
        valid_dtype = mstype.float_type
        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Uniform, self).__init__(seed, dtype, name, param)

        self._low = self._add_parameter(low, 'low')
        self._high = self._add_parameter(high, 'high')
        if self.low is not None and self.high is not None:
            check_greater(self.low, self.high, 'low', 'high')

        # ops needed for the class
        self.exp = exp_generic
        self.log = log_generic
        self.squeeze = P.Squeeze(0)
        self.cast = P.Cast()
        self.const = P.ScalarToArray()
        self.dtypeop = P.DType()
        self.fill = P.Fill()
        self.less = P.Less()
        self.lessequal = P.LessEqual()
        self.logicaland = P.LogicalAnd()
        self.select = P.Select()
        self.shape = P.Shape()
        self.sq = P.Square()
        self.zeroslike = P.ZerosLike()
        self.uniform = C.uniform

    def extend_repr(self):
        if self.is_scalar_batch:
            s = f'low = {self.low}, high = {self.high}'
        else:
            s = f'batch_shape = {self._broadcast_shape}'
        return s

    @property
    def low(self):
        """
        Return the lower bound of the distribution.
        """
        return self._low

    @property
    def high(self):
        """
        Return the upper bound of the distribution.
        """
        return self._high

    def _get_dist_type(self):
        return "Uniform"

    def _get_dist_args(self, low=None, high=None):
        if low is not None:
            self.checktensor(low, 'low')
        else:
            low = self.low
        if high is not None:
            self.checktensor(high, 'high')
        else:
            high = self.high
        return low, high

    def _range(self, low=None, high=None):
        r"""
        Return the range of the distribution.

        .. math::
            range(U) = high - low
        """
        low, high = self._check_param_type(low, high)
        return high - low

    def _mean(self, low=None, high=None):
        r"""
        .. math::
            MEAN(U) = \frac{low + high}{2}.
        """
        low, high = self._check_param_type(low, high)
        return (low + high) / 2.
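
    # Hedged sanity note (illustrative only, mirrors the class docstring example):
    # with low_b = [-1.5] and high_b = [2.5, 5.], broadcasting gives
    # mean = (low + high) / 2 = [(-1.5 + 2.5) / 2, (-1.5 + 5.) / 2] = [0.5, 1.75],
    # which matches the documented output of `u1.mean(low_b, high_b)`.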

    def _var(self, low=None, high=None):
        r"""
        .. math::
            VAR(U) = \frac{(high - low) ^ 2}{12}.
        """
        low, high = self._check_param_type(low, high)
        return self.sq(high - low) / 12.0
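
    # Hedged sanity note (illustrative only): for low = 0.0 and high = 1.0,
    # VAR(U) = (1.0 - 0.0) ** 2 / 12 = 1 / 12, roughly 0.0833.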

    def _entropy(self, low=None, high=None):
        r"""
        .. math::
            H(U) = \log(high - low).
        """
        low, high = self._check_param_type(low, high)
        return self.log(high - low)
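
    # Hedged sanity note (illustrative only): for low = 0.0 and high = 1.0,
    # H(U) = log(1.0 - 0.0) = 0, so `u1.entropy()` is expected to return 0.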

    def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
        """
        Evaluate cross entropy between Uniform distributions.

        Args:
            dist (str): The type of the distributions. Should be "Uniform" in this case.
            low_b (Tensor): The lower bound of distribution b.
            high_b (Tensor): The upper bound of distribution b.
            low_a (Tensor): The lower bound of distribution a. Default: self.low.
            high_a (Tensor): The upper bound of distribution a. Default: self.high.
        """
        check_distribution_name(dist, 'Uniform')
        return self._entropy(low, high) + self._kl_loss(dist, low_b, high_b, low, high)

    def _prob(self, value, low=None, high=None):
        r"""
        pdf of Uniform distribution.

        Args:
            value (Tensor): The value to be evaluated.
            low (Tensor): The lower bound of the distribution. Default: self.low.
            high (Tensor): The upper bound of the distribution. Default: self.high.

        .. math::
            pdf(x) = 0 if x < low;
            pdf(x) = \frac{1.0}{high - low} if low <= x <= high;
            pdf(x) = 0 if x > high;
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.dtype)
        low, high = self._check_param_type(low, high)
        neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
        prob = self.exp(neg_ones * self.log(high - low))
        broadcast_shape = self.shape(prob)
        zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
        comp_lo = self.less(value, low)
        comp_hi = self.lessequal(value, high)
        less_than_low = self.select(comp_lo, zeros, prob)
        return self.select(comp_hi, less_than_low, zeros)
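
    # Hedged sanity note (illustrative only, mirrors the class docstring example):
    # inside [low, high] the density is the constant 1 / (high - low), so with
    # low_b = [-1.5] and high_b = [2.5, 5.] the result is
    # [1 / 4.0, 1 / 6.5], approximately [0.25, 0.15384614]; outside the interval it is 0.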

    def _kl_loss(self, dist, low_b, high_b, low=None, high=None):
        """
        Evaluate uniform-uniform KL divergence, i.e. KL(a||b).

        Args:
            dist (str): The type of the distributions. Should be "Uniform" in this case.
            low_b (Tensor): The lower bound of distribution b.
            high_b (Tensor): The upper bound of distribution b.
            low_a (Tensor): The lower bound of distribution a. Default: self.low.
            high_a (Tensor): The upper bound of distribution a. Default: self.high.
        """
        check_distribution_name(dist, 'Uniform')
        low_b = self._check_value(low_b, 'low_b')
        low_b = self.cast(low_b, self.parameter_type)
        high_b = self._check_value(high_b, 'high_b')
        high_b = self.cast(high_b, self.parameter_type)
        low_a, high_a = self._check_param_type(low, high)
        kl = self.log(high_b - low_b) - self.log(high_a - low_a)
        comp = self.logicaland(self.lessequal(
            low_b, low_a), self.lessequal(high_a, high_b))
        inf = self.fill(self.dtypeop(kl), self.shape(kl), np.inf)
        return self.select(comp, kl, inf)
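
    # Hedged sanity note (illustrative only, mirrors the class docstring example):
    # KL(a||b) = log((high_b - low_b) / (high_a - low_a)) when [low_a, high_a] is
    # contained in [low_b, high_b], and +inf otherwise. With a = U(0, 1) and
    # b = U(-1.5, 2.5) this gives log(4.0 / 1.0), approximately 1.3862944, matching
    # the documented output of `u1.kl_loss('Uniform', low_b, high_b)`.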

    def _cdf(self, value, low=None, high=None):
        r"""
        The cumulative distribution function of Uniform distribution.

        Args:
            value (Tensor): The value to be evaluated.
            low (Tensor): The lower bound of the distribution. Default: self.low.
            high (Tensor): The upper bound of the distribution. Default: self.high.

        .. math::
            cdf(x) = 0 if x < low;
            cdf(x) = \frac{x - low}{high - low} if low <= x <= high;
            cdf(x) = 1 if x > high;
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.dtype)
        low, high = self._check_param_type(low, high)
        prob = (value - low) / (high - low)
        broadcast_shape = self.shape(prob)
        zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
        ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0)
        comp_lo = self.less(value, low)
        comp_hi = self.less(value, high)
        less_than_low = self.select(comp_lo, zeros, prob)
        return self.select(comp_hi, less_than_low, ones)
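
    # Hedged sanity note (illustrative only): for value = 0.5 with low = 0.0 and
    # high = 2.0, cdf = (0.5 - 0.0) / (2.0 - 0.0) = 0.25; values below `low` map
    # to 0 and values above `high` map to 1 via the two `select` calls above.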

    def _sample(self, shape=(), low=None, high=None):
        """
        Sampling.

        Args:
            shape (tuple): The shape of the sample. Default: ().
            low (Tensor): The lower bound of the distribution. Default: self.low.
            high (Tensor): The upper bound of the distribution. Default: self.high.

        Returns:
            Tensor, with the shape being shape + batch_shape.
        """
        shape = self.checktuple(shape, 'shape')
        low, high = self._check_param_type(low, high)
        broadcast_shape = self.shape(low + high)
        origin_shape = shape + broadcast_shape
        if origin_shape == ():
            sample_shape = (1,)
        else:
            sample_shape = origin_shape
        l_zero = self.const(0.0)
        h_one = self.const(1.0)
        sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
        sample = (high - low) * sample_uniform + low
        value = self.cast(sample, self.dtype)
        if origin_shape == ():
            value = self.squeeze(value)
        return value
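

# ----------------------------------------------------------------------------
# Hedged, NumPy-only sketch (not part of the MindSpore public API): the helper
# below, `_np_uniform_summary`, is an illustrative name introduced here purely
# to cross-check the closed-form expressions used by the methods above.
def _np_uniform_summary(low, high):
    """Return (mean, var, entropy) of U(low, high) computed with NumPy."""
    low = np.asarray(low, dtype=np.float32)
    high = np.asarray(high, dtype=np.float32)
    mean = (low + high) / 2.           # same formula as Uniform._mean
    var = (high - low) ** 2 / 12.      # same formula as Uniform._var
    entropy = np.log(high - low)       # same formula as Uniform._entropy
    return mean, var, entropy


# Illustrative usage: _np_uniform_summary([-1.5], [2.5, 5.]) broadcasts to a
# mean of [0.5, 1.75], matching the docstring output of `u1.mean(low_b, high_b)`.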