Browse Source

[feat][assistant][I40FG5] add new Ascend operator IsClose

tags/v1.6.0
韩峥嵘 4 years ago
parent
commit
43a495f3fd
8 changed files with 231 additions and 2 deletions
  1. +1
    -0
      mindspore/core/base/core_ops.h
  2. +85
    -0
      mindspore/core/ops/is_close.cc
  3. +42
    -0
      mindspore/core/ops/is_close.h
  4. +1
    -0
      mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
  5. +40
    -0
      mindspore/python/mindspore/ops/_op_impl/tbe/is_close.py
  6. +1
    -1
      mindspore/python/mindspore/ops/operations/__init__.py
  7. +57
    -0
      mindspore/python/mindspore/ops/operations/math_ops.py
  8. +4
    -1
      tests/ut/python/ops/test_ops.py

+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -584,6 +584,7 @@ inline const PrimitivePtr kPrimErfinv = std::make_shared<Primitive>("Erfinv");
inline const PrimitivePtr kPrimIsNan = std::make_shared<Primitive>("IsNan");
inline const PrimitivePtr kPrimIsInf = std::make_shared<Primitive>("IsInf");
inline const PrimitivePtr kPrimIsFinite = std::make_shared<Primitive>("IsFinite");
inline const PrimitivePtr kPrimIsClose = std::make_shared<Primitive>("IsClose");
inline const PrimitivePtr kPrimLerp = std::make_shared<Primitive>("Lerp");
inline const PrimitivePtr kPrimSquareSumAll = std::make_shared<Primitive>("SquareSumAll");
inline const PrimitivePtr kPrimComplex = std::make_shared<Primitive>("Complex");


+ 85
- 0
mindspore/core/ops/is_close.cc View File

@@ -0,0 +1,85 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ops/is_close.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr IsCloseInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
const int MAX = 0x3fffffff;
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
const int input_num = 2;
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, input_num, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto other_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto input_rank = SizeToLong(input_shape.size());
auto other_rank = SizeToLong(other_shape.size());
CheckAndConvertUtils::Check("input rank", input_rank, kEqual, "other rank", other_rank, op_name);
int64_t input_size = 1, other_size = 1;
for (size_t i = 0; i < input_shape.size(); i++) {
input_size *= input_shape[i];
other_size *= other_shape[i];
if (input_shape[i] != other_shape[i] && (input_shape[i] != 1 || other_shape[i] != 1)) {
MS_EXCEPTION(ValueError) << "The size of tensor input must match the size of tensor other at the " << i
<< " dimension!";
}
}
if (input_size > MAX)
MS_EXCEPTION(ValueError) << "The size of tensor input must should be less than [2147483648], actual is "
<< input_size;
if (other_size > MAX)
MS_EXCEPTION(ValueError) << "The size of tensor other must should be less than [2147483648], actual is "
<< other_size;
return BroadCastInferShape(op_name, input_args);
}
TypePtr IsCloseInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto op_name = prim->name();
const int input_num = 2;
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, input_num, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
std::map<std::string, TypePtr> types;
types.emplace("input", input_args[0]->BuildType());
types.emplace("other", input_args[1]->BuildType());
CheckAndConvertUtils::CheckTensorTypeValid("input", input_args[0]->BuildType(), valid_types, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("other", input_args[1]->BuildType(), valid_types, op_name);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
}
} // namespace
AbstractBasePtr IsCloseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
(void)IsCloseInferType(primitive, input_args);
return abstract::MakeAbstract(IsCloseInferShape(primitive, input_args), kBool);
}
REGISTER_PRIMITIVE_EVAL_IMPL(IsClose, prim::kPrimIsClose, IsCloseInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

+ 42
- 0
mindspore/core/ops/is_close.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_OPS_IsClose_H_
#define MINDSPORE_CORE_OPS_IsClose_H_

#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameIsClose = "IsClose";
class IsClose : public PrimitiveC {
public:
IsClose() : PrimitiveC(kNameIsClose) { InitIOName({"x1", "x2"}, {"y"}); }
~IsClose() = default;
MS_DECLARE_PARENT(IsClose, PrimitiveC);
};
AbstractBasePtr IsCloseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimIsClosePtr = std::shared_ptr<IsClose>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_IsClose_H_

+ 1
- 0
mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py View File

@@ -496,3 +496,4 @@ from .non_zero_ds import _non_zero_ds_tbe
from .trunc import _trunc_tbe
from .extract_volume_patches import _extract_volume_patches_tbe
from .round_ds import _round_ds_tbe
from .is_close import _is_close_tbe

+ 40
- 0
mindspore/python/mindspore/ops/_op_impl/tbe/is_close.py View File

@@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""IsClose op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

is_close_op_info = TBERegOp("IsClose") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("is_close.so") \
.compute_cost(10) \
.kernel_name("is_close") \
.partial_flag(True) \
.attr("rtol", "optional", "float", "all", "1e-05")\
.attr("atol", "optional", "float", "all", "1e-08")\
.attr("equal_nan", "optional", "bool", "true,false", "False")\
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
.get_op_info()

@op_info_register(is_close_op_info)
def _is_close_tbe():
"""IsClose TBE register"""
return

+ 1
- 1
mindspore/python/mindspore/ops/operations/__init__.py View File

@@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc)
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose)

from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,


+ 57
- 0
mindspore/python/mindspore/ops/operations/math_ops.py View File

@@ -5569,3 +5569,60 @@ class Trunc(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize Trunc"""

class IsClose(Primitive):
r"""
Returns a boolean tensor where two tensors are element-wise equal within a tolerance.

Note:
Returns a new tensor with boolean elements representing if each element of input
is “close” to the corresponding element of other. Closeness is defined as:
∣input−other∣ ≤ atol + rtol × ∣other∣

.. warning::
When the input is nan or inf, the result is uncertain.

Args:
rtol(float): Relative tolerance. Default: 1e-05.
atol(float): Absolute tolerance. Default: 1e-08.
equal_nan(bool): If True, then two NaNs will be considered equal. At present, `equal_nan` must be True,
we will support False in future version. Default: True.

Inputs:
-**input**(Tensor) – First tensor to compare, with data type belongs to float32, float16, int32.
-**other**(Tensor) – Second tensor to compare, with data type belongs to float32, float16, int32.

Outputs:
Tensor, with same shape as input and other. When the input is close to the other, it is true,
otherwise it is false.

Raises:
TypeError: If either of `input` and `other` is not tensor.
TypeError: If either of `input` and `other` is not float16, float32 or int32.
TypeError: If either of `atol` and `rtol` is not float.
TypeError: If `equal_nan` is not bool.
TypeError: If the dtype of `input` is not same as the `other`.
ValueError: If shape of `input` is not same as the `other`.
ValueError: If either of `atol` and `rtol` is less than zero.
ValueError: If `equal_nan` is False.

Supported Platforms:
``Ascend``

Examples:
>>> input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16)
>>> other = Tensor(np.array([1.3, 3.3, 2.3, 3.1, 5.1]), mindspore.float16)
>>> output = ops.IsClose()(input, other)
>>> print(output)
[true false false false true]
"""
@prim_attr_register
def __init__(self, rtol=1e-05, atol=1e-08, equal_nan=True):
"""Initialize IsClose"""
validator.check_value_type('rtol', rtol, [float], self.name)
validator.check_value_type('atol', atol, [float], self.name)
validator.check_value_type('equal_nan', equal_nan, [bool], self.name)
if equal_nan is not True:
raise ValueError("For IsClose, the `equal_nan` must be True, but got False.")
validator.check_non_negative_float(rtol, 'rtol', self.name)
validator.check_non_negative_float(atol, 'atol', self.name)

+ 4
- 1
tests/ut/python/ops/test_ops.py View File

@@ -1195,7 +1195,10 @@ test_case_math_ops = [
('IsInf', {
'block': P.IsInf(),
'desc_inputs': [Tensor(np.array([np.log(-1), 1, np.log(0)]).astype(np.float32))],
'desc_bprop': [],
'desc_bprop': []}),
('IsClose', {
'block': P.IsClose(rtol=1e-05, atol=1e-08, equal_nan=True),
'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32)],
'skip': ['backward']}),
('ACos', {
'block': P.ACos(),


Loading…
Cancel
Save