| @@ -584,6 +584,7 @@ inline const PrimitivePtr kPrimErfinv = std::make_shared<Primitive>("Erfinv"); | |||
| inline const PrimitivePtr kPrimIsNan = std::make_shared<Primitive>("IsNan"); | |||
| inline const PrimitivePtr kPrimIsInf = std::make_shared<Primitive>("IsInf"); | |||
| inline const PrimitivePtr kPrimIsFinite = std::make_shared<Primitive>("IsFinite"); | |||
| inline const PrimitivePtr kPrimIsClose = std::make_shared<Primitive>("IsClose"); | |||
| inline const PrimitivePtr kPrimLerp = std::make_shared<Primitive>("Lerp"); | |||
| inline const PrimitivePtr kPrimSquareSumAll = std::make_shared<Primitive>("SquareSumAll"); | |||
| inline const PrimitivePtr kPrimComplex = std::make_shared<Primitive>("Complex"); | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/is_close.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr IsCloseInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| const int MAX = 0x3fffffff; | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto op_name = primitive->name(); | |||
| const int input_num = 2; | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, input_num, op_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; | |||
| auto other_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; | |||
| auto input_rank = SizeToLong(input_shape.size()); | |||
| auto other_rank = SizeToLong(other_shape.size()); | |||
| CheckAndConvertUtils::Check("input rank", input_rank, kEqual, "other rank", other_rank, op_name); | |||
| int64_t input_size = 1, other_size = 1; | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size *= input_shape[i]; | |||
| other_size *= other_shape[i]; | |||
| if (input_shape[i] != other_shape[i] && (input_shape[i] != 1 || other_shape[i] != 1)) { | |||
| MS_EXCEPTION(ValueError) << "The size of tensor input must match the size of tensor other at the " << i | |||
| << " dimension!"; | |||
| } | |||
| } | |||
| if (input_size > MAX) | |||
| MS_EXCEPTION(ValueError) << "The size of tensor input must should be less than [2147483648], actual is " | |||
| << input_size; | |||
| if (other_size > MAX) | |||
| MS_EXCEPTION(ValueError) << "The size of tensor other must should be less than [2147483648], actual is " | |||
| << other_size; | |||
| return BroadCastInferShape(op_name, input_args); | |||
| } | |||
| TypePtr IsCloseInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto op_name = prim->name(); | |||
| const int input_num = 2; | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, input_num, op_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32}; | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("input", input_args[0]->BuildType()); | |||
| types.emplace("other", input_args[1]->BuildType()); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("input", input_args[0]->BuildType(), valid_types, op_name); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("other", input_args[1]->BuildType(), valid_types, op_name); | |||
| return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr IsCloseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| (void)IsCloseInferType(primitive, input_args); | |||
| return abstract::MakeAbstract(IsCloseInferShape(primitive, input_args), kBool); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(IsClose, prim::kPrimIsClose, IsCloseInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_IsClose_H_ | |||
| #define MINDSPORE_CORE_OPS_IsClose_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameIsClose = "IsClose"; | |||
| class IsClose : public PrimitiveC { | |||
| public: | |||
| IsClose() : PrimitiveC(kNameIsClose) { InitIOName({"x1", "x2"}, {"y"}); } | |||
| ~IsClose() = default; | |||
| MS_DECLARE_PARENT(IsClose, PrimitiveC); | |||
| }; | |||
| AbstractBasePtr IsCloseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimIsClosePtr = std::shared_ptr<IsClose>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_IsClose_H_ | |||
| @@ -496,3 +496,4 @@ from .non_zero_ds import _non_zero_ds_tbe | |||
| from .trunc import _trunc_tbe | |||
| from .extract_volume_patches import _extract_volume_patches_tbe | |||
| from .round_ds import _round_ds_tbe | |||
| from .is_close import _is_close_tbe | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """IsClose op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| is_close_op_info = TBERegOp("IsClose") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("is_close.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("is_close") \ | |||
| .partial_flag(True) \ | |||
| .attr("rtol", "optional", "float", "all", "1e-05")\ | |||
| .attr("atol", "optional", "float", "all", "1e-08")\ | |||
| .attr("equal_nan", "optional", "bool", "true,false", "False")\ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(is_close_op_info) | |||
| def _is_close_tbe(): | |||
| """IsClose TBE register""" | |||
| return | |||
| @@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||
| Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | |||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | |||
| Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, | |||
| MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc) | |||
| MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose) | |||
| from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | |||
| RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler, | |||
| @@ -5569,3 +5569,60 @@ class Trunc(Primitive): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize Trunc""" | |||
| class IsClose(Primitive): | |||
| r""" | |||
| Returns a boolean tensor where two tensors are element-wise equal within a tolerance. | |||
| Note: | |||
| Returns a new tensor with boolean elements representing if each element of input | |||
| is “close” to the corresponding element of other. Closeness is defined as: | |||
| ∣input−other∣ ≤ atol + rtol × ∣other∣ | |||
| .. warning:: | |||
| When the input is nan or inf, the result is uncertain. | |||
| Args: | |||
| rtol(float): Relative tolerance. Default: 1e-05. | |||
| atol(float): Absolute tolerance. Default: 1e-08. | |||
| equal_nan(bool): If True, then two NaNs will be considered equal. At present, `equal_nan` must be True, | |||
| we will support False in future version. Default: True. | |||
| Inputs: | |||
| -**input**(Tensor) – First tensor to compare, with data type belongs to float32, float16, int32. | |||
| -**other**(Tensor) – Second tensor to compare, with data type belongs to float32, float16, int32. | |||
| Outputs: | |||
| Tensor, with same shape as input and other. When the input is close to the other, it is true, | |||
| otherwise it is false. | |||
| Raises: | |||
| TypeError: If either of `input` and `other` is not tensor. | |||
| TypeError: If either of `input` and `other` is not float16, float32 or int32. | |||
| TypeError: If either of `atol` and `rtol` is not float. | |||
| TypeError: If `equal_nan` is not bool. | |||
| TypeError: If the dtype of `input` is not same as the `other`. | |||
| ValueError: If shape of `input` is not same as the `other`. | |||
| ValueError: If either of `atol` and `rtol` is less than zero. | |||
| ValueError: If `equal_nan` is False. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16) | |||
| >>> other = Tensor(np.array([1.3, 3.3, 2.3, 3.1, 5.1]), mindspore.float16) | |||
| >>> output = ops.IsClose()(input, other) | |||
| >>> print(output) | |||
| [true false false false true] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, rtol=1e-05, atol=1e-08, equal_nan=True): | |||
| """Initialize IsClose""" | |||
| validator.check_value_type('rtol', rtol, [float], self.name) | |||
| validator.check_value_type('atol', atol, [float], self.name) | |||
| validator.check_value_type('equal_nan', equal_nan, [bool], self.name) | |||
| if equal_nan is not True: | |||
| raise ValueError("For IsClose, the `equal_nan` must be True, but got False.") | |||
| validator.check_non_negative_float(rtol, 'rtol', self.name) | |||
| validator.check_non_negative_float(atol, 'atol', self.name) | |||
| @@ -1195,7 +1195,10 @@ test_case_math_ops = [ | |||
| ('IsInf', { | |||
| 'block': P.IsInf(), | |||
| 'desc_inputs': [Tensor(np.array([np.log(-1), 1, np.log(0)]).astype(np.float32))], | |||
| 'desc_bprop': [], | |||
| 'desc_bprop': []}), | |||
| ('IsClose', { | |||
| 'block': P.IsClose(rtol=1e-05, atol=1e-08, equal_nan=True), | |||
| 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32)], | |||
| 'skip': ['backward']}), | |||
| ('ACos', { | |||
| 'block': P.ACos(), | |||