/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "pipeline/jit/pipeline_split.h" #include "utils/ms_context.h" #include "utils/comm_manager.h" #include "frontend/parallel/context.h" #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" #include "frontend/parallel/step_parallel.h" namespace mindspore { namespace pipeline { static int64_t GetRank() { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); std::string world_group; std::string backend = ms_context->get_param(MS_CTX_DEVICE_TARGET); if (backend == kAscendDevice) { world_group = parallel::HCCL_WORLD_GROUP; } else if (backend == kGPUDevice) { world_group = parallel::NCCL_WORLD_GROUP; } else { MS_LOG(EXCEPTION) << "Invalid backend: " << backend; } int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank(); uint32_t rank_id = 0; if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) { if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { MS_LOG(EXCEPTION) << "Get rank id failed."; } global_rank = UintToInt(rank_id); } return global_rank; } static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) { if (stage_num == 0) { MS_LOG(EXCEPTION) << "stage_num is zero"; } if (device_num % stage_num != 0) { MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num << "stage_num: " << stage_num; } auto per_stage_rank_num = device_num / stage_num; return rank_id / per_stage_rank_num; } // Only auto_parallel and semi_auto_parallel support PipelineSplit bool PipelineSplit(const ResourcePtr &res) { auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) { MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split."; return true; } auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num(); if (stage_num <= 1) { MS_LOG(INFO) << "stage num is: " << stage_num << ". No need Pipeline split."; return true; } auto manager = res->manager(); auto root = res->func_graph(); auto global_rank = GetRank(); auto device_num = parallel::ParallelContext::GetInstance()->device_num(); if (device_num < 1) { MS_LOG(EXCEPTION) << "Invalid device num: " << device_num; } if (global_rank < 0) { MS_LOG(EXCEPTION) << "Invalid global rank: " << global_rank; } auto stage = InferStage(global_rank, stage_num, device_num); auto per_stage_rank_num = device_num / stage_num; if (parallel::ParallelInit() != parallel::SUCCESS) { MS_LOG(EXCEPTION) << "parallel init failed."; } auto transformer = std::make_shared(manager, stage, root, global_rank, per_stage_rank_num); // step1: Do color graph transformer->LabelRequiredGradCNode(); transformer->Coloring(); // step2: Do color broadcast transformer->BroadCastColoring(); // step3: Handle shared parameters transformer->ParameterColoring(); transformer->HandleSharedParameter(); // step4: Cut Graph transformer->CutGraph(); // step5: Handle Sens transformer->CoverSensShape(); // step6: Elim Graph stages and no used parameter transformer->ElimGraphStage(); transformer->ElimParameter(); return true; } } // namespace pipeline } // namespace mindspore