|
|
|
@@ -62,6 +62,12 @@ class AllGatherFusion : public CommunicationOpFusion { |
|
|
|
explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {}
|
|
|
|
~AllGatherFusion() override = default;
|
|
|
|
};
|
|
|
|
|
|
|
|
class BroadcastFusion : public CommunicationOpFusion {
|
|
|
|
public:
|
|
|
|
explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {}
|
|
|
|
~BroadcastFusion() override = default;
|
|
|
|
};
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace mindspore
|
|
|
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
|