diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 1a6f87df97..a5d28584d5 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -55,6 +55,10 @@ static const std::set INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; // it will be one item in map with key: C, and value: (B, i) static std::map> g_RefMap; static void HandleNoUsedParameter(const FuncGraphPtr &root); +static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, + const std::string &instance_name); +static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, + const std::string &opt_shard_group); void SetCommunicationOpGroupLabel(std::vector new_node_input) { if (new_node_input.empty()) { @@ -125,6 +129,30 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An MS_LOG(INFO) << "Insert " << instance_name << " success"; } +// Replace pre_node with pre_node->op +static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, + const std::string &instance_name) { + // insert new node before the node + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + ScopePtr scope = pre_node->scope(); + MS_EXCEPTION_IF_NULL(scope); + std::vector node_input = CreateInput(op, pre_node, instance_name); + CNodePtr new_node = func_graph->NewCNode(node_input); + MS_EXCEPTION_IF_NULL(new_node); + if (instance_name.find(SPLIT_SENS) == std::string::npos) { + new_node->set_in_forward_flag(true); // mark forward flag + } + auto new_node_prim = GetValueNode(node_input[0]); + new_node_prim->set_instance_name(instance_name); + new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); + new_node->set_scope(scope); + node_input[0]->set_scope(scope); + manager->Replace(pre_node, new_node); + MS_LOG(INFO) << "Insert " << instance_name << " success"; + return new_node; +} + std::string CreateInstanceName(const CNodePtr &node, size_t index) { MS_EXCEPTION_IF_NULL(node); if (!IsValueNode(node->input(0))) { @@ -1380,18 +1408,26 @@ void InsertAllGatherOp(const std::string &group, const std::paircast(); auto graph = cnode->func_graph(); MS_EXCEPTION_IF_NULL(graph); - InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER); + auto cnode_prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(cnode_prim); + CNodePtr allgather; + if (cnode_prim->name() == CAST) { + allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER); + } else { + InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER); + allgather = cnode->input(res.second)->cast(); + } // add fusion flag - auto allgather = cnode->input(res.second)->cast(); + MS_EXCEPTION_IF_NULL(allgather); auto prim = GetValueNode(allgather->input(0)); auto attrs = prim->attrs(); // enable fusion flag later when it's supported in backend - attrs["fusion"] = MakeValue(0); + attrs["fusion"] = MakeValue(0); prim->SetAttrs(attrs); } -void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, - const std::string &opt_shard_group) { +static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, + const std::string &opt_shard_group) { if (opt_shard_group.empty()) { return; } diff --git a/tests/ut/python/parallel/test_loss_and_o2_level.py b/tests/ut/python/parallel/test_loss_and_o2_level.py index 358b11e8d3..05b112d89f 100755 --- a/tests/ut/python/parallel/test_loss_and_o2_level.py +++ b/tests/ut/python/parallel/test_loss_and_o2_level.py @@ -119,3 +119,12 @@ def test_neg_repeat_calc2(): strategy2 = ((4, 4),) net = Net(_w1, strategy1, strategy2) compile_net(net) + + +def test_parallel_optimizer_with_mix_precision(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0, + enable_parallel_optimizer=True) + strategy1 = ((8, 1), (8, 1)) + strategy2 = ((8, 1),) + net = Net(_w1, strategy1, strategy2) + compile_net(net)