You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

pipeline_split.cc 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include <memory>
  18. #include "pipeline/jit/pipeline_split.h"
  19. #include "utils/ms_context.h"
  20. #include "utils/comm_manager.h"
  21. #include "frontend/parallel/context.h"
  22. #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
  23. #include "frontend/parallel/step_parallel.h"
  24. namespace mindspore {
  25. namespace pipeline {
  26. static int64_t GetRank() {
  27. auto ms_context = MsContext::GetInstance();
  28. MS_EXCEPTION_IF_NULL(ms_context);
  29. std::string world_group;
  30. std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  31. if (backend == kAscendDevice) {
  32. world_group = parallel::HCCL_WORLD_GROUP;
  33. } else if (backend == kGPUDevice) {
  34. world_group = parallel::NCCL_WORLD_GROUP;
  35. } else {
  36. MS_LOG(EXCEPTION) << "Invalid backend: " << backend;
  37. }
  38. int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank();
  39. uint32_t rank_id = 0;
  40. if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) {
  41. if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
  42. MS_LOG(EXCEPTION) << "Get rank id failed.";
  43. }
  44. global_rank = UintToInt(rank_id);
  45. }
  46. return global_rank;
  47. }
  48. static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) {
  49. if (stage_num == 0) {
  50. MS_LOG(EXCEPTION) << "stage_num is zero";
  51. }
  52. if (device_num % stage_num != 0) {
  53. MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
  54. << "stage_num: " << stage_num;
  55. }
  56. auto per_stage_rank_num = device_num / stage_num;
  57. return rank_id / per_stage_rank_num;
  58. }
  59. // Only auto_parallel and semi_auto_parallel support PipelineSplit
  60. bool PipelineSplit(const ResourcePtr &res) {
  61. auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  62. if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) {
  63. MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
  64. return true;
  65. }
  66. auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
  67. if (stage_num <= 1) {
  68. MS_LOG(INFO) << "stage num is: " << stage_num << ". No need Pipeline split.";
  69. return true;
  70. }
  71. auto manager = res->manager();
  72. auto root = res->func_graph();
  73. auto global_rank = GetRank();
  74. auto device_num = parallel::ParallelContext::GetInstance()->device_num();
  75. if (device_num < 1) {
  76. MS_LOG(EXCEPTION) << "Invalid device num: " << device_num;
  77. }
  78. if (global_rank < 0) {
  79. MS_LOG(EXCEPTION) << "Invalid global rank: " << global_rank;
  80. }
  81. auto stage = InferStage(global_rank, stage_num, device_num);
  82. auto per_stage_rank_num = device_num / stage_num;
  83. if (parallel::ParallelInit() != parallel::SUCCESS) {
  84. MS_LOG(EXCEPTION) << "parallel init failed.";
  85. }
  86. auto transformer =
  87. std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
  88. // step1: Do color graph
  89. transformer->LabelRequiredGradCNode();
  90. transformer->Coloring();
  91. // step2: Do color broadcast
  92. transformer->BroadCastColoring();
  93. // step3: Handle shared parameters
  94. transformer->ParameterColoring();
  95. transformer->HandleSharedParameter();
  96. // step4: Cut Graph
  97. transformer->CutGraph();
  98. // step5: Handle Sens
  99. transformer->CoverSensShape();
  100. // step6: Elim Graph stages and no used parameter
  101. transformer->ElimGraphStage();
  102. transformer->ElimParameter();
  103. return true;
  104. }
  105. } // namespace pipeline
  106. } // namespace mindspore