/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_parallel_utils.h"

// NOTE(review): the angle-bracket contents of these standard headers were lost
// in extraction; the list below was reconstructed from the std:: names used in
// this translation unit (any_of, set_intersection, make_shared, queue, set,
// string, pair, vector, fixed-width ints) — confirm against the original file.
#include <algorithm>
#include <cstdint>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "utils/hash_map.h"
#include "base/core_ops.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/parameter_manager.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"

namespace mindspore {
namespace parallel {
// NOTE(review): all template arguments (cast<T>(), IsValueNode<T>(), GetValue<T>(),
// user_data<T>(), dyn_cast<T>(), container element types) were stripped by the
// extraction that produced this file; they have been reconstructed from the
// surrounding call patterns and should be diffed against the upstream source.

// Returns true iff `cnode`'s input(0) is a ValueNode holding a Primitive whose
// name equals `name`. Null-safe: a null cnode, non-ValueNode input(0), or
// non-Primitive value all yield false.
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
  if (!cnode) {
    return false;
  }
  ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
  if (!anf_node) {
    return false;
  }
  PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
  if (!prim) {
    return false;
  }
  return (prim->name() == name);
}

// Returns true iff `cnode`'s primitive name matches any entry of `check_list`.
bool IsSomePrimitiveList(const CNodePtr &cnode, const std::set<string> &check_list) {
  return std::any_of(check_list.begin(), check_list.end(),
                     [cnode](const string &in) { return IsSomePrimitive(cnode, in); });
}

// Returns the primitive name of `node`; raises if node has no primitive.
std::string GetPrimName(const CNodePtr &node) {
  auto prim = GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(prim);
  return prim->name();
}

// Looks up the TensorInfo of the input slot described by `param_info`
// (user cnode, 1-based input index). For Send nodes the slot comes from the
// PARAM_INDEX primal attr instead of the pair's index.
TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
  auto user_cnode = param_info.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  auto user_input_index = param_info.second;
  OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);

  TensorInfo tensor_info;
  if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
    auto param_index = IntToSize(GetValue<int64_t>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
    tensor_info = op_info->inputs_tensor_info()[param_index];
  } else {
    size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
    // user_input_index counts input(0) (the primitive), hence the -1.
    if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
      MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
                        << ", but the index is " << (user_input_index - 1);
    }
    tensor_info = op_info->inputs_tensor_info()[LongToSize(user_input_index - 1)];
  }
  return tensor_info;
}

// Verifies that every parallel-aware user of `node` shares the same input
// TensorInfo and returns the first such user; raises if the users disagree.
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
  auto node_users = manager->node_users()[node];
  bool is_first_tensor_info = true;
  TensorInfo first_tensor_info;
  AnfNodePtr first_node;
  for (auto &node_user : node_users) {
    auto user_node = node_user.first->cast<CNodePtr>();
    if (!user_node->has_user_data<OperatorInfo>()) {
      continue;  // users without an OperatorInfo are not parallel-aware
    }
    auto tensor_info = GetInputsTensorInfo(node_user);
    if (is_first_tensor_info) {
      is_first_tensor_info = false;
      first_tensor_info = tensor_info;
      first_node = node_user.first;
      continue;
    }
    if (first_tensor_info == tensor_info) {
      continue;
    } else {
      MS_LOG(EXCEPTION) << "The node: " << node->DebugString()
                        << " has multiple users, but the TensorInfo are different";
    }
  }
  return first_node;
}

// Decides whether `cnode` participates in auto-parallel planning: blacklisted
// primitives never do; GET_NEXT/VIRTUAL_OUTPUT always do; Cast only when it
// carries an OperatorInfo; everything else follows the forward flag.
bool IsParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
  if (prim == nullptr) {
    return false;
  }
  if (IsInParallelBlackList(prim)) {
    MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
    return false;
  }
  // get_next is not in the forward graph, we need mark the get_next as the forward node
  if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
    return true;
  }
  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }
  return cnode->in_forward_flag();
}

// Collects the shapes of the tensors stored in a ValueList/ValueTuple value
// node; stops early (with a warning) at the first non-tensor element.
Shapes GetValueListShape(const AnfNodePtr &node) {
  Shapes shapes;
  std::vector<ValuePtr> inputs_seq;
  if (IsValueNode<ValueList>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
  } else if (IsValueNode<ValueTuple>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  } else {
    // Bug fix: original message read "node is eigther ValueList or ValueTuple",
    // but this branch fires precisely when the node is NEITHER of those.
    MS_LOG(EXCEPTION) << "node is neither ValueList nor ValueTuple";
  }
  for (auto &ele : inputs_seq) {
    auto tensor = ele->cast<tensor::TensorPtr>();
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "The value node is not a tensor";
      break;
    }
    auto one_shape = tensor->shape();
    shapes.push_back(one_shape);
  }
  return shapes;
}

// Returns the shape(s) of `node`'s output: value sequences are delegated to
// GetValueListShape, MAKEREF nodes to GetRefKeyNodeShape; tuple shapes are
// flattened into one Shapes vector.
Shapes GetNodeShape(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shapes;
  if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
    return GetValueListShape(node);
  }
  BaseShapePtr base_shape_ptr = node->Shape();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    if (IsValueNode<Primitive>(cnode->input(0))) {
      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      if (prim->name() == MAKEREF) {
        AnfNodePtr ref_node = cnode->input(1);
        auto func_graph = cnode->func_graph();
        MS_EXCEPTION_IF_NULL(ref_node);
        MS_EXCEPTION_IF_NULL(func_graph);
        return GetRefKeyNodeShape(ref_node, func_graph);
      }
    }
    if (cnode->input(0)->isa<CNode>()) {
      // input(0) is itself a call: fall back to the first real argument's shape.
      if (cnode->inputs().size() < 2) {
        MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
      }
      base_shape_ptr = cnode->input(1)->Shape();
    }
  }
  if (base_shape_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
                      << node->fullname_with_scope();
  }
  auto tuple_shape_ptr = dyn_cast<abstract::SequenceShape>(base_shape_ptr);
  if (tuple_shape_ptr != nullptr) {
    auto tuple_shape = tuple_shape_ptr->shape();
    for (auto &shape : tuple_shape) {
      auto each_shape = dyn_cast<abstract::Shape>(shape);
      MS_EXCEPTION_IF_NULL(each_shape);
      shapes.push_back(each_shape->shape());
    }
  } else {
    auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
    MS_EXCEPTION_IF_NULL(shape_ptr);
    shapes.push_back(shape_ptr->shape());
  }
  return shapes;
}

// Computes the intersection of the rank lists of all mirror operators in the
// graph — the common data-parallel group. Returns just {global_rank} when any
// trainable parameter is fully sharded (no common group exists for this rank).
RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (!(param_ptr->has_default() && ParameterRequireGrad(param_ptr))) {
      continue;  // only trainable parameters with defaults matter here
    }
    size_t allow_repeat_num = 1;
    if (ParallelContext::GetInstance()->enable_parallel_optimizer() &&
        (!param_ptr->param_info() || param_ptr->param_info()->parallel_optimizer())) {
      if (ParallelContext::GetInstance()->optimizer_weight_shard_size() == -1) {
        MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
                        << " is fully shard by optimizer parallel,"
                           " thus cannot find common data parallel group for this rank";
        return {g_device_manager->global_rank()};
      }
      allow_repeat_num = size_t(ParallelContext::GetInstance()->optimizer_weight_shard_size());
    }
    if (IsFullySplitParameter(param_ptr, allow_repeat_num)) {
      MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
                      << " is fully shard, thus cannot find common data parallel group for this rank";
      return {g_device_manager->global_rank()};
    }
  }
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<int64_t> common_group_list;
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  bool is_first_group = true;
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMirror) && !IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) &&
        !IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep)) {
      continue;
    }
    auto prim = GetCNodePrimitive(node);
    if (!prim->HasAttr(GROUP)) {
      MS_LOG(EXCEPTION) << "The mirror operator does not have group attr : " << node->DebugString();
    }
    std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
    std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
    if (is_first_group) {
      common_group_list = group_list;
      is_first_group = false;
    } else {
      // NOTE(review): std::set_intersection requires both inputs sorted;
      // presumably FindRankListByHashName returns sorted ranks — confirm.
      std::vector<int64_t> new_comm_group_list;
      (void)std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(),
                                  group_list.end(), std::back_inserter(new_comm_group_list));
      common_group_list = new_comm_group_list;
    }
  }
  MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
  return common_group_list;
}

// Builds a hashed instance name "<fullname>_<index>" for the node's primitive.
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsValueNode<Primitive>(node->input(0))) {
    MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
  }
  std::string name_base = node->fullname_with_scope();
  std::string name = name_base + "_" + std::to_string(index);
  std::string instance_name = HashInstanceName(name);
  return instance_name;
}

// If the primitive at new_node_input[0] carries a GROUP attr (hash name),
// resolves it to the human-readable rank list and stores it as GROUP_RANKS.
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
  if (new_node_input.empty()) {
    return;
  }
  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);

  auto attrs = prim->attrs();
  auto iter = attrs.find(GROUP);
  if (iter != attrs.end()) {
    auto value = iter->second;
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<StringImm>()) {
      std::string hash_name = value->cast<StringImmPtr>()->value();
      MS_EXCEPTION_IF_NULL(g_device_manager);
      std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
      (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
    }
  }
}

// Builds the input list for the node that replaces `node` with `replace_op`:
// a fresh op instance, the original data inputs required by the replacement,
// and any extra constant params spliced in at their declared positions.
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
                                       const CNodePtr &node) {
  OperatorArgs arg_replace_op = replace_op.second;
  ValuePtr pyop_instance = CreateOpInstance(arg_replace_op.first, replace_op.first, instance_name);
  if (pyop_instance == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreateOpInstance failed";
  }
  OperatorParams params = arg_replace_op.second;
  if (node->inputs().size() < 2) {
    // GetNext operator does not have inputs
    if (node->inputs().size() == 1) {
      return {NewValueNode(pyop_instance)};
    }
    MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
  }
  std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
  if (replace_op.first == EMBEDDING_LOOKUP) {
    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
  }
  if (!params.empty()) {
    Param param_first = *(params.begin());
    int64_t first_position = param_first.second;
    if (first_position == 1) {
      replace_input.pop_back();  // first constant param replaces the data input
    }
    for (auto &param : params) {
      AnfNodePtr val = NewValueNode(param.first.second);
      if (val == nullptr) {
        MS_LOG(EXCEPTION) << "Failure:val is nullptr";
      }
      int64_t position = param.second;
      (void)replace_input.insert(replace_input.begin() + position, val);
    }
  } else if (replace_op.first == SYNC_BATCH_NORM) {
    // SyncBatchNorm keeps all of the original node's remaining inputs.
    for (size_t i = 2; i < node->inputs().size(); ++i) {
      replace_input.push_back(node->input(i));
    }
  }
  SetCommunicationOpGroupLabel(replace_input);
  return replace_input;
}

// Applies the strided-slice split strategy to every StridedSlice cnode that
// carries the FUNC_GRAPH_FLAG_STRIDED_SLICE attribute.
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (!IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
      continue;
    }
    auto slice_prim = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(slice_prim);
    if (slice_prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
      SetStridedSliceStrategy(cnode);
    }
  }
}

// Check the given tensor, return nullptr if the given type is not an TensorType
bool CheckTensorType(const TypePtr &node_type) {
  MS_EXCEPTION_IF_NULL(node_type);
  if (!node_type->isa<mindspore::TensorType>()) {
    return false;
  }
  return true;
}

// For the weight used by cast and matmul at the same time, like the followings
// weight1->mirror->cast1-> matmul1;
// weight1->add
// we will not insert the cast(FP32->FP16), as it will cause the input of the
// operator add to be changed to fp16.
// BFS through the weight's forward users (skipping MakeTuple; traversing
// through AllGather/Load/Reshape) and return the first Cast user, or nullptr
// if any user is not a Cast (inserting a cast would then be unsafe).
AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map) {
  std::queue<AnfNodePtr> visited;
  AnfNodePtr queue_node = nullptr;
  CNodePtr cnode = nullptr;
  AnfNodePtr node = nullptr;
  if (!node_ptr) {
    return nullptr;
  }
  auto users = node_users_map.at(node_ptr);
  for (auto &node_user : users) {
    cnode = node_user.first->cast<CNodePtr>();
    if (!cnode || !cnode->in_forward_flag()) {
      continue;
    }
    if (node_user.first) {
      visited.push(node_user.first);
    }
  }
  while (!visited.empty()) {
    queue_node = visited.front();
    visited.pop();
    cnode = queue_node->cast<CNodePtr>();
    // MAKE_TUPLE will not appear after the load in the forward graph
    if (IsSomePrimitive(cnode, MAKE_TUPLE)) {
      continue;
    } else if (IsInAllGatherNodeList(cnode) || IsSomePrimitiveList(cnode, {LOAD, RESHAPE})) {
      // Pass-through ops: keep searching their users.
      auto node_set = node_users_map.at(queue_node);
      for (auto &node_user : node_set) {
        visited.push(node_user.first);
      }
    } else if (!IsSomePrimitive(cnode, CAST)) {
      MS_LOG(INFO) << "The weight's users including the non cast node So "
                   << "will not insert cast for this parameter " << node_ptr->DebugString();
      return nullptr;
    } else if (!node) {
      node = queue_node;  // remember the first Cast found
    }
  }
  return node;
}

// Given the cnode ptr, find its users until we find the computation node, then return the type of the
// computation node. This function is used to find the target type for CreateFP16Cast. Only returns the target type if
// it is float16, and the source node is float32. If the situation is not matched, then return the nullptr.
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map) {
  auto node_ptr = cnode_ptr->cast<AnfNodePtr>();
  if (!node_ptr) {
    return nullptr;
  }
  auto cnode_inputs = cnode_ptr->inputs();
  if (cnode_inputs.size() < TWO_INPUT_SIZE) {
    return nullptr;
  }
  // As we execute the function IsWeightValidUsed when we start to insert the mirror, so the second parameter
  // is always the parameter.
  auto weight = cnode_inputs[1];
  if (!weight->isa<Parameter>()) {
    return nullptr;
  }
  MS_LOG(INFO) << "Start to search the weight params:" << weight->DebugString();
  AnfNodePtr node = GetChildCastNode(weight, node_users_map);
  if (!node) {
    return nullptr;
  }
  // get the output dtype of the operator
  auto node_type = node->Type();
  if (!CheckTensorType(node_type)) {
    return nullptr;
  }
  auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(input_element_type);
  auto source_node_type = node_ptr->Type();
  if (!CheckTensorType(source_node_type)) {
    return nullptr;
  }
  auto source_element_type = source_node_type->cast<mindspore::TensorTypePtr>()->element();
  // Bug fix: the original repeated MS_EXCEPTION_IF_NULL(input_element_type)
  // here (copy-paste); the freshly computed source_element_type is what must
  // be checked before the dereference below.
  MS_EXCEPTION_IF_NULL(source_element_type);

  // We only add cast operation when the source is fp32 type, and the users is fp16 type.
  if (source_element_type->type_id() == kNumberTypeFloat32 && input_element_type->type_id() == kNumberTypeFloat16) {
    return input_element_type;
  }
  return nullptr;
}

// Create a cast node given the current node and the previous node. The target type of the the cast is from the
// compute_node_type.
// Return the new cast node with pre_node as the inputs.
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const TypePtr &compute_node_type) {
  const char kOpsFunctionModelName[] = "mindspore.ops.functional";
  // Cached Python "cast" primitive; fetched once per process.
  static py::object cast_prim = python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
  MS_EXCEPTION_IF_NULL(adapter);
  MS_EXCEPTION_IF_NULL(compute_node_type);
  auto prim = adapter->attached_primitive();
  if (prim == nullptr) {
    prim = std::make_shared<PrimitivePy>(cast_prim, adapter);
  }
  // Insert cast.
  auto type_node = NewValueNode(compute_node_type);
  type_node->set_abstract(compute_node_type->ToAbstract());
  auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
  new_node->set_abstract(node->abstract());
  new_node->set_in_forward_flag(true);
  return new_node;
}

// Copies each DropoutDoMask node's primal attrs onto its gen-mask input cnode
// so the micro-batch label propagates to the mask generator.
void LabelGenMaskMicro(const FuncGraphPtr &root) {
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  for (auto &node : all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimDropoutDoMask)) {
      auto gen_mask_node = RealInputNode(node->cast<CNodePtr>(), 2);
      if (gen_mask_node->isa<CNode>()) {
        gen_mask_node->cast<CNodePtr>()->set_primal_attrs(node->cast<CNodePtr>()->primal_attrs());
      }
    }
  }
}

// Marks Cast-of-parameter nodes as non-recomputed to avoid redundant
// trans_data operators during recomputation.
void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes) {
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto cast_input = RealInputNode(cnode, 1);
    if (cast_input->isa<Parameter>()) {
      MS_LOG(INFO) << "Cast for parameter no needs recompute to avoid redundant trans_data operator";
      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
      (void)prim->AddAttr("recompute", MakeValue(false));
    }
  }
}

// Returns the primitive attr `key` of `node`, or nullptr when the node is
// null, not a cnode with a primitive, or lacks the attr.
std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node, const string &key) {
  if (!node) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  auto prim = GetCNodePrimitive(cnode);
  if (prim && prim->HasAttr(key)) {
    return prim->GetAttr(key);
  }
  return nullptr;
}

// Walks forward from `node` following users that match the (primitive name,
// input index) chain in `match_pattern`; returns the node matching the last
// pattern entry, or nullptr if the chain breaks anywhere.
AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map,
                        const std::vector<std::pair<std::string, int64_t>> &match_pattern) {
  AnfNodePtr start_node = node;
  bool find = false;
  for (uint32_t i = 0; i < match_pattern.size(); ++i) {
    find = false;
    if (!IsSomePrimitive(start_node->cast<CNodePtr>(), match_pattern[i].first)) {
      break;
    } else if (i == match_pattern.size() - 1) {
      find = true;
      break;
    }
    auto next_node_users = user_map.at(start_node);
    for (auto &next_node : next_node_users) {
      if (i + 1 < match_pattern.size() &&
          IsSomePrimitive(next_node.first->cast<CNodePtr>(), match_pattern[i + 1].first) &&
          next_node.second == match_pattern[i + 1].second) {
        start_node = next_node.first;
        break;
      }
    }
  }
  if (!find) {
    start_node = nullptr;
  }
  return start_node;
}
}  // namespace parallel
}  // namespace mindspore