Browse Source

fix ParallelConcat

tags/v0.6.0-beta
jiangjinsheng 5 years ago
parent
commit
bdcc607b1a
1 changed files with 7 additions and 4 deletions
  1. +7
    -4
      mindspore/ops/operations/array_ops.py

+ 7
- 4
mindspore/ops/operations/array_ops.py View File

@@ -1532,7 +1532,8 @@ class ParallelConcat(PrimitiveWithInfer):
The input tensors are all required to have size 1 in the first dimension. The input tensors are all required to have size 1 in the first dimension.


Inputs: Inputs:
- **values** (tuple, list) - Tuple or list of input tensors.
- **values** (tuple, list) - Tuple or list of input tensors. The data type and shape of these
tensors must be same.


Outputs: Outputs:
Tensor, data type same as `values`. Tensor, data type same as `values`.
@@ -1542,6 +1543,7 @@ class ParallelConcat(PrimitiveWithInfer):
>>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32)) >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32))
>>> op = P.ParallelConcat() >>> op = P.ParallelConcat()
>>> output = op((data1, data2)) >>> output = op((data1, data2))
[[0, 1], [2, 1]]
""" """


@prim_attr_register @prim_attr_register
@@ -1553,14 +1555,15 @@ class ParallelConcat(PrimitiveWithInfer):
x_type = values['dtype'] x_type = values['dtype']


validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name) validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name)

args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)

first_elem = x_shp[0] first_elem = x_shp[0]
args = {}
for i, elem in enumerate(x_shp[1:]): for i, elem in enumerate(x_shp[1:]):
j = i + 1 j = i + 1
args[f'x_type[{j}]'] = x_type[j]
validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name) validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name)
validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name) validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)


ret_shp = x_shp[0].copy() ret_shp = x_shp[0].copy()
ret_shp[0] = len(x_shp) ret_shp[0] = len(x_shp)


Loading…
Cancel
Save