You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

context.h 6.9 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_
  18. #include <cstdint>
  19. #include <map>
  20. #include <memory>
  21. #include <string>
  22. #include <vector>
  23. #include "abstract/abstract_value.h"
  24. #include "frontend/parallel/ops_info/ops_utils.h"
  25. #include "frontend/parallel/status.h"
  26. #include "ir/anf.h"
  27. #include "ir/func_graph.h"
  28. #include "utils/convert_utils.h"
  29. #include "utils/info.h"
  30. #include "pipeline/jit/pipeline.h"
  31. namespace mindspore {
  32. namespace parallel {
// Parallel execution mode names.  These string values presumably form the set
// accepted by ParallelContext::set_parallel_mode() — confirm in the .cc.
constexpr char STAND_ALONE[] = "stand_alone";
constexpr char DATA_PARALLEL[] = "data_parallel";
constexpr char HYBRID_PARALLEL[] = "hybrid_parallel";
constexpr char AUTO_PARALLEL[] = "auto_parallel";
constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel";
// Strategy-search algorithm names (candidate inputs of
// set_strategy_search_mode() — verify against the definition).
constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
// Execution phase tags.
constexpr char TRAINING[] = "training";
constexpr char ACCUMULATION[] = "accumulation";
// Communication parallel mode names (candidate inputs of
// set_communi_parallel_mode() — verify against the definition).
constexpr char ALL_GROUP_PARALLEL[] = "all_group_parallel";
constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel";
constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel";
// NOTE(review): since these live in a header, C++17 `inline constexpr` would
// give each constant one definition shared across TUs instead of a per-TU
// copy; harmless as-is, but worth changing if the project targets C++17.
/**
 * Process-wide configuration holder for parallel training.
 *
 * Singleton: the constructor is private, copy construction/assignment are
 * deleted, and callers obtain the shared instance via GetInstance().
 * Most member functions are declared here and defined in the matching .cc
 * file; the setters that return bool (set_parallel_mode,
 * set_strategy_search_mode, set_communi_parallel_mode) presumably validate
 * their argument and report acceptance — confirm in the definition.
 *
 * NOTE(review): nothing in this header synchronizes access to the shared
 * instance; thread-safety (if any) must be established by the .cc — verify
 * before using from multiple threads.
 */
class ParallelContext {
 public:
  ~ParallelContext() = default;
  ParallelContext(const ParallelContext &) = delete;
  ParallelContext &operator=(const ParallelContext &) = delete;
  // Access point for the shared singleton instance.
  static std::shared_ptr<ParallelContext> GetInstance();

  // --- gradient handling options ---
  void set_gradients_mean(bool gradients_mean);
  bool gradients_mean() const { return gradients_mean_; }
  void set_full_batch(bool full_batch);
  bool full_batch() const { return full_batch_; }
  void set_gradient_fp32_sync(bool gradient_fp32_sync);
  bool gradient_fp32_sync() const { return gradient_fp32_sync_; }
  void set_loss_repeated_mean(bool loss_repeated_mean);
  bool loss_repeated_mean() const { return loss_repeated_mean_; }

  // --- device topology / rank configuration ---
  void set_device_num(int64_t device_num);
  int64_t device_num() const { return device_num_; }
  void set_pipeline_stage_split_num(const int64_t stages);
  int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
  void set_global_rank(int64_t global_rank);
  int64_t global_rank() const { return global_rank_; }
  void set_grad_accumulation_step(int64_t grad_accumulation_step);
  int64_t grad_accumulation_step() const { return grad_accumulation_step_; }

  // --- parallel mode selection (returns false on rejection — confirm in .cc) ---
  bool set_parallel_mode(const std::string &parallel_mode);
  std::string parallel_mode() const { return parallel_mode_; }
  bool set_strategy_search_mode(const std::string &strategy_search_mode);
  std::string strategy_search_mode() const { return strategy_search_mode_; }

  void set_parameter_broadcast(bool parameter_broadcast);
  bool parameter_broadcast() const { return parameter_broadcast_; }
  // "is_set" flags distinguish explicit user configuration from defaults.
  bool device_num_is_set() const { return device_num_is_set_; }
  bool global_rank_is_set() const { return global_rank_is_set_; }
  bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }

  // --- optimizer weight sharding ---
  void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
  int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
  void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
  bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }

  // --- all-reduce fusion configuration, keyed by communication group name ---
  // NOTE(review): the vectors below are passed as `const` *by value*, which
  // copies them; `const std::vector<uint32_t> &` would avoid the copy, but
  // changing the declaration also requires updating the out-of-line
  // definitions in the .cc.
  void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
  const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
  void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
  // NOTE(review): returning `const std::vector<uint32_t>` by const value
  // prevents callers from moving from the result; a plain value (or const
  // reference, if lifetime allows) would be preferable.
  const std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
  void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
    enable_all_reduce_fusion_ = enable_all_reduce_fusion;
  }
  bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }

  // --- strategy / group checkpoint file paths ---
  void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
  std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
  void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
  std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
  void set_group_ckpt_save_file(const std::string &group_ckpt_save_file);
  std::string group_ckpt_save_file() const { return group_ckpt_save_file_; }

  void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
    enable_parallel_optimizer_ = enable_parallel_optimizer;
  }
  bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }

  bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
  std::string communi_parallel_mode() const { return communi_parallel_mode_; }

  // Restores all settings to their defaults (defined in the .cc).
  void Reset();

  // --- parameter-shape bookkeeping across parallel transformation ---
  // Semantics live in the .cc; names suggest init/restore/checkpoint of
  // parameter shapes for the given graph — confirm against the definitions.
  void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
  void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                            AbstractBasePtr ptr);
  void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                         const AbstractBasePtr &ptr);

 private:
  ParallelContext();  // private: construct only through GetInstance()
  static std::shared_ptr<ParallelContext> inst_context_;
  // NOTE(review): members carry no in-class initializers; they rely on the
  // private constructor (in the .cc) for initialization — verify every member
  // is covered there, or consider brace-initializing here.
  bool gradients_mean_;
  bool full_batch_;
  bool gradient_fp32_sync_;
  bool loss_repeated_mean_;
  int64_t device_num_;
  int64_t global_rank_;
  int64_t grad_accumulation_step_;
  std::string parallel_mode_;
  std::string strategy_search_mode_;
  int64_t pipeline_stage_split_num_;
  bool parameter_broadcast_;
  bool device_num_is_set_;
  bool global_rank_is_set_;
  bool parameter_broadcast_is_set_;
  bool enable_all_reduce_fusion_;
  // Per-communication-group fusion split configuration.
  std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
  std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
  std::string strategy_ckpt_load_file_;
  std::string strategy_ckpt_save_file_;
  std::string group_ckpt_save_file_;
  bool enable_parallel_optimizer_;
  // No public accessor in this header; presumably set/read by the
  // ParallelParameterContext* methods — confirm in the .cc.
  bool init_param_shape_;
  std::string communi_parallel_mode_;
  int64_t optimizer_weight_shard_size_;
  bool optimizer_weight_shard_integrated_save_;
};
  135. } // namespace parallel
  136. } // namespace mindspore
  137. #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_