You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

random_ops.py 19 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Operators for random."""
  16. from ..._checkparam import Validator, Rel
  17. from ...common import dtype as mstype
  18. from ..primitive import PrimitiveWithInfer, prim_attr_register
  19. from .._utils import get_broadcast_shape
  20. class StandardNormal(PrimitiveWithInfer):
  21. r"""
  22. Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
  23. Args:
  24. seed (int): Random seed, must be non-negative. Default: 0.
  25. seed2 (int): Random seed2, must be non-negative. Default: 0.
  26. Inputs:
  27. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
  28. Outputs:
  29. Tensor. The shape is the same as the input `shape`. The dtype is float32.
  30. Examples:
  31. >>> shape = (4, 16)
  32. >>> stdnormal = P.StandardNormal(seed=2)
  33. >>> output = stdnormal(shape)
  34. """
  35. @prim_attr_register
  36. def __init__(self, seed=0, seed2=0):
  37. """Initialize StandardNormal"""
  38. self.init_prim_io_names(inputs=['shape'], outputs=['output'])
  39. Validator.check_non_negative_int(seed, "seed", self.name)
  40. Validator.check_non_negative_int(seed2, "seed2", self.name)
  41. def __infer__(self, shape):
  42. shape_v = shape["value"]
  43. if shape_v is None:
  44. raise ValueError(f"For {self.name}, shape must be const.")
  45. Validator.check_value_type("shape", shape_v, [tuple], self.name)
  46. for i, shape_i in enumerate(shape_v):
  47. Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
  48. out = {
  49. 'shape': shape_v,
  50. 'dtype': mstype.float32,
  51. 'value': None}
  52. return out
  53. class StandardLaplace(PrimitiveWithInfer):
  54. r"""
  55. Generates random numbers according to the Laplace random number distribution (mean=0, lambda=1).
  56. It is defined as:
  57. .. math::
  58. \text{f}(x;0,1) = \frac{1}{2}\exp(-|x|),
  59. Args:
  60. seed (int): Random seed. Default: 0.
  61. seed2 (int): Random seed2. Default: 0.
  62. Inputs:
  63. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
  64. Outputs:
  65. Tensor. The shape that the input 'shape' denotes. The dtype is float32.
  66. Examples:
  67. >>> shape = (4, 16)
  68. >>> stdlaplace = P.StandardLaplace(seed=2)
  69. >>> output = stdlaplace(shape)
  70. """
  71. @prim_attr_register
  72. def __init__(self, seed=0, seed2=0):
  73. """Initialize StandardLaplace"""
  74. self.init_prim_io_names(inputs=['shape'], outputs=['output'])
  75. Validator.check_value_type('seed', seed, [int], self.name)
  76. Validator.check_value_type('seed2', seed2, [int], self.name)
  77. def __infer__(self, shape):
  78. shape_v = shape["value"]
  79. if shape_v is None:
  80. raise ValueError(f"For {self.name}, shape must be const.")
  81. Validator.check_value_type("shape", shape_v, [tuple], self.name)
  82. for i, shape_i in enumerate(shape_v):
  83. Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
  84. out = {
  85. 'shape': shape_v,
  86. 'dtype': mstype.float32,
  87. 'value': None}
  88. return out
  89. class Gamma(PrimitiveWithInfer):
  90. r"""
  91. Produces random positive floating-point values x, distributed according to probability density function:
  92. .. math::
  93. \text{P}(x|α,β) = \frac{\exp(-x/β)}{{β^α}\cdot{\Gamma(α)}}\cdot{x^{α-1}},
  94. Args:
  95. seed (int): Random seed, must be non-negative. Default: 0.
  96. seed2 (int): Random seed2, must be non-negative. Default: 0.
  97. Inputs:
  98. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
  99. - **alpha** (Tensor) - The α distribution parameter. It must be greater than 0.
  100. It is also known as the shape parameter with float32 data type.
  101. - **beta** (Tensor) - The β distribution parameter. It must be greater than 0.
  102. It is also known as the scale parameter with float32 data type.
  103. Outputs:
  104. Tensor. The shape must be the broadcasted shape of Input "shape" and shapes of alpha and beta.
  105. The dtype is float32.
  106. Examples:
  107. >>> shape = (4, 16)
  108. >>> alpha = Tensor(1.0, mstype.float32)
  109. >>> beta = Tensor(1.0, mstype.float32)
  110. >>> gamma = P.Gamma(seed=3)
  111. >>> output = Gamma(shape, alpha, beta)
  112. """
  113. @prim_attr_register
  114. def __init__(self, seed=0, seed2=0):
  115. """Initialize Gamma"""
  116. self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
  117. Validator.check_non_negative_int(seed, "seed", self.name)
  118. Validator.check_non_negative_int(seed2, "seed2", self.name)
  119. def __infer__(self, shape, alpha, beta):
  120. shape_v = shape["value"]
  121. if shape_v is None:
  122. raise ValueError(f"For {self.name}, shape must be const.")
  123. Validator.check_value_type("shape", shape_v, [tuple], self.name)
  124. for i, shape_i in enumerate(shape_v):
  125. Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
  126. Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
  127. Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
  128. broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
  129. broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
  130. out = {
  131. 'shape': broadcast_shape,
  132. 'dtype': mstype.float32,
  133. 'value': None}
  134. return out
  135. class Poisson(PrimitiveWithInfer):
  136. r"""
  137. Produces random non-negative integer values i, distributed according to discrete probability function:
  138. .. math::
  139. \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!},
  140. Args:
  141. seed (int): Random seed, must be non-negative. Default: 0.
  142. seed2 (int): Random seed2, must be non-negative. Default: 0.
  143. Inputs:
  144. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
  145. - **mean** (Tensor) - μ parameter the distribution was constructed with. The parameter defines mean number
  146. of occurrences of the event. It must be greater than 0. With float32 data type.
  147. Outputs:
  148. Tensor. Its shape must be the broadcasted shape of `shape` and the shape of `mean`.
  149. The dtype is int32.
  150. Examples:
  151. >>> shape = (4, 16)
  152. >>> mean = Tensor(5.0, mstype.float32)
  153. >>> poisson = P.Poisson(seed=5)
  154. >>> output = poisson(shape, mean)
  155. """
  156. @prim_attr_register
  157. def __init__(self, seed=0, seed2=0):
  158. """Initialize Poisson"""
  159. self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
  160. Validator.check_non_negative_int(seed, "seed", self.name)
  161. Validator.check_non_negative_int(seed2, "seed2", self.name)
  162. def __infer__(self, shape, mean):
  163. shape_v = shape["value"]
  164. if shape_v is None:
  165. raise ValueError(f"For {self.name}, shape must be const.")
  166. Validator.check_value_type("shape", shape_v, [tuple], self.name)
  167. for i, shape_i in enumerate(shape_v):
  168. Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
  169. Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
  170. broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
  171. out = {
  172. 'shape': broadcast_shape,
  173. 'dtype': mstype.int32,
  174. 'value': None}
  175. return out
  176. class UniformInt(PrimitiveWithInfer):
  177. r"""
  178. Produces random integer values i, uniformly distributed on the closed interval [minval, maxval), that is,
  179. distributed according to the discrete probability function:
  180. .. math::
  181. \text{P}(i|a,b) = \frac{1}{b-a+1},
  182. Note:
  183. The number in tensor minval must be strictly less than maxval at any position after broadcasting.
  184. Args:
  185. seed (int): Random seed, must be non-negative. Default: 0.
  186. seed2 (int): Random seed2, must be non-negative. Default: 0.
  187. Inputs:
  188. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
  189. - **minval** (Tensor) - The distribution parameter, a.
  190. It defines the minimum possibly generated value, with int32 data type. Only one number is supported.
  191. - **maxval** (Tensor) - The distribution parameter, b.
  192. It defines the maximum possibly generated value, with int32 data type. Only one number is supported.
  193. Outputs:
  194. Tensor. The shape is the same as the input 'shape', and the data type is int32.
  195. Examples:
  196. >>> shape = (4, 16)
  197. >>> minval = Tensor(1, mstype.int32)
  198. >>> maxval = Tensor(5, mstype.int32)
  199. >>> uniform_int = P.UniformInt(seed=10)
  200. >>> output = uniform_int(shape, minval, maxval)
  201. """
  202. @prim_attr_register
  203. def __init__(self, seed=0, seed2=0):
  204. """Initialize UniformInt"""
  205. self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
  206. Validator.check_non_negative_int(seed, "seed", self.name)
  207. Validator.check_non_negative_int(seed2, "seed2", self.name)
  208. def __infer__(self, shape, minval, maxval):
  209. shape_v = shape["value"]
  210. if shape_v is None:
  211. raise ValueError(f"For {self.name}, shape must be const.")
  212. Validator.check_value_type("shape", shape_v, [tuple], self.name)
  213. for i, shape_i in enumerate(shape_v):
  214. Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
  215. Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
  216. Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
  217. minval_shape = minval['shape']
  218. maxval_shape = maxval['shape']
  219. Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
  220. Validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
  221. out = {
  222. 'shape': shape_v,
  223. 'dtype': mstype.int32,
  224. 'value': None}
  225. return out
  226. class UniformReal(PrimitiveWithInfer):
  227. r"""
  228. Produces random floating-point values i, uniformly distributed to the interval [0, 1).
  229. Args:
  230. seed (int): Random seed, must be non-negative. Default: 0.
  231. seed2 (int): Random seed2, must be non-negative. Default: 0.
  232. Inputs:
  233. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
  234. Outputs:
  235. Tensor. The shape that the input 'shape' denotes. The dtype is float32.
  236. Examples:
  237. >>> shape = (4, 16)
  238. >>> uniformreal = P.UniformReal(seed=2)
  239. >>> output = uniformreal(shape)
  240. """
  241. @prim_attr_register
  242. def __init__(self, seed=0, seed2=0):
  243. """Initialize UniformReal"""
  244. self.init_prim_io_names(inputs=['shape'], outputs=['output'])
  245. Validator.check_non_negative_int(seed, "seed", self.name)
  246. Validator.check_non_negative_int(seed2, "seed2", self.name)
  247. def __infer__(self, shape):
  248. shape_v = shape["value"]
  249. if shape_v is None:
  250. raise ValueError(f"For {self.name}, shape must be const.")
  251. Validator.check_value_type("shape", shape_v, [tuple], self.name)
  252. for i, shape_i in enumerate(shape_v):
  253. Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
  254. out = {
  255. 'shape': shape_v,
  256. 'dtype': mstype.float32,
  257. 'value': None}
  258. return out
  259. class RandomChoiceWithMask(PrimitiveWithInfer):
  260. """
  261. Generates a random sample as index tensor with a mask tensor from a given tensor.
  262. The input must be a tensor of rank not less than 1. If its rank is greater than or equal to 2,
  263. the first dimension specifies the number of samples.
  264. The index tensor and the mask tensor have the fixed shapes. The index tensor denotes the index of the nonzero
  265. sample, while the mask tensor denotes which elements in the index tensor are valid.
  266. Args:
  267. count (int): Number of items expected to get and the number must be greater than 0. Default: 256.
  268. seed (int): Random seed. Default: 0.
  269. seed2 (int): Random seed2. Default: 0.
  270. Inputs:
  271. - **input_x** (Tensor[bool]) - The input tensor.
  272. The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
  273. Outputs:
  274. Two tensors, the first one is the index tensor and the other one is the mask tensor.
  275. - **index** (Tensor) - The output shape is 2-D.
  276. - **mask** (Tensor) - The output shape is 1-D.
  277. Examples:
  278. >>> rnd_choice_mask = P.RandomChoiceWithMask()
  279. >>> input_x = Tensor(np.ones(shape=[240000, 4]).astype(np.bool))
  280. >>> output_y, output_mask = rnd_choice_mask(input_x)
  281. >>> output_y.shape
  282. (256, 2)
  283. >>> output_mask.shape
  284. (256,)
  285. """
  286. @prim_attr_register
  287. def __init__(self, count=256, seed=0, seed2=0):
  288. """Initialize RandomChoiceWithMask"""
  289. Validator.check_value_type("count", count, [int], self.name)
  290. Validator.check_positive_int(count, "count", self.name)
  291. Validator.check_value_type('seed', seed, [int], self.name)
  292. Validator.check_value_type('seed2', seed2, [int], self.name)
  293. def infer_shape(self, x_shape):
  294. Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name)
  295. Validator.check_int(len(x_shape), 5, Rel.LE, "input_x rank", self.name)
  296. return ([self.count, len(x_shape)], [self.count])
  297. def infer_dtype(self, x_dtype):
  298. Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
  299. return (mstype.int32, mstype.bool_)
  300. class RandomCategorical(PrimitiveWithInfer):
  301. """
  302. Generates random samples from a given categorical distribution tensor.
  303. Args:
  304. dtype (mindspore.dtype): The type of output. Its value must be one of mindspore.int16,
  305. mindspore.int32 and mindspore.int64. Default: mindspore.int64.
  306. Inputs:
  307. - **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes].
  308. - **num_sample** (int) - Number of sample to be drawn. Only constant values is allowed.
  309. - **seed** (int) - Random seed. Default: 0. Only constant values is allowed.
  310. Outputs:
  311. - **output** (Tensor) - The output Tensor with shape [batch_size, num_samples].
  312. Examples:
  313. >>> class Net(nn.Cell):
  314. >>> def __init__(self, num_sample):
  315. >>> super(Net, self).__init__()
  316. >>> self.random_categorical = P.RandomCategorical(mindspore.int64)
  317. >>> self.num_sample = num_sample
  318. >>> def construct(self, logits, seed=0):
  319. >>> return self.random_categorical(logits, self.num_sample, seed)
  320. >>>
  321. >>> x = np.random.random((10, 5)).astype(np.float32)
  322. >>> net = Net(8)
  323. >>> output = net(Tensor(x))
  324. """
  325. @prim_attr_register
  326. def __init__(self, dtype=mstype.int64):
  327. """Initialize RandomCategorical"""
  328. self.dtype = dtype
  329. valid_values = (mstype.int32, mstype.int16, mstype.int64)
  330. Validator.check_type_name("dtype", dtype, valid_values, self.name)
  331. self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
  332. outputs=['output'])
  333. def __infer__(self, logits, num_samples, seed):
  334. logits_dtype = logits['dtype']
  335. valid_types = (mstype.float32, mstype.float16, mstype.float64)
  336. Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
  337. num_samples_v = num_samples['value']
  338. seed_v = seed['value']
  339. Validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
  340. Validator.check_value_type('seed', seed_v, (int,), self.name)
  341. Validator.check_positive_int(num_samples_v, "num_samples", self.name)
  342. x_shape = list(logits['shape'])
  343. if len(x_shape) != 2:
  344. raise ValueError("RandomCategorical shape should be 2-dimension.")
  345. ndim = len(x_shape) - 1
  346. x_shape[ndim] = num_samples_v
  347. return {'shape': (x_shape),
  348. 'dtype': (self.dtype),
  349. 'value': None}
  350. class Multinomial(PrimitiveWithInfer):
  351. r"""
  352. Returns a tensor sampled from the multinomial probability distribution located in the corresponding
  353. row of tensor input.
  354. Note:
  355. The rows of input do not need to sum to one (in which case we use the values as weights),
  356. but must be non-negative, finite and have a non-zero sum.
  357. Args:
  358. seed (int): Seed data is used as entropy source for Random number engines to generate pseudo-random numbers.
  359. Must be non-negative. Default: 0.
  360. Inputs:
  361. - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2
  362. dimensions.
  363. - **num_samples** (int32) - number of samples to draw.
  364. Outputs:
  365. Tensor with the same rows as input, each row has num_samples sampled indices.
  366. Examples:
  367. >>> input = Tensor([0., 9., 4., 0.], mstype.float32)
  368. >>> multinomial = P.Multinomial(seed=10)
  369. >>> output = multinomial(input, 2)
  370. """
  371. @prim_attr_register
  372. def __init__(self, seed=0):
  373. """init"""
  374. Validator.check_value_type("seed", seed, [int], self.name)
  375. Validator.check_non_negative_int(seed, "seed", self.name)
  376. self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
  377. def __infer__(self, inputs, num_samples):
  378. input_shape = inputs["shape"]
  379. if len(input_shape) != 1 and len(input_shape) != 2:
  380. raise ValueError("input dim must be 1 or 2")
  381. Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
  382. num_samples_value = num_samples["value"]
  383. if num_samples_value is None:
  384. raise ValueError(f"For {self.name}, shape nust be const")
  385. Validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
  386. Validator.check_positive_int(num_samples_value, "num_samples")
  387. y_shape = (num_samples_value,)
  388. if len(input_shape) == 2:
  389. y_shape = (input_shape[0], num_samples_value)
  390. out = {
  391. "shape": y_shape,
  392. "dtype": mstype.int32,
  393. "value": None}
  394. return out