diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc index 89a8d5fbdb..4145ff63a8 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc @@ -314,28 +314,14 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const { VarPtr X = std::make_shared(); - VarPtr Y = std::make_shared(); - VarPtr Z = std::make_shared(); - auto dropout = VectorRef({prim::kPrimDropout, X}); - auto getitem0 = VectorRef({prim::kPrimTupleGetItem, dropout, Y}); - auto getitem1 = VectorRef({prim::kPrimTupleGetItem, dropout, Z}); - auto maketuple = VectorRef({prim::kPrimMakeTuple, getitem0, getitem1}); - return maketuple; + return VectorRef({prim::kPrimDropout, X}); } const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); - auto maketuple_cnode = node->cast(); - MS_EXCEPTION_IF_NULL(maketuple_cnode); - auto getitem0_node = maketuple_cnode->input(1); - MS_EXCEPTION_IF_NULL(getitem0_node); - auto getitem1_node = maketuple_cnode->input(2); - MS_EXCEPTION_IF_NULL(getitem1_node); - auto getitem1_cnode = getitem1_node->cast(); - MS_EXCEPTION_IF_NULL(getitem1_cnode); - auto dropout_node = getitem1_cnode->input(1); + auto dropout_node = node->cast(); MS_EXCEPTION_IF_NULL(dropout_node); auto inputx_type_id = GetInputXDataType(dropout_node); @@ -368,13 +354,9 @@ const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_gr dropout_do_mask->set_abstract(do_mask_abstract); dropout_do_mask->set_scope(node->scope()); - // replace genmask and domask - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(getitem0_node, dropout_do_mask); - (void)manager->Replace(getitem1_node, dropout_gen_mask); - - return node; + std::vector make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), dropout_do_mask, dropout_gen_mask}; + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; } const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const {