You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

random_ops.py 9.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Operations for random number generators."""
  16. from .. import operations as P
  17. from .. import functional as F
  18. from ..primitive import constexpr
  19. from .multitype_ops import _constexpr_utils as const_utils
  20. from ...common import dtype as mstype
# Graph-level RNG seed shared by the random ops in this module.
# It starts at 0 and is only written by set_seed() / read by get_seed().
# Per the set_seed/get_seed docstrings, a value of 0 means the system generates a random seed
# when the op-level seed is also 0.
_GRAPH_SEED = 0
  23. @constexpr
  24. def set_seed(seed):
  25. """
  26. Set the graph-level seed.
  27. Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
  28. If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
  29. random seed.
  30. Args:
  31. seed(Int): the graph-level seed value that to be set. Must be non-negative.
  32. Examples:
  33. >>> C.set_seed(10)
  34. """
  35. const_utils.check_non_negative("seed", seed, "set_seed")
  36. global _GRAPH_SEED
  37. _GRAPH_SEED = seed
  38. @constexpr
  39. def get_seed():
  40. """
  41. Get the graph-level seed.
  42. Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
  43. If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
  44. random seed.
  45. Returns:
  46. Interger. The current graph-level seed.
  47. Examples:
  48. >>> C.get_seed()
  49. """
  50. return _GRAPH_SEED
  51. def normal(shape, mean, stddev, seed=0):
  52. """
  53. Generates random numbers according to the Normal (or Gaussian) random number distribution.
  54. Args:
  55. shape (tuple): The shape of random tensor to be generated.
  56. mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
  57. With float32 data type.
  58. stddev (Tensor): The deviation σ distribution parameter. With float32 data type.
  59. seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
  60. Must be non-negative. Default: 0.
  61. Returns:
  62. Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev.
  63. The dtype is float32.
  64. Examples:
  65. >>> shape = (4, 16)
  66. >>> mean = Tensor(1.0, mstype.float32)
  67. >>> stddev = Tensor(1.0, mstype.float32)
  68. >>> output = C.normal(shape, mean, stddev, seed=5)
  69. """
  70. mean_dtype = F.dtype(mean)
  71. stddev_dtype = F.dtype(stddev)
  72. const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
  73. const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
  74. const_utils.check_non_negative("seed", seed, "normal")
  75. seed1 = get_seed()
  76. seed2 = seed
  77. stdnormal = P.StandardNormal(seed1, seed2)
  78. random_normal = stdnormal(shape)
  79. value = random_normal * stddev + mean
  80. return value
  81. def uniform(shape, a, b, seed=0, dtype=mstype.float32):
  82. """
  83. Generates random numbers according to the Uniform random number distribution.
  84. Note:
  85. The number in tensor a should be strictly less than b at any position after broadcasting.
  86. Args:
  87. shape (tuple): The shape of random tensor to be generated.
  88. a (Tensor): The a distribution parameter.
  89. It defines the minimum possibly generated value. With int32 or float32 data type.
  90. If dtype is int32, only one number is allowed.
  91. b (Tensor): The b distribution parameter.
  92. It defines the maximum possibly generated value. With int32 or float32 data type.
  93. If dtype is int32, only one number is allowed.
  94. seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
  95. Must be non-negative. Default: 0.
  96. Returns:
  97. Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b.
  98. The dtype is designated as the input `dtype`.
  99. Examples:
  100. >>> For discrete uniform distribution, only one number is allowed for both a and b:
  101. >>> shape = (4, 2)
  102. >>> a = Tensor(1, mstype.int32)
  103. >>> b = Tensor(2, mstype.int32)
  104. >>> output = C.uniform(shape, a, b, seed=5)
  105. >>>
  106. >>> For continuous uniform distribution, a and b can be multi-dimentional:
  107. >>> shape = (4, 2)
  108. >>> a = Tensor([1.0, 2.0], mstype.float32)
  109. >>> b = Tensor([4.0, 5.0], mstype.float32)
  110. >>> output = C.uniform(shape, a, b, seed=5)
  111. """
  112. a_dtype = F.dtype(a)
  113. b_dtype = F.dtype(b)
  114. const_utils.check_tensors_dtype_same(a_dtype, dtype, "uniform")
  115. const_utils.check_tensors_dtype_same(b_dtype, dtype, "uniform")
  116. const_utils.check_non_negative("seed", seed, "uniform")
  117. seed1 = get_seed()
  118. seed2 = seed
  119. if const_utils.is_same_type(dtype, mstype.int32):
  120. random_uniform = P.UniformInt(seed1, seed2)
  121. value = random_uniform(shape, a, b)
  122. else:
  123. uniform_real = P.UniformReal(seed1, seed2)
  124. random_uniform = uniform_real(shape)
  125. value = random_uniform * (b - a) + a
  126. return value
  127. def gamma(shape, alpha, beta, seed=0):
  128. """
  129. Generates random numbers according to the Gamma random number distribution.
  130. Args:
  131. shape (tuple): The shape of random tensor to be generated.
  132. alpha (Tensor): The alpha α distribution parameter. With float32 data type.
  133. beta (Tensor): The beta β distribution parameter. With float32 data type.
  134. seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
  135. Must be non-negative. Default: 0.
  136. Returns:
  137. Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of alpha and beta.
  138. The dtype is float32.
  139. Examples:
  140. >>> shape = (4, 16)
  141. >>> alpha = Tensor(1.0, mstype.float32)
  142. >>> beta = Tensor(1.0, mstype.float32)
  143. >>> output = C.gamma(shape, alpha, beta, seed=5)
  144. """
  145. const_utils.check_non_negative("seed", seed, "gamma")
  146. seed1 = get_seed()
  147. seed2 = seed
  148. random_gamma = P.Gamma(seed1, seed2)
  149. value = random_gamma(shape, alpha, beta)
  150. return value
  151. def poisson(shape, mean, seed=0):
  152. """
  153. Generates random numbers according to the Poisson random number distribution.
  154. Args:
  155. shape (tuple): The shape of random tensor to be generated.
  156. mean (Tensor): The mean μ distribution parameter. With float32 data type.
  157. seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
  158. Must be non-negative. Default: 0.
  159. Returns:
  160. Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean.
  161. The dtype is float32.
  162. Examples:
  163. >>> shape = (4, 16)
  164. >>> mean = Tensor(1.0, mstype.float32)
  165. >>> output = C.poisson(shape, mean, seed=5)
  166. """
  167. const_utils.check_non_negative("seed", seed, "poisson")
  168. seed1 = get_seed()
  169. seed2 = seed
  170. random_poisson = P.Poisson(seed1, seed2)
  171. value = random_poisson(shape, mean)
  172. return value
  173. def multinomial(inputs, num_sample, replacement=True, seed=0):
  174. r"""
  175. Returns a tensor sampled from the multinomial probability distribution located in the corresponding
  176. row of tensor input.
  177. Note:
  178. The rows of input do not need to sum to one (in which case we use the values as weights),
  179. but must be non-negative, finite and have a non-zero sum.
  180. Args:
  181. input (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
  182. num_samples (int) - number of samples to draw.
  183. replacement (bool, optional) - whether to draw with replacement or not, default True.
  184. seed (int, optional) - used as entropy source for Random number engines generating pseudo-random numbers.
  185. Must be non-negative. Default: 0.
  186. Outputs:
  187. Tensor. have the same rows with input, each row has num_samples sampled indices.
  188. Examples:
  189. >>> input = Tensor([0, 9, 4, 0], mstype.float32)
  190. >>> output = C.multinomial(input, 2, True)
  191. """
  192. shape = P.Shape()
  193. reshape = P.Reshape()
  194. if inputs.dim() != 1 and inputs.dim() != 2:
  195. raise ValueError("inputs dim must be 1d or 2d")
  196. if not replacement:
  197. P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample)
  198. if shape(inputs)[-1] < num_sample:
  199. raise ValueError("num_sample must be less than shape(input)[-1] without replacement")
  200. n_dist = 1
  201. if len(shape(inputs)) > 1:
  202. n_dist = shape(inputs)[-2]
  203. random_uniform = P.UniformReal(seed=seed)((n_dist * shape(inputs)[-1],))
  204. if n_dist != 1:
  205. random_uniform = reshape(random_uniform, (n_dist, shape(inputs)[-1]))
  206. vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
  207. _, indices = P.TopK()(vals, num_sample)
  208. return indices
  209. return P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample)