/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "parallel/auto_parallel/costmodel.h"

#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>

#include "parallel/auto_parallel/graph_costmodel.h"

namespace mindspore {
namespace parallel {
namespace {
// Reduces *clist_ptrs to the costs that are not (weakly) dominated in the (computation_cost_, communication) plane,
// where the communication value of a cost is obtained through `communication_of`. A cost is dropped when some other
// cost has both computation_cost_ and communication no larger than its own; of exact duplicates only the first one in
// input order survives. The result is ordered by strictly increasing computation_cost_ and strictly decreasing
// communication.
template <typename CommunicationGetter>
void KeepNonDominatedCosts(CostPtrList *clist_ptrs, CommunicationGetter communication_of) {
  MS_EXCEPTION_IF_NULL(clist_ptrs);
  std::vector<size_t> id(clist_ptrs->size());
  std::iota(id.begin(), id.end(), size_t(0));
  // Primary key: computation ascending. Tie-break: communication ascending, so that among costs with the same
  // computation the one with the least communication is seen first and the others are rejected by the filter below.
  // std::stable_sort keeps the input order of exact duplicates; std::sort would leave the order of equal keys
  // unspecified and make the surviving set depend on the standard-library implementation.
  std::stable_sort(id.begin(), id.end(), [clist_ptrs, &communication_of](size_t x, size_t y) {
    const auto &lhs = clist_ptrs->at(x);
    const auto &rhs = clist_ptrs->at(y);
    if (lhs->computation_cost_ < rhs->computation_cost_) {
      return true;
    }
    if (rhs->computation_cost_ < lhs->computation_cost_) {
      return false;
    }
    return communication_of(lhs) < communication_of(rhs);
  });
  CostPtrList ret;
  for (size_t index : id) {
    auto &candidate = clist_ptrs->at(index);
    // Computation is non-decreasing along `id`, so a candidate is kept only if it strictly lowers the communication of
    // the last kept cost (which holds the minimum communication seen so far).
    if (ret.empty() || communication_of(candidate) < communication_of(ret.back())) {
      ret.emplace_back(std::move(candidate));
    }
  }
  *clist_ptrs = std::move(ret);
}

// Maps one communication volume to its practical cost: values in (EPS, COST_MODEL_COMMUNI_THRESHOLD] become
// COST_MODEL_COMMUNI_CONST, values above COST_MODEL_COMMUNI_THRESHOLD are increased by COST_MODEL_COMMUNI_BIAS, and
// values <= EPS are returned unchanged.
double RefineCommunication(double communication) {
  if (communication > COST_MODEL_COMMUNI_THRESHOLD) {
    return communication + COST_MODEL_COMMUNI_BIAS;
  }
  if (communication > EPS) {
    return COST_MODEL_COMMUNI_CONST;
  }
  return communication;
}
}  // namespace

void Simplify(CostPtrList *clist_ptrs) {
  // Sort the cost list by computation_cost_ increasing and keep only the costs whose communication_cost_ strictly
  // decreases along that order: a cost whose computation_cost_ and communication_cost_ are both no smaller than those
  // of another cost is excluded.
  // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<100, 20>, <200, 10>}
  // Does nothing when COST_MODEL_SIMPLIFY_CALCULATION is false.
  if (!COST_MODEL_SIMPLIFY_CALCULATION) {
    return;
  }
  KeepNonDominatedCosts(clist_ptrs, [](const CostPtr &cost) { return cost->communication_cost_; });
}

void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
  // Same as Simplify, but compares costs on communication_with_partial_para_ instead of communication_cost_: a cost
  // whose computation_cost_ and communication_with_partial_para_ are both no smaller than those of another cost is
  // excluded. Does nothing when COST_MODEL_SIMPLIFY_CALCULATION is false.
  // NOTE: "Dreasing" is a misspelling of "Decreasing"; the name is left unchanged so existing declarations and
  // callers keep working.
  if (!COST_MODEL_SIMPLIFY_CALCULATION) {
    return;
  }
  KeepNonDominatedCosts(clist_ptrs, [](const CostPtr &cost) { return cost->communication_with_partial_para_; });
}

// Rewrites the communication fields of origin_cost in place into "practical" costs using the
// COST_MODEL_COMMUNI_{THRESHOLD,CONST,BIAS} constants.
//  - is_redistribution == true: forward and backward redistribution volumes are refined individually, and
//    communication_cost_, communication_without_parameter_ and communication_with_partial_para_ are all set to
//    their sum.
//  - is_redistribution == false (operator cost): only communication_without_parameter_ is refined. The remainder
//    (communication_cost_ - communication_without_parameter_, called `backward` below) is added back unrefined to
//    communication_cost_, and weighted by COST_MODEL_GAMMA for communication_with_partial_para_.
void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
  MS_EXCEPTION_IF_NULL(origin_cost);
  if (is_redistribution) {
    // Redistribution cost
    origin_cost->communication_redis_forward_ = RefineCommunication(origin_cost->communication_redis_forward_);
    origin_cost->communication_redis_backward_ = RefineCommunication(origin_cost->communication_redis_backward_);
    origin_cost->communication_cost_ =
      origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
    origin_cost->communication_without_parameter_ = origin_cost->communication_cost_;
    origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_;
    return;
  }
  // Operator cost. Differences within EPS are treated as zero.
  double backward = 0.0;
  if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) {
    backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_;
  }
  // forward cost
  origin_cost->communication_without_parameter_ = RefineCommunication(origin_cost->communication_without_parameter_);
  // total: only rebuilt when the original value was positive (> EPS).
  if (origin_cost->communication_cost_ > EPS) {
    origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
  }
  if (origin_cost->communication_with_partial_para_ > EPS) {
    origin_cost->communication_with_partial_para_ =
      origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward;
  }
}
}  // namespace parallel
}  // namespace mindspore