@@ -285,3 +285,4 @@ from .mod import _mod_tbe
 from .max_pool_grad_grad import _max_pool_grad_grad_tbe
 from .max_pool_grad_grad_with_argmax import _max_pool_grad_grad_with_argmax_tbe
 from .population_count import _population_count_tbe
+from .parallel_concat import _parallel_concat_tbe
@@ -0,0 +1,80 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ParallelConcat op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
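+# operator information for the TBE "parallel_concat" kernel: the "shape" and "N"
+# attributes, the dynamic input "values", and every supported dtype/format pair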
+parallel_concat_op_info = TBERegOp("ParallelConcat") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("parallel_concat.so") \
+    .compute_cost(10) \
+    .kernel_name("parallel_concat") \
+    .partial_flag(True) \
+    .attr("shape", "required", "listInt", "all") \
+    .attr("N", "required", "int", "all") \
+    .input(0, "values", False, "dynamic", "all") \
+    .output(0, "output_data", False, "required", "all") \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I16_5HD, DataType.I16_5HD) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
+    .dtype_format(DataType.U16_5HD, DataType.U16_5HD) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U32_5HD, DataType.U32_5HD) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I64_5HD, DataType.I64_5HD) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.U64_5HD, DataType.U64_5HD) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
+    .dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
+    .dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \
+    .dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
+    .dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \
+    .dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
+    .dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \
+    .dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
+    .dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \
+    .dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
+    .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \
+    .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
+    .dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \
+    .dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
+    .dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \
+    .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
+    .dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \
+    .dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
+    .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
+    .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
+    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
+    .get_op_info()
+
+
+@op_info_register(parallel_concat_op_info)
+def _parallel_concat_tbe():
+    """ParallelConcat TBE register"""
+    return
@@ -28,6 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
                         ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
                         Shape, Size, Slice, Split, TransShape,
+                        ParallelConcat,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
@@ -329,7 +330,8 @@ __all__ = [
     "InTopK",
     "LRN",
     "Mod",
-    "PopulationCount"
+    "PopulationCount",
+    "ParallelConcat",
 ]

 __all__.sort()
@@ -1463,6 +1463,57 @@ class Concat(PrimitiveWithInfer):
         return out


+class ParallelConcat(PrimitiveWithInfer):
+    r"""
+    Concats input tensors along the first dimension.
+
+    Note:
+        The input tensors are all required to have size 1 in the first dimension.
+
+    Inputs:
+        - **values** (tuple, list) - Tuple or list of input tensors.
+
+    Outputs:
+        Tensor, with the same data type as the elements of `values`.
+
+    Examples:
+        >>> data1 = Tensor(np.array([[0, 1]]).astype(np.int32))
+        >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32))
+        >>> op = P.ParallelConcat()
+        >>> output = op((data1, data2))
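+        >>> # output: [[0, 1], [2, 1]], shape (2, 2)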
| """ | |||||
+
+    @prim_attr_register
+    def __init__(self):
+        """init ParallelConcat"""
+
+    def __infer__(self, values):
+        x_shp = values['shape']
+        x_type = values['dtype']
+        validator.check_integer('x_shp length', len(x_shp), 1, Rel.GE, self.name)
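+        # every input must have size 1 in its first dimension and the same
+        # shape and dtype as the first input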
+        first_elem = x_shp[0]
+        validator.check_integer('x_shp[0][0]', first_elem[0], 1, Rel.EQ, self.name)
+        args = {'x_type[0]': x_type[0]}
+        for i, elem in enumerate(x_shp[1:]):
+            j = i + 1
+            args[f'x_type[{j}]'] = x_type[j]
+            validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name)
+            validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)
+        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
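+        # the output stacks the N inputs, so its first dimension becomes N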
+        ret_shp = x_shp[0].copy()
+        ret_shp[0] = len(x_shp)
+        self.add_prim_attr('shape', ret_shp)
+        self.add_prim_attr('N', len(x_shp))
+        out = {'shape': ret_shp,
+               'dtype': x_type[0],
+               'value': None}
+        return out
+
+
 def _get_pack_shape(x_shape, x_type, axis, prim_name):
     """for pack output shape"""
     validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
@@ -596,6 +596,15 @@ def test_strided_slice_const():
     assert (ret.asnumpy() == np.array([], np.float32).reshape([0, 1, 7, 8, 9, 3, 1])).all()


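+# minimal Cell wrapper so the test harness below can call P.ParallelConcat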
+class ParallelConcatNet(nn.Cell):
+    def __init__(self):
+        super(ParallelConcatNet, self).__init__()
+        self.parallel_concat = P.ParallelConcat()
+
+    def construct(self, x1, x2):
+        return self.parallel_concat((x1, x2))
+
+
 test_case_math_ops = [
     ('BitwiseAnd', {
         'block': P.BitwiseAnd(),
@@ -1875,6 +1884,12 @@ test_case_array_ops = [
         'desc_inputs': [[1, 3, 24, 24]],
         'desc_bprop': [[1, 12, 24, 24]],
     }),
+    ('ParallelConcat', {
+        'block': ParallelConcatNet(),
+        'desc_inputs': [Tensor([[1, 2]], mstype.float32),
+                        Tensor([[5, 6]], mstype.float32)],
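+        # forward only; the backward pass is skipped for this op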
+        'skip': ['backward'],
+    }),
 ]

 test_case_other_ops = [