|
|
|
@@ -68,6 +68,13 @@ class BroadcastFusion : public CommunicationOpFusion { |
|
|
|
explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {}
|
|
|
|
~BroadcastFusion() override = default;
|
|
|
|
};
|
|
|
|
|
|
|
|
class ReduceScatterFusion : public CommunicationOpFusion {
|
|
|
|
public:
|
|
|
|
explicit ReduceScatterFusion(size_t groups = 1)
|
|
|
|
: CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {}
|
|
|
|
~ReduceScatterFusion() override = default;
|
|
|
|
};
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace mindspore
|
|
|
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
|