| @@ -155,6 +155,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(), | virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(), | ||||
| "virtual_dataset_eliminate", prim::kPrimVirtualDataset); | "virtual_dataset_eliminate", prim::kPrimVirtualDataset); | ||||
| // Receive | |||||
| receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive); | |||||
| // Convert | // Convert | ||||
| print_tuple_wrapper_ = | print_tuple_wrapper_ = | ||||
| MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint); | MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint); | ||||
| @@ -99,6 +99,9 @@ class OptimizeIRPassLib { | |||||
| // virtual dataset | // virtual dataset | ||||
| SubstitutionPtr virtual_dataset_eliminate_; | SubstitutionPtr virtual_dataset_eliminate_; | ||||
| // Receive | |||||
| SubstitutionPtr receive_eliminate_; | |||||
| // Convert | // Convert | ||||
| SubstitutionPtr print_tuple_wrapper_; | SubstitutionPtr print_tuple_wrapper_; | ||||
| @@ -102,6 +102,25 @@ class VirtualDatasetEliminater : public AnfVisitor { | |||||
| void Visit(const AnfNodePtr &) override {} | void Visit(const AnfNodePtr &) override {} | ||||
| }; | }; | ||||
| // {prim::kPrimReceive, X} -> prim::kPrimReceive | |||||
| class ReceiveEliminater : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimReceive) || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (cnode->inputs().size() == 1) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> args = {cnode->input(0)}; | |||||
| return node->func_graph()->NewCNode(args); | |||||
| } | |||||
| void Visit(const AnfNodePtr &) override {} | |||||
| }; | |||||
| // {prim::kPrimSameTypeShape, X, Y} -> X | // {prim::kPrimSameTypeShape, X, Y} -> X | ||||
| class SameEliminater : public AnfVisitor { | class SameEliminater : public AnfVisitor { | ||||
| public: | public: | ||||
| @@ -201,7 +201,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, | {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, | ||||
| irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | ||||
| irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | ||||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_}); | |||||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); | |||||
| opt::OptPassConfig b_2 = opt::OptPassConfig({ | opt::OptPassConfig b_2 = opt::OptPassConfig({ | ||||
| irpass.replace_refkey_by_param_, | irpass.replace_refkey_by_param_, | ||||
| irpass.make_ref_eliminate_, | irpass.make_ref_eliminate_, | ||||
| @@ -185,6 +185,7 @@ inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD"); | |||||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | ||||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | ||||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | ||||
| inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("_Receive"); | |||||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | ||||
| inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap"); | inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap"); | ||||
| inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast"); | inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast"); | ||||