You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

uniform.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Uniform Distribution"""
  16. from mindspore.ops import operations as P
  17. from mindspore.ops import composite as C
  18. from mindspore.common import dtype as mstype
  19. from .distribution import Distribution
  20. from ._utils.utils import convert_to_batch, check_greater, check_type
  21. class Uniform(Distribution):
  22. """
  23. Example class: Uniform Distribution.
  24. Args:
  25. low (int, float, list, numpy.ndarray, Tensor, Parameter): lower bound of the distribution.
  26. high (int, float, list, numpy.ndarray, Tensor, Parameter): upper bound of the distribution.
  27. seed (int): seed to use in sampling. Default: 0.
  28. dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
  29. name (str): name of the distribution. Default: Uniform.
  30. Note:
  31. low should be stricly less than high.
  32. Dist_spec_args are high and low.
  33. Examples:
  34. >>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0
  35. >>> import mindspore.nn.probability.distribution as msd
  36. >>> u = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
  37. >>>
  38. >>> # The following creates two independent Uniform distributions
  39. >>> u = msd.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32)
  40. >>>
  41. >>> # A Uniform distribution can be initilized without arguments
  42. >>> # In this case, high and low must be passed in through args during function calls.
  43. >>> u = msd.Uniform(dtype=mstype.float32)
  44. >>>
  45. >>> # To use Uniform in a network
  46. >>> class net(Cell):
  47. >>> def __init__(self)
  48. >>> super(net, self).__init__():
  49. >>> self.u1 = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
  50. >>> self.u2 = msd.Uniform(dtype=mstype.float32)
  51. >>>
  52. >>> # All the following calls in construct are valid
  53. >>> def construct(self, value, low_b, high_b, low_a, high_a):
  54. >>>
  55. >>> # Similar calls can be made to other probability functions
  56. >>> # by replacing 'prob' with the name of the function
  57. >>> ans = self.u1.prob(value)
  58. >>> # Evaluate with the respect to distribution b
  59. >>> ans = self.u1.prob(value, low_b, high_b)
  60. >>>
  61. >>> # High and low must be passed in during function calls
  62. >>> ans = self.u2.prob(value, low_a, high_a)
  63. >>>
  64. >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
  65. >>> # Will return 0.5
  66. >>> ans = self.u1.mean()
  67. >>> # Will return (low_b + high_b) / 2
  68. >>> ans = self.u1.mean(low_b, high_b)
  69. >>>
  70. >>> # High and low must be passed in during function calls
  71. >>> ans = self.u2.mean(low_a, high_a)
  72. >>>
  73. >>> # Usage of 'kl_loss' and 'cross_entropy' are similar
  74. >>> ans = self.u1.kl_loss('Uniform', low_b, high_b)
  75. >>> ans = self.u1.kl_loss('Uniform', low_b, high_b, low_a, high_a)
  76. >>>
  77. >>> # Additional high and low must be passed
  78. >>> ans = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
  79. >>>
  80. >>> # Sample
  81. >>> ans = self.u1.sample()
  82. >>> ans = self.u1.sample((2,3))
  83. >>> ans = self.u1.sample((2,3), low_b, high_b)
  84. >>> ans = self.u2.sample((2,3), low_a, high_a)
  85. """
  86. def __init__(self,
  87. low=None,
  88. high=None,
  89. seed=0,
  90. dtype=mstype.float32,
  91. name="Uniform"):
  92. """
  93. Constructor of Uniform distribution.
  94. """
  95. param = dict(locals())
  96. valid_dtype = mstype.float_type
  97. check_type(dtype, valid_dtype, "Uniform")
  98. super(Uniform, self).__init__(seed, dtype, name, param)
  99. if low is not None and high is not None:
  100. self._low = convert_to_batch(low, self.broadcast_shape, dtype)
  101. self._high = convert_to_batch(high, self.broadcast_shape, dtype)
  102. check_greater(self.low, self.high, "low value", "high value")
  103. else:
  104. self._low = low
  105. self._high = high
  106. # ops needed for the class
  107. self.cast = P.Cast()
  108. self.const = P.ScalarToArray()
  109. self.dtypeop = P.DType()
  110. self.exp = P.Exp()
  111. self.fill = P.Fill()
  112. self.less = P.Less()
  113. self.lessequal = P.LessEqual()
  114. self.log = P.Log()
  115. self.logicaland = P.LogicalAnd()
  116. self.select = P.Select()
  117. self.shape = P.Shape()
  118. self.sq = P.Square()
  119. self.sqrt = P.Sqrt()
  120. self.zeroslike = P.ZerosLike()
  121. self.uniform = C.uniform
  122. def extend_repr(self):
  123. if self.is_scalar_batch:
  124. str_info = f'low = {self.low}, high = {self.high}'
  125. else:
  126. str_info = f'batch_shape = {self._broadcast_shape}'
  127. return str_info
  128. @property
  129. def low(self):
  130. """
  131. Return lower bound of the distribution.
  132. """
  133. return self._low
  134. @property
  135. def high(self):
  136. """
  137. Return upper bound of the distribution.
  138. """
  139. return self._high
  140. def _range(self, low=None, high=None):
  141. r"""
  142. Return the range of the distribution.
  143. .. math::
  144. range(U) = high -low
  145. """
  146. low = self.low if low is None else low
  147. high = self.high if high is None else high
  148. return high - low
  149. def _mean(self, low=None, high=None):
  150. r"""
  151. .. math::
  152. MEAN(U) = \frac{low + high}{2}.
  153. """
  154. low = self.low if low is None else low
  155. high = self.high if high is None else high
  156. return (low + high) / 2.
  157. def _var(self, low=None, high=None):
  158. r"""
  159. .. math::
  160. VAR(U) = \frac{(high -low) ^ 2}{12}.
  161. """
  162. low = self.low if low is None else low
  163. high = self.high if high is None else high
  164. return self.sq(high - low) / 12.0
  165. def _entropy(self, low=None, high=None):
  166. r"""
  167. .. math::
  168. H(U) = \log(high - low).
  169. """
  170. low = self.low if low is None else low
  171. high = self.high if high is None else high
  172. return self.log(high - low)
  173. def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None):
  174. """
  175. Evaluate cross_entropy between Uniform distributoins.
  176. Args:
  177. dist (str): type of the distributions. Should be "Uniform" in this case.
  178. low_b (Tensor): lower bound of distribution b.
  179. high_b (Tensor): upper bound of distribution b.
  180. low_a (Tensor): lower bound of distribution a. Default: self.low.
  181. high_a (Tensor): upper bound of distribution a. Default: self.high.
  182. """
  183. if dist == 'Uniform':
  184. return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)
  185. return None
  186. def _prob(self, value, low=None, high=None):
  187. r"""
  188. pdf of Uniform distribution.
  189. Args:
  190. value (Tensor): value to be evaluated.
  191. low (Tensor): lower bound of the distribution. Default: self.low.
  192. high (Tensor): upper bound of the distribution. Default: self.high.
  193. .. math::
  194. pdf(x) = 0 if x < low;
  195. pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
  196. pdf(x) = 0 if x > high;
  197. """
  198. low = self.low if low is None else low
  199. high = self.high if high is None else high
  200. ones = self.fill(self.dtype, self.shape(value), 1.0)
  201. prob = ones / (high - low)
  202. broadcast_shape = self.shape(prob)
  203. zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
  204. comp_lo = self.less(value, low)
  205. comp_hi = self.lessequal(value, high)
  206. less_than_low = self.select(comp_lo, zeros, prob)
  207. return self.select(comp_hi, less_than_low, zeros)
  208. def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None):
  209. """
  210. Evaluate uniform-uniform kl divergence, i.e. KL(a||b).
  211. Args:
  212. dist (str): type of the distributions. Should be "Uniform" in this case.
  213. low_b (Tensor): lower bound of distribution b.
  214. high_b (Tensor): upper bound of distribution b.
  215. low_a (Tensor): lower bound of distribution a. Default: self.low.
  216. high_a (Tensor): upper bound of distribution a. Default: self.high.
  217. """
  218. if dist == 'Uniform':
  219. low_a = self.low if low_a is None else low_a
  220. high_a = self.high if high_a is None else high_a
  221. kl = self.log(high_b - low_b) / self.log(high_a - low_a)
  222. comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
  223. return self.select(comp, kl, self.log(self.zeroslike(kl)))
  224. return None
  225. def _cdf(self, value, low=None, high=None):
  226. r"""
  227. cdf of Uniform distribution.
  228. Args:
  229. value (Tensor): value to be evaluated.
  230. low (Tensor): lower bound of the distribution. Default: self.low.
  231. high (Tensor): upper bound of the distribution. Default: self.high.
  232. .. math::
  233. cdf(x) = 0 if x < low;
  234. cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
  235. cdf(x) = 1 if x > high;
  236. """
  237. low = self.low if low is None else low
  238. high = self.high if high is None else high
  239. prob = (value - low) / (high - low)
  240. broadcast_shape = self.shape(prob)
  241. zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
  242. ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0)
  243. comp_lo = self.less(value, low)
  244. comp_hi = self.less(value, high)
  245. less_than_low = self.select(comp_lo, zeros, prob)
  246. return self.select(comp_hi, less_than_low, ones)
  247. def _sample(self, shape=(), low=None, high=None):
  248. """
  249. Sampling.
  250. Args:
  251. shape (tuple): shape of the sample. Default: ().
  252. low (Tensor): lower bound of the distribution. Default: self.low.
  253. high (Tensor): upper bound of the distribution. Default: self.high.
  254. Returns:
  255. Tensor, shape is shape + batch_shape.
  256. """
  257. low = self.low if low is None else low
  258. high = self.high if high is None else high
  259. broadcast_shape = self.shape(low + high)
  260. l_zero = self.const(0.0)
  261. h_one = self.const(1.0)
  262. sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed)
  263. sample = (high - low) * sample_uniform + low
  264. return self.cast(sample, self.dtype)