You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

seed.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Provide random seed api."""
  16. import numpy as np
  17. import mindspore.dataset as de
  18. from mindspore._checkparam import Validator
  19. # constants
  20. _MAXINT32 = 2**31 - 1
  21. keyConstant = [3528531795, 2654435769, 3449720151, 3144134277]
  22. # set global RNG seed
  23. _GLOBAL_SEED = None
  24. _KERNEL_SEED = {}
  25. def _reset_op_seed():
  26. """
  27. Reset op seeds in the kernel's dictionary.
  28. """
  29. for (kernel_name, op_seed) in _KERNEL_SEED:
  30. _KERNEL_SEED[(kernel_name, op_seed)] = op_seed
  31. def set_seed(seed):
  32. """
  33. Set global random seed.
  34. Note:
  35. The global seed is used by numpy.random, mindspore.common.Initializer, mindspore.ops.composite.random_ops and
  36. mindspore.nn.probability.distribution.
  37. If global seed is not set, these packages will use their own default seed independently, numpy.random and
  38. mindspore.common.Initializer will choose a random seed, mindspore.ops.composite.random_ops and
  39. mindspore.nn.probability.distribution will use zero.
  40. Seed set by numpy.random.seed() only used by numpy.random, while seed set by this API will also used by
  41. numpy.random, so just set all seed by this API is recommended.
  42. Args:
  43. seed (int): The seed to be set.
  44. Raises:
  45. ValueError: If seed is invalid (< 0).
  46. TypeError: If seed isn't a int.
  47. Examples:
  48. >>> # 1. If global seed is not set, numpy.random and initializer will choose a random seed:
  49. >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
  50. >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A2
  51. >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W1
  52. >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W2
  53. >>> # Rerun the program will get diferent results:
  54. >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A3
  55. >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A4
  56. >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W3
  57. >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W4
  58. >>>
  59. >>> 2. If global seed is set, numpy.random and initializer will use it:
  60. >>> set_seed(1234)
  61. >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
  62. >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A2
  63. >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W1
  64. >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W2
  65. >>> # Rerun the program will get the same results:
  66. >>> set_seed(1234)
  67. >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
  68. >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A2
  69. >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W1
  70. >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W2
  71. >>>
  72. >>> # 3. If neither global seed nor op seed is set, mindspore.ops.composite.random_ops and
  73. >>> # mindspore.nn.probability.distribution will choose a random seed:
  74. >>> c1 = C.uniform((1, 4)) # C1
  75. >>> c2 = C.uniform((1, 4)) # C2
  76. >>> Rerun the program will get different results:
  77. >>> c1 = C.uniform((1, 4)) # C3
  78. >>> c2 = C.uniform((1, 4)) # C4
  79. >>>
  80. >>> # 4. If global seed is set, but op seed is not set, mindspore.ops.composite.random_ops and
  81. >>> # mindspore.nn.probability.distribution will caculate a seed according to global seed and
  82. >>> # default op seed. Each call will change the default op seed, thus each call get different
  83. >>> # results.
  84. >>> set_seed(1234)
  85. >>> c1 = C.uniform((1, 4)) # C1
  86. >>> c2 = C.uniform((1, 4)) # C2
  87. >>> # Rerun the program will get the same results:
  88. >>> set_seed(1234)
  89. >>> c1 = C.uniform((1, 4)) # C1
  90. >>> c2 = C.uniform((1, 4)) # C2
  91. >>>
  92. >>> # 5. If both global seed and op seed are set, mindspore.ops.composite.random_ops and
  93. >>> # mindspore.nn.probability.distribution will caculate a seed according to global seed and
  94. >>> # op seed counter. Each call will change the op seed counter, thus each call get different
  95. >>> # results.
  96. >>> set_seed(1234)
  97. >>> c1 = C.uniform((1, 4), seed=2) # C1
  98. >>> c2 = C.uniform((1, 4), seed=2) # C2
  99. >>> Rerun the program will get the same results:
  100. >>> set_seed(1234)
  101. >>> c1 = C.uniform((1, 4), seed=2) # C1
  102. >>> c2 = C.uniform((1, 4), seed=2) # C2
  103. >>>
  104. >>> # 6. If op seed is set but global seed is not set, 0 will be used as global seed. Then
  105. >>> # mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution act as in
  106. >>> # condition 5.
  107. >>> c1 = C.uniform((1, 4), seed=2) # C1
  108. >>> c2 = C.uniform((1, 4), seed=2) # C2
  109. >>> #Rerun the program will get the same results:
  110. >>> c1 = C.uniform((1, 4), seed=2) # C1
  111. >>> c2 = C.uniform((1, 4), seed=2) # C2
  112. >>>
  113. >>> # 7. Recall set_seed() in the program will reset numpy seed and op seed counter of
  114. >>> # mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution.
  115. >>> set_seed(1234)
  116. >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
  117. >>> c1 = C.uniform((1, 4), seed=2) # C1
  118. >>> set_seed(1234)
  119. >>> np_2 = np.random.normal(0, 1, [1]).astype(np.float32) # still get A1
  120. >>> c2 = C.uniform((1, 4), seed=2) # still get C1
  121. """
  122. if not isinstance(seed, int):
  123. raise TypeError("The seed must be type of int.")
  124. Validator.check_non_negative_int(seed, "seed", "global_seed")
  125. np.random.seed(seed)
  126. de.config.set_seed(seed)
  127. _reset_op_seed()
  128. global _GLOBAL_SEED
  129. _GLOBAL_SEED = seed
  130. def get_seed():
  131. """
  132. Get global random seed.
  133. """
  134. return _GLOBAL_SEED
  135. def _truncate_seed(seed):
  136. """
  137. Truncate the seed with MAXINT32.
  138. Args:
  139. seed (int): The seed to be truncated.
  140. """
  141. return seed % _MAXINT32 # Truncate to fit into 32-bit integer
  142. def _update_seeds(op_seed, kernel_name):
  143. """
  144. Update the seed every time when a random op is called.
  145. Args:
  146. seed (int): The op-seed to be updated.
  147. kernel_name (string): The random op kernel.
  148. """
  149. global _KERNEL_SEED
  150. if op_seed is not None:
  151. _KERNEL_SEED[(kernel_name, op_seed)] = _KERNEL_SEED[(kernel_name, op_seed)] + (keyConstant[0] ^ keyConstant[2])
  152. def _get_op_seed(op_seed, kernel_name):
  153. """
  154. Get op seed which is relating to the specific kernel.
  155. If the seed does not exist, add it into the kernel's dictionary.
  156. Args:
  157. seed (int): The op-seed to be updated.
  158. kernel_name (string): The random op kernel.
  159. """
  160. if (kernel_name, op_seed) not in _KERNEL_SEED:
  161. _KERNEL_SEED[(kernel_name, op_seed)] = op_seed
  162. return _KERNEL_SEED[(kernel_name, op_seed)]
  163. def _get_graph_seed(op_seed, kernel_name):
  164. """
  165. Get the graph-level seed.
  166. Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
  167. If op-level seed is 0, use graph-level seed; if graph-level seed is also 0, the system would generate a
  168. random seed.
  169. Note:
  170. For each seed, either op-seed or graph-seed, a random sequence will be generated relating to this seed.
  171. So, the state of the seed regarding to this op should be recorded.
  172. A simple illustration should be:
  173. If a random op is called twice within one program, the two results should be different:
  174. print(C.uniform((1, 4), seed=1)) # generates 'A1'
  175. print(C.uniform((1, 4), seed=1)) # generates 'A2'
  176. If the same program runs again, it repeat the results:
  177. print(C.uniform((1, 4), seed=1)) # generates 'A1'
  178. print(C.uniform((1, 4), seed=1)) # generates 'A2'
  179. Returns:
  180. Interger. The current graph-level seed.
  181. Examples:
  182. >>> _get_graph_seed(seed, 'normal')
  183. """
  184. global_seed = get_seed()
  185. if global_seed is None:
  186. global_seed = 0
  187. if op_seed is None:
  188. op_seed = 0
  189. # neither global seed or op seed is set, return (0, 0) to let kernel choose random seed.
  190. if global_seed == 0 and op_seed == 0:
  191. seeds = 0, 0
  192. else:
  193. Validator.check_non_negative_int(op_seed, "seed", kernel_name)
  194. temp_seed = _get_op_seed(op_seed, kernel_name)
  195. seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
  196. _update_seeds(op_seed, kernel_name)
  197. return seeds