You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

config.cc 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/core/config_manager.h"
  17. #include "minddata/dataset/core/global_context.h"
  18. #include "minddata/dataset/include/config.h"
  19. #include "minddata/dataset/util/status.h"
  20. namespace mindspore {
  21. namespace dataset {
  22. // Config operations for setting and getting the configuration.
  23. namespace config {
  24. std::shared_ptr<ConfigManager> _config = GlobalContext::config_manager();
  25. // Function to set the seed to be used in any random generator
  26. bool set_seed(int32_t seed) {
  27. if (seed < 0 || seed > INT32_MAX) {
  28. MS_LOG(ERROR) << "Seed given is not within the required range: " << seed;
  29. return false;
  30. }
  31. _config->set_seed((uint32_t)seed);
  32. return true;
  33. }
  34. // Function to get the seed
  35. uint32_t get_seed() { return _config->seed(); }
  36. // Function to set the number of rows to be prefetched
  37. bool set_prefetch_size(int32_t prefetch_size) {
  38. if (prefetch_size <= 0 || prefetch_size > INT32_MAX) {
  39. MS_LOG(ERROR) << "Prefetch size given is not within the required range: " << prefetch_size;
  40. return false;
  41. }
  42. _config->set_op_connector_size(prefetch_size);
  43. return true;
  44. }
  45. // Function to get prefetch size in number of rows
  46. int32_t get_prefetch_size() { return _config->op_connector_size(); }
  47. // Function to set the default number of parallel workers
  48. bool set_num_parallel_workers(int32_t num_parallel_workers) {
  49. if (num_parallel_workers <= 0 || num_parallel_workers > INT32_MAX) {
  50. MS_LOG(ERROR) << "Number of parallel workers given is not within the required range: " << num_parallel_workers;
  51. return false;
  52. }
  53. _config->set_num_parallel_workers(num_parallel_workers);
  54. return true;
  55. }
  56. // Function to get the default number of parallel workers
  57. int32_t get_num_parallel_workers() { return _config->num_parallel_workers(); }
  58. // Function to set the default interval (in milliseconds) for monitor sampling
  59. bool set_monitor_sampling_interval(int32_t interval) {
  60. if (interval <= 0 || interval > INT32_MAX) {
  61. MS_LOG(ERROR) << "Interval given is not within the required range: " << interval;
  62. return false;
  63. }
  64. _config->set_monitor_sampling_interval((uint32_t)interval);
  65. return true;
  66. }
  67. // Function to get the default interval of performance monitor sampling
  68. int32_t get_monitor_sampling_interval() { return _config->monitor_sampling_interval(); }
  69. // Function to set the default timeout (in seconds) for DSWaitedCallback
  70. bool set_callback_timeback(int32_t timeout) {
  71. if (timeout <= 0 || timeout > INT32_MAX) {
  72. MS_LOG(ERROR) << "Timeout given is not within the required range: " << timeout;
  73. return false;
  74. }
  75. _config->set_callback_timeout((uint32_t)timeout);
  76. return true;
  77. }
  78. // Function to get the default timeout for DSWaitedCallback
  79. int32_t get_callback_timeout() { return _config->callback_timeout(); }
  80. // Function to load configurations from a file
  81. bool load(std::string file) {
  82. Status rc = _config->LoadFile(file);
  83. if (rc.IsError()) {
  84. MS_LOG(ERROR) << rc << file;
  85. return false;
  86. }
  87. return true;
  88. }
  89. } // namespace config
  90. } // namespace dataset
  91. } // namespace mindspore