You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

configuration.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. The configuration manager.
  17. """
import numbers
import random

import numpy

import mindspore._c_dataengine as cde
  21. INT32_MAX = 2147483647
  22. UINT32_MAX = 4294967295
  23. class ConfigurationManager:
  24. """The configuration manager"""
  25. def __init__(self):
  26. self.config = cde.GlobalContext.config_manager()
  27. def set_seed(self, seed):
  28. """
  29. Set the seed to be used in any random generator. This is used to produce deterministic results.
  30. Note:
  31. This set_seed function sets the seed in the python random library and numpy.random library
  32. for deterministic python augmentations using randomness. This set_seed function should
  33. be called with every iterator created to reset the random seed. In our pipeline this
  34. does not guarantee deterministic results with num_parallel_workers > 1.
  35. Args:
  36. seed(int): seed to be set
  37. Raises:
  38. ValueError: If seed is invalid (< 0 or > MAX_UINT_32).
  39. Examples:
  40. >>> import mindspore.dataset as ds
  41. >>> con = ds.engine.ConfigurationManager()
  42. >>> # sets the new seed value, now operators with a random seed will use new seed value.
  43. >>> con.set_seed(1000)
  44. """
  45. if seed < 0 or seed > UINT32_MAX:
  46. raise ValueError("Seed given is not within the required range")
  47. self.config.set_seed(seed)
  48. random.seed(seed)
  49. # numpy.random isn't thread safe
  50. numpy.random.seed(seed)
  51. def get_seed(self):
  52. """
  53. Get the seed
  54. Returns:
  55. Int, seed.
  56. """
  57. return self.config.get_seed()
  58. def set_prefetch_size(self, size):
  59. """
  60. Set the number of rows to be prefetched.
  61. Args:
  62. size: total number of rows to be prefetched.
  63. Raises:
  64. ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).
  65. Examples:
  66. >>> import mindspore.dataset as ds
  67. >>> con = ds.engine.ConfigurationManager()
  68. >>> # sets the new prefetch value.
  69. >>> con.set_prefetch_size(1000)
  70. """
  71. if size <= 0 or size > INT32_MAX:
  72. raise ValueError("Prefetch size given is not within the required range")
  73. self.config.set_op_connector_size(size)
  74. def get_prefetch_size(self):
  75. """
  76. Get the prefetch size in number of rows.
  77. Returns:
  78. Size, total number of rows to be prefetched.
  79. """
  80. return self.config.get_op_connector_size()
  81. def set_num_parallel_workers(self, num):
  82. """
  83. Set the default number of parallel workers
  84. Args:
  85. num: number of parallel workers to be used as a default for each operation
  86. Raises:
  87. ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).
  88. Examples:
  89. >>> import mindspore.dataset as ds
  90. >>> con = ds.engine.ConfigurationManager()
  91. >>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers.
  92. >>> con.set_num_parallel_workers(8)
  93. """
  94. if num <= 0 or num > INT32_MAX:
  95. raise ValueError("Num workers given is not within the required range")
  96. self.config.set_num_parallel_workers(num)
  97. def get_num_parallel_workers(self):
  98. """
  99. Get the default number of parallel workers.
  100. Returns:
  101. Int, number of parallel workers to be used as a default for each operation
  102. """
  103. return self.config.get_num_parallel_workers()
  104. def set_monitor_sampling_interval(self, interval):
  105. """
  106. Set the default interval(ms) of monitor sampling.
  107. Args:
  108. interval: interval(ms) to be used to performance monitor sampling.
  109. Raises:
  110. ValueError: If interval is invalid (<= 0 or > MAX_INT_32).
  111. Examples:
  112. >>> import mindspore.dataset as ds
  113. >>> con = ds.engine.ConfigurationManager()
  114. >>> # sets the new interval value.
  115. >>> con.set_monitor_sampling_interval(100)
  116. """
  117. if interval <= 0 or interval > INT32_MAX:
  118. raise ValueError("Interval given is not within the required range")
  119. self.config.set_monitor_sampling_interval(interval)
  120. def get_monitor_sampling_interval(self):
  121. """
  122. Get the default interval of performance monitor sampling.
  123. Returns:
  124. Interval: interval(ms) of performance monitor sampling.
  125. """
  126. return self.config.get_monitor_sampling_interval()
  127. def __str__(self):
  128. """
  129. String representation of the configurations.
  130. Returns:
  131. Str, configurations.
  132. """
  133. return str(self.config)
  134. def load(self, file):
  135. """
  136. Load configuration from a file.
  137. Args:
  138. file: path the config file to be loaded
  139. Raises:
  140. RuntimeError: If file is invalid and parsing fails.
  141. Examples:
  142. >>> import mindspore.dataset as ds
  143. >>> con = ds.engine.ConfigurationManager()
  144. >>> # sets the default value according to values in configuration file.
  145. >>> con.load("path/to/config/file")
  146. >>> # example config file:
  147. >>> # {
  148. >>> # "logFilePath": "/tmp",
  149. >>> # "rowsPerBuffer": 32,
  150. >>> # "numParallelWorkers": 4,
  151. >>> # "workerConnectorSize": 16,
  152. >>> # "opConnectorSize": 16,
  153. >>> # "seed": 5489,
  154. >>> # "monitorSamplingInterval": 30
  155. >>> # }
  156. """
  157. self.config.load(file)
  158. config = ConfigurationManager()