Browse Source

!1124 Add broadcast fusion pass

Merge pull request !1124 from YuJianfeng/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
8b98f921cc
3 changed files with 7 additions and 0 deletions
  1. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
  2. +6
    -0
      mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h
  3. +0
    -0
      tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc

+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc View File

@@ -276,6 +276,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto other_pm = std::make_shared<PassManager>("other_pm");
other_pm->AddPass(std::make_shared<AllReduceFusion>());
other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<BroadcastFusion>());
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
other_pm->AddPass(std::make_shared<BufferFusion>());


+ 6
- 0
mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h View File

@@ -62,6 +62,12 @@ class AllGatherFusion : public CommunicationOpFusion {
explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {}
~AllGatherFusion() override = default;
};
class BroadcastFusion : public CommunicationOpFusion {
public:
explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {}
~BroadcastFusion() override = default;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_

tests/ut/cpp/pre_activate/common/ir_fusion/allreduce_fusion_test.cc → tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc View File


Loading…
Cancel
Save