Merge pull request !6933 from wuxuejian/new_optags/v1.1.0
| @@ -43,8 +43,10 @@ constexpr auto kSeed2 = "seed2"; | |||||
| constexpr auto kTopK = "TopK"; | constexpr auto kTopK = "TopK"; | ||||
| constexpr auto kTopKV2 = "TopKV2"; | constexpr auto kTopKV2 = "TopKV2"; | ||||
| constexpr auto kEditDistance = "EditDistance"; | constexpr auto kEditDistance = "EditDistance"; | ||||
| constexpr auto kGatherD = "GatherD"; | |||||
| constexpr auto kIdentity = "Identity"; | |||||
| constexpr auto kCustRunApi = "RunCpuKernel"; | constexpr auto kCustRunApi = "RunCpuKernel"; | ||||
| const std::set<std::string> kCustAiCpuKernelOps{kTopK, kEditDistance}; | |||||
| const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kGatherD, kIdentity}; | |||||
| struct AicpuParamHead { | struct AicpuParamHead { | ||||
| uint32_t length; // Total length: include cunstom message | uint32_t length; // Total length: include cunstom message | ||||
| @@ -16,6 +16,8 @@ | |||||
| from .init_data_set_queue import _init_data_set_queue_aicpu | from .init_data_set_queue import _init_data_set_queue_aicpu | ||||
| from .embedding_lookup import _embedding_lookup_aicpu | from .embedding_lookup import _embedding_lookup_aicpu | ||||
| from .padding import _padding_aicpu | from .padding import _padding_aicpu | ||||
| from .gather import _gather_aicpu | |||||
| from .identity import _identity_aicpu | |||||
| from .dropout_genmask import _dropout_genmask_aicpu | from .dropout_genmask import _dropout_genmask_aicpu | ||||
| from .get_next import _get_next_aicpu | from .get_next import _get_next_aicpu | ||||
| from .print_tensor import _print_aicpu | from .print_tensor import _print_aicpu | ||||
| @@ -0,0 +1,78 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """GatherD op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| gather_op_info = AiCPURegOp("GatherD") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "input", "required") \ | |||||
| .input(1, "dim", "required") \ | |||||
| .input(2, "index", "required") \ | |||||
| .output(0, "output", "required") \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, \ | |||||
| DataType.I64_Default, DataType.BOOL_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(gather_op_info) | |||||
| def _gather_aicpu(): | |||||
| """GatherD AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Identity op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| identity_op_info = AiCPURegOp("Identity") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .output(0, "y", "required") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.F64_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(identity_op_info) | |||||
| def _identity_aicpu(): | |||||
| """Identity AiCPU register""" | |||||
| return | |||||
| @@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, | Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, | ||||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | ||||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | ||||
| Unique) | |||||
| Unique, GatherD, Identity) | |||||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | ||||
| _MirrorOperator, ReduceOp, _VirtualDataset, | _MirrorOperator, ReduceOp, _VirtualDataset, | ||||
| _VirtualDiv, _GetTensorSlice, | _VirtualDiv, _GetTensorSlice, | ||||
| @@ -153,6 +153,8 @@ __all__ = [ | |||||
| 'SparseGatherV2', | 'SparseGatherV2', | ||||
| 'EmbeddingLookup', | 'EmbeddingLookup', | ||||
| 'Padding', | 'Padding', | ||||
| 'GatherD', | |||||
| 'Identity', | |||||
| 'Concat', | 'Concat', | ||||
| 'Pack', | 'Pack', | ||||
| 'Unpack', | 'Unpack', | ||||
| @@ -3797,3 +3797,76 @@ class EmbeddingLookup(PrimitiveWithInfer): | |||||
| 'dtype': params['dtype'], | 'dtype': params['dtype'], | ||||
| 'value': None} | 'value': None} | ||||
| return out | return out | ||||
| class GatherD(PrimitiveWithInfer): | |||||
| """ | |||||
| Gathers values along an axis specified by dim. | |||||
| Inputs: | |||||
| - **x** (Tensor) - The source tensor. | |||||
| - **dim** (int) - The axis along which to index. It must be int32. Only constant value is allowed. | |||||
| - **index** (Tensor) - The indices of elements to gather. It can be one of the following data types: | |||||
| int32, int64. | |||||
| Outputs: | |||||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | |||||
| Examples: | |||||
| >>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32) | |||||
| >>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32) | |||||
| >>> dim = 1 | |||||
| >>> out = P.GatherD()(x, dim, index) | |||||
| [[1, 1], [4, 3]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """Initialize GatherD""" | |||||
| def __infer__(self, x, dim, index): | |||||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"index": index['dtype']}, [mstype.int32, mstype.int64], self.name) | |||||
| validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name) | |||||
| x_shp = x['shape'] | |||||
| idx_shp = index['shape'] | |||||
| x_rank = len(x_shp) | |||||
| idx_rank = len(idx_shp) | |||||
| validator.check("x_rank, idx_rank", x_rank, "expected", idx_rank, Rel.EQ, self.name) | |||||
| dim_v = dim['value'] | |||||
| validator.check("dim value", dim_v, "expected", 0, Rel.GE, self.name) | |||||
| validator.check("dim value", dim_v, "expected", x_rank, Rel.LT, self.name) | |||||
| for i in range(x_rank): | |||||
| if i == dim_v: | |||||
| continue | |||||
| validator.check("x_shp[{0}], idx_shp[{0}]".format(i), x_shp[i], "expected", idx_shp[i], Rel.EQ, self.name) | |||||
| out = {'shape': index['shape'], | |||||
| 'dtype': x['dtype'], | |||||
| 'value': None} | |||||
| return out | |||||
| class Identity(PrimitiveWithInfer): | |||||
| """ | |||||
| Returns a Tensor with the same shape and contents as input. | |||||
| Inputs: | |||||
| - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||||
| Outputs: | |||||
| Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`. | |||||
| Examples: | |||||
| >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64) | |||||
| >>> y = P.Identity()(x) | |||||
| [1, 2, 3, 4] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """Initialize identity""" | |||||
| def __infer__(self, x): | |||||
| out = {'shape': x['shape'], | |||||
| 'dtype': x['dtype'], | |||||
| 'value': None} | |||||
| return out | |||||