Browse Source

!787 Fix dtype judge sentence in infer_dtype function of hcom operations

Merge pull request !787 from zhouyuanshen/r0.2
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
f4e8bca783
3 changed files with 13 additions and 11 deletions
  1. +9
    -8
      mindspore/ops/operations/comm_ops.py
  2. +1
    -1
      tests/ut/python/communication/test_comm.py
  3. +3
    -2
      tests/ut/python/parallel/test_bool_grad.py

+ 9
- 8
mindspore/ops/operations/comm_ops.py View File

@@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer):

Note:
The operation of AllReduce does not support "prod" currently.
The input of AllReduce does not support dtype "Bool".
Tensor must have same shape and format in all processes participating in the collective.

Args:
@@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
if x_dtype.element_type() == mstype.bool_:
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!")
return x_dtype

@@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
if x_dtype.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype

@@ -218,7 +217,7 @@ class ReduceScatter(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
if x_dtype.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype

@@ -275,11 +274,13 @@ class Broadcast(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
if not isinstance(x_dtype, tuple):
raise TypeError(f"{self.name}'s input should be a tuple!")
for _ele in x_dtype:
if _ele.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype


class _AlltoAll(PrimitiveWithInfer):
"""
AlltoAll is a collective operation.
@@ -318,7 +319,7 @@ class _AlltoAll(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
if x_dtype.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype



+ 1
- 1
tests/ut/python/communication/test_comm.py View File

@@ -55,7 +55,7 @@ class BroadCastNet(nn.Cell):
self.broadcast = Broadcast(0)

def construct(self, x):
x = self.broadcast((x))
x, = self.broadcast((x,))
x = self.dense(x)
return x



+ 3
- 2
tests/ut/python/parallel/test_bool_grad.py View File

@@ -52,7 +52,7 @@ class CommonNet(nn.Cell):
def __init__(self):
super(CommonNet, self).__init__()
self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight")
self.logicalnot = P.LogicalNot().set_strategy(((4,1),))
self.logicalnot = P.LogicalNot().set_strategy(((4,2),))
self.equal = P.Equal().set_strategy(((4,2),(4,2)))

def construct(self, x, label):
@@ -78,4 +78,5 @@ def common_net():


def test_bool_grad():
common_net()
common_net()


Loading…
Cancel
Save