You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

config.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. The configuration module provides various functions to set and get the supported
  17. configuration parameters, and read a configuration file.
  18. """
  19. import random
  20. import numpy
  21. import mindspore._c_dataengine as cde
  22. __all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
  23. 'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load']
  24. INT32_MAX = 2147483647
  25. UINT32_MAX = 4294967295
  26. _config = cde.GlobalContext.config_manager()
  27. def set_seed(seed):
  28. """
  29. Set the seed to be used in any random generator. This is used to produce deterministic results.
  30. Note:
  31. This set_seed function sets the seed in the Python random library and numpy.random library
  32. for deterministic Python augmentations using randomness. This set_seed function should
  33. be called with every iterator created to reset the random seed. In the pipeline, this
  34. does not guarantee deterministic results with num_parallel_workers > 1.
  35. Args:
  36. seed(int): Seed to be set.
  37. Raises:
  38. ValueError: If seed is invalid (< 0 or > MAX_UINT_32).
  39. Examples:
  40. >>> import mindspore.dataset as ds
  41. >>>
  42. >>> # Set a new global configuration value for the seed value.
  43. >>> # Operations with randomness will use the seed value to generate random values.
  44. >>> ds.config.set_seed(1000)
  45. """
  46. if seed < 0 or seed > UINT32_MAX:
  47. raise ValueError("Seed given is not within the required range.")
  48. _config.set_seed(seed)
  49. random.seed(seed)
  50. # numpy.random isn't thread safe
  51. numpy.random.seed(seed)
  52. def get_seed():
  53. """
  54. Get the seed.
  55. Returns:
  56. Int, seed.
  57. """
  58. return _config.get_seed()
  59. def set_prefetch_size(size):
  60. """
  61. Set the number of rows to be prefetched.
  62. Args:
  63. size (int): Total number of rows to be prefetched.
  64. Raises:
  65. ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).
  66. Examples:
  67. >>> import mindspore.dataset as ds
  68. >>>
  69. >>> # Set a new global configuration value for the prefetch size.
  70. >>> ds.config.set_prefetch_size(1000)
  71. """
  72. if size <= 0 or size > INT32_MAX:
  73. raise ValueError("Prefetch size given is not within the required range.")
  74. _config.set_op_connector_size(size)
  75. def get_prefetch_size():
  76. """
  77. Get the prefetch size in number of rows.
  78. Returns:
  79. Size, total number of rows to be prefetched.
  80. """
  81. return _config.get_op_connector_size()
  82. def set_num_parallel_workers(num):
  83. """
  84. Set the default number of parallel workers.
  85. Args:
  86. num (int): Number of parallel workers to be used as a default for each operation.
  87. Raises:
  88. ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).
  89. Examples:
  90. >>> import mindspore.dataset as ds
  91. >>>
  92. >>> # Set a new global configuration value for the number of parallel workers.
  93. >>> # Now parallel dataset operators will run with 8 workers.
  94. >>> ds.config.set_num_parallel_workers(8)
  95. """
  96. if num <= 0 or num > INT32_MAX:
  97. raise ValueError("Number of parallel workers given is not within the required range.")
  98. _config.set_num_parallel_workers(num)
  99. def get_num_parallel_workers():
  100. """
  101. Get the default number of parallel workers.
  102. Returns:
  103. Int, number of parallel workers to be used as a default for each operation
  104. """
  105. return _config.get_num_parallel_workers()
  106. def set_monitor_sampling_interval(interval):
  107. """
  108. Set the default interval (in milliseconds) for monitor sampling.
  109. Args:
  110. interval (int): Interval (in milliseconds) to be used for performance monitor sampling.
  111. Raises:
  112. ValueError: If interval is invalid (<= 0 or > MAX_INT_32).
  113. Examples:
  114. >>> import mindspore.dataset as ds
  115. >>>
  116. >>> # Set a new global configuration value for the monitor sampling interval.
  117. >>> ds.config.set_monitor_sampling_interval(100)
  118. """
  119. if interval <= 0 or interval > INT32_MAX:
  120. raise ValueError("Interval given is not within the required range.")
  121. _config.set_monitor_sampling_interval(interval)
  122. def get_monitor_sampling_interval():
  123. """
  124. Get the default interval of performance monitor sampling.
  125. Returns:
  126. Int, interval (in milliseconds) for performance monitor sampling.
  127. """
  128. return _config.get_monitor_sampling_interval()
  129. def set_callback_timeout(timeout):
  130. """
  131. Set the default timeout (in seconds) for DSWaitedCallback.
  132. In case of a deadlock, the wait function will exit after the timeout period.
  133. Args:
  134. timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
  135. Raises:
  136. ValueError: If timeout is invalid (<= 0 or > MAX_INT_32).
  137. Examples:
  138. >>> import mindspore.dataset as ds
  139. >>>
  140. >>> # Set a new global configuration value for the timeout value.
  141. >>> ds.config.set_callback_timeout(100)
  142. """
  143. if timeout <= 0 or timeout > INT32_MAX:
  144. raise ValueError("Timeout given is not within the required range.")
  145. _config.set_callback_timeout(timeout)
  146. def get_callback_timeout():
  147. """
  148. Get the default timeout for DSWaitedCallback.
  149. In case of a deadlock, the wait function will exit after the timeout period.
  150. Returns:
  151. Int, the duration in seconds
  152. """
  153. return _config.get_callback_timeout()
def __str__():
    """
    String representation of the configurations.

    Note:
        This is a module-level function, not a class method, so it is NOT
        invoked implicitly by str() on the module; callers must invoke it
        explicitly (e.g. ds.config.__str__()).

    Returns:
        Str, configurations.
    """
    # Delegates formatting to the underlying C++ configuration manager.
    return str(_config)
  161. def load(file):
  162. """
  163. Load configurations from a file.
  164. Args:
  165. file (str): Path of the configuration file to be loaded.
  166. Raises:
  167. RuntimeError: If file is invalid and parsing fails.
  168. Examples:
  169. >>> import mindspore.dataset as ds
  170. >>>
  171. >>> # Set new default configuration values according to values in the configuration file.
  172. >>> ds.config.load("path/to/config/file")
  173. >>> # example config file:
  174. >>> # {
  175. >>> # "logFilePath": "/tmp",
  176. >>> # "numParallelWorkers": 4,
  177. >>> # "seed": 5489,
  178. >>> # "monitorSamplingInterval": 30
  179. >>> # }
  180. """
  181. _config.load(file)