Merge pull request !29819 from huangxinjing/add_global_norm_search
feature/build-system-rewrite
| @@ -228,6 +228,22 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank, | |||
| RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); } | |||
| RankList DeviceManager::GetDeviceListBetweenStage() const { | |||
| std::vector<int64_t> rank_list; | |||
| auto rank_id = g_device_manager->global_rank(); | |||
| auto stage_id = g_device_manager->stage_id(); | |||
| auto stage_num = g_device_manager->stage_num(); | |||
| if (stage_num < 1) { | |||
| MS_LOG(EXCEPTION) << "Stage num got " << stage_num << ", expected a positive integer."; | |||
| } | |||
| auto device_num = parallel::ParallelContext::GetInstance()->device_num(); | |||
| auto per_stage_rank_num = device_num / stage_num; | |||
| for (int64_t i = 0; i < stage_num; ++i) { | |||
| rank_list.push_back(rank_id + per_stage_rank_num * (i - stage_id)); | |||
| } | |||
| return rank_list; | |||
| } | |||
| RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const { | |||
| if (LongToSize(stage_id) >= stage_devices_.size()) | |||
| MS_LOG(ERROR) << "the 'stage_id': " << stage_id | |||
| @@ -66,6 +66,7 @@ class DeviceManager { | |||
| static DeviceManager &GetInstance(); | |||
| RankList GetDeviceListByStageId(int64_t stage_id) const; | |||
| RankList GetDeviceListInThisStage() const; | |||
| RankList GetDeviceListBetweenStage() const; | |||
| Device CreateNewDeviceByRank(int64_t rank) const; | |||
| std::vector<Device> CreateDeviceListByRankList(RankList ranks); | |||
| @@ -319,6 +319,20 @@ Operator CreateVirtualDivOp(int64_t div_num) { | |||
| return op; | |||
| } | |||
| Operator CreateDivOp(float scale) { | |||
| OperatorName operator_name = REAL_DIV; | |||
| OperatorAttrs operator_attrs; | |||
| OperatorParams operator_param; | |||
| size_t parameter_pos = 2; | |||
| mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(scale); | |||
| ValuePtr scale_value = MakeValue(tensor_ptr); | |||
| operator_param.push_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos)); | |||
| OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); | |||
| Operator op = std::make_pair(operator_name, operator_arg); | |||
| return op; | |||
| } | |||
| static OperatorArgs CreateReduceCommunicationOpArgs(const std::string &reduce_op, const std::string &group) { | |||
| ValuePtr attr0_value = MakeValue(reduce_op); | |||
| ValuePtr attr1_value = MakeValue(group); | |||
| @@ -324,6 +324,7 @@ Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &grou | |||
| Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group); | |||
| Operator CreateAllGatherOp(const std::string &group); | |||
| Operator CreateCastOp(TypePtr type); | |||
| Operator CreateDivOp(float scale); | |||
| Operator CreateMiniStepAllGatherOp(const std::string &group); | |||
| void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node); | |||
| void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror); | |||
| @@ -71,6 +71,7 @@ constexpr size_t SoftmaxCrossEntropyWithLogitsOutputsSize = 2; | |||
| constexpr size_t UNIQUE_INPUTS_SIZE = 1; | |||
| constexpr size_t UNIQUE_INPUT_SIZE = 1; | |||
| constexpr size_t UNIQUE_OUTPUTS_SIZE = 2; | |||
| constexpr size_t RESHAPE_INPUT_SIZE = 3; | |||
| constexpr size_t TRANSFER_PERMUTE_ARGS_SIZE = 5; | |||
| constexpr size_t TRANSFER_PERMUTE_SPLIT_COUNT_INDEX = 0; | |||
| constexpr size_t TRANSFER_PERMUTE_SPLIT_DIM_INDEX = 1; | |||
| @@ -82,6 +83,7 @@ constexpr size_t TRANSFER_CONCAT_TENSOR_DIM_INDEX = 0; | |||
| constexpr size_t TRANSFER_CONCAT_DEV_DIM_INDEX = 1; | |||
| constexpr size_t TRANSFER_CONCAT_SPLIT_COUNT_INDEX = 2; | |||
| constexpr size_t TRANSFER_SPLIT_ARGS_SIZE = 3; | |||
| constexpr size_t TUPLE_GETITEM_INDEX_POS = 2; | |||
| constexpr size_t MATMUL_DDS_INPUTS_SIZE = 4; | |||
| constexpr size_t MATMUL_DDS_OUTPUTS_SIZE = 2; | |||
| constexpr size_t MATMUL_DDS_STRATEGY_SIZE = 4; | |||
| @@ -216,11 +218,16 @@ constexpr char REDISTRIBUTION_OP[] = "redistribution_op"; | |||
| constexpr char DARA_PARALLEL[] = "data_parallel"; | |||
| constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter"; | |||
| constexpr char FIELD_SIZE[] = "field_size"; | |||
| constexpr char Y[] = "Y"; | |||
| constexpr char OPTIMIZER_SUB_STRING[] = "optimizer"; | |||
| constexpr char DEVICE[] = "Device"; | |||
| constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather"; | |||
| constexpr char PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE[] = "parallel_optimizer_allgather_not_recompute"; | |||
| constexpr char PARALLEL_OPTIMIZER_COMM_OP[] = "parallel_optimizer_comm_op"; | |||
| constexpr char PARALLEL_GLOBALNORM[] = "PARALLEL_GLOBALNORM_IN_STAGES"; | |||
| constexpr char PARALLEL_GLOBALNORM_BETWEEN[] = "PARALLEL_GLOBALNORM_BETWEEN_STAGES"; | |||
| constexpr char PARALLEL_GLOBALNORM_DIV[] = "PARALLEL_GLOBALNORM_DIV"; | |||
| constexpr char GRAD_SCALE[] = "grad_scale"; | |||
| constexpr char CELLLIST_KEYWORD_PATTERN[] = "-CellList/(\\d+)-"; | |||
| constexpr char OUT_CHANNEL[] = "out_channel"; | |||
| @@ -296,6 +303,7 @@ constexpr char L2_NORMALIZE[] = "L2Normalize"; | |||
| constexpr char TRANSPOSE[] = "Transpose"; | |||
| constexpr char RESHAPE[] = "Reshape"; | |||
| constexpr char ADD[] = "Add"; | |||
| constexpr char ADDN[] = "AddN"; | |||
| constexpr char BIAS_ADD[] = "BiasAdd"; | |||
| constexpr char SUB[] = "Sub"; | |||
| constexpr char MUL[] = "Mul"; | |||
| @@ -93,6 +93,7 @@ bool PipelineTransformer::NeedGrad(const CNodePtr &cnode, const CNodePtr &graph_ | |||
| auto temp = input; | |||
| while (IsPrimitiveCNode(temp, prim::kPrimLoad) || IsPrimitiveCNode(temp, prim::kPrimCast)) { | |||
| auto input_cnode = temp->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||
| temp = input_cnode->input(1); | |||
| } | |||
| if (temp->isa<Parameter>()) { | |||
| @@ -177,13 +178,7 @@ void PipelineTransformer::LabelMicroBatch() { | |||
| } | |||
| void PipelineTransformer::CreateForwardGroup() { | |||
| std::vector<int64_t> rank_list; | |||
| auto rank_id = g_device_manager->global_rank(); | |||
| auto stage_id = g_device_manager->stage_id(); | |||
| auto stage_num = g_device_manager->stage_num(); | |||
| for (int64_t i = 0; i < stage_num; ++i) { | |||
| rank_list.push_back(rank_id + per_stage_rank_num_ * (i - stage_id)); | |||
| } | |||
| std::vector<int64_t> rank_list = g_device_manager->GetDeviceListBetweenStage(); | |||
| auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list); | |||
| auto g = g_device_manager->CreateGroup(rank_list); | |||
| auto g_back_name = g.name() + BACKWARD; | |||
| @@ -48,7 +48,6 @@ | |||
| #include "utils/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| #include "mindspore/core/utils/parallel_node_check.h" | |||
| #include "mindspore/ccsrc/pybind_api/ir/primitive_py.h" | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| @@ -61,9 +60,12 @@ namespace parallel { | |||
| static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER}; | |||
| static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE}; | |||
| static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL}; | |||
| static const std::vector<std::pair<const std::string, int64_t>> REDUCE_SUM_MATCH_PATTERN = { | |||
| std::make_pair(MAKE_TUPLE, 1), std::make_pair(ADDN, 1), std::make_pair(SQRT, 1)}; | |||
| // g_RefMap, for CNode B input i is a RefKey[Parameter C], | |||
| // it will be one item in map with key: C, and value: (B, i) | |||
| std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap; | |||
| const uint32_t MAX_BFS_DEPTH = 7; | |||
| void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool do_mirror, bool accu_flag) { | |||
| if (new_node_input.empty()) { | |||
| @@ -421,12 +423,6 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr & | |||
| return tensorinfo_in.tensor_layout(); | |||
| } | |||
| std::string GetPrimName(const CNodePtr &node) { | |||
| auto prim = GetCNodePrimitive(node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return prim->name(); | |||
| } | |||
| OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!IsParallelCareNode(node)) { | |||
| @@ -850,11 +846,11 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) { | |||
| MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3"; | |||
| } | |||
| if (!cnode->input(2)->isa<ValueNode>()) { | |||
| if (!cnode->input(TUPLE_GETITEM_INDEX_POS)->isa<ValueNode>()) { | |||
| MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node"; | |||
| } | |||
| ValuePtr tuple_index_value = GetValueNode(cnode->input(2)); | |||
| ValuePtr tuple_index_value = GetValueNode(cnode->input(TUPLE_GETITEM_INDEX_POS)); | |||
| MS_EXCEPTION_IF_NULL(tuple_index_value); | |||
| if (!tuple_index_value->isa<Int64Imm>()) { | |||
| MS_LOG(EXCEPTION) << "The index of tuple getitem is not int32"; | |||
| @@ -892,6 +888,51 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node | |||
| } | |||
| } | |||
| void InsertRealDivOpToNodeInput(const CNodePtr &node, int64_t scale, const string &instance_name) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (scale == 0) { | |||
| MS_LOG(EXCEPTION) << "Find the scale value is 0, you should check the mirror operators's group size."; | |||
| } | |||
| size_t node_size = node->inputs().size(); | |||
| FuncGraphPtr func_graph = node->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| // instance the real div operator | |||
| Operator div_op = CreateDivOp(scale); | |||
| // Insert it as the input of the node | |||
| for (size_t index = 1; index < node_size; ++index) { | |||
| AnfNodePtr input = node->input(index); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| // if it is not a tensor, continue | |||
| if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) { | |||
| continue; | |||
| } | |||
| InsertNode(div_op, node, index, node->input(index), func_graph, instance_name); | |||
| } | |||
| } | |||
| void InsertAllReduceToNodeInput(const CNodePtr &node, const std::string &group, const std::string &instance_name) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| size_t node_size = node->inputs().size(); | |||
| FuncGraphPtr func_graph = node->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| // instance the real div operator | |||
| CheckGlobalDeviceManager(); | |||
| Operator allreduce_op = CreateAllReduceOp(REDUCE_OP_SUM, group); | |||
| // Insert it as the input of the node | |||
| for (size_t index = 1; index < node_size; ++index) { | |||
| AnfNodePtr input = node->input(index); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| // if it is not a tensor, continue | |||
| if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) { | |||
| continue; | |||
| } | |||
| InsertNode(allreduce_op, node, index, node->input(index), func_graph, instance_name); | |||
| } | |||
| } | |||
| void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) { | |||
| vector<std::string> last_forward_node_ids; | |||
| vector<size_t> last_indexs; | |||
| @@ -2192,7 +2233,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) { | |||
| if (StrategyFound(attrs)) { | |||
| MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; | |||
| } | |||
| MS_ASSERT(cnode->inputs().size() == 3); | |||
| MS_ASSERT(cnode->inputs().size() == RESHAPE_INPUT_SIZE); | |||
| auto prev_layout_ptr = FindPrevLayout(cnode->input(1)); | |||
| if (prev_layout_ptr) { | |||
| auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | |||
| @@ -3092,6 +3133,148 @@ static void PipelinePostProcess(const FuncGraphPtr &root, const std::vector<AnfN | |||
| } | |||
| } | |||
| static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) { | |||
| auto cnode = res_node->cast<CNodePtr>(); | |||
| auto graphs = res_node->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(graphs); | |||
| auto manager = graphs->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto node_user_map = manager->node_users(); | |||
| if (!IsSomePrimitive(cnode, EXPAND_DIMS)) { | |||
| MS_LOG(ERROR) << "Expected the operator expand_dims, but found the " << GetPrimName(cnode) | |||
| << "This may cause the calculation of the global norm incorrect"; | |||
| return; | |||
| } | |||
| auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num(); | |||
| auto expand_dims_node = node_user_map.at(res_node).front().first; | |||
| auto sqrt_node = MatchPattern(expand_dims_node, node_user_map, REDUCE_SUM_MATCH_PATTERN); | |||
| if (!sqrt_node) return; | |||
| auto cur_stage_rank_list = g_device_manager->GetDeviceListInThisStage(); | |||
| Group cur_stage_device_list = g_device_manager->CreateGroup(cur_stage_rank_list); | |||
| InsertAllReduceToNodeInput(sqrt_node->cast<CNodePtr>(), cur_stage_device_list.name(), PARALLEL_GLOBALNORM); | |||
| MS_LOG(INFO) << "Insert the AllReduce for global norm value in stages succeed."; | |||
| if (pipeline_stages > 1) { | |||
| MS_LOG(INFO) << "Insert the AllReduce for global norm value between stages succeed."; | |||
| auto ranks_between_stages = g_device_manager->GetDeviceListBetweenStage(); | |||
| Group group_between_stages = g_device_manager->CreateGroup(ranks_between_stages); | |||
| InsertAllReduceToNodeInput(sqrt_node->cast<CNodePtr>(), group_between_stages.name(), PARALLEL_GLOBALNORM_BETWEEN); | |||
| } | |||
| } | |||
| AnfNodePtr FindPrimitiveWithAtrribute(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map, uint32_t limits) { | |||
| std::queue<AnfNodePtr> visited; | |||
| AnfNodePtr queue_node = nullptr; | |||
| CNodePtr cnode = nullptr; | |||
| AnfNodePtr last_node = nullptr; | |||
| uint32_t depth = 0; | |||
| if (!node_ptr) { | |||
| return nullptr; | |||
| } | |||
| visited.push(node_ptr); | |||
| while (!visited.empty()) { | |||
| queue_node = visited.front(); | |||
| visited.pop(); | |||
| cnode = queue_node->cast<CNodePtr>(); | |||
| // MAKE_TUPLE will not appear after the load in the forward graph | |||
| if (IsSomePrimitive(cnode, EXPAND_DIMS)) { | |||
| auto value = GetAttrsFromAnfNode(queue_node, GRAD_SCALE); | |||
| if (!value || !GetValue<bool>(value)) { | |||
| continue; | |||
| } | |||
| return queue_node; | |||
| } | |||
| if (!IsSomePrimitiveList(cnode, {ENVIRONGET, MUL, SQUARE, REDUCE_SUM, EXPAND_DIMS, DEPEND, CAST, REF_TO_EMBED})) { | |||
| continue; | |||
| } | |||
| auto node_set = node_users_map.at(queue_node); | |||
| for (auto &node_user : node_set) { | |||
| visited.push(node_user.first); | |||
| } | |||
| if (!last_node || last_node == queue_node) { | |||
| if (++depth == limits) { | |||
| break; | |||
| } | |||
| last_node = visited.back(); | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, const AnfNodePtr ¶meter, | |||
| uint32_t dev_num) { | |||
| AnfNodePtr expand_dims_node = nullptr; | |||
| AnfNodePtr prefix_node = nullptr; | |||
| auto params_user_set = node_user_map.at(parameter); | |||
| for (auto ¶m_pair : params_user_set) { | |||
| expand_dims_node = nullptr; | |||
| auto cnode = param_pair.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->in_forward_flag()) { | |||
| continue; | |||
| } | |||
| expand_dims_node = FindPrimitiveWithAtrribute(cnode, node_user_map, MAX_BFS_DEPTH); | |||
| if (!expand_dims_node) { | |||
| continue; | |||
| } | |||
| auto value = GetAttrsFromAnfNode(expand_dims_node, GRAD_SCALE); | |||
| if (!value || !GetValue<bool>(value)) { | |||
| continue; | |||
| } | |||
| InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV); | |||
| MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->DebugString() | |||
| << "succeed!"; | |||
| // If already inserted allreduce, the pattern will not be matched and thus no allreduce will be inserted. | |||
| InsertAllReduceForNormValue(expand_dims_node); | |||
| } | |||
| } | |||
| static AnfNodePtr GetMirrorOp(const NodeUsersMap &node_user_map, const AnfNodePtr ¶meter) { | |||
| auto params_user_set = node_user_map.at(parameter); | |||
| for (auto ¶m_pair : params_user_set) { | |||
| auto cnode = param_pair.first->cast<CNodePtr>(); | |||
| std::vector<AnfNodePtr> candidate = {cnode}; | |||
| if (!cnode->in_forward_flag()) { | |||
| continue; | |||
| } | |||
| if (IsInTrivialNodeList(cnode) || IsSomePrimitive(cnode, LOAD)) { | |||
| auto load_users = node_user_map.at(param_pair.first); | |||
| std::transform(load_users.begin(), load_users.end(), std::back_inserter(candidate), | |||
| [](const auto &v) { return v.first; }); | |||
| } | |||
| for (auto &node : candidate) { | |||
| auto local_cnode = node->cast<CNodePtr>(); | |||
| if (!IsPrimitiveCNode(local_cnode, prim::kPrimMirror) && | |||
| !IsPrimitiveCNode(local_cnode, prim::kPrimMirrorMicroStep) && | |||
| !IsPrimitiveCNode(local_cnode, prim::kPrimMirrorMiniStep)) { | |||
| continue; | |||
| } | |||
| return node; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| static void HandlGlobalNormScale(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, | |||
| const FuncGraphManagerPtr &manager) { | |||
| auto parameters = root->parameters(); | |||
| auto node_user_map = manager->node_users(); | |||
| MS_LOG(INFO) << "Start to process the global norm"; | |||
| for (auto ¶meter : parameters) { | |||
| if (!ParameterRequireGrad(parameter)) continue; | |||
| auto mirror_node = GetMirrorOp(node_user_map, parameter); | |||
| if (!mirror_node) continue; | |||
| auto device_num_ptr = GetAttrsFromAnfNode(mirror_node, DEV_NUM); | |||
| if (!device_num_ptr) { | |||
| MS_LOG(ERROR) << "The mirror operator is excepted to have device number attribute, but found none. This " | |||
| "will cause the global norm calculation with wrong precision."; | |||
| continue; | |||
| } | |||
| auto dev_num = device_num_ptr->cast<Int64ImmPtr>()->value(); | |||
| if (dev_num == 0) continue; | |||
| InsertDivAndAllReduceForNorm(node_user_map, parameter, dev_num); | |||
| } | |||
| } | |||
| bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__)) | |||
| if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { | |||
| @@ -3201,6 +3384,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| // handle full split parammeters in grad accumulation, do not contain optimizer-sharding's parameter | |||
| HandleFullySplitParameters(root); | |||
| HandlGlobalNormScale(root, all_nodes, manager); | |||
| DumpGraph(root, std::string(STEP_PARALLEL_END)); | |||
| // step parallel only run once | |||
| @@ -58,6 +58,17 @@ bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { | |||
| return (prim->name() == name); | |||
| } | |||
| bool IsSomePrimitiveList(const CNodePtr &cnode, const std::set<string> &check_list) { | |||
| return std::any_of(check_list.begin(), check_list.end(), | |||
| [cnode](const string &in) { return IsSomePrimitive(cnode, in); }); | |||
| } | |||
| std::string GetPrimName(const CNodePtr &node) { | |||
| auto prim = GetCNodePrimitive(node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return prim->name(); | |||
| } | |||
| TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> ¶m_info) { | |||
| auto user_cnode = param_info.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(user_cnode); | |||
| @@ -108,11 +119,6 @@ AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr | |||
| return first_node; | |||
| } | |||
| bool IsInNodeList(const CNodePtr &cnode, const std::set<string> &check_list) { | |||
| return std::any_of(check_list.begin(), check_list.end(), | |||
| [cnode](const string &in) { return IsSomePrimitive(cnode, in); }); | |||
| } | |||
| bool IsParallelCareNode(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| @@ -394,9 +400,9 @@ AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node | |||
| visited.pop(); | |||
| cnode = queue_node->cast<CNodePtr>(); | |||
| // MAKE_TUPLE will not appear after the load in the forward graph | |||
| if (IsInNodeList(cnode, {MAKE_TUPLE})) { | |||
| if (IsSomePrimitive(cnode, MAKE_TUPLE)) { | |||
| continue; | |||
| } else if (IsInAllGatherNodeList(cnode) || IsInNodeList(cnode, {LOAD, RESHAPE})) { | |||
| } else if (IsInAllGatherNodeList(cnode) || IsSomePrimitiveList(cnode, {LOAD, RESHAPE})) { | |||
| auto node_set = node_users_map.at(queue_node); | |||
| for (auto &node_user : node_set) { | |||
| visited.push(node_user.first); | |||
| @@ -473,6 +479,7 @@ AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, cons | |||
| type_node->set_abstract(compute_node_type->ToAbstract()); | |||
| auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node}); | |||
| new_node->set_abstract(node->abstract()); | |||
| new_node->set_in_forward_flag(true); | |||
| return new_node; | |||
| } | |||
| @@ -504,5 +511,44 @@ void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node, const string &key) { | |||
| if (!node) return nullptr; | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto prim = GetCNodePrimitive(cnode); | |||
| if (prim && prim->HasAttr(key)) { | |||
| return prim->GetAttr(key); | |||
| } | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map, | |||
| const std::vector<std::pair<const std::string, int64_t>> &match_pattern) { | |||
| AnfNodePtr start_node = node; | |||
| bool find = false; | |||
| for (uint32_t i = 0; i < match_pattern.size(); ++i) { | |||
| find = false; | |||
| if (!IsSomePrimitive(start_node->cast<CNodePtr>(), {match_pattern[i].first})) { | |||
| break; | |||
| } else if (i == match_pattern.size() - 1) { | |||
| find = true; | |||
| break; | |||
| } | |||
| auto next_node_users = user_map.at(start_node); | |||
| for (auto &next_node : next_node_users) { | |||
| if (i + 1 < match_pattern.size() && | |||
| IsSomePrimitive(next_node.first->cast<CNodePtr>(), {match_pattern[i + 1].first}) && | |||
| next_node.second == match_pattern[i + 1].second) { | |||
| start_node = next_node.first; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (!find) { | |||
| start_node = nullptr; | |||
| } | |||
| return start_node; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -20,6 +20,8 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <set> | |||
| #include <memory> | |||
| #include "base/base.h" | |||
| #include "frontend/parallel/device_manager.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | |||
| @@ -30,20 +32,24 @@ const int64_t TWO_INPUT_SIZE = 2; | |||
| // common method | |||
| bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name); | |||
| bool IsSomePrimitiveList(const CNodePtr &cnode, const std::set<string> &check_list); | |||
| bool IsParallelCareNode(const CNodePtr &cnode); | |||
| Shapes GetNodeShape(const AnfNodePtr &node); | |||
| std::string GetPrimName(const CNodePtr &node); | |||
| std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node, const string &key); | |||
| std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, | |||
| const CNodePtr &node); | |||
| std::string CreateInstanceName(const CNodePtr &node, size_t index); | |||
| TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> ¶m_info); | |||
| AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager); | |||
| AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map, | |||
| const std::vector<std::pair<const std::string, int64_t>> &match_pattern); | |||
| // for specific scenarios | |||
| RankList FindCommonMirrorGroup(const FuncGraphPtr &root); | |||
| void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input); | |||
| void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes); | |||
| AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const TypePtr &compute_node_type); | |||
| AnfNodePtr GetChildCastNode(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map); | |||
| TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map); | |||
| void LabelGenMaskMicro(const FuncGraphPtr &root); | |||
| void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes); | |||
| @@ -128,11 +128,16 @@ def clip_by_value(x, clip_value_min=None, clip_value_max=None): | |||
| return x_max | |||
| # The attribute grad_scale is needed for enabling the parallel mode | |||
| # If this is removed, c.clip_by_global_norm will have precision error in semi/auto parallel mode. | |||
| expand_dims = P.ExpandDims().add_prim_attr("grad_scale", True) | |||
| get_square_sum = C.MultitypeFuncGraph("get_square_sum") | |||
| @get_square_sum.register("Tensor") | |||
| def _get_square_sum(x): | |||
| norm = P.ReduceSum(False)(F.square(x), ()) | |||
| norm = F.expand_dims(F.cast(norm, mstype.float32), 0) | |||
| norm = expand_dims(F.cast(norm, mstype.float32), 0) | |||
| return norm | |||
| @@ -0,0 +1,195 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test global norm test """ | |||
| import re | |||
| import os | |||
| import shutil | |||
| import glob | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| import mindspore.dataset as ds | |||
| from mindspore import Tensor, Parameter, Model | |||
| from mindspore.train import DynamicLossScaleManager | |||
| from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell, MicroBatchInterleaved, PipelineCell | |||
| from mindspore.nn.optim import AdamWeightDecay | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore import context | |||
class Net(nn.Cell):
    """Two sharded float16 MatMuls followed by a Sub, for global-norm tests."""

    def __init__(self, param_type, strategy1, strategy2):
        super(Net, self).__init__()
        self.fc1 = P.MatMul().shard(strategy1)
        self.fc2 = P.MatMul().shard(strategy2)
        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(param_type)), name="weight1")
        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(param_type)), name="weight2", parallel_optimizer=False)
        self.sub = P.Sub()

    def construct(self, x, y):
        # Compute in float16 regardless of the parameter dtype.
        cast = P.Cast()
        hidden = self.fc1(cast(x, ms.float16), cast(self.p1, ms.float16))
        out = self.fc2(hidden, cast(self.p2, ms.float16))
        return self.sub(out, y)
class Net2(nn.Cell):
    """Stacks two Net instances on separate pipeline stages and subtracts their outputs."""

    def __init__(self, param_type, strategy1, strategy2):
        super(Net2, self).__init__()
        self.net1 = Net(param_type, strategy1, strategy2)
        self.net2 = Net(param_type, strategy1, strategy2)
        # Assign the two sub-networks to pipeline stages 0 and 1.
        self.net1.pipeline_stage = 0
        self.net2.pipeline_stage = 1
        self.sub = P.Sub()

    def construct(self, x, y):
        first = self.net1(x, y)
        second = self.net2(x, y)
        return self.sub(first, second)
def get_dataset():
    """Builds a 10-step GeneratorDataset of constant (inputs, label) pairs."""
    features = np.ones([64, 48]).astype(np.float32)
    targets = np.zeros([64, 16]).astype(np.float32)

    def dataset_generator():
        # Each of the 10 steps yields the same constant batch.
        for _ in range(10):
            yield features, targets

    return ds.GeneratorDataset(dataset_generator, column_names=["inputs", "label"])
class CustomOptimizer(AdamWeightDecay):
    """AdamWeightDecay that clips gradients by global norm before applying them."""

    def __init__(self, params):
        super(CustomOptimizer, self).__init__(params)
        # Keep a handle on the parent's update step so construct can wrap it.
        self.optimizer = super(CustomOptimizer, self).construct

    def construct(self, gradients):
        clipped = C.clip_by_global_norm(gradients)
        return self.optimizer(clipped)
def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None,
                              interleaved_batch=2, stages=1, micro_size=1, param_type=np.float32,
                              loss_scale_manager=None):
    """Wraps `net` for the requested parallel configuration and trains one epoch.

    `mode`/`dev_num`/`stages` configure the auto-parallel context; `net` is a
    cell class taking (param_type, strategy1, strategy2). A PipelineCell is
    added only when more than one stage is requested.
    """
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True,
                                      pipeline_stages=stages)
    wrapped = MicroBatchInterleaved(net(param_type, strategy1, strategy2), interleaved_batch)
    if stages > 1:
        wrapped = PipelineCell(wrapped, micro_size=micro_size)
    wrapped = _VirtualDatasetCell(wrapped).set_comm_fusion(4)
    if stages == 1:
        parameters = wrapped.trainable_params()
    else:
        parameters = wrapped.infer_param_pipeline_stage()
    optimizer = CustomOptimizer(parameters)
    if loss_scale_manager:
        model = Model(wrapped, optimizer=optimizer, loss_scale_manager=loss_scale_manager)
    else:
        model = Model(wrapped, optimizer=optimizer)
    model.train(1, get_dataset())
class TestGlobalNormInserted:
    """Verifies that StepParallel inserts the RealDiv / AllReduce nodes needed
    for a correct clip_by_global_norm in semi-auto / pipeline parallel modes,
    by scanning the saved `step_parallel_end` IR file for instance-name keywords."""

    def setup_method(self):
        # Save compiled IR graphs into a per-test directory so run_count_check
        # can inspect them afterwards.
        self.output_path = './graphs' + self.__str__()
        context.set_context(save_graphs=True,
                            save_graphs_path=self.output_path)

    def teardown_method(self):
        # Remove the per-test IR dump directory.
        shutil.rmtree(self.output_path)

    def run_count_check(self, target_count, pattern):
        """
        This function will check the target_key counts with the golden one.
        :param target_count: The expected number of IR lines matching `pattern`.
        :param pattern: The generated keyword in the Ir files.
        """
        # Find the step_parallel_end
        ir_files = glob.glob(os.path.join(self.output_path, 'rank_0', 'step_parallel_end*.ir'))
        assert len(ir_files) == 1
        appear_count = 0
        with open(ir_files[0], 'r') as fp:
            for line in fp:
                res = re.findall(pattern, line)
                if len(res) >= 1:
                    appear_count += 1
        assert appear_count == target_count

    def test_nonpipeline_global_norm(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in semi auto parallel mode, scale for data parallel should be 8
        Expectation: The real div (scale 8) and the in-stage AllReduce are inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"=8.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=2, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in pipeline mode, scale for data parallel should be 16
        Expectation: The real div (scale 16) and both AllReduces (in-stage and between-stage) are inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_loss_scale(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in pipeline mode with a dynamic loss scale
        Expectation: The real div (scale 16) and both AllReduces (in-stage and between-stage) are inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float32,
                                  loss_scale_manager=DynamicLossScaleManager())
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_fp16(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in pipeline mode with float16 parameters
        Expectation: The real div (scale 16) and both AllReduces (in-stage and between-stage) are inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float16)
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_loss_scale_fp16(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in pipeline mode with float16 parameters and a dynamic loss scale
        Expectation: The real div (scale 16) and both AllReduces (in-stage and between-stage) are inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float16,
                                  loss_scale_manager=DynamicLossScaleManager())
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")