You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

config_manager.h 4.7 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONFIG_MANAGER_H_
  17. #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONFIG_MANAGER_H_
  18. #include <string>
  19. #include <memory>
  20. #include <vector>
  21. #include <map>
  22. #include <utility>
  23. #include <sstream>
  24. #include "utils/overload.h"
  25. #include "include/common/visible.h"
  26. namespace mindspore {
  // Parallel strategy selector stored by ConfigManager.
  // Enumerator values are spelled out explicitly; they match the implicit numbering (0, 1).
  enum ParallelStrategy {
    ONE_DEVICE = 0,    // first enumerator, value 0
    DISTRIBUTION = 1,  // second enumerator, value 1
  };
  // Dataset mode selector stored by ConfigManager.
  // ConfigManager::iter_num() reports 1 whenever the mode is DS_NORMAL_MODE.
  enum DatasetMode {
    DS_NORMAL_MODE = 0,  // value 0
    DS_SINK_MODE = 1,    // value 1
  };
  32. class DatasetGraphParam {
  33. public:
  34. DatasetGraphParam(const std::string &name, int64_t size, int64_t batch_size, const std::vector<int64_t> &ge_types,
  35. const std::vector<std::vector<int64_t>> &shapes, const std::vector<int64_t> &input_indexes)
  36. : queue_name_(name),
  37. loop_size_(size),
  38. batch_size_(batch_size),
  39. ge_types_(ge_types),
  40. shapes_(shapes),
  41. input_indexes_(input_indexes) {}
  42. ~DatasetGraphParam() = default;
  43. std::string ToString() const {
  44. std::ostringstream buffer;
  45. buffer << "DatasetGraphParam: queue_name=" << queue_name_ << " size=" << loop_size_ << " batch_size=" << batch_size_
  46. << " ge_types=" << ge_types_ << " shapes=" << shapes_ << " input_indexes=" << input_indexes_;
  47. return buffer.str();
  48. }
  49. std::string queue_name() const { return queue_name_; }
  50. int64_t loop_size() const { return loop_size_; }
  51. int64_t batch_size() const { return batch_size_; }
  52. std::vector<int64_t> ge_types() const { return ge_types_; }
  53. std::vector<std::vector<int64_t>> shapes() const { return shapes_; }
  54. std::vector<int64_t> input_indexes() const { return input_indexes_; }
  55. private:
  56. std::string queue_name_;
  57. int64_t loop_size_;
  58. int64_t batch_size_;
  59. std::vector<int64_t> ge_types_;
  60. std::vector<std::vector<int64_t>> shapes_;
  61. std::vector<int64_t> input_indexes_;
  62. };
  63. class COMMON_EXPORT ConfigManager {
  64. public:
  65. ConfigManager(const ConfigManager &) = delete;
  66. ConfigManager &operator=(const ConfigManager &) = delete;
  67. static ConfigManager &GetInstance() noexcept;
  68. ParallelStrategy parallel_strategy() const { return parallel_strategy_; }
  69. void set_parallel_strategy(ParallelStrategy strategy) { parallel_strategy_ = strategy; }
  70. const std::map<std::string, std::string> &ge_initialize_options() const { return ge_initialize_options_; }
  71. void set_ge_initialize_options(const std::map<std::string, std::string> &options) {
  72. ge_initialize_options_ = options;
  73. }
  74. DatasetMode dataset_mode() const { return dataset_mode_; }
  75. void set_dataset_mode(DatasetMode mode) { dataset_mode_ = mode; }
  76. int64_t iter_num() const {
  77. if (dataset_mode_ == DS_NORMAL_MODE) return 1;
  78. return iter_num_;
  79. }
  80. void set_iter_num(const std::string &queue_name, const int64_t num) {
  81. queue_name_ = queue_name;
  82. iter_num_ = num;
  83. queue_info_map[queue_name_] = static_cast<int16_t>(num);
  84. }
  85. std::string dataset_phase() const { return dataset_phase_; }
  86. void set_dataset_phase(const std::string &phase) { dataset_phase_ = phase; }
  87. DatasetGraphParam dataset_param() const { return dataset_param_; }
  88. void set_dataset_param(const DatasetGraphParam &param) { dataset_param_ = param; }
  89. static void SetDatasetModeConfig(const std::string &mode);
  90. void ResetConfig() noexcept;
  91. void ResetIterNum() noexcept;
  92. void ResetQueue(const std::string &queue_name) noexcept;
  93. std::map<std::string, std::string> ge_initialize_options_;
  94. int64_t gpu_loopsink_size() const { return gpu_loopsink_size_; }
  95. void set_gpu_loopsink_size(const int64_t size) { gpu_loopsink_size_ = size; }
  96. private:
  97. ConfigManager() = default;
  98. ~ConfigManager() = default;
  99. ParallelStrategy parallel_strategy_{ONE_DEVICE};
  100. DatasetMode dataset_mode_{DS_NORMAL_MODE};
  101. DatasetGraphParam dataset_param_{"", 0, 0, {}, {}, {}};
  102. int64_t iter_num_{1};
  103. std::string queue_name_{""};
  104. // now only save iter_num_ in the map
  105. std::map<std::string, int16_t> queue_info_map;
  106. std::string dataset_phase_{""};
  107. int64_t gpu_loopsink_size_{1};
  108. };
  109. } // namespace mindspore
  110. #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONFIG_MANAGER_H_