You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parallel_context.h 10 kB

4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. /**
  2. * Copyright 2019-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PARALLEL_CONTEXT_H_
  17. #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PARALLEL_CONTEXT_H_
  18. #include <cstdint>
  19. #include <map>
  20. #include <memory>
  21. #include <string>
  22. #include <vector>
  23. #include "abstract/abstract_value.h"
  24. #include "ir/anf.h"
  25. #include "ir/func_graph.h"
  26. #include "include/common/utils/convert_utils.h"
  27. #include "utils/info.h"
  28. #include "include/common/visible.h"
  29. namespace mindspore::parallel {
  // Constants shared by the parallel configuration code.
  //
  // These are `inline constexpr` (C++17, already required by the nested
  // namespace definition in this header) so that every translation unit that
  // includes this header refers to one single definition, instead of each TU
  // getting its own internal-linkage copy as plain namespace-scope `constexpr`
  // variables would.

  // Parallel-mode names; presumably the values accepted by
  // ParallelContext::set_parallel_mode -- TODO confirm against the .cpp.
  inline constexpr char kStandalone[] = "stand_alone";
  inline constexpr char kDataParallel[] = "data_parallel";
  inline constexpr char kHybridParallel[] = "hybrid_parallel";
  inline constexpr char kAutoParallel[] = "auto_parallel";
  inline constexpr char kSemiAutoParallel[] = "semi_auto_parallel";

  // Strategy-search-mode names; presumably the values accepted by
  // ParallelContext::set_strategy_search_mode -- TODO confirm.
  inline constexpr char kDynamicProgramming[] = "dynamic_programming";
  inline constexpr char kRecursiveProgramming[] = "recursive_programming";
  inline constexpr char kShardingPropagation[] = "sharding_propagation";

  // NOTE(review): usage of these two tags is not visible in this header.
  inline constexpr char kTraining[] = "training";
  inline constexpr char kAccumulation[] = "accumulation";

  // Communication-parallel-mode names; presumably the values accepted by
  // ParallelContext::set_communi_parallel_mode -- TODO confirm.
  inline constexpr char kAllGroupParallel[] = "all_group_parallel";
  inline constexpr char kSameServerGroupParallel[] = "same_server_group_parallel";
  inline constexpr char kNoGroupParallel[] = "no_group_parallel";

  // NOTE(review): key/flag name whose consumer is not visible in this header.
  inline constexpr char kIsFirstIteration[] = "is_first_iteration";

  // Fusion-mode names; presumably the values accepted by
  // ParallelContext::set_fusion_mode -- TODO confirm.
  inline constexpr char kFusionAuto[] = "auto";
  inline constexpr char kFusionSize[] = "size";
  inline constexpr char kFusionIndex[] = "index";

  // Default fusion threshold. The related setters are suffixed `_mb`, so this
  // is likely in megabytes -- verify against the implementation.
  inline constexpr int64_t kFusionThreshold = 64;
  48. class COMMON_EXPORT ParallelContext {
  49. public:
  50. static std::shared_ptr<ParallelContext> GetInstance();
  51. ~ParallelContext() = default;
  52. ParallelContext(const ParallelContext &) = delete;
  53. ParallelContext &operator=(const ParallelContext &) = delete;
  54. void set_gradients_mean(bool gradients_mean);
  55. bool gradients_mean() const { return gradients_mean_; }
  56. void set_full_batch(bool full_batch);
  57. bool full_batch() const { return full_batch_; }
  58. void set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy);
  59. std::vector<std::vector<int64_t>> dataset_strategy() const { return dataset_strategy_; }
  60. void set_gradient_fp32_sync(bool gradient_fp32_sync);
  61. bool gradient_fp32_sync() const { return gradient_fp32_sync_; }
  62. void set_loss_repeated_mean(bool loss_repeated_mean);
  63. bool loss_repeated_mean() const { return loss_repeated_mean_; }
  64. void set_device_num(int64_t device_num);
  65. int64_t device_num() const { return device_num_; }
  66. void set_fusion_threshold_mb(int64_t fusion_threshold);
  67. int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; }
  68. void set_allgather_fusion_threshold_mb(int64_t allgather_fusion_threshold);
  69. int64_t allgather_fusion_threshold_mb() const { return allgather_fusion_threshold_mb_; }
  70. void set_reducescatter_fusion_threshold_mb(int64_t rs_fusion_threshold);
  71. int64_t reducescatter_fusion_threshold_mb() const { return reducescatter_fusion_threshold_mb_; }
  72. bool set_fusion_mode(const std::string &fusion_mode);
  73. std::string get_fusion_mode() const { return fusion_mode_; }
  74. void set_pipeline_stage_split_num(const int64_t stages);
  75. int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
  76. void set_global_rank(int64_t global_rank);
  77. int64_t global_rank() const { return global_rank_; }
  78. void set_grad_accumulation_step(int64_t grad_accumulation_step);
  79. int64_t grad_accumulation_step() const { return grad_accumulation_step_; }
  80. bool set_parallel_mode(const std::string &parallel_mode);
  81. std::string parallel_mode() const { return parallel_mode_; }
  82. bool set_strategy_search_mode(const std::string &strategy_search_mode);
  83. std::string strategy_search_mode() const { return strategy_search_mode_; }
  84. void set_parameter_broadcast(bool parameter_broadcast);
  85. bool parameter_broadcast() const { return parameter_broadcast_; }
  86. bool device_num_is_set() const { return device_num_is_set_; }
  87. bool global_rank_is_set() const { return global_rank_is_set_; }
  88. bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }
  89. void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
  90. int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
  91. void set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save);
  92. bool optimizer_weight_shard_aggregated_save() const { return optimizer_weight_shard_aggregated_save_; }
  93. void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group);
  94. std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
  95. void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group);
  96. std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
  97. void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
  98. enable_all_reduce_fusion_ = enable_all_reduce_fusion;
  99. }
  100. bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
  101. void set_enable_all_gather_fusion(bool enable_all_gather_fusion) {
  102. enable_all_gather_fusion_ = enable_all_gather_fusion;
  103. }
  104. bool enable_all_gather_fusion() const { return enable_all_gather_fusion_; }
  105. void set_enable_reduce_scatter_fusion(bool enable_reduce_scatter_fusion) {
  106. enable_reduce_scatter_fusion_ = enable_reduce_scatter_fusion;
  107. }
  108. bool enable_reduce_scatter_fusion() const { return enable_reduce_scatter_fusion_; }
  109. void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
  110. std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
  111. void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
  112. std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
  113. void set_group_ckpt_save_file(const std::string &group_ckpt_save_file);
  114. std::string group_ckpt_save_file() const { return group_ckpt_save_file_; }
  115. void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
  116. enable_parallel_optimizer_ = enable_parallel_optimizer;
  117. }
  118. bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
  119. void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; }
  120. bool hccl_test_available() const { return hccl_test_available_; }
  121. void set_grad_accumulation_shard(const bool grad_accumulation_shard) {
  122. grad_accumulation_shard_ = grad_accumulation_shard;
  123. }
  124. bool grad_accumulation_shard() const { return grad_accumulation_shard_; }
  125. void set_parallel_optimizer_threshold(const int64_t parallel_optimizer_threshold) {
  126. parallel_optimizer_threshold_ = parallel_optimizer_threshold;
  127. }
  128. int64_t get_parallel_optimizer_threshold() const { return parallel_optimizer_threshold_; }
  129. bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
  130. std::string communi_parallel_mode() const { return communi_parallel_mode_; }
  131. void set_enable_all2all(const bool);
  132. bool enable_all2all() const { return enable_all2all_; }
  133. void set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right) {
  134. dataset_repeat_dim_right_ = dataset_repeat_dim_right;
  135. }
  136. bool dataset_repeat_dim_right() const { return dataset_repeat_dim_right_; }
  137. void Reset();
  138. void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
  139. void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
  140. const AbstractBasePtr &ptr);
  141. void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
  142. const AbstractBasePtr &ptr);
  143. void set_sharding_propagation(const bool);
  144. bool sharding_propagation() const { return sharding_propagation_; }
  145. private:
  146. ParallelContext();
  147. bool gradients_mean_;
  148. bool full_batch_;
  149. bool gradient_fp32_sync_;
  150. bool loss_repeated_mean_;
  151. int64_t device_num_;
  152. int64_t fusion_threshold_mb_;
  153. int64_t allgather_fusion_threshold_mb_;
  154. int64_t reducescatter_fusion_threshold_mb_; // reducescatter
  155. int64_t global_rank_;
  156. int64_t grad_accumulation_step_;
  157. std::string parallel_mode_;
  158. std::string strategy_search_mode_;
  159. int64_t pipeline_stage_split_num_;
  160. bool parameter_broadcast_;
  161. bool device_num_is_set_;
  162. bool fusion_threshold_is_set_;
  163. bool global_rank_is_set_;
  164. bool parameter_broadcast_is_set_;
  165. bool enable_all_reduce_fusion_;
  166. bool enable_all_gather_fusion_;
  167. bool enable_reduce_scatter_fusion_;
  168. std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
  169. std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
  170. std::string strategy_ckpt_load_file_;
  171. std::string strategy_ckpt_save_file_;
  172. std::string group_ckpt_save_file_;
  173. bool enable_parallel_optimizer_;
  174. bool init_param_shape_;
  175. std::string communi_parallel_mode_;
  176. int64_t optimizer_weight_shard_size_;
  177. bool optimizer_weight_shard_aggregated_save_;
  178. bool grad_accumulation_shard_;
  179. int64_t parallel_optimizer_threshold_;
  180. // Enable AllToAll or not. If false, use AllGather and Split.
  181. bool enable_all2all_;
  182. std::vector<std::vector<int64_t>> dataset_strategy_;
  183. bool dataset_repeat_dim_right_ = false;
  184. bool hccl_test_available_ = false;
  185. bool sharding_propagation_;
  186. std::string fusion_mode_;
  187. };
  188. } // namespace mindspore::parallel
  189. #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PARALLEL_CONTEXT_H_