You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parallel_context.cc 12 kB

4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. /**
  2. * Copyright 2019-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "include/common/utils/parallel_context.h"
  17. #include <algorithm>
  18. #include <cstdint>
  19. #include <functional>
  20. #include <map>
  21. #include <memory>
  22. namespace mindspore::parallel {
  23. namespace {
  24. std::map<std::string, std::vector<int64_t>> param_shapes;
  25. std::vector<std::string> kParallelModeList = {kStandalone, kDataParallel, kHybridParallel, kSemiAutoParallel,
  26. kAutoParallel};
  27. std::vector<std::string> kStrategySearchModeList = {kDynamicProgramming, kRecursiveProgramming, kShardingPropagation};
  28. std::vector<std::string> kCommuniParallelModeList = {kAllGroupParallel, kSameServerGroupParallel, kNoGroupParallel};
  29. std::vector<std::string> kFusionModeList = {kFusionAuto, kFusionSize, kFusionIndex};
  30. } // namespace
  31. std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
  32. static std::shared_ptr<ParallelContext> inst_context_ = std::shared_ptr<ParallelContext>(new ParallelContext());
  33. return inst_context_;
  34. }
  35. ParallelContext::ParallelContext() { Reset(); }
  36. void ParallelContext::Reset() {
  37. init_param_shape_ = true;
  38. gradients_mean_ = false;
  39. full_batch_ = false;
  40. gradient_fp32_sync_ = true;
  41. loss_repeated_mean_ = true;
  42. device_num_ = 1;
  43. global_rank_ = 0;
  44. device_num_is_set_ = false;
  45. global_rank_is_set_ = false;
  46. parallel_mode_ = kStandalone;
  47. parameter_broadcast_ = false;
  48. parameter_broadcast_is_set_ = false;
  49. enable_all_reduce_fusion_ = true;
  50. enable_all_gather_fusion_ = true;
  51. enable_reduce_scatter_fusion_ = true;
  52. strategy_ckpt_load_file_ = "";
  53. strategy_ckpt_save_file_ = "";
  54. enable_parallel_optimizer_ = false;
  55. all_reduce_fusion_split_indices_.clear();
  56. all_reduce_fusion_split_sizes_.clear();
  57. strategy_search_mode_ = kDynamicProgramming;
  58. pipeline_stage_split_num_ = 1;
  59. grad_accumulation_step_ = 1;
  60. communi_parallel_mode_ = kAllGroupParallel;
  61. optimizer_weight_shard_size_ = -1;
  62. optimizer_weight_shard_aggregated_save_ = false;
  63. enable_all2all_ = false;
  64. grad_accumulation_shard_ = true;
  65. parallel_optimizer_threshold_ = -1;
  66. sharding_propagation_ = false;
  67. dataset_strategy_.clear();
  68. fusion_threshold_mb_ = kFusionThreshold;
  69. allgather_fusion_threshold_mb_ = kFusionThreshold;
  70. reducescatter_fusion_threshold_mb_ = kFusionThreshold;
  71. fusion_threshold_is_set_ = true;
  72. fusion_mode_ = kFusionAuto;
  73. }
  74. void ParallelContext::set_device_num(int64_t device_num) {
  75. device_num_ = device_num;
  76. device_num_is_set_ = true;
  77. }
  78. void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) {
  79. fusion_threshold_mb_ = fusion_threshold;
  80. fusion_threshold_is_set_ = true;
  81. enable_all_reduce_fusion_ = true;
  82. }
  83. void ParallelContext::set_allgather_fusion_threshold_mb(int64_t fusion_threshold) {
  84. allgather_fusion_threshold_mb_ = fusion_threshold;
  85. enable_all_gather_fusion_ = true;
  86. }
  87. void ParallelContext::set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold) {
  88. reducescatter_fusion_threshold_mb_ = fusion_threshold;
  89. enable_reduce_scatter_fusion_ = true;
  90. }
  91. bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
  92. auto iter = std::find(kFusionModeList.begin(), kFusionModeList.end(), fusion_mode);
  93. if (iter == kFusionModeList.end()) {
  94. MS_LOG(INFO) << "Invalid fusion mode:" << fusion_mode;
  95. return false;
  96. }
  97. fusion_mode_ = fusion_mode;
  98. return true;
  99. }
  100. void ParallelContext::set_global_rank(int64_t global_rank) {
  101. global_rank_ = global_rank;
  102. global_rank_is_set_ = true;
  103. }
  104. void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ = gradients_mean; }
  105. void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
  106. void ParallelContext::set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy) {
  107. dataset_strategy_ = dataset_strategy;
  108. }
  109. void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) {
  110. grad_accumulation_step_ = grad_accumulation_step;
  111. }
  112. void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
  113. void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
  114. void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }
  115. bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
  116. auto iter = std::find(kParallelModeList.begin(), kParallelModeList.end(), parallel_mode);
  117. if (iter == kParallelModeList.end()) {
  118. MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode;
  119. return false;
  120. }
  121. parallel_mode_ = parallel_mode;
  122. return true;
  123. }
  124. bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) {
  125. auto iter = std::find(kStrategySearchModeList.begin(), kStrategySearchModeList.end(), strategy_search_mode);
  126. if (iter == kStrategySearchModeList.end()) {
  127. MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode;
  128. return false;
  129. }
  130. strategy_search_mode_ = strategy_search_mode;
  131. return true;
  132. }
  133. void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
  134. parameter_broadcast_ = parameter_broadcast;
  135. parameter_broadcast_is_set_ = true;
  136. }
  137. void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
  138. strategy_ckpt_load_file_ = strategy_ckpt_load_file;
  139. }
  140. void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
  141. strategy_ckpt_save_file_ = strategy_ckpt_save_file;
  142. }
  143. void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
  144. group_ckpt_save_file_ = group_ckpt_save_file;
  145. }
  146. void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
  147. optimizer_weight_shard_size_ = optimizer_weight_shard_size;
  148. }
  149. void ParallelContext::set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save) {
  150. optimizer_weight_shard_aggregated_save_ = optimizer_weight_shard_aggregated_save;
  151. }
  152. void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
  153. if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
  154. group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
  155. group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
  156. all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
  157. all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
  158. all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
  159. }
  160. all_reduce_fusion_split_indices_[group] = indices;
  161. }
  162. std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
  163. auto iter = all_reduce_fusion_split_indices_.find(group);
  164. if (iter != all_reduce_fusion_split_indices_.end()) {
  165. return iter->second;
  166. }
  167. return {};
  168. }
  169. void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
  170. if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
  171. group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
  172. group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
  173. all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
  174. all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
  175. all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
  176. }
  177. all_reduce_fusion_split_sizes_[group] = sizes;
  178. }
  179. std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
  180. auto iter = all_reduce_fusion_split_sizes_.find(group);
  181. if (iter != all_reduce_fusion_split_sizes_.end()) {
  182. return iter->second;
  183. }
  184. return {};
  185. }
  186. bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
  187. auto iter = std::find(kCommuniParallelModeList.begin(), kCommuniParallelModeList.end(), communi_parallel_mode);
  188. if (iter == kCommuniParallelModeList.end()) {
  189. MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode;
  190. return false;
  191. }
  192. communi_parallel_mode_ = communi_parallel_mode;
  193. return true;
  194. }
  195. // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
  196. void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
  197. MS_EXCEPTION_IF_NULL(func_graph);
  198. if (!func_graph->has_flag(kAutoParallel)) {
  199. return;
  200. }
  201. if (func_graph->has_flag(kIsFirstIteration)) {
  202. param_shapes.clear();
  203. init_param_shape_ = true;
  204. MS_LOG(INFO) << "Init the parameter shape dict in increment predict with two graph";
  205. return;
  206. }
  207. if (!func_graph->has_flag(kTraining)) {
  208. init_param_shape_ = false;
  209. MS_LOG(INFO) << "In parallel evaluation or prediction, may be need to restore the parameter shape";
  210. return;
  211. }
  212. if ((ParallelContext::GetInstance()->grad_accumulation_step() > 1) && !func_graph->has_flag(kAccumulation)) {
  213. init_param_shape_ = false;
  214. MS_LOG(INFO) << "In parallel grad accumulation second graph, need to restore the parameter shape";
  215. } else {
  216. param_shapes.clear();
  217. init_param_shape_ = true;
  218. MS_LOG(INFO) << "Init the parameter shape dict";
  219. }
  220. }
  221. // Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
  222. void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
  223. const ParameterPtr &param_node, const AbstractBasePtr &ptr) {
  224. MS_EXCEPTION_IF_NULL(func_graph);
  225. MS_EXCEPTION_IF_NULL(param_node);
  226. MS_EXCEPTION_IF_NULL(ptr);
  227. if (!func_graph->has_flag(kAutoParallel)) {
  228. return;
  229. }
  230. if (init_param_shape_) {
  231. return;
  232. }
  233. auto iter = param_shapes.find(param_node->name());
  234. if (iter == param_shapes.end()) {
  235. MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name();
  236. return;
  237. }
  238. auto shape = iter->second;
  239. std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
  240. ptr->set_shape(base_shape);
  241. MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
  242. }
  243. // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
  244. // Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
  245. void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
  246. const AbstractBasePtr &ptr) {
  247. MS_EXCEPTION_IF_NULL(func_graph);
  248. MS_EXCEPTION_IF_NULL(param_node);
  249. MS_EXCEPTION_IF_NULL(ptr);
  250. if (!func_graph->has_flag(kAutoParallel)) {
  251. return;
  252. }
  253. if (!init_param_shape_) {
  254. return;
  255. }
  256. std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
  257. auto ret = param_shapes.try_emplace(param_node->name(), shape);
  258. if (!ret.second) {
  259. MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed";
  260. return;
  261. }
  262. MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
  263. }
  264. void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
  265. void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
  266. } // namespace mindspore::parallel