Browse Source

!231 add bool type check in communication operator

Merge pull request !231 from chentingting/add_bool_type_check_in_comm_op
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
54481c30c8
1 changed files with 8 additions and 0 deletions
  1. +8
    -0
      mindspore/ops/operations/comm_ops.py

+ 8
- 0
mindspore/ops/operations/comm_ops.py View File

@@ -162,6 +162,8 @@ class AllGather(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("AllGather does not support 'Bool' as the dtype of input!")
return x_dtype

def __call__(self, tensor):
@@ -219,6 +221,8 @@ class ReduceScatter(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!")
return x_dtype

def __call__(self, tensor):
@@ -276,6 +280,8 @@ class Broadcast(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("Broadcast does not support 'Bool' as the dtype of input!")
return x_dtype


@@ -318,6 +324,8 @@ class _AlltoAll(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!")
return x_dtype

def __call__(self, tensor):


Loading…
Cancel
Save