| @@ -1532,7 +1532,8 @@ class ParallelConcat(PrimitiveWithInfer): | |||||
| The input tensors are all required to have size 1 in the first dimension. | The input tensors are all required to have size 1 in the first dimension. | ||||
| Inputs: | Inputs: | ||||
| - **values** (tuple, list) - Tuple or list of input tensors. | |||||
| - **values** (tuple, list) - Tuple or list of input tensors. The data type and shape of these | |||||
| tensors must be same. | |||||
| Outputs: | Outputs: | ||||
| Tensor, data type same as `values`. | Tensor, data type same as `values`. | ||||
| @@ -1542,6 +1543,7 @@ class ParallelConcat(PrimitiveWithInfer): | |||||
| >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32)) | >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32)) | ||||
| >>> op = P.ParallelConcat() | >>> op = P.ParallelConcat() | ||||
| >>> output = op((data1, data2)) | >>> output = op((data1, data2)) | ||||
| [[0, 1], [2, 1]] | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -1553,14 +1555,15 @@ class ParallelConcat(PrimitiveWithInfer): | |||||
| x_type = values['dtype'] | x_type = values['dtype'] | ||||
| validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name) | validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name) | ||||
| args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} | |||||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||||
| first_elem = x_shp[0] | first_elem = x_shp[0] | ||||
| args = {} | |||||
| for i, elem in enumerate(x_shp[1:]): | for i, elem in enumerate(x_shp[1:]): | ||||
| j = i + 1 | j = i + 1 | ||||
| args[f'x_type[{j}]'] = x_type[j] | |||||
| validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name) | validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name) | ||||
| validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name) | validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name) | ||||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||||
| ret_shp = x_shp[0].copy() | ret_shp = x_shp[0].copy() | ||||
| ret_shp[0] = len(x_shp) | ret_shp[0] = len(x_shp) | ||||