Merge pull request !4789 from liuxiao93/Add-EditDistance-op-for-GEtags/v0.7.0-beta
| @@ -190,6 +190,7 @@ constexpr const char kNameSquareSumAll[] = "SquareSumAll"; | |||||
| constexpr const char kNameAscendQuant[] = "Quant"; | constexpr const char kNameAscendQuant[] = "Quant"; | ||||
| constexpr const char kNameAscendDequant[] = "Dequant"; | constexpr const char kNameAscendDequant[] = "Dequant"; | ||||
| constexpr const char kNameReverseSequence[] = "ReverseSequence"; | constexpr const char kNameReverseSequence[] = "ReverseSequence"; | ||||
| constexpr const char kNameEditDistance[] = "EditDistance"; | |||||
| constexpr const char kNameCase[] = "Case"; | constexpr const char kNameCase[] = "Case"; | ||||
| class OpAdapterMap { | class OpAdapterMap { | ||||
| @@ -87,4 +87,12 @@ ATTR_MAP(ReverseSequence) = {{"seq_dim", ATTR_DESC(seq_dim, AnyTraits<int>())}, | |||||
| {"batch_dim", ATTR_DESC(batch_dim, AnyTraits<int>())}}; | {"batch_dim", ATTR_DESC(batch_dim, AnyTraits<int>())}}; | ||||
| OUTPUT_MAP(ReverseSequence) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(ReverseSequence) = {{0, OUTPUT_DESC(y)}}; | ||||
| REG_ADPT_DESC(ReverseSequence, kNameReverseSequence, ADPT_DESC(ReverseSequence)) | REG_ADPT_DESC(ReverseSequence, kNameReverseSequence, ADPT_DESC(ReverseSequence)) | ||||
| // EditDistance | |||||
| INPUT_MAP(EditDistance) = {{1, INPUT_DESC(hypothesis_indices)}, {2, INPUT_DESC(hypothesis_values)}, | |||||
| {3, INPUT_DESC(hypothesis_shape)}, {4, INPUT_DESC(truth_indices)}, | |||||
| {5, INPUT_DESC(truth_values)}, {6, INPUT_DESC(truth_shape)}}; | |||||
| ATTR_MAP(EditDistance) = {{"normalize", ATTR_DESC(normalize, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(EditDistance) = {{0, OUTPUT_DESC(output)}}; | |||||
| REG_ADPT_DESC(EditDistance, kNameEditDistance, ADPT_DESC(EditDistance)) | |||||
| } // namespace mindspore::transform | } // namespace mindspore::transform | ||||
| @@ -54,5 +54,8 @@ DECLARE_OP_ADAPTER(Data) | |||||
| DECLARE_OP_ADAPTER(ReverseSequence) | DECLARE_OP_ADAPTER(ReverseSequence) | ||||
| DECLARE_OP_USE_OUTPUT(ReverseSequence) | DECLARE_OP_USE_OUTPUT(ReverseSequence) | ||||
| DECLARE_OP_ADAPTER(EditDistance) | |||||
| DECLARE_OP_USE_OUTPUT(EditDistance) | |||||
| } // namespace mindspore::transform | } // namespace mindspore::transform | ||||
| #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_ | #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_ | ||||
| @@ -23,7 +23,6 @@ from .acos_grad import _acos_grad_tbe | |||||
| from .acosh import _acosh_tbe | from .acosh import _acosh_tbe | ||||
| from .acosh_grad import _acosh_grad_tbe | from .acosh_grad import _acosh_grad_tbe | ||||
| from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe | from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe | ||||
| from .add import _add_tbe | |||||
| from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe | from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe | ||||
| from .add_n import _add_n_tbe | from .add_n import _add_n_tbe | ||||
| from .accumulate_n_v2 import _accumulate_n_v2_tbe | from .accumulate_n_v2 import _accumulate_n_v2_tbe | ||||
| @@ -1,37 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Add op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| add_op_info = TBERegOp("Add") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("add.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("add") \ | |||||
| .partial_flag(True) \ | |||||
| .op_pattern("dynamicFormat") \ | |||||
| .input(0, "x1", False, "required", "all") \ | |||||
| .input(1, "x2", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \ | |||||
| .get_op_info() | |||||
| @op_info_register(add_op_info) | |||||
| def _add_tbe(): | |||||
| """Add TBE register""" | |||||
| return | |||||
| @@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ||||
| Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding, | Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding, | ||||
| ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint, | ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint, | ||||
| Squeeze, StridedSlice, Tile, TensorScatterUpdate, | |||||
| Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, | |||||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, | Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, | ||||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | ||||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | ||||
| @@ -92,6 +92,7 @@ from .sparse_ops import SparseToDense | |||||
| __all__ = [ | __all__ = [ | ||||
| 'ReverseSequence', | 'ReverseSequence', | ||||
| 'EditDistance', | |||||
| 'CropAndResize', | 'CropAndResize', | ||||
| 'TensorAdd', | 'TensorAdd', | ||||
| 'Argmax', | 'Argmax', | ||||
| @@ -3470,6 +3470,93 @@ class ReverseSequence(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| class EditDistance(PrimitiveWithInfer): | |||||
| """ | |||||
| Computes the Levebshtein Edit Distance. It is used to measure the similarity of two sequences. | |||||
| Args: | |||||
| normalize (bool): If True, edit distances are normalized by length of truth. Default: True. | |||||
| Inputs: | |||||
| - **hypothesis_indices** (Tensor) - The indices of the hypothesis list SparseTensor. With int64 data type. | |||||
| The shape of tensor is :math:`(N, R)`. | |||||
| - **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor. | |||||
| Must be 1-D vector with length of N. | |||||
| - **hypothesis_shape** (Tensor) - The values of the hypothesis list SparseTensor. | |||||
| Must be R-length vector with int64 data type. Only constant value is allowed. | |||||
| - **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type. | |||||
| The shape of tensor is :math:`(M, R)`. | |||||
| - **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M. | |||||
| - **truth_shape** (Tensor) - The values of the truth list SparseTensor. | |||||
| Must be R-length vector with int64 data type. Only constant value is allowed. | |||||
| Outputs: | |||||
| Tensor, a dense tensor with rank `R-1` and float32 data type. | |||||
| Examples: | |||||
| >>> class EditDistance(nn.Cell): | |||||
| >>> def __init__(self, hypothesis_shape, truth_shape, normalize=True): | |||||
| >>> super(EditDistance, self).__init__() | |||||
| >>> self.edit_distance = P.EditDistance(normalize) | |||||
| >>> self.hypothesis_shape = hypothesis_shape | |||||
| >>> self.truth_shape = truth_shape | |||||
| >>> | |||||
| >>> def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values): | |||||
| >>> return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape, | |||||
| >>> truth_indices, truth_values, self.truth_shape) | |||||
| >>> | |||||
| >>> hypothesis_indices = Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64)) | |||||
| >>> hypothesis_values = Tensor(np.array([1, 2, 3]).astype(np.float32)) | |||||
| >>> hypothesis_shape = Tensor(np.array([1, 1, 2]).astype(np.int64)) | |||||
| >>> truth_indices = Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64)) | |||||
| >>> truth_values = Tensor(np.array([1, 3, 2, 1]).astype(np.float32)) | |||||
| >>> truth_shape = Tensor(np.array([2, 2, 2]).astype(np.int64)) | |||||
| >>> edit_distance = EditDistance(hypothesis_shape, truth_shape) | |||||
| >>> out = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values) | |||||
| >>> [[1.0, 1.0], [1.0, 1.0]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, normalize=True): | |||||
| """init EditDistance""" | |||||
| self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name) | |||||
| def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape): | |||||
| validator.check_const_input('hypothesis_shape', h_shape['value'], self.name) | |||||
| validator.check_const_input('truth_shape', truth_shape['value'], self.name) | |||||
| args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'], | |||||
| "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']} | |||||
| validator.check_tensor_type_same(args_int, [mstype.int64], self.name) | |||||
| args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']} | |||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape'] | |||||
| validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name) | |||||
| validator.check("truth_indices rank", len(truth_indices_shp), "expected", 2, Rel.EQ, self.name) | |||||
| validator.check("hypothesis_values rank", len(h_values['shape']), "expected", 1, Rel.EQ, self.name) | |||||
| validator.check("hypothesis_shape rank", len(h_shape['shape']), "expected", 1, Rel.EQ, self.name) | |||||
| validator.check("truth_values rank", len(truth_values['shape']), "expected", 1, Rel.EQ, self.name) | |||||
| validator.check("truth_shape rank", len(truth_shape['shape']), "expected", 1, Rel.EQ, self.name) | |||||
| validator.check("hypothesis_values shape", h_values['shape'][0], | |||||
| "hypothesis_indices shape[0]", hypothesis_indices_shp[0], Rel.EQ, self.name) | |||||
| validator.check("hypothesis_shape", h_shape['shape'][0], | |||||
| "hypothesis_indices shape[1]", hypothesis_indices_shp[1], Rel.EQ, self.name) | |||||
| validator.check("truth_values shape", truth_values['shape'][0], | |||||
| "truth_indices shape[0]", truth_indices_shp[0], Rel.EQ, self.name) | |||||
| validator.check("hypothesis_shape", h_shape['shape'][0], | |||||
| "truth_shape", truth_shape['shape'][0], Rel.EQ, self.name) | |||||
| hypothesis_shape_v = h_shape['value'].asnumpy() | |||||
| truth_shape_v = truth_shape['value'].asnumpy() | |||||
| out_shape_rank = len(hypothesis_shape_v) - 1 | |||||
| out_shape = [] | |||||
| for i in range(out_shape_rank): | |||||
| out_shape.append(max(hypothesis_shape_v[i], truth_shape_v[i])) | |||||
| return {'shape': tuple(out_shape), | |||||
| 'dtype': mstype.tensor_type(mstype.float32), | |||||
| 'value': None} | |||||
| class TransShape(PrimitiveWithInfer): | class TransShape(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Transform the shape of input tensor to target shape. | Transform the shape of input tensor to target shape. | ||||
| @@ -684,6 +684,18 @@ class ParallelConcatNet(nn.Cell): | |||||
| return self.parallel_concat((x1, x2)) | return self.parallel_concat((x1, x2)) | ||||
| class EditDistance(nn.Cell): | |||||
| def __init__(self, hypothesis_shape, truth_shape, normalize=True): | |||||
| super(EditDistance, self).__init__() | |||||
| self.edit_distance = P.EditDistance(normalize) | |||||
| self.hypothesis_shape = hypothesis_shape | |||||
| self.truth_shape =truth_shape | |||||
| def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values): | |||||
| return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape, | |||||
| truth_indices, truth_values, self.truth_shape) | |||||
| test_case_math_ops = [ | test_case_math_ops = [ | ||||
| ('BitwiseAnd', { | ('BitwiseAnd', { | ||||
| 'block': P.BitwiseAnd(), | 'block': P.BitwiseAnd(), | ||||
| @@ -1978,6 +1990,15 @@ test_case_array_ops = [ | |||||
| 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)), | 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)), | ||||
| Tensor(np.array([1, 2, 3]).astype(np.int32))], | Tensor(np.array([1, 2, 3]).astype(np.int32))], | ||||
| 'desc_bprop': [[3, 3]]}), | 'desc_bprop': [[3, 3]]}), | ||||
| ('EditDistance', { | |||||
| 'block': EditDistance(Tensor(np.array([1, 1, 2]).astype(np.int64)), | |||||
| Tensor(np.array([2, 2, 2]).astype(np.int64))), | |||||
| 'desc_inputs': [Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64)), | |||||
| Tensor(np.array([1, 2, 3]).astype(np.float32)), | |||||
| Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64)), | |||||
| Tensor(np.array([1, 3, 2, 1]).astype(np.float32))], | |||||
| 'skip': ['backward'], | |||||
| }), | |||||
| ('LinSpace', { | ('LinSpace', { | ||||
| 'block': inner.LinSpace(), | 'block': inner.LinSpace(), | ||||
| 'desc_inputs': [Tensor([5, 5.5], mstype.float32), | 'desc_inputs': [Tensor([5, 5.5], mstype.float32), | ||||