// costmodel_context.h
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_COSTMODEL_CONTEXT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_COSTMODEL_CONTEXT_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace parallel {
  25. #define OPERATOR_TO_OPERATOR_CONNECTOR "-"
  26. #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
  27. #define DEFAULT_COST_MODEL_ALPHA 1.0
  28. #define DEFAULT_COST_MODEL_BETA_ASCEND 400.0 // for 'device_target = Ascend'
  29. #define DEFAULT_COST_MODEL_BETA_GPU 50.0 // for 'device_target = GPU'
  30. #define DEFAULT_COST_MODEL_GAMMA 0.001
  31. #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
  32. #define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
  33. #define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0
  34. #define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0
  35. #define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false
  36. #define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16
  37. #define DEFAULT_FULLY_USE_DEVICES true
  38. #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
  39. #define DEFAULT_IS_MULTI_SUBGRAPHS false
  40. #define DEFAULT_RUN_PHASE 0
  41. #define TRAINING_PHASE 0
  42. #define INFERENCE_PHASE 1
  43. #define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true;
  44. class CostModelContext {
  45. public:
  46. ~CostModelContext() = default;
  47. CostModelContext(const CostModelContext &) = delete;
  48. CostModelContext &operator=(const CostModelContext &) = delete;
  49. void ResetCostModel();
  50. void ResetAlgoParameters();
  51. static std::shared_ptr<CostModelContext> GetInstance();
  52. void set_costmodel_context_for_device(const std::string &);
  53. // DEVICE_MEMORY_CAPACITY
  54. void set_device_memory_capacity(double);
  55. double device_memory_capacity() const { return device_memory_capacity_; }
  56. // COST_MODEL_ALPHA
  57. void set_costmodel_alpha(double);
  58. double costmodel_alpha() const { return costmodel_alpha_; }
  59. // COST_MODEL_BETA
  60. void set_costmodel_beta(double);
  61. double costmodel_beta() const { return costmodel_beta_; }
  62. // COST_MODEL_GAMMA
  63. void set_costmodel_gamma(double);
  64. double costmodel_gamma() const { return costmodel_gamma_; }
  65. // COST_MODEL_SIMPLIFY_CALCULATION
  66. void set_costmodel_simplify_cal(bool);
  67. bool costmodel_simplify_cal() const { return costmodel_simplify_cal_; }
  68. // COST_MODEL_COMMUNI_THRESHOLD
  69. void set_costmodel_communi_threshold(double);
  70. double costmodel_communi_threshold() const { return costmodel_communi_threshold_; }
  71. // COST_MODEL_COMMUNI_CONST
  72. void set_costmodel_communi_const(double);
  73. double costmodel_communi_const() const { return costmodel_communi_const_; }
  74. // COST_MODEL_COMMUNI_BIAS
  75. void set_costmodel_communi_bias(double);
  76. double costmodel_communi_bias() const { return costmodel_communi_bias_; }
  77. void set_multi_subgraphs(bool);
  78. bool is_multi_subgraphs() const { return is_multi_subgraphs_; }
  79. void set_costmodel_allreduce_fusion_algorithm(int32_t);
  80. int32_t costmodel_allreduce_fusion_algorithm() const { return costmodel_allreduce_fusion_algorithm_; }
  81. void set_costmodel_allreduce_fusion_times(int32_t);
  82. int32_t costmodel_allreduce_fusion_times() const { return costmodel_allreduce_fusion_times_; }
  83. void set_costmodel_allreduce_fusion_tail_percent(double);
  84. double costmodel_allreduce_fusion_tail_percent() const { return costmodel_allreduce_fusion_tail_percent_; }
  85. void set_costmodel_allreduce_fusion_tail_time(double);
  86. double costmodel_allreduce_fusion_tail_time() const { return costmodel_allreduce_fusion_tail_time_; }
  87. void set_costmodel_allreduce_fusion_allreduce_inherent_time(double);
  88. double costmodel_allreduce_fusion_allreduce_inherent_time() const {
  89. return costmodel_allreduce_fusion_allreduce_inherent_time_;
  90. }
  91. void set_costmodel_allreduce_fusion_allreduce_bandwidth(double);
  92. double costmodel_allreduce_fusion_allreduce_bandwidth() const {
  93. return costmodel_allreduce_fusion_allreduce_bandwidth_;
  94. }
  95. void set_costmodel_allreduce_fusion_computation_time_parameter(double);
  96. double costmodel_allreduce_fusion_computation_time_parameter() const {
  97. return costmodel_allreduce_fusion_computation_time_parameter_;
  98. }
  99. // TENSOR_SLICE_ALIGNMENT_ENABLE
  100. void set_tensor_slice_alignment_enable(bool);
  101. bool tensor_slice_alignment_enable() const { return tensor_slice_alignment_enable_; }
  102. // TENSOR_SLICE_ALIGNMENT_SIZE
  103. void set_tensor_slice_alignment_size(size_t);
  104. size_t tensor_slice_alignment_size() const { return tensor_slice_alignment_size_; }
  105. // FULLY_USE_DEVICES
  106. void set_fully_use_device(bool);
  107. bool fully_use_device() const { return fully_use_device_; }
  108. // ELEMENTWISE_OP_STRA_FOLLOW
  109. void set_elementwise_stra_follow(bool);
  110. bool elementwise_stra_follow() const { return elementwise_stra_follow_; }
  111. void set_triangle_star_strategy_overwrite(bool);
  112. bool triangle_star_strategy_overwrite() const { return triangle_star_strategy_overwrite_; }
  113. void set_run_phase(int32_t);
  114. int32_t run_phase() const { return run_phase_; }
  115. private:
  116. CostModelContext();
  117. static std::shared_ptr<CostModelContext> cm_context_inst_;
  118. // DEVICE_MEMORY_CAPACITY
  119. double device_memory_capacity_;
  120. // COST_MODEL_ALPHA
  121. double costmodel_alpha_;
  122. // COST_MODEL_BETA
  123. double costmodel_beta_;
  124. // COST_MODEL_GAMMA
  125. double costmodel_gamma_;
  126. // COST_MODEL_SIMPLIFY_CALCULATION
  127. bool costmodel_simplify_cal_;
  128. // COST_MODEL_COMMUNI_THRESHOLD
  129. double costmodel_communi_threshold_;
  130. // COST_MODEL_COMMUNI_CONST
  131. double costmodel_communi_const_;
  132. // COST_MODEL_COMMUNI_BIAS
  133. double costmodel_communi_bias_;
  134. // MULTI_SUBGRAPHS
  135. bool is_multi_subgraphs_;
  136. // In the recovery phase of DP algorithm, when encountering triangle structure and star structure,
  137. // whether overwrite the right-node strategy
  138. bool triangle_star_strategy_overwrite_;
  139. int32_t run_phase_; // 0: 'training', 1: 'inference'
  140. int32_t costmodel_allreduce_fusion_algorithm_;
  141. int32_t costmodel_allreduce_fusion_times_;
  142. double costmodel_allreduce_fusion_tail_percent_;
  143. double costmodel_allreduce_fusion_tail_time_;
  144. double costmodel_allreduce_fusion_allreduce_inherent_time_;
  145. double costmodel_allreduce_fusion_allreduce_bandwidth_;
  146. double costmodel_allreduce_fusion_computation_time_parameter_;
  147. // TENSOR_SLICE_ALIGNMENT_ENABLE
  148. bool tensor_slice_alignment_enable_;
  149. // TENSOR_SLICE_ALIGNMENT_SIZE
  150. size_t tensor_slice_alignment_size_;
  151. // FULLY_USE_DEVICES
  152. bool fully_use_device_;
  153. // ELEMENTWISE_OP_STRA_FOLLOW
  154. bool elementwise_stra_follow_;
  155. };
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_COSTMODEL_CONTEXT_H_