/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_parallel_utils.h"

#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "utils/hash_map.h"
#include "base/core_ops.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/parameter_manager.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"

namespace mindspore {
namespace parallel {
/**
 * @brief Checks whether `cnode` calls the primitive named `name`.
 * @param cnode Node to inspect; a null pointer yields false.
 * @param name Primitive name to compare against.
 * @return true if input(0) of `cnode` is a value node holding a Primitive whose name equals `name`.
 * @throws (via MS_EXCEPTION_IF_NULL) if input(0) is not a ValueNode or does not hold a Primitive.
 */
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
  if (!cnode) {
    return false;
  }
  ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(anf_node);
  PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
  // The value node may hold a non-Primitive value; fail loudly instead of dereferencing null.
  MS_EXCEPTION_IF_NULL(prim);
  return prim->name() == name;
}

/**
 * @brief Decides whether the parallel pass should handle `cnode`.
 *
 * Non-primitive calls and primitives in the parallel black list are ignored. GetNext and
 * VirtualOutput are always handled. Cast is ignored unless it carries OperatorInfo user data.
 * Otherwise the node's forward flag decides.
 */
bool IsParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
  if (prim == nullptr) {
    return false;
  }
  if (IsInParallelBlackList(prim)) {
    MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
    return false;
  }
  // get_next is not in the forward graph, we need mark the get_next as the forward node
  if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
    return true;
  }
  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }
  return cnode->in_forward_flag();
}

/**
 * @brief Collects the shapes of the tensor elements of a ValueList / ValueTuple node.
 *
 * Collection stops at the first element that is not a tensor (a warning is logged), so the
 * result may be shorter than the sequence.
 * @throws if `node` is neither a ValueList nor a ValueTuple value node.
 */
Shapes GetValueListShape(const AnfNodePtr &node) {
  Shapes shapes;
  std::vector<ValuePtr> inputs_seq;
  if (IsValueNode<ValueList>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
  } else if (IsValueNode<ValueTuple>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  } else {
    // Reached only when the node is of neither kind; the message must say so.
    MS_LOG(EXCEPTION) << "node is neither ValueList nor ValueTuple";
  }
  for (auto &ele : inputs_seq) {
    auto tensor = ele->cast<tensor::TensorPtr>();
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "The value node is not a tensor";
      break;
    }
    shapes.push_back(tensor->shape());
  }
  return shapes;
}

/**
 * @brief Returns the output shape(s) of `node`.
 *
 * - ValueList / ValueTuple value nodes: per-element tensor shapes (see GetValueListShape).
 * - MakeRef cnodes: the shape from GetRefKeyNodeShape(input(1), func_graph).
 * - CNodes whose input(0) is itself a CNode: the shape of input(1).
 * - Otherwise: node->Shape().
 * A sequence shape is flattened one level into one Shape per element.
 * @throws if no shape is available or an element is not a plain abstract::Shape.
 */
Shapes GetNodeShape(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shapes;
  if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
    return GetValueListShape(node);
  }
  BaseShapePtr base_shape_ptr = node->Shape();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    if (IsValueNode<Primitive>(cnode->input(0))) {
      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      if (prim->name() == MAKEREF) {
        AnfNodePtr ref_node = cnode->input(1);
        auto func_graph = cnode->func_graph();
        MS_EXCEPTION_IF_NULL(ref_node);
        MS_EXCEPTION_IF_NULL(func_graph);
        return GetRefKeyNodeShape(ref_node, func_graph);
      }
    }
    if (cnode->input(0)->isa<CNode>()) {
      if (cnode->inputs().size() < 2) {
        MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
      }
      base_shape_ptr = cnode->input(1)->Shape();
    }
  }
  if (base_shape_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
                      << node->fullname_with_scope();
  }
  // NOTE(review): `SequeueShape` is the project's spelling of the sequence-shape class in this
  // code base version; confirm against abstract/dshape.h.
  auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
  if (tuple_shape_ptr != nullptr) {
    auto tuple_shape = tuple_shape_ptr->shape();
    for (auto &shape : tuple_shape) {
      auto each_shape = dyn_cast<abstract::Shape>(shape);
      MS_EXCEPTION_IF_NULL(each_shape);
      shapes.push_back(each_shape->shape());
    }
  } else {
    auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
    MS_EXCEPTION_IF_NULL(shape_ptr);
    shapes.push_back(shape_ptr->shape());
  }
  return shapes;
}

/**
 * @brief Finds the ranks shared by every mirror group in `root`.
 *
 * If a trainable parameter with a default value is fully sharded (by optimizer parallel with
 * optimizer_weight_shard_size == -1, or per IsFullySplitParameter), only the current global rank
 * is returned. Otherwise the result is the intersection of the rank lists of every Mirror /
 * MirrorMicroStep / MirrorMiniStep group reachable from root's return node, sorted ascending.
 * The result is empty when no mirror operator is found.
 * @throws if a mirror operator has no GROUP attribute.
 */
RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  const auto &parameters = root->parameters();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (!(param_ptr->has_default() && ParameterRequireGrad(param_ptr))) {
      continue;
    }
    size_t allow_repeat_num = 1;
    if (ParallelContext::GetInstance()->enable_parallel_optimizer() &&
        (!param_ptr->param_info() || param_ptr->param_info()->parallel_optimizer())) {
      if (ParallelContext::GetInstance()->optimizer_weight_shard_size() == -1) {
        MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
                        << " is fully shard by optimizer parallel,"
                           " thus cannot find common data parallel group for this rank";
        MS_EXCEPTION_IF_NULL(g_device_manager);
        return {g_device_manager->global_rank()};
      }
      allow_repeat_num = static_cast<size_t>(ParallelContext::GetInstance()->optimizer_weight_shard_size());
    }
    if (IsFullySplitParameter(param_ptr, allow_repeat_num)) {
      MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
                      << " is fully shard, thus cannot find common data parallel group for this rank";
      MS_EXCEPTION_IF_NULL(g_device_manager);
      return {g_device_manager->global_rank()};
    }
  }
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<int64_t> common_group_list;
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  bool is_first_group = true;
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMirror) && !IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) &&
        !IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep)) {
      continue;
    }
    auto prim = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(prim);
    if (!prim->HasAttr(GROUP)) {
      MS_LOG(EXCEPTION) << "The mirror operator dose not have group attr : " << node->DebugString();
    }
    std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
    MS_EXCEPTION_IF_NULL(g_device_manager);
    std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
    // std::set_intersection requires both input ranges to be sorted; sort every group so the
    // intersection is correct regardless of the order in which the device manager stores ranks.
    std::sort(group_list.begin(), group_list.end());
    if (is_first_group) {
      common_group_list = std::move(group_list);
      is_first_group = false;
    } else {
      std::vector<int64_t> new_comm_group_list;
      (void)std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(),
                                  group_list.end(), std::back_inserter(new_comm_group_list));
      common_group_list = std::move(new_comm_group_list);
    }
  }
  MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
  return common_group_list;
}

/**
 * @brief Builds a hashed instance name from `node`'s full scope name and `index`.
 * @throws if input(0) of `node` is not a Primitive value node.
 */
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsValueNode<Primitive>(node->input(0))) {
    MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
  }
  std::string name_base = node->fullname_with_scope();
  std::string name = name_base + "_" + std::to_string(index);
  std::string instance_name = HashInstanceName(name);
  return instance_name;
}

/**
 * @brief Annotates a communication primitive with the name of its rank list.
 *
 * If the primitive in new_node_input[0] has a string GROUP attribute, that string is looked up
 * with g_device_manager->FindRankListNameByHashName and stored as the GROUP_RANKS attribute.
 * Empty input is a no-op. The parameter stays by value so the signature matches the header.
 * @throws if new_node_input[0] does not hold a Primitive, or GROUP is present but null.
 */
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
  if (new_node_input.empty()) {
    return;
  }
  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  // Query the single attribute directly instead of copying the whole attribute map.
  if (!prim->HasAttr(GROUP)) {
    return;
  }
  auto value = prim->GetAttr(GROUP);
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<StringImm>()) {
    return;
  }
  std::string hash_name = value->cast<StringImmPtr>()->value();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
  (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
}

/**
 * @brief Builds the input list of the operator that replaces `node`.
 *
 * The list starts with a value node for the new operator instance, followed by node->input(1)
 * (and node->input(2) for EmbeddingLookup). If the operator has params, each param value is
 * inserted at its recorded position; when the first param's position is 1, the last forwarded
 * data input is dropped first. SyncBatchNorm without params forwards all of node's inputs from
 * index 2 on. A node with only input(0) yields just the operator value node.
 * @throws if the instance cannot be created, `node` has no inputs, EmbeddingLookup has fewer
 *         than two data inputs, or a param position lies outside the current input list.
 */
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
                                       const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const OperatorArgs &arg_replace_op = replace_op.second;
  ValuePtr pyop_instance = CreateOpInstance(arg_replace_op.first, replace_op.first, instance_name);
  if (pyop_instance == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreateOpInstance failed";
  }
  const OperatorParams &params = arg_replace_op.second;
  const size_t input_size = node->inputs().size();
  if (input_size < 2) {
    // GetNext operator dose not has input
    if (input_size == 1) {
      return {NewValueNode(pyop_instance)};
    }
    MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
  }
  std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
  if (replace_op.first == EMBEDDING_LOOKUP) {
    // EmbeddingLookup also forwards input(2); make sure it exists before reading it.
    if (input_size < 3) {
      MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 3";
    }
    replace_input.push_back(node->input(2));
  }
  if (!params.empty()) {
    const Param &param_first = params.front();
    int64_t first_position = param_first.second;
    if (first_position == 1) {
      replace_input.pop_back();
    }
    for (auto &param : params) {
      AnfNodePtr val = NewValueNode(param.first.second);
      if (val == nullptr) {
        MS_LOG(EXCEPTION) << "Failure:val is nullptr";
      }
      int64_t position = param.second;
      // Inserting outside [begin, end] is undefined behaviour; reject bad positions explicitly.
      if (position < 0 || static_cast<size_t>(position) > replace_input.size()) {
        MS_LOG(EXCEPTION) << "Failure: param position " << position << " is out of range [0, "
                          << replace_input.size() << "] for " << node->ToString();
      }
      (void)replace_input.insert(replace_input.begin() + position, val);
    }
  } else if (replace_op.first == SYNC_BATCH_NORM) {
    for (size_t i = 2; i < input_size; ++i) {
      replace_input.push_back(node->input(i));
    }
  }
  SetCommunicationOpGroupLabel(replace_input);
  return replace_input;
}

/**
 * @brief Calls SetStridedSliceStrategy on every StridedSlice cnode in `all_nodes` whose primitive
 *        carries the FUNC_GRAPH_FLAG_STRIDED_SLICE attribute. Non-cnodes are skipped.
 */
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (!IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
      continue;
    }
    auto slice_prim = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(slice_prim);
    if (slice_prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
      SetStridedSliceStrategy(cnode);
    }
  }
}
}  // namespace parallel
}  // namespace mindspore