You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

config_manager.h 4.1 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_UTILS_CONFIG_MANAGER_H_
  17. #define MINDSPORE_CCSRC_UTILS_CONFIG_MANAGER_H_
  18. #include <string>
  19. #include <memory>
  20. #include <vector>
  21. #include <map>
  22. #include <utility>
  23. #include <sstream>
  24. #include "utils/overload.h"
  25. namespace mindspore {
  // Parallel strategy selector. ONE_DEVICE (= 0) is the default held by ConfigManager;
  // DISTRIBUTION follows it with value 1.
  enum ParallelStrategy { ONE_DEVICE = 0, DISTRIBUTION };
  // Dataset feeding mode. DS_NORMAL_MODE (= 0) is ConfigManager's default; in that mode
  // ConfigManager::iter_num() reports 1 regardless of the stored iteration count.
  enum DatasetMode {
    DS_NORMAL_MODE = 0,
    DS_SINK_MODE,
  };
  31. class DatasetGraphParam {
  32. public:
  33. DatasetGraphParam(const std::string &name, int64_t size, int64_t batch_size, const std::vector<int64_t> &ge_types,
  34. const std::vector<std::vector<int64_t>> &shapes, const std::vector<int64_t> &input_indexes)
  35. : queue_name_(name),
  36. loop_size_(size),
  37. batch_size_(batch_size),
  38. ge_types_(ge_types),
  39. shapes_(shapes),
  40. input_indexes_(input_indexes) {}
  41. ~DatasetGraphParam() = default;
  42. std::string ToString() const {
  43. std::ostringstream buffer;
  44. buffer << "DatasetGraphParam: queue_name=" << queue_name_ << " size=" << loop_size_ << " batch_size=" << batch_size_
  45. << " ge_types=" << ge_types_ << " shapes=" << shapes_ << " input_indexes=" << input_indexes_;
  46. return buffer.str();
  47. }
  48. std::string queue_name() const { return queue_name_; }
  49. int64_t loop_size() const { return loop_size_; }
  50. int64_t batch_size() const { return batch_size_; }
  51. std::vector<int64_t> ge_types() const { return ge_types_; }
  52. std::vector<std::vector<int64_t>> shapes() const { return shapes_; }
  53. std::vector<int64_t> input_indexes() const { return input_indexes_; }
  54. private:
  55. std::string queue_name_;
  56. int64_t loop_size_;
  57. int64_t batch_size_;
  58. std::vector<int64_t> ge_types_;
  59. std::vector<std::vector<int64_t>> shapes_;
  60. std::vector<int64_t> input_indexes_;
  61. };
  62. class ConfigManager {
  63. public:
  64. ConfigManager(const ConfigManager &) = delete;
  65. ConfigManager &operator=(const ConfigManager &) = delete;
  66. static ConfigManager &GetInstance() noexcept;
  67. ParallelStrategy parallel_strategy() const { return parallel_strategy_; }
  68. void set_parallel_strategy(ParallelStrategy strategy) { parallel_strategy_ = strategy; }
  69. const std::map<std::string, std::string> &ge_initialize_options() const { return ge_initialize_options_; }
  70. void set_ge_initialize_options(const std::map<std::string, std::string> &options) {
  71. ge_initialize_options_ = options;
  72. }
  73. DatasetMode dataset_mode() const { return dataset_mode_; }
  74. void set_dataset_mode(DatasetMode mode) { dataset_mode_ = mode; }
  75. int64_t iter_num() const {
  76. if (dataset_mode_ == DS_NORMAL_MODE) return 1;
  77. return iter_num_;
  78. }
  79. void set_iter_num(const int64_t num) { iter_num_ = num; }
  80. std::string dataset_phase() const { return dataset_phase_; }
  81. void set_dataset_phase(const std::string &phase) { dataset_phase_ = phase; }
  82. DatasetGraphParam dataset_param() const { return dataset_param_; }
  83. void set_dataset_param(const DatasetGraphParam &param) { dataset_param_ = param; }
  84. static void SetDatasetModeConfig(const std::string &mode);
  85. void ResetConfig() noexcept;
  86. std::map<std::string, std::string> ge_initialize_options_;
  87. private:
  88. ConfigManager() = default;
  89. ~ConfigManager() = default;
  90. ParallelStrategy parallel_strategy_{ONE_DEVICE};
  91. DatasetMode dataset_mode_{DS_NORMAL_MODE};
  92. DatasetGraphParam dataset_param_{"", 0, 0, {}, {}, {}};
  93. int64_t iter_num_{1};
  94. std::string dataset_phase_{""};
  95. };
  96. } // namespace mindspore
  97. #endif // MINDSPORE_CCSRC_UTILS_CONFIG_MANAGER_H_