| @@ -64,6 +64,7 @@ constexpr auto kTile = "Tile"; | |||
| constexpr auto kBiasAddGrad = "BiasAddGrad"; | |||
| constexpr auto kCos = "Cos"; | |||
| constexpr auto kAbs = "Abs"; | |||
| constexpr auto kTrunc = "Trunc"; | |||
| constexpr auto kSquare = "Square"; | |||
| // Arrays | |||
| @@ -123,6 +124,7 @@ inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_l | |||
| inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin"); | |||
| inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos"); | |||
| inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan"); | |||
| inline const PrimitivePtr kPrimTrunc = std::make_shared<Primitive>(kTrunc); | |||
| // Comparisons | |||
| inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq"); | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/trunc.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr TruncInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| auto x_shape = input_args[0]->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(x_shape); | |||
| auto output_shape = x_shape->cast<abstract::ShapePtr>(); | |||
| return output_shape; | |||
| } | |||
| TypePtr TruncInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name()); | |||
| std::set<TypePtr> check_list = {kFloat16, kFloat32, kInt8, kInt32, kUInt8}; | |||
| auto input_type = input_args[0]->BuildType(); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_type, check_list, prim->name()); | |||
| return input_type; | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr TruncInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return abstract::MakeAbstract(TruncInferShape(primitive, input_args), TruncInferType(primitive, input_args)); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Trunc, prim::kPrimTrunc, TruncInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_TRUNC_H_ | |||
| #define MINDSPORE_CORE_OPS_TRUNC_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameTrunc = "Trunc"; | |||
| class Trunc : public PrimitiveC { | |||
| public: | |||
| Trunc() : PrimitiveC(kNameTrunc) { InitIOName({"input_x"}, {"output_y"}); } | |||
| ~Trunc() = default; | |||
| MS_DECLARE_PARENT(Trunc, PrimitiveC); | |||
| }; | |||
| AbstractBasePtr TruncInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimTruncPtr = std::shared_ptr<Trunc>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_TRUNC_H_ | |||
| @@ -21,6 +21,7 @@ from .. import functional as F | |||
| from .. import operations as P | |||
| from .._grad.grad_base import bprop_getters | |||
| from .._grad.grad_math_ops import binop_grad_common | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..operations import _grad_ops as G | |||
| from ..primitive import constexpr | |||
| @@ -96,3 +97,14 @@ def get_bprop_erfinv(self): | |||
| return (dx,) | |||
| return bprop | |||
| @bprop_getters.register(P.Trunc) | |||
| def get_bprop_trunc(self): | |||
| """Grad definition for `Trunc` operation.""" | |||
| def bprop(input_x, output_y, dout): | |||
| bc_x = zeros_like(input_x) | |||
| return (bc_x,) | |||
| return bprop | |||
| @@ -418,3 +418,4 @@ from .hshrink import _hshrink_tbe | |||
| from .hshrink_grad import _hshrink_grad_tbe | |||
| from .new_im2col import _new_im2col_tbe | |||
| from .non_zero_ds import _non_zero_ds_tbe | |||
| from .trunc import _trunc_tbe | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Trunc op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| trunc_op_info = TBERegOp("Trunc") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("trunc.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("trunc") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "input_x", False, "required", "all") \ | |||
| .output(0, "output_y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(trunc_op_info) | |||
| def _trunc_tbe(): | |||
| """Trunc TBE register""" | |||
| return | |||
| @@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||
| Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | |||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | |||
| Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, | |||
| MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex) | |||
| MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc,) | |||
| from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | |||
| RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler, | |||
| @@ -483,6 +483,7 @@ __all__ = [ | |||
| "Conj", | |||
| "Real", | |||
| "Imag", | |||
| "Trunc", | |||
| "Complex" | |||
| ] | |||
| @@ -5516,3 +5516,31 @@ class Imag(PrimitiveWithInfer): | |||
| elif input_dtype == mstype.tensor_type(mstype.complex128): | |||
| output_dtype = mstype.float64 | |||
| return output_dtype | |||
| class Trunc(Primitive): | |||
| """ | |||
| Returns a new tensor with the truncated integer values of the elements of input. | |||
| Inputs: | |||
| - **input_x** (Tensor) - Input_x is a tensor. | |||
| Outputs: | |||
| Tensor, the same shape and data type as the input. | |||
| Raises: | |||
| TypeError: If `input_x` is not a Tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> trunc = ops.Trunc() | |||
| >>> output = trunc(Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]),mindspore.float32)) | |||
| >>> print(output) | |||
| [ 3. 0. 0. -3.] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize Trunc""" | |||
| @@ -469,6 +469,10 @@ raise_set = [ | |||
| ('AssignAdd_Error', { | |||
| 'block': (P.AssignAdd(), {'exception': ValueError}), | |||
| 'desc_inputs': [[1]]}), | |||
| ('Trunc', { | |||
| 'block': P.Trunc(), | |||
| 'desc_inputs': [Tensor(np.array([[1.1, 2.2, -4.1]], np.float32))], | |||
| 'skip': ['backward']}), | |||
| ] | |||