Merge pull request !4789 from liuxiao93/Add-EditDistance-op-for-GEtags/v0.7.0-beta
| @@ -190,6 +190,7 @@ constexpr const char kNameSquareSumAll[] = "SquareSumAll"; | |||
| constexpr const char kNameAscendQuant[] = "Quant"; | |||
| constexpr const char kNameAscendDequant[] = "Dequant"; | |||
| constexpr const char kNameReverseSequence[] = "ReverseSequence"; | |||
| constexpr const char kNameEditDistance[] = "EditDistance"; | |||
| constexpr const char kNameCase[] = "Case"; | |||
| class OpAdapterMap { | |||
| @@ -87,4 +87,12 @@ ATTR_MAP(ReverseSequence) = {{"seq_dim", ATTR_DESC(seq_dim, AnyTraits<int>())}, | |||
| {"batch_dim", ATTR_DESC(batch_dim, AnyTraits<int>())}}; | |||
| OUTPUT_MAP(ReverseSequence) = {{0, OUTPUT_DESC(y)}}; | |||
| REG_ADPT_DESC(ReverseSequence, kNameReverseSequence, ADPT_DESC(ReverseSequence)) | |||
| // EditDistance: maps six inputs (indices 1-6: hypothesis indices/values/shape, then truth indices/values/shape), the bool attribute "normalize", and a single output at index 0, then registers the adapter under kNameEditDistance. | |||
| INPUT_MAP(EditDistance) = {{1, INPUT_DESC(hypothesis_indices)}, {2, INPUT_DESC(hypothesis_values)}, | |||
| {3, INPUT_DESC(hypothesis_shape)}, {4, INPUT_DESC(truth_indices)}, | |||
| {5, INPUT_DESC(truth_values)}, {6, INPUT_DESC(truth_shape)}}; | |||
| ATTR_MAP(EditDistance) = {{"normalize", ATTR_DESC(normalize, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(EditDistance) = {{0, OUTPUT_DESC(output)}}; | |||
| REG_ADPT_DESC(EditDistance, kNameEditDistance, ADPT_DESC(EditDistance)) | |||
| } // namespace mindspore::transform | |||
| @@ -54,5 +54,8 @@ DECLARE_OP_ADAPTER(Data) | |||
| DECLARE_OP_ADAPTER(ReverseSequence) | |||
| DECLARE_OP_USE_OUTPUT(ReverseSequence) | |||
| DECLARE_OP_ADAPTER(EditDistance) | |||
| DECLARE_OP_USE_OUTPUT(EditDistance) | |||
| } // namespace mindspore::transform | |||
| #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_ | |||
| @@ -23,7 +23,6 @@ from .acos_grad import _acos_grad_tbe | |||
| from .acosh import _acosh_tbe | |||
| from .acosh_grad import _acosh_grad_tbe | |||
| from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe | |||
| from .add import _add_tbe | |||
| from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe | |||
| from .add_n import _add_n_tbe | |||
| from .accumulate_n_v2 import _accumulate_n_v2_tbe | |||
| @@ -1,37 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Add op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Registration record for the TBE "Add" kernel: two required inputs (x1, x2),
# one required output (y), dynamic format selection and no fixed dtype/format.
add_op_info = (
    TBERegOp("Add")
    .fusion_type("ELEMWISE")
    .async_flag(False)
    .binfile_name("add.so")
    .compute_cost(10)
    .kernel_name("add")
    .partial_flag(True)
    .op_pattern("dynamicFormat")
    .input(0, "x1", False, "required", "all")
    .input(1, "x2", False, "required", "all")
    .output(0, "y", False, "required", "all")
    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None)
    .get_op_info()
)


@op_info_register(add_op_info)
def _add_tbe():
    """Placeholder whose decoration registers the Add TBE op info; the body does nothing."""
    return
| @@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | |||
| Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding, | |||
| ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint, | |||
| Squeeze, StridedSlice, Tile, TensorScatterUpdate, | |||
| Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, | |||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, | |||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | |||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | |||
| @@ -92,6 +92,7 @@ from .sparse_ops import SparseToDense | |||
| __all__ = [ | |||
| 'ReverseSequence', | |||
| 'EditDistance', | |||
| 'CropAndResize', | |||
| 'TensorAdd', | |||
| 'Argmax', | |||
| @@ -3470,6 +3470,93 @@ class ReverseSequence(PrimitiveWithInfer): | |||
| return x | |||
class EditDistance(PrimitiveWithInfer):
    """
    Computes the Levenshtein Edit Distance. It is used to measure the similarity of two sequences.

    Args:
        normalize (bool): If True, edit distances are normalized by length of truth. Default: True.

    Inputs:
        - **hypothesis_indices** (Tensor) - The indices of the hypothesis list SparseTensor. With int64 data type.
          The shape of tensor is :math:`(N, R)`.
        - **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor.
          Must be 1-D vector with length of N.
        - **hypothesis_shape** (Tensor) - The shape of the hypothesis list SparseTensor.
          Must be R-length vector with int64 data type. Only constant value is allowed.
        - **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type.
          The shape of tensor is :math:`(M, R)`.
        - **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M.
        - **truth_shape** (Tensor) - The shape of the truth list SparseTensor.
          Must be R-length vector with int64 data type. Only constant value is allowed.

    Outputs:
        Tensor, a dense tensor with rank `R-1` and float32 data type. Each output dimension is the larger of
        the corresponding leading dimensions of `hypothesis_shape` and `truth_shape`.

    Examples:
        >>> class EditDistance(nn.Cell):
        >>>     def __init__(self, hypothesis_shape, truth_shape, normalize=True):
        >>>         super(EditDistance, self).__init__()
        >>>         self.edit_distance = P.EditDistance(normalize)
        >>>         self.hypothesis_shape = hypothesis_shape
        >>>         self.truth_shape = truth_shape
        >>>
        >>>     def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
        >>>         return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
        >>>                                   truth_indices, truth_values, self.truth_shape)
        >>>
        >>> hypothesis_indices = Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64))
        >>> hypothesis_values = Tensor(np.array([1, 2, 3]).astype(np.float32))
        >>> hypothesis_shape = Tensor(np.array([1, 1, 2]).astype(np.int64))
        >>> truth_indices = Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64))
        >>> truth_values = Tensor(np.array([1, 3, 2, 1]).astype(np.float32))
        >>> truth_shape = Tensor(np.array([2, 2, 2]).astype(np.int64))
        >>> edit_distance = EditDistance(hypothesis_shape, truth_shape)
        >>> out = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values)
        >>> [[1.0, 1.0], [1.0, 1.0]]
    """

    @prim_attr_register
    def __init__(self, normalize=True):
        """init EditDistance"""
        self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name)

    def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape):
        """
        Validate the six inputs and infer the output description.

        Each argument is a dict with 'shape', 'dtype' and 'value' entries. The two shape inputs must be
        constants because their values determine the output shape.

        Returns:
            dict, with 'shape' (tuple of the element-wise maximum of the first R-1 entries of the two
            shape inputs), 'dtype' (float32 tensor type) and 'value' (None).
        """
        # The output shape is computed from these values, so they must be known at compile time.
        validator.check_const_input('hypothesis_shape', h_shape['value'], self.name)
        validator.check_const_input('truth_shape', truth_shape['value'], self.name)

        # Indices and dense shapes must be int64; values may be any numeric type but must agree.
        args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'],
                    "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']}
        validator.check_tensor_type_same(args_int, [mstype.int64], self.name)
        args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)

        # Rank checks: indices are 2-D, values and shapes are 1-D.
        hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape']
        validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name)
        validator.check("truth_indices rank", len(truth_indices_shp), "expected", 2, Rel.EQ, self.name)
        validator.check("hypothesis_values rank", len(h_values['shape']), "expected", 1, Rel.EQ, self.name)
        validator.check("hypothesis_shape rank", len(h_shape['shape']), "expected", 1, Rel.EQ, self.name)
        validator.check("truth_values rank", len(truth_values['shape']), "expected", 1, Rel.EQ, self.name)
        validator.check("truth_shape rank", len(truth_shape['shape']), "expected", 1, Rel.EQ, self.name)

        # Cross-input consistency between indices, values and dense shapes.
        validator.check("hypothesis_values shape", h_values['shape'][0],
                        "hypothesis_indices shape[0]", hypothesis_indices_shp[0], Rel.EQ, self.name)
        validator.check("hypothesis_shape", h_shape['shape'][0],
                        "hypothesis_indices shape[1]", hypothesis_indices_shp[1], Rel.EQ, self.name)
        validator.check("truth_values shape", truth_values['shape'][0],
                        "truth_indices shape[0]", truth_indices_shp[0], Rel.EQ, self.name)
        # Mirror of the hypothesis check above: each truth index row needs one entry per dimension.
        validator.check("truth_shape", truth_shape['shape'][0],
                        "truth_indices shape[1]", truth_indices_shp[1], Rel.EQ, self.name)
        validator.check("hypothesis_shape", h_shape['shape'][0],
                        "truth_shape", truth_shape['shape'][0], Rel.EQ, self.name)

        hypothesis_shape_v = h_shape['value'].asnumpy()
        truth_shape_v = truth_shape['value'].asnumpy()
        # Drop the last dimension; every remaining dimension takes the larger of the two extents.
        out_shape = tuple(max(h_dim, t_dim)
                          for h_dim, t_dim in zip(hypothesis_shape_v[:-1], truth_shape_v[:-1]))

        return {'shape': out_shape,
                'dtype': mstype.tensor_type(mstype.float32),
                'value': None}
| class TransShape(PrimitiveWithInfer): | |||
| """ | |||
| Transform the shape of input tensor to target shape. | |||
| @@ -684,6 +684,18 @@ class ParallelConcatNet(nn.Cell): | |||
| return self.parallel_concat((x1, x2)) | |||
class EditDistance(nn.Cell):
    """Test cell that feeds fixed hypothesis/truth shape tensors into P.EditDistance."""

    def __init__(self, hypothesis_shape, truth_shape, normalize=True):
        """Store the two shape tensors and build the EditDistance primitive."""
        super().__init__()
        self.hypothesis_shape = hypothesis_shape
        self.truth_shape = truth_shape
        self.edit_distance = P.EditDistance(normalize)

    def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
        """Run the primitive, supplying the stored shape tensors as its third and sixth inputs."""
        distance = self.edit_distance(
            hypothesis_indices, hypothesis_values, self.hypothesis_shape,
            truth_indices, truth_values, self.truth_shape)
        return distance
| test_case_math_ops = [ | |||
| ('BitwiseAnd', { | |||
| 'block': P.BitwiseAnd(), | |||
| @@ -1978,6 +1990,15 @@ test_case_array_ops = [ | |||
| 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)), | |||
| Tensor(np.array([1, 2, 3]).astype(np.int32))], | |||
| 'desc_bprop': [[3, 3]]}), | |||
| ('EditDistance', { | |||
| 'block': EditDistance(Tensor(np.array([1, 1, 2]).astype(np.int64)), | |||
| Tensor(np.array([2, 2, 2]).astype(np.int64))), | |||
| 'desc_inputs': [Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64)), | |||
| Tensor(np.array([1, 2, 3]).astype(np.float32)), | |||
| Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64)), | |||
| Tensor(np.array([1, 3, 2, 1]).astype(np.float32))], | |||
| 'skip': ['backward'], | |||
| }), | |||
| ('LinSpace', { | |||
| 'block': inner.LinSpace(), | |||
| 'desc_inputs': [Tensor([5, 5.5], mstype.float32), | |||