|
|
|
@@ -3619,6 +3619,12 @@ class EditDistance(PrimitiveWithInfer): |
|
|
|
Tensor, a dense tensor with rank `R-1` and float32 data type. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> import numpy as np |
|
|
|
>>> from mindspore import context |
|
|
|
>>> from mindspore import Tensor |
|
|
|
>>> import mindspore.nn as nn |
|
|
|
>>> import mindspore.ops.operations as P |
|
|
|
>>> context.set_context(mode=context.GRAPH_MODE) |
|
|
|
>>> class EditDistance(nn.Cell): |
|
|
|
>>> def __init__(self, hypothesis_shape, truth_shape, normalize=True): |
|
|
|
>>> super(EditDistance, self).__init__() |
|
|
|
@@ -3645,6 +3651,7 @@ class EditDistance(PrimitiveWithInfer): |
|
|
|
def __init__(self, normalize=True): |
|
|
|
"""Initialize EditDistance""" |
|
|
|
self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name) |
|
|
|
self.set_const_input_indexes([2, 5]) |
|
|
|
|
|
|
|
def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape): |
|
|
|
validator.check_const_input('hypothesis_shape', h_shape['value'], self.name) |
|
|
|
|