|
|
@@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer): |
|
|
|
|
|
|
|
|
Note: |
|
|
Note: |
|
|
The operation of AllReduce does not support "prod" currently. |
|
|
The operation of AllReduce does not support "prod" currently. |
|
|
The input of AllReduce does not support dtype "Bool". |
|
|
|
|
|
Tensor must have same shape and format in all processes participating in the collective. |
|
|
Tensor must have same shape and format in all processes participating in the collective. |
|
|
|
|
|
|
|
|
Args: |
|
|
Args: |
|
|
@@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer): |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
def infer_dtype(self, x_dtype): |
|
|
if x_dtype == mstype.bool_: |
|
|
|
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!") |
|
|
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!") |
|
|
return x_dtype |
|
|
return x_dtype |
|
|
|
|
|
|
|
|
@@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer): |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
def infer_dtype(self, x_dtype): |
|
|
if x_dtype == mstype.bool_: |
|
|
|
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
return x_dtype |
|
|
return x_dtype |
|
|
|
|
|
|
|
|
@@ -176,6 +175,7 @@ class ReduceScatter(PrimitiveWithInfer): |
|
|
Note: |
|
|
Note: |
|
|
The back propagation of the op is not surported yet. Stay tuned for more. |
|
|
The back propagation of the op is not surported yet. Stay tuned for more. |
|
|
Tensor must have the same shape and format in all processes participating in the collective. |
|
|
Tensor must have the same shape and format in all processes participating in the collective. |
|
|
|
|
|
|
|
|
Args: |
|
|
Args: |
|
|
op (str): Specifies an operation used for element-wise reductions, |
|
|
op (str): Specifies an operation used for element-wise reductions, |
|
|
like sum, max, avg. Default: ReduceOp.SUM. |
|
|
like sum, max, avg. Default: ReduceOp.SUM. |
|
|
@@ -218,7 +218,7 @@ class ReduceScatter(PrimitiveWithInfer): |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
def infer_dtype(self, x_dtype): |
|
|
if x_dtype == mstype.bool_: |
|
|
|
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
return x_dtype |
|
|
return x_dtype |
|
|
|
|
|
|
|
|
@@ -275,8 +275,11 @@ class Broadcast(PrimitiveWithInfer): |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
def infer_dtype(self, x_dtype): |
|
|
if x_dtype == mstype.bool_: |
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
|
|
|
|
|
if not isinstance(x_dtype, tuple): |
|
|
|
|
|
raise TypeError(f"{self.name}'s input should be a tuple!") |
|
|
|
|
|
for _ele in x_dtype: |
|
|
|
|
|
if _ele.element_type() == mstype.bool_: |
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
return x_dtype |
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -318,7 +321,7 @@ class _AlltoAll(PrimitiveWithInfer): |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
def infer_dtype(self, x_dtype): |
|
|
if x_dtype == mstype.bool_: |
|
|
|
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
return x_dtype |
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|