|
|
|
@@ -95,8 +95,6 @@ class AllReduce(PrimitiveWithInfer): |
|
|
|
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): |
|
|
|
if not isinstance(op, type(ReduceOp.SUM)): |
|
|
|
raise TypeError("The operation of AllReduce should be str.") |
|
|
|
if op == ReduceOp.PROD: |
|
|
|
raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.") |
|
|
|
if not isinstance(_get_group(group), str): |
|
|
|
raise TypeError("The group of AllReduce should be str.") |
|
|
|
self.op = op |
|
|
|
|