|
|
|
@@ -314,28 +314,14 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, |
|
|
|
|
|
|
|
const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const { |
|
|
|
VarPtr X = std::make_shared<Var>(); |
|
|
|
VarPtr Y = std::make_shared<Var>(); |
|
|
|
VarPtr Z = std::make_shared<Var>(); |
|
|
|
auto dropout = VectorRef({prim::kPrimDropout, X}); |
|
|
|
auto getitem0 = VectorRef({prim::kPrimTupleGetItem, dropout, Y}); |
|
|
|
auto getitem1 = VectorRef({prim::kPrimTupleGetItem, dropout, Z}); |
|
|
|
auto maketuple = VectorRef({prim::kPrimMakeTuple, getitem0, getitem1}); |
|
|
|
return maketuple; |
|
|
|
return VectorRef({prim::kPrimDropout, X}); |
|
|
|
} |
|
|
|
|
|
|
|
const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const EquivPtr &) const { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto maketuple_cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(maketuple_cnode); |
|
|
|
auto getitem0_node = maketuple_cnode->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(getitem0_node); |
|
|
|
auto getitem1_node = maketuple_cnode->input(2); |
|
|
|
MS_EXCEPTION_IF_NULL(getitem1_node); |
|
|
|
auto getitem1_cnode = getitem1_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(getitem1_cnode); |
|
|
|
auto dropout_node = getitem1_cnode->input(1); |
|
|
|
auto dropout_node = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_node); |
|
|
|
|
|
|
|
auto inputx_type_id = GetInputXDataType(dropout_node); |
|
|
|
@@ -368,13 +354,9 @@ const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_gr |
|
|
|
dropout_do_mask->set_abstract(do_mask_abstract); |
|
|
|
dropout_do_mask->set_scope(node->scope()); |
|
|
|
|
|
|
|
// replace genmask and domask |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
(void)manager->Replace(getitem0_node, dropout_do_mask); |
|
|
|
(void)manager->Replace(getitem1_node, dropout_gen_mask); |
|
|
|
|
|
|
|
return node; |
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), dropout_do_mask, dropout_gen_mask}; |
|
|
|
auto make_tuple = func_graph->NewCNode(make_tuple_inputs); |
|
|
|
return make_tuple; |
|
|
|
} |
|
|
|
|
|
|
|
const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const { |
|
|
|
|