Browse Source

add_receive_eliminate_pass

tags/v1.1.0
lichenever 5 years ago
parent
commit
ee34ae9259
5 changed files with 27 additions and 1 deletions
  1. +3
    -0
      mindspore/ccsrc/frontend/optimizer/irpass.cc
  2. +3
    -0
      mindspore/ccsrc/frontend/optimizer/irpass.h
  3. +19
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
  4. +1
    -1
      mindspore/ccsrc/pipeline/jit/pass.cc
  5. +1
    -0
      mindspore/core/base/core_ops.h

+ 3
- 0
mindspore/ccsrc/frontend/optimizer/irpass.cc View File

@@ -155,6 +155,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(), virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
"virtual_dataset_eliminate", prim::kPrimVirtualDataset); "virtual_dataset_eliminate", prim::kPrimVirtualDataset);


// Receive
receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);

// Convert // Convert
print_tuple_wrapper_ = print_tuple_wrapper_ =
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint); MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);


+ 3
- 0
mindspore/ccsrc/frontend/optimizer/irpass.h View File

@@ -99,6 +99,9 @@ class OptimizeIRPassLib {
// virtual dataset // virtual dataset
SubstitutionPtr virtual_dataset_eliminate_; SubstitutionPtr virtual_dataset_eliminate_;


// Receive
SubstitutionPtr receive_eliminate_;

// Convert // Convert
SubstitutionPtr print_tuple_wrapper_; SubstitutionPtr print_tuple_wrapper_;




+ 19
- 0
mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h View File

@@ -102,6 +102,25 @@ class VirtualDatasetEliminater : public AnfVisitor {
void Visit(const AnfNodePtr &) override {} void Visit(const AnfNodePtr &) override {}
}; };


// {prim::kPrimReceive, X} -> prim::kPrimReceive
class ReceiveEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimReceive) || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() == 1) {
return nullptr;
}
std::vector<AnfNodePtr> args = {cnode->input(0)};
return node->func_graph()->NewCNode(args);
}

void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimSameTypeShape, X, Y} -> X // {prim::kPrimSameTypeShape, X, Y} -> X
class SameEliminater : public AnfVisitor { class SameEliminater : public AnfVisitor {
public: public:


+ 1
- 1
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -201,7 +201,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_});
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});
opt::OptPassConfig b_2 = opt::OptPassConfig({ opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_, irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_, irpass.make_ref_eliminate_,


+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -185,6 +185,7 @@ inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("_Receive");
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap"); inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast"); inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");


Loading…
Cancel
Save