Merge pull request !30118 from zhuyuxiao/I4S85V
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -78,6 +78,7 @@ void ParallelContext::Reset() {
   optimizer_weight_shard_aggregated_save_ = false;
   enable_all2all_ = false;
   grad_accumulation_shard_ = true;
+  parallel_optimizer_threshold_ = -1;
   sharding_propagation_ = false;
   dataset_strategy_.clear();
   fusion_threshold_mb_ = FUSUION_THRESHOLD;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -157,6 +157,10 @@ class ParallelContext {
     grad_accumulation_shard_ = grad_accumulation_shard;
   }
   bool grad_accumulation_shard() const { return grad_accumulation_shard_; }
+  void set_parallel_optimizer_threshold(const int64_t parallel_optimizer_threshold) {
+    parallel_optimizer_threshold_ = parallel_optimizer_threshold;
+  }
+  int64_t get_parallel_optimizer_threshold() const { return parallel_optimizer_threshold_; }
   bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
   std::string communi_parallel_mode() const { return communi_parallel_mode_; }
@@ -211,6 +215,7 @@ class ParallelContext {
   int64_t optimizer_weight_shard_size_;
   bool optimizer_weight_shard_aggregated_save_;
   bool grad_accumulation_shard_;
+  int64_t parallel_optimizer_threshold_;
   // Enable AllToAll or not. If false, use AllGather and Split.
   bool enable_all2all_;
   std::vector<std::vector<int64_t>> dataset_strategy_;
@@ -0,0 +1,161 @@
/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/parallel_optimizer/opt_param_mgr.h"
#include <string>
#include <vector>
#include <functional>
#include <map>
#include <memory>
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/context.h"
#include "ir/dtype/type_id.h"

namespace mindspore {
namespace parallel {
class OptParamMgrImpl : public OptParamMgr {
 public:
  explicit OptParamMgrImpl(const FuncGraphPtr &root) : root_(root) {}
  virtual ~OptParamMgrImpl() = default;

  std::string ShardOptGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
                            const OperatorInfoPtr &distribute_operator) const override {
    if (!SplitParam(parameter)) {
      return "";
    }
    Status ret = tensor_layout->GenerateOptShardSliceShape();
    if (ret != Status::SUCCESS) {
      MS_LOG(INFO) << parameter->ToString() << "'s distributed shape " << tensor_layout->slice_shape().ToString()
                   << " does not satisfy the conditions.";
      return "";
    }

    // get the shard tensor slice shape if the weight is repeated on devices
    // and the shape of the first dimension could be divided
    // apply parallel optimizer on parameters
    // create communication group for allgather operator
    std::string opt_shard_group;
    std::vector<Group> dev_group;
    MS_LOG(INFO) << "Creating shard group for param: " << parameter->ToString()
                 << ", shape: " << parameter->Shape()->ToString();
    if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
        !dev_group.empty()) {
      opt_shard_group = dev_group[0].name();
      MS_LOG(INFO) << "create group success.";
    } else {
      MS_LOG(ERROR) << "create group failed.";
    }
    return opt_shard_group;
  }

 private:
  int64_t ComputeShapeSize(const AnfNodePtr &parameter) const {
    ShapeVector shape(parameter->Shape()->cast<abstract::ShapePtr>()->shape());
    int64_t total_size =
      std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
    return total_size;
  }

  // unit: KB
  float ComputeMemorySize(const AnfNodePtr &parameter) const {
    // key: type id; value: type size in bytes
    const std::map<TypeId, size_t> dtype_size_map = {
      {kNumberTypeBool, sizeof(bool)},       {kNumberTypeInt8, sizeof(int8_t)},
      {kNumberTypeInt16, sizeof(int16_t)},   {kNumberTypeInt32, sizeof(int32_t)},
      {kNumberTypeInt64, sizeof(int64_t)},   {kNumberTypeFloat16, sizeof(float16)},
      {kNumberTypeFloat32, sizeof(float)},   {kNumberTypeFloat64, sizeof(double)},
      {kNumberTypeUInt8, sizeof(uint8_t)},   {kNumberTypeUInt16, sizeof(uint16_t)},
      {kNumberTypeUInt32, sizeof(uint32_t)}, {kNumberTypeUInt64, sizeof(uint64_t)}};
    int64_t shape_size = ComputeShapeSize(parameter);
    TypeId type_id = parameter->Type()->cast<mindspore::TensorTypePtr>()->element()->type_id();
    if (dtype_size_map.find(type_id) == dtype_size_map.end()) {
      MS_LOG(EXCEPTION) << "unsupported type of parameter: " << parameter->DebugString();
    }
    size_t type_size = dtype_size_map.find(type_id)->second;
    return static_cast<float>(shape_size) * type_size / DIVISOR_K;
  }

  bool StageSharedParam(const AnfNodePtr &parameter) const {
    MS_EXCEPTION_IF_NULL(root_);
    FuncGraphManagerPtr manager = root_->manager();
    auto user_set = manager->node_users()[parameter];
    for (auto &param_pair : user_set) {
      CNodePtr cnode = param_pair.first->cast<CNodePtr>();
      if (IsPrimitiveCNode(cnode, prim::kPrimSend) || IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
        return true;
      }
    }
    return false;
  }

  int64_t GetThresholdFromUsrInput() const {
    return ParallelContext::GetInstance()->get_parallel_optimizer_threshold();
  }

  bool SplitParam(const AnfNodePtr &parameter) const {
    if (!ParallelContext::GetInstance()->enable_parallel_optimizer()) {
      MS_LOG(INFO) << "Parallel optimizer: feature is not enabled. Skipped.";
      return false;
    }
    if (StageSharedParam(parameter)) {
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString()
                   << " is stage-shared in pipeline parallel. Skipped.";
      return false;
    }
    if (!ParameterRequireGrad(parameter)) {
      // only trainable parameters need parallel optimizer
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not a trainable parameter.";
      return false;
    }
    if (parameter->cast<ParameterPtr>()->param_info() &&
        !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is manually set to be skipped.";
      return false;
    }

    int64_t param_split_threshold = DEFAULT_VAL;
    int64_t user_define_threshold = GetThresholdFromUsrInput();
    if (user_define_threshold != -1) {
      MS_LOG(INFO) << "Parallel optimizer: use user-defined threshold = " << user_define_threshold << "KB.";
      param_split_threshold = user_define_threshold;
    } else {
      MS_LOG(INFO) << "Parallel optimizer: use DEFAULT threshold = " << DEFAULT_VAL << "KB.";
    }

    float param_size = ComputeMemorySize(parameter);
    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " size = " << param_size << "KB";
    if (param_size < param_split_threshold) {
      MS_LOG(INFO) << "Parallel optimizer: the size of " << parameter->ToString() << "(" << param_size
                   << "KB) is smaller than the threshold(" << param_split_threshold << "KB). Skipped.";
      return false;
    }
    return true;
  }

 private:
  FuncGraphPtr root_;
  int64_t DEFAULT_VAL = 64;  // unit: KB
  int64_t DIVISOR_K = 1024;
};

std::unique_ptr<OptParamMgr> createOptParamMgr(const FuncGraphPtr &root) {
  return std::make_unique<OptParamMgrImpl>(root);
}
}  // namespace parallel
}  // namespace mindspore
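As a reading aid (not part of this PR), here is a minimal Python sketch of the size check SplitParam performs above, assuming the parameter's dtype byte width is known (4 bytes for float32); the constants mirror DEFAULT_VAL (64KB) and DIVISOR_K (1024), and -1 stands for "no user threshold", matching the Reset() default of parallel_optimizer_threshold_:

import numpy as np

DEFAULT_THRESHOLD_KB = 64  # mirrors DEFAULT_VAL in OptParamMgrImpl
DIVISOR_K = 1024           # bytes per KB, mirrors DIVISOR_K

def should_shard(shape, dtype_bytes, user_threshold_kb=-1):
    """Shard a parameter only if its memory footprint (KB) reaches the threshold."""
    threshold = DEFAULT_THRESHOLD_KB if user_threshold_kb == -1 else user_threshold_kb
    size_kb = int(np.prod(shape)) * dtype_bytes / DIVISOR_K
    return size_kb >= threshold

# weight2 from the tests below: [1152, 16] float32 -> 72KB
print(should_shard((1152, 16), 4))       # True: 72KB >= 64KB default
print(should_shard((1152, 16), 4, 100))  # False: 72KB < 100KB user threshold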
@@ -0,0 +1,39 @@
/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_OPTPARAMMGR_H
#define MINDSPORE_OPTPARAMMGR_H

#include <string>
#include <memory>
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "base/base.h"

namespace mindspore {
namespace parallel {
class OptParamMgr {
 public:
  virtual ~OptParamMgr() = default;
  virtual std::string ShardOptGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
                                    const OperatorInfoPtr &distribute_operator) const = 0;
};

std::unique_ptr<OptParamMgr> createOptParamMgr(const FuncGraphPtr &root);
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_OPTPARAMMGR_H
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@
 #include "utils/ms_context.h"
 #include "utils/symbolic.h"
 #include "mindspore/core/utils/parallel_node_check.h"
+#include "frontend/parallel/parallel_optimizer/opt_param_mgr.h"
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/util.h"
 #include "ps/ps_context.h"
@@ -1578,9 +1579,9 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
   if (opt_shard_group.empty()) {
     return;
   }
-  FuncGraphManagerPtr manager = root->manager();
+  // set all gather type
   MS_EXCEPTION_IF_NULL(parameter);
-  MS_EXCEPTION_IF_NULL(manager);
   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
   int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
   std::string op_name;
@@ -1591,6 +1592,10 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
   } else {
     op_name = ALL_GATHER;
   }
+  // insert all gather
+  FuncGraphManagerPtr manager = root->manager();
+  MS_EXCEPTION_IF_NULL(manager);
   auto param_sub_set = manager->node_users()[parameter];
   bool insert_flag = false;
   for (auto &param_pair : param_sub_set) {
@@ -1605,6 +1610,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
       MS_LOG(EXCEPTION) << "The index is out of range, index is " << (param_pair.second - 1) << ", vector size is "
                         << distribute_operator->inputs_tensor_info().size();
     }
     if (insert_flag) {
+      // if there are multiple node users, they share one same allgather
       auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph(), 0);
@@ -1627,53 +1633,24 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
   }
 }
-static std::string GetOptShardGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
-                                    const OperatorInfoPtr &distribute_operator) {
-  std::string opt_shard_group;
-  if (!ParameterRequireGrad(parameter)) {
-    // only trainable parameters need parallel optimizer
-    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
-  } else if (parameter->cast<ParameterPtr>()->param_info() &&
-             !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
-    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
-  } else if (tensor_layout->GenerateOptShardSliceShape() == Status::SUCCESS) {
-    // get the shard tensor slice shape if the weight is repeated on devices
-    // and the shape of the first dimension could be divided
-    // apply parallel optimizer on parameters
-    // create communication group for allgather operator
-    std::vector<Group> dev_group;
-    if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
-        !dev_group.empty()) {
-      opt_shard_group = dev_group[0].name();
-      MS_LOG(INFO) << "Parallel optimizer: create group for " << parameter->ToString() << " success.";
-    } else {
-      MS_LOG(ERROR) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
-    }
-  } else {
-    MS_LOG(WARNING) << "Parallel optimizer: " << parameter->ToString() << "'s distributed shape "
-                    << tensor_layout->slice_shape().ToString() << " does not satisfy the conditions.";
-  }
-  return opt_shard_group;
-}
 void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr &parameter) {
   MS_EXCEPTION_IF_NULL(root);
   MS_EXCEPTION_IF_NULL(parameter);
   FuncGraphManagerPtr manager = root->manager();
   MS_EXCEPTION_IF_NULL(manager);
-  auto parameter_ptr = parameter->cast<ParameterPtr>();
-  if (!parameter_ptr) {
-    MS_LOG(INFO) << parameter->ToString() << " is not a parameter";
+  ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
+  if (parameter_ptr == nullptr) {
+    MS_LOG(INFO) << parameter->ToString() << ": cast to ParameterPtr failed. It may not be a parameter.";
     return;
   }
-  auto param_sub_set = manager->node_users()[parameter];
-  int32_t users_count = 0;
-  for (auto &param_pair : param_sub_set) {
-    auto cnode = param_pair.first->cast<CNodePtr>();
+  auto user_set = manager->node_users()[parameter];
+  int32_t user_count = 0;
+  for (auto &param_pair : user_set) {
+    CNodePtr cnode = param_pair.first->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
-    if (cnode->in_forward_flag()) users_count++;
+    if (cnode->in_forward_flag()) user_count++;
   }
-  if (users_count > 1) {
+  if (user_count > 1) {
     auto tensor_layout = parameter_ptr->user_data<TensorLayout>();
     tensor_layout->set_is_shared_param(true);
     MS_LOG(WARNING) << "There are multiple users for " << parameter->ToString()
@@ -1682,41 +1659,57 @@ void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr &parameter
 }
 // When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
-std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res) {
+std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res,
+                             const FuncGraphPtr &root) {
+  // check null for param and cnode
+  auto param_shape = parameter->Shape();
   MS_EXCEPTION_IF_NULL(parameter);
-  AbstractBasePtr abstract = parameter->abstract();
-  MS_EXCEPTION_IF_NULL(abstract);
-  MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
+  MS_EXCEPTION_IF_NULL(param_shape);
   CNodePtr cnode = res.first->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
+  // get slice_shape
   OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
   if (distribute_operator == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
+    MS_LOG(EXCEPTION) << "node " << cnode->ToString() << " 's distribute_operator is nullptr";
   }
   if (LongToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
-    MS_LOG(EXCEPTION) << "The index is out of range, index is " << (res.second - 1) << ", vector size is "
-                      << distribute_operator->inputs_tensor_info().size();
+    MS_LOG(EXCEPTION) << "The parameter index is not in inputs_tensor_info. index = " << (res.second - 1)
+                      << ", inputs_tensor_info size = " << distribute_operator->inputs_tensor_info().size();
   }
   TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(res.second - 1)];
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
   Shape slice_shape = tensor_layout.slice_shape().array();
+  // generate shard group
   std::string opt_shard_group;
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
   if (enable_parallel_optimizer) {
-    opt_shard_group = GetOptShardGroup(parameter, &tensor_layout, distribute_operator);
-  }
-  if (!opt_shard_group.empty()) {
-    slice_shape = tensor_layout.opt_shard_slice_shape();
-  }
-  MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
-               << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
-  std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
-  MS_EXCEPTION_IF_NULL(parallel_shape);
-  // Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis.
-  auto cloned_abstract = abstract->Clone();
-  MS_EXCEPTION_IF_NULL(cloned_abstract);
-  cloned_abstract->set_shape(parallel_shape);
+    std::unique_ptr<OptParamMgr> apOptParamMgr = createOptParamMgr(root);
+    opt_shard_group = apOptParamMgr->ShardOptGroup(parameter, &tensor_layout, distribute_operator);
+    // set the shape of parameter to sliced shape
+    if (!opt_shard_group.empty()) {
+      slice_shape = tensor_layout.opt_shard_slice_shape();
+    }
+    MS_LOG(INFO) << "the shape of " << parameter->ToString() << "(original: " << param_shape->ToString() << ")"
+                 << " will be sliced into " << MakeValue(slice_shape)->ToString() << " in op "
+                 << distribute_operator->name();
+  }
+  AbstractBasePtr abstract = parameter->abstract();
+  if (abstract == nullptr) {
+    MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract is nullptr";
+  }
+  AbstractBasePtr cloned_abstract = abstract->Clone();
+  if (cloned_abstract == nullptr) {
+    MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract clone failed";
+  }
+  cloned_abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
   parameter->set_abstract(cloned_abstract);
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
@@ -1729,19 +1722,21 @@ void CoverSliceShape(const FuncGraphPtr &root) {
   auto parameters = root->parameters();
   for (auto &parameter : parameters) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
     auto iter = g_RefMap.find(parameter);
     if (iter != g_RefMap.end()) {
-      std::string group = SetParallelShape(parameter, g_RefMap[parameter]);
+      std::string group = SetParallelShape(parameter, g_RefMap[parameter], root);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
       SetSharedParameterFlag(root, parameter);
       ApplyParallelOptOnParam(root, parameter, group);
       continue;
     }
     std::pair<AnfNodePtr, int64_t> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
-      MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
+      MS_LOG(INFO) << "Parameter " << parameter->ToString() << " is not in graph, thus no need to set parallel shape";
     } else {
-      std::string group = SetParallelShape(parameter, res);
+      std::string group = SetParallelShape(parameter, res, root);
+      // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
       SetSharedParameterFlag(root, parameter);
       ApplyParallelOptOnParam(root, parameter, group);
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -161,6 +161,8 @@ PYBIND11_MODULE(_c_expression, m) {
   .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
   .def("get_grad_accumulation_shard", &ParallelContext::grad_accumulation_shard, "Get grad_accumulation_shard.")
   .def("set_grad_accumulation_shard", &ParallelContext::set_grad_accumulation_shard, "Set grad_accumulation_shard.")
+  .def("get_parallel_optimizer_threshold", &ParallelContext::get_parallel_optimizer_threshold, "Get opt threshold.")
+  .def("set_parallel_optimizer_threshold", &ParallelContext::set_parallel_optimizer_threshold, "Set opt threshold.")
   .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
   .def("get_gradients_mean", &ParallelContext::gradients_mean, "Get mirror mean.")
   .def("set_gradients_mean", &ParallelContext::set_gradients_mean, "Set mirror mean.")
@@ -1,4 +1,4 @@
-# Copyright 2020-2021 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -496,7 +496,7 @@ def set_auto_parallel_context(**kwargs):
             context.set_auto_parallel_context(enable_parallel_optimizer=True).
             It supports the following keys.
-            - gradient_accumulation_shard: If true, the accumulation gradient parameters will be
+            - gradient_accumulation_shard(bool): If true, the accumulation gradient parameters will be
               sharded across the data parallel devices. This will
               introduce additional communication (ReduceScatter) at
               each step when accumulating the gradients, but saves a
@@ -504,6 +504,11 @@ def set_auto_parallel_context(**kwargs):
              with larger batch size. This configuration is effective only
              when the model runs on pipeline training or gradient
              accumulation with data parallel. Default: True.
+            - parallel_optimizer_threshold(int): Set the threshold of the parallel optimizer. When the parallel
+              optimizer is enabled, parameters with size smaller than this threshold will not be sharded
+              across the devices. Unit: KB. Default: 64.
         comm_fusion (dict): A dict contains the types and configurations for setting the communication fusion. Each
             communication fusion config has two keys: "mode" and "config".
             It supports following communication fusion types and configurations:
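A quick usage sketch of the new key (the values are borrowed from the tests further down in this PR, not a canonical recipe):

from mindspore import context

# Enable optimizer weight sharding and lower the size threshold to 1KB so
# that even small parameters are sharded across the data-parallel group.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32,
                                  enable_parallel_optimizer=True,
                                  parallel_optimizer_config={"parallel_optimizer_threshold": 1})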
@@ -767,8 +772,8 @@ def set_context(**kwargs):
             Indicates whether to enable image-computing convergence to optimize network execution performance.
             If enable_graph_kernel is set to True, acceleration can be enabled.
             For details of graph kernel fusion, please check
-            `Enabling Graph Kernel Fusion <https://www.mindspore.cn/docs/programming_guide
-            /en/master/enable_graph_kernel_fusion.html>`_.
+            `Enabling Graph Kernel Fusion
+            <https://www.mindspore.cn/docs/programming_guide/en/master/enable_graph_kernel_fusion.html>`_.
         graph_kernel_flags (str) –
             Optimization options of graph kernel fusion, and the priority is higher when it conflicts
             with enable_graph_kernel. Only for experienced users.
@@ -802,8 +807,8 @@ def set_context(**kwargs):
             (Automatic selection).
             For more information about the enable operator tuning tool settings, please check
-            `Enable the operator optimization tool <https://www.mindspore.cn/docs/programming_guide/en
-            /master/enable_auto_tune.html>`_.
+            `Enable the operator optimization tool
+            <https://www.mindspore.cn/docs/programming_guide/en/master/enable_auto_tune.html>`_.
         check_bprop (bool): Whether to check back propagation nodes. The checking ensures that the shape and dtype
             of back propagation node outputs is the same as input parameters. Default: False.
         max_call_depth (int): Specify the maximum depth of function call. Must be positive integer. Default: 1000.
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,6 +46,7 @@ class _ParallelOptimizerConfig:
     The key of the Parallel Optimizer. There are three
     """
     GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard"
+    PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold"

 class _AutoParallelContext:
@@ -771,37 +772,43 @@
             parallel_optimizer_config(dict): A dict contains the keys and values for setting the parallel optimizer
                 configuration. It supports the following keys:
-                - gradient_accumulation_shard: If true, the accumulation gradient parameters will be sharded
-                  across the data parallel devices. This will introduce additional
-                  communication(ReduceScatter) at each step when accumulate the
-                  gradients, but saves a lot of device memories,
-                  thus can make model be trained with larger batch size.
-                  This configure is effective only when the model runs on pipeline
-                  training or gradient accumulation with data parallel.
+                - gradient_accumulation_shard(bool): If true, the accumulation gradient parameters will be sharded
+                  across the data parallel devices. This will introduce additional
+                  communication cost (ReduceScatter) at each step when accumulating the
+                  gradients, but saves a lot of device memory,
+                  thus can make the model be trained with a larger batch size.
+                  This configuration is effective only when the model runs on pipeline
+                  training or gradient accumulation with data parallel.
+                - parallel_optimizer_threshold(int): Set the threshold of the parallel optimizer. When the parallel
+                  optimizer is enabled, parameters with size smaller than this threshold will
+                  not be sharded across the devices. Unit: KB. Default: 64.
             """
         self.check_context_handle()
         grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD
-        if len(parallel_optimizer_config) > 1 and grad_shard_name in parallel_optimizer_config:
-            other_keys = list(parallel_optimizer_config.keys())
-            other_keys.remove(grad_shard_name)
-            raise ValueError(f"Except {grad_shard_name}, there are useless keys in parallel_optimizer_config "
-                             f"{other_keys}, please check your "
-                             f"parallel_optimizer_config to remove the useless keys.")
-        if grad_shard_name not in parallel_optimizer_config:
-            raise ValueError(f"The parallel_optimizer_config does not support the keys "
-                             f"{list(parallel_optimizer_config.keys())}, "
-                             f"you should input the key {grad_shard_name} only, please check your "
-                             f"parallel_optimizer_config.")
-        Validator.check_bool(
-            parallel_optimizer_config[grad_shard_name], grad_shard_name, grad_shard_name)
-        self._context_handle.set_grad_accumulation_shard(
-            parallel_optimizer_config[grad_shard_name])
+        threshold_name = _ParallelOptimizerConfig.PARALLEL_OPTIMIZER_THRESHOLD
+        if grad_shard_name in parallel_optimizer_config:
+            Validator.check_bool(
+                parallel_optimizer_config[grad_shard_name], grad_shard_name, grad_shard_name)
+            self._context_handle.set_grad_accumulation_shard(
+                parallel_optimizer_config[grad_shard_name])
+        if threshold_name in parallel_optimizer_config:
+            Validator.check_positive_int(
+                parallel_optimizer_config[threshold_name])
+            self._context_handle.set_parallel_optimizer_threshold(
+                parallel_optimizer_config[threshold_name])

     def get_grad_accumulation_shard(self):
         """Get grad accumulation shard."""
         self.check_context_handle()
         return self._context_handle.get_grad_accumulation_shard()

+    def get_parallel_optimizer_threshold(self):
+        """Get parallel optimizer threshold."""
+        self.check_context_handle()
+        return self._context_handle.get_parallel_optimizer_threshold()

     def set_enable_alltoall(self, enable_a2a):
         """
         Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
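Because the setter goes through Validator.check_positive_int, a non-positive threshold is rejected at configuration time; a small sketch of the expected behavior (exercised by test_edge_case further down):

from mindspore import context

try:
    # -1 is not a positive integer, so validation raises ValueError
    context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": -1})
except ValueError as err:
    print(err)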
@@ -93,6 +93,7 @@ def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None
                               loss_scale_manager=None):
     context.set_context(mode=context.GRAPH_MODE)
     context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True,
+                                      parallel_optimizer_config={"parallel_optimizer_threshold": 1},
                                       pipeline_stages=stages)
     net = MicroBatchInterleaved(net(param_type, strategy1, strategy2), interleaved_batch)
@@ -115,8 +115,9 @@
     Expectation: group info list matches the expected value.
     """
     os.environ['GROUP_INFO_FILE'] = "./test_mirror_group_parallel_optimizer.pb"
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
-                                      device_num=32, enable_parallel_optimizer=True)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32,
+                                      parallel_optimizer_config={"parallel_optimizer_threshold": 1},
+                                      enable_parallel_optimizer=True)
     auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
     group_info_list = restore_group_info_list("./test_mirror_group_parallel_optimizer.pb")
     assert group_info_list == [0]
@@ -130,8 +131,9 @@
     Expectation: group info list matches the expected value.
     """
     os.environ['GROUP_INFO_FILE'] = "./test_mirror_group_parallel_optimizer_not_full_shard.pb"
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
-                                      device_num=32, enable_parallel_optimizer=True, optimizer_weight_shard_size=2)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32,
+                                      parallel_optimizer_config={"parallel_optimizer_threshold": 2},
+                                      enable_parallel_optimizer=True, optimizer_weight_shard_size=2)
     auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
     group_info_list = restore_group_info_list("./test_mirror_group_parallel_optimizer_not_full_shard.pb")
     assert group_info_list == [0, 8, 16, 24]
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -81,6 +81,21 @@ class Net3(nn.Cell):
         return x - y

+class Net4(nn.Cell):
+    """Net definition"""
+    def __init__(self, strategy1, strategy2):
+        super(Net4, self).__init__()
+        self.fc1 = P.MatMul().shard(strategy1)
+        self.fc2 = P.MatMul().shard(strategy2)
+        self.p1 = Parameter(Tensor(np.ones([48, 1152]).astype(np.float32)), name="weight1")
+        self.p2 = Parameter(Tensor(np.ones([1152, 16]).astype(np.float32)), name="weight2")
+    def construct(self, x, y):
+        x = self.fc1(x, self.p1)
+        x = self.fc2(x, self.p2)
+        return x - y

 def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None):
     context.set_context(mode=context.GRAPH_MODE)
     context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
@@ -109,11 +124,13 @@
 def test_auto_parallel_momentum_3():
     # hybrid parallel case
     # weight1 could not be sharded and weight2 is repeated
-    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    dp = 4
+    context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 1})
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((dp, 8), (8, 1)), ((dp, 4), (4, 2)))
     param_dict = train_network.parameter_layout_dict
     # validate opt_shard_group
     assert not param_dict["weight1"][5]
-    assert param_dict["weight2"][5].startswith("4")
+    assert param_dict["weight2"][5].startswith(str(dp))

 def test_auto_parallel_momentum_4():
@@ -124,6 +141,7 @@
 def test_auto_parallel_momentum_5():
     # test parallel optimizer filter
+    context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 1})
     train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net3, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
     param_dict = train_network.parameter_layout_dict
     # validate opt_shard_group
@@ -134,17 +152,45 @@
 def test_auto_parallel_momentum_6():
     # test not fully using parallel optimizer with optimizer_weight_shard_size
     # weight1 could not be sharded and weight2 is repeated
-    context.set_auto_parallel_context(optimizer_weight_shard_size=2)
+    param_shard_group_size = 2
+    context.set_auto_parallel_context(optimizer_weight_shard_size=param_shard_group_size)
+    context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 1})
     train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
     param_dict = train_network.parameter_layout_dict
     # validate opt_shard_group
-    assert param_dict["weight1"][5].startswith("2")
-    assert param_dict["weight2"][5].startswith("2")
+    assert param_dict["weight1"][5].startswith(str(param_shard_group_size))
+    assert param_dict["weight2"][5].startswith(str(param_shard_group_size))
+def test_default_threshold():
+    """
+    Feature: auto-parallel-optimizer(I4S85V)
+    Description: the memory size of weight2 (72KB) is higher than the threshold (64KB).
+    Expectation: weight2 is sharded, with a sharding group size equal to dp.
+    """
+    dp = 4
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net4, ((dp, 8), (8, 1)), ((dp, 4), (4, 2)))
+    param_dict = train_network.parameter_layout_dict
+    # validate opt_shard_group
+    assert param_dict["weight2"][5]

+def test_user_define_threshold():
+    """
+    Feature: auto-parallel-optimizer(I4S85V)
+    Description: the memory size of weight2 (72KB) is lower than the threshold (100KB).
+    Expectation: weight2 is not sharded.
+    """
+    dp = 4
+    context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 100})
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net4, ((dp, 8), (8, 1)), ((dp, 4), (4, 2)))
+    param_dict = train_network.parameter_layout_dict
+    # validate opt_shard_group
+    assert not param_dict["weight2"][5]
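For reference, the 72KB figure in the two docstrings above follows directly from weight2's shape in Net4:

# weight2 is [1152, 16] float32: 1152 * 16 elements * 4 bytes = 73728 bytes
assert 1152 * 16 * 4 / 1024 == 72.0  # KB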
 def test_AdamWeightDecay():
     """ test_AdamWeightDecay """
-    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
+                                      parallel_optimizer_config={"parallel_optimizer_threshold": 1})
     inputs = Tensor(np.ones([32, 128]).astype(np.float32))
     label = Tensor(np.zeros([32, 768]).astype(np.float32))
     net = Net()
@@ -160,7 +206,8 @@
 def test_lamb_compile():
     """ test_Lamb_compile """
-    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
+                                      parallel_optimizer_config={"parallel_optimizer_threshold": 2})
     inputs = Tensor(np.ones([32, 128]).astype(np.float32))
     label = Tensor(np.zeros([32, 768]).astype(np.float32))
     net = Net()
@@ -177,7 +224,8 @@
 def test_lamb_split_fusion():
     """ test_Lamb_split_fusion """
     context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
-                                      all_reduce_fusion_config=[2, 4, 6, 8])
+                                      all_reduce_fusion_config=[2, 4, 6, 8],
+                                      parallel_optimizer_config={"parallel_optimizer_threshold": 1})
     inputs = Tensor(np.ones([32, 128]).astype(np.float32))
     label = Tensor(np.zeros([32, 768]).astype(np.float32))
     net = Net()
@@ -209,4 +257,7 @@
     with pytest.raises(RuntimeError):
         context.set_auto_parallel_context(device_num=16)
         Lamb(net.trainable_params(), learning_rate=0.1)
+    with pytest.raises(ValueError):
+        context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": -1})
+        Lamb(net.trainable_params(), learning_rate=0.1)
     context.reset_auto_parallel_context()
@@ -66,8 +66,9 @@
 def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None,
                               interleaved_batch=2, stages=1, micro_size=1):
     context.set_context(mode=context.GRAPH_MODE)
-    context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True,
-                                      pipeline_stages=stages)
+    context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, pipeline_stages=stages,
+                                      enable_parallel_optimizer=True,
+                                      parallel_optimizer_config={"parallel_optimizer_threshold": 1})
     inputs = Tensor(np.ones([64, 48]).astype(np.float32))
     label = Tensor(np.zeros([64, 16]).astype(np.float32))
     net = MicroBatchInterleaved(net(strategy1, strategy2), interleaved_batch)