|
|
|
@@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Note: |
|
|
|
The operation of AllReduce does not support "prod" currently. |
|
|
|
The input of AllReduce does not support dtype "Bool". |
|
|
|
Tensor must have same shape and format in all processes participating in the collective. |
|
|
|
|
|
|
|
Args: |
|
|
|
@@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype == mstype.bool_: |
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
|
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!") |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
@@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype == mstype.bool_: |
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
@@ -218,7 +217,7 @@ class ReduceScatter(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype == mstype.bool_: |
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
@@ -275,11 +274,13 @@ class Broadcast(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype == mstype.bool_: |
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
if not isinstance(x_dtype, tuple): |
|
|
|
raise TypeError(f"{self.name}'s input should be a tuple!") |
|
|
|
for _ele in x_dtype: |
|
|
|
if _ele.element_type() == mstype.bool_: |
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
class _AlltoAll(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
AlltoAll is a collective operation. |
|
|
|
@@ -318,7 +319,7 @@ class _AlltoAll(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype == mstype.bool_: |
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|