diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 8a1b3736b5..8142b305df 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -155,6 +155,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); + // Receive + receive_eliminate_ = MakeSubstitution(std::make_shared(), "receive_eliminate", prim::kPrimReceive); + // Convert print_tuple_wrapper_ = MakeSubstitution(std::make_shared(), "print_tuple_wrapper", prim::kPrimPrint); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 890352c406..e5b2371f92 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -99,6 +99,9 @@ class OptimizeIRPassLib { // virtual dataset SubstitutionPtr virtual_dataset_eliminate_; + // Receive + SubstitutionPtr receive_eliminate_; + // Convert SubstitutionPtr print_tuple_wrapper_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index 3cd09d80c3..b8699f32fa 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -102,6 +102,25 @@ class VirtualDatasetEliminater : public AnfVisitor { void Visit(const AnfNodePtr &) override {} }; +// {prim::kPrimReceive, X} -> prim::kPrimReceive +class ReceiveEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimReceive) || node->func_graph() == nullptr) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() == 1) { + return nullptr; + } + std::vector args = {cnode->input(0)}; + return node->func_graph()->NewCNode(args); + } + + void Visit(const AnfNodePtr &) override {} +}; + // {prim::kPrimSameTypeShape, X, Y} -> X class SameEliminater : public AnfVisitor { public: diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 59be55f43a..fe3d7844e4 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -201,7 +201,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, - irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_}); + irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, irpass.make_ref_eliminate_, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 7ac2553bc2..42318173c4 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -185,6 +185,7 @@ inline const PrimitivePtr kPrimSGD = std::make_shared("SGD"); inline const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); inline const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); inline const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); +inline const PrimitivePtr kPrimReceive = std::make_shared("_Receive"); inline const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); inline const PrimitivePtr kPrimAllSwap = std::make_shared("AllSwap"); inline const PrimitivePtr kPrimBroadcast = std::make_shared("Broadcast");