Browse Source

!10969 fix dropout pynative unify_ir pattern

From: @yuchaojie
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @kisnwang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
f54939aa99
1 changed files with 5 additions and 23 deletions
  1. +5
    -23
      mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc

+ 5
- 23
mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc View File

@@ -314,28 +314,14 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,

const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
VarPtr Z = std::make_shared<Var>();
auto dropout = VectorRef({prim::kPrimDropout, X});
auto getitem0 = VectorRef({prim::kPrimTupleGetItem, dropout, Y});
auto getitem1 = VectorRef({prim::kPrimTupleGetItem, dropout, Z});
auto maketuple = VectorRef({prim::kPrimMakeTuple, getitem0, getitem1});
return maketuple;
return VectorRef({prim::kPrimDropout, X});
}

const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto maketuple_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maketuple_cnode);
auto getitem0_node = maketuple_cnode->input(1);
MS_EXCEPTION_IF_NULL(getitem0_node);
auto getitem1_node = maketuple_cnode->input(2);
MS_EXCEPTION_IF_NULL(getitem1_node);
auto getitem1_cnode = getitem1_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem1_cnode);
auto dropout_node = getitem1_cnode->input(1);
auto dropout_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_node);

auto inputx_type_id = GetInputXDataType(dropout_node);
@@ -368,13 +354,9 @@ const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_gr
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());

// replace genmask and domask
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
(void)manager->Replace(getitem0_node, dropout_do_mask);
(void)manager->Replace(getitem1_node, dropout_gen_mask);

return node;
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), dropout_do_mask, dropout_gen_mask};
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
return make_tuple;
}

const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const {


Loading…
Cancel
Save