You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

configuration.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration manager.
"""
import numbers

import mindspore._c_dataengine as cde
  19. INT32_MAX = 2147483647
  20. UINT32_MAX = 4294967295
  21. class ConfigurationManager:
  22. """The configuration manager"""
  23. def __init__(self):
  24. self.config = cde.GlobalContext.config_manager()
  25. def set_seed(self, seed):
  26. """
  27. Set the seed to be used in any random generator. This is used to produce deterministic results.
  28. Args:
  29. seed(int): seed to be set
  30. Raises:
  31. ValueError: If seed is invalid (< 0 or > MAX_UINT_32).
  32. Examples:
  33. >>> import mindspore.dataset as ds
  34. >>> con = ds.engine.ConfigurationManager()
  35. >>> # sets the new seed value, now operators with a random seed will use new seed value.
  36. >>> con.set_seed(1000)
  37. """
  38. if seed < 0 or seed > UINT32_MAX:
  39. raise ValueError("Seed given is not within the required range")
  40. self.config.set_seed(seed)
  41. def get_seed(self):
  42. """
  43. Get the seed
  44. Returns:
  45. Int, seed.
  46. """
  47. return self.config.get_seed()
  48. def set_prefetch_size(self, size):
  49. """
  50. Set the number of rows to be prefetched.
  51. Args:
  52. size: total number of rows to be prefetched.
  53. Raises:
  54. ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).
  55. Examples:
  56. >>> import mindspore.dataset as ds
  57. >>> con = ds.engine.ConfigurationManager()
  58. >>> # sets the new prefetch value.
  59. >>> con.set_prefetch_size(1000)
  60. """
  61. if size <= 0 or size > INT32_MAX:
  62. raise ValueError("Prefetch size given is not within the required range")
  63. self.config.set_op_connector_size(size)
  64. def get_prefetch_size(self):
  65. """
  66. Get the prefetch size in number of rows.
  67. Returns:
  68. Size, total number of rows to be prefetched.
  69. """
  70. return self.config.get_op_connector_size()
  71. def set_num_parallel_workers(self, num):
  72. """
  73. Set the default number of parallel workers
  74. Args:
  75. num: number of parallel workers to be used as a default for each operation
  76. Raises:
  77. ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).
  78. Examples:
  79. >>> import mindspore.dataset as ds
  80. >>> con = ds.engine.ConfigurationManager()
  81. >>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers.
  82. >>> con.set_num_parallel_workers(8)
  83. """
  84. if num <= 0 or num > INT32_MAX:
  85. raise ValueError("Num workers given is not within the required range")
  86. self.config.set_num_parallel_workers(num)
  87. def get_num_parallel_workers(self):
  88. """
  89. Get the default number of parallel workers.
  90. Returns:
  91. Int, number of parallel workers to be used as a default for each operation
  92. """
  93. return self.config.get_num_parallel_workers()
  94. def __str__(self):
  95. """
  96. String representation of the configurations.
  97. Returns:
  98. Str, configurations.
  99. """
  100. return str(self.config)
  101. def load(self, file):
  102. """
  103. Load configuration from a file.
  104. Args:
  105. file: path the config file to be loaded
  106. Raises:
  107. RuntimeError: If file is invalid and parsing fails.
  108. Examples:
  109. >>> import mindspore.dataset as ds
  110. >>> con = ds.engine.ConfigurationManager()
  111. >>> # sets the default value according to values in configuration file.
  112. >>> con.load("path/to/config/file")
  113. >>> # example config file:
  114. >>> # {
  115. >>> # "logFilePath": "/tmp",
  116. >>> # "rowsPerBuffer": 32,
  117. >>> # "numParallelWorkers": 4,
  118. >>> # "workerConnectorSize": 16,
  119. >>> # "opConnectorSize": 16,
  120. >>> # "seed": 5489
  121. >>> # }
  122. """
  123. self.config.load(file)
  124. config = ConfigurationManager()