|
|
|
@@ -39,6 +39,8 @@ class ReduceOp: |
|
|
|
PROD = "prod" |
|
|
|
|
|
|
|
|
|
|
|
target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32) |
|
|
|
|
|
|
|
class AllReduce(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Reduces the tensor data across all devices in such a way that all devices will get the same final result. |
|
|
|
@@ -102,8 +104,7 @@ class AllReduce(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
|
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!") |
|
|
|
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -161,8 +162,7 @@ class AllGather(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
def __call__(self, tensor): |
|
|
|
@@ -219,8 +219,7 @@ class ReduceScatter(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
def __call__(self, tensor): |
|
|
|
@@ -279,8 +278,7 @@ class Broadcast(PrimitiveWithInfer): |
|
|
|
if not isinstance(x_dtype, tuple): |
|
|
|
raise TypeError(f"{self.name}'s input should be a tuple!") |
|
|
|
for _ele in x_dtype: |
|
|
|
if _ele.element_type() == mstype.bool_: |
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -322,8 +320,7 @@ class _AlltoAll(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype.element_type() == mstype.bool_: |
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") |
|
|
|
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
def __call__(self, tensor): |
|
|
|
|