You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

context.cc 12 kB

4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/parallel/context.h"
  17. #include <algorithm>
  18. #include <cstdint>
  19. #include <functional>
  20. #include <map>
  21. #include <memory>
  22. #include <utility>
  23. #include "frontend/parallel/device_manager.h"
  24. namespace mindspore {
  25. namespace parallel {
// Global dict recording parameter name -> shape, checkpointed while building a
// training graph and restored for evaluation/prediction graphs (see the
// ParallelParameterContext* methods below).
std::map<std::string, Shape> param_shapes;
// Whitelists used to validate the corresponding set_* calls.
std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
                                               AUTO_PARALLEL};
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING, SHARDING_PROPAGATION};
std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
                                                       NO_GROUP_PARALLEL};
std::vector<std::string> FUSION_MODE_LIST = {FUSION_AUTO, FUSION_SIZE, FUSION_INDEX};
// Lazily-created singleton instance (see GetInstance()).
std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
// Lazily creates and returns the process-wide ParallelContext singleton.
// NOTE(review): there is no locking around the null-check/reset pair, so this
// assumes the first call happens before any concurrent use — confirm if
// multi-threaded callers exist. new (std::nothrow) may return nullptr on
// allocation failure, in which case a null shared_ptr is handed back.
std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
  if (inst_context_ == nullptr) {
    inst_context_.reset(new (std::nothrow) ParallelContext());
  }
  return inst_context_;
}
// The constructor simply initializes every field to its default via Reset().
ParallelContext::ParallelContext() { Reset(); }

// Restores every configuration field of the parallel context to its default
// value. Invoked by the constructor and whenever the context is reset.
void ParallelContext::Reset() {
  // Parameter-shape bookkeeping defaults to "record shapes".
  init_param_shape_ = true;
  // Gradient handling.
  gradients_mean_ = false;
  full_batch_ = false;
  gradient_fp32_sync_ = true;
  loss_repeated_mean_ = true;
  // Device topology: single device, rank 0, neither explicitly set yet.
  device_num_ = 1;
  global_rank_ = 0;
  device_num_is_set_ = false;
  global_rank_is_set_ = false;
  // Parallel execution mode defaults.
  parallel_mode_ = STAND_ALONE;
  parameter_broadcast_ = false;
  parameter_broadcast_is_set_ = false;
  // AllReduce fusion configuration.
  enable_all_reduce_fusion_ = true;
  strategy_ckpt_load_file_ = "";
  strategy_ckpt_save_file_ = "";
  enable_parallel_optimizer_ = false;
  all_reduce_fusion_split_indices_.clear();
  all_reduce_fusion_split_sizes_.clear();
  strategy_search_mode_ = DYNAMIC_PROGRAMMING;
  // Pipeline / accumulation defaults: single stage, single step.
  pipeline_stage_split_num_ = 1;
  grad_accumulation_step_ = 1;
  communi_parallel_mode_ = ALL_GROUP_PARALLEL;
  optimizer_weight_shard_size_ = -1;  // -1 means "not sharded / unset"
  optimizer_weight_shard_aggregated_save_ = false;
  enable_all2all_ = false;
  grad_accumulation_shard_ = true;
  sharding_propagation_ = false;
  dataset_strategy_.clear();
  fusion_threshold_mb_ = FUSUION_THRESHOLD;
  // NOTE(review): Reset leaves fusion_threshold_is_set_ true (unlike the other
  // *_is_set_ flags, which reset to false) — confirm this is intentional.
  fusion_threshold_is_set_ = true;
  fusion_mode_ = FUSION_AUTO;
}
  74. void ParallelContext::set_device_num(int64_t device_num) {
  75. device_num_ = device_num;
  76. device_num_is_set_ = true;
  77. }
  78. void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) {
  79. fusion_threshold_mb_ = fusion_threshold;
  80. fusion_threshold_is_set_ = true;
  81. enable_all_reduce_fusion_ = true;
  82. }
  83. bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
  84. auto iter = std::find(FUSION_MODE_LIST.begin(), FUSION_MODE_LIST.end(), fusion_mode);
  85. if (iter == FUSION_MODE_LIST.end()) {
  86. MS_LOG(INFO) << "Invalid fusion mode:" << fusion_mode;
  87. return false;
  88. }
  89. fusion_mode_ = fusion_mode;
  90. return true;
  91. }
  92. void ParallelContext::set_global_rank(int64_t global_rank) {
  93. global_rank_ = global_rank;
  94. global_rank_is_set_ = true;
  95. }
  96. void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ = gradients_mean; }
  97. void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
  98. void ParallelContext::set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy) {
  99. dataset_strategy_ = dataset_strategy;
  100. }
  101. void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) {
  102. grad_accumulation_step_ = grad_accumulation_step;
  103. }
  104. void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
  105. void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
  106. void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }
  107. bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
  108. auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
  109. if (iter == PARALLEL_MODE_LIST.end()) {
  110. MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode;
  111. return false;
  112. }
  113. parallel_mode_ = parallel_mode;
  114. return true;
  115. }
  116. bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) {
  117. auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode);
  118. if (iter == STRATEGY_SEARCH_MODE_LIST.end()) {
  119. MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode;
  120. return false;
  121. }
  122. strategy_search_mode_ = strategy_search_mode;
  123. return true;
  124. }
  125. void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
  126. parameter_broadcast_ = parameter_broadcast;
  127. parameter_broadcast_is_set_ = true;
  128. }
  129. void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
  130. strategy_ckpt_load_file_ = strategy_ckpt_load_file;
  131. }
  132. void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
  133. strategy_ckpt_save_file_ = strategy_ckpt_save_file;
  134. }
  135. void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
  136. group_ckpt_save_file_ = group_ckpt_save_file;
  137. }
  138. void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
  139. optimizer_weight_shard_size_ = optimizer_weight_shard_size;
  140. }
  141. void ParallelContext::set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save) {
  142. optimizer_weight_shard_aggregated_save_ = optimizer_weight_shard_aggregated_save;
  143. }
  144. void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
  145. if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
  146. group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
  147. group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
  148. all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
  149. all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
  150. all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
  151. }
  152. all_reduce_fusion_split_indices_[group] = indices;
  153. }
  154. std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
  155. auto iter = all_reduce_fusion_split_indices_.find(group);
  156. if (iter != all_reduce_fusion_split_indices_.end()) {
  157. return iter->second;
  158. }
  159. return {};
  160. }
  161. void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
  162. if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
  163. group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
  164. group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
  165. all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
  166. all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
  167. all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
  168. }
  169. all_reduce_fusion_split_sizes_[group] = sizes;
  170. }
  171. std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
  172. auto iter = all_reduce_fusion_split_sizes_.find(group);
  173. if (iter != all_reduce_fusion_split_sizes_.end()) {
  174. return iter->second;
  175. }
  176. return {};
  177. }
  178. bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
  179. auto iter = std::find(COMMUNI_PARALLEL_MODE_LIST.begin(), COMMUNI_PARALLEL_MODE_LIST.end(), communi_parallel_mode);
  180. if (iter == COMMUNI_PARALLEL_MODE_LIST.end()) {
  181. MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode;
  182. return false;
  183. }
  184. communi_parallel_mode_ = communi_parallel_mode;
  185. return true;
  186. }
// Prepares the global param_shapes dict before compiling a graph in
// auto-parallel or semi-auto-parallel mode. Decides whether parameter shapes
// should be (re)recorded (init_param_shape_ = true, dict cleared) or restored
// from the previously recorded ones (init_param_shape_ = false):
//   - graphs without the AUTO_PARALLEL flag are ignored entirely;
//   - the first iteration of two-graph incremental predict re-records;
//   - evaluation/prediction graphs (no TRAINING flag) restore;
//   - the second graph of gradient accumulation (step > 1 and no ACCUMULATION
//     flag) restores; any other training graph re-records.
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (!func_graph->has_flag(AUTO_PARALLEL)) {
    return;
  }
  if (func_graph->has_flag(IS_FIRST_ITERATION)) {
    param_shapes.clear();
    init_param_shape_ = true;
    MS_LOG(INFO) << "Init the parameter shape dict in increment predict with two graph";
    return;
  }
  if (!func_graph->has_flag(TRAINING)) {
    init_param_shape_ = false;
    MS_LOG(INFO) << "In parallel evaluation or prediction, may be need to restore the parameter shape";
    return;
  }
  if ((ParallelContext::GetInstance()->grad_accumulation_step() > 1) && !func_graph->has_flag(ACCUMULATION)) {
    init_param_shape_ = false;
    MS_LOG(INFO) << "In parallel grad accumulation second graph, need to restore the parameter shape";
  } else {
    param_shapes.clear();
    init_param_shape_ = true;
    MS_LOG(INFO) << "Init the parameter shape dict";
  }
}
// Restores a parameter's shape from the global param_shapes dict for
// evaluation/prediction in auto-parallel or semi-auto-parallel mode.
// No-op when the graph lacks the AUTO_PARALLEL flag or when shapes are being
// recorded rather than restored (init_param_shape_ set by
// ParallelParameterContextInitShape). Logs a warning and leaves the abstract
// untouched when no shape was recorded for this parameter name.
void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
                                                           const ParameterPtr &param_node, const AbstractBasePtr &ptr) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(param_node);
  MS_EXCEPTION_IF_NULL(ptr);
  if (!func_graph->has_flag(AUTO_PARALLEL)) {
    return;
  }
  if (init_param_shape_) {
    // Shapes are being recorded this pass, so there is nothing to restore.
    return;
  }
  auto iter = param_shapes.find(param_node->name());
  if (iter == param_shapes.end()) {
    MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name();
    return;
  }
  Shape shape = iter->second;
  std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
  // Overwrite the abstract's shape with the recorded training-time shape.
  ptr->set_shape(base_shape);
  MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
  235. // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
  236. // Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
  237. void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
  238. const AbstractBasePtr &ptr) {
  239. MS_EXCEPTION_IF_NULL(func_graph);
  240. MS_EXCEPTION_IF_NULL(param_node);
  241. MS_EXCEPTION_IF_NULL(ptr);
  242. if (!func_graph->has_flag(AUTO_PARALLEL)) {
  243. return;
  244. }
  245. if (!init_param_shape_) {
  246. return;
  247. }
  248. std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
  249. auto ret = param_shapes.try_emplace(param_node->name(), shape);
  250. if (!ret.second) {
  251. MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed";
  252. return;
  253. }
  254. MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
  255. }
  256. void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
  257. void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
  258. } // namespace parallel
  259. } // namespace mindspore