|
|
|
@@ -162,6 +162,8 @@ class AllGather(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype == mstype.bool_: |
|
|
|
raise TypeError("AllGather does not support 'Bool' as the dtype of input!") |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
def __call__(self, tensor): |
|
|
|
@@ -219,6 +221,8 @@ class ReduceScatter(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype == mstype.bool_: |
|
|
|
raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!") |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
def __call__(self, tensor): |
|
|
|
@@ -276,6 +280,8 @@ class Broadcast(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype == mstype.bool_: |
|
|
|
raise TypeError("Broadcast does not support 'Bool' as the dtype of input!") |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -318,6 +324,8 @@ class _AlltoAll(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
if x_dtype == mstype.bool_: |
|
|
|
raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!") |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
def __call__(self, tensor): |
|
|
|
|