|
|
|
@@ -1415,7 +1415,7 @@ class EqualCount(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Computes the number of the same elements of two tensors. |
|
|
|
|
|
|
|
The two input tensors should have same data type. |
|
|
|
The two input tensors should have same data type and shape. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x** (Tensor) - The first input tensor. |
|
|
|
@@ -1438,6 +1438,7 @@ class EqualCount(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape, y_shape): |
|
|
|
validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name) |
|
|
|
output_shape = (1,) |
|
|
|
return output_shape |
|
|
|
|
|
|
|
|