diff --git a/mindspore/core/ops/gelu.cc b/mindspore/core/ops/gelu.cc index 8f3282b4c8..d8cba7a140 100644 --- a/mindspore/core/ops/gelu.cc +++ b/mindspore/core/ops/gelu.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,40 +13,51 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#include -#include +#include "ops/gelu.h" #include -#include +#include +#include +#include +#include -#include "ops/gelu.h" +#include "ops/op_utils.h" #include "utils/check_convert_utils.h" #include "abstract/primitive_infer_map.h" namespace mindspore { namespace ops { namespace { -abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { +abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - return std::make_shared(input_shape); -} - -TypePtr GeLUInferType(const PrimitivePtr &prim, const std::vector &input_args) { - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); + auto op_name = primitive->name(); + CheckAndConvertUtils::CheckInteger("gelu infer", input_args.size(), kEqual, 1, op_name); + MS_EXCEPTION_IF_NULL(input_args[0]); + auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); + auto in_shape = shape_map[kShape]; + auto min_shape = shape_map[kMinShape]; + auto max_shape = shape_map[kMaxShape]; + if (min_shape.size() != 0 && max_shape.size() != 0) { + return std::make_shared(in_shape, min_shape, max_shape); } - const std::set valid_types = {kFloat16, kFloat32}; + return std::make_shared(in_shape); +} +TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto op_name = prim->name(); + CheckAndConvertUtils::CheckInteger("gelu infer", input_args.size(), kEqual, 1, op_name); std::map types; - types.emplace("input_x", input_args[0]->BuildType()); + const std::set valid_types = {kFloat16, kFloat32}; + MS_EXCEPTION_IF_NULL(input_args[0]); + types.emplace("x", input_args[0]->BuildType()); return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); } } // namespace AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - return std::make_shared(GeLUInferType(primitive, input_args), - GeLUInferShape(primitive, input_args)->shape()); + return std::make_shared(InferType(primitive, input_args), + InferShape(primitive, input_args)->shape()); } -REGISTER_PRIMITIVE_C(kNameGeLU, GeLU); +REGISTER_PRIMITIVE_EVAL_IMPL(GeLU, prim::kPrimGeLU, GeLUInfer, nullptr, true); + } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/gelu.h b/mindspore/core/ops/gelu.h index b3ffa39a9e..17d83ac7e4 100644 --- a/mindspore/core/ops/gelu.h +++ b/mindspore/core/ops/gelu.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,30 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef MINDSPORE_CORE_OPS_GELU_H_ #define MINDSPORE_CORE_OPS_GELU_H_ #include #include - #include "ops/primitive_c.h" +#include "ops/op_utils.h" #include "abstract/abstract_value.h" #include "utils/check_convert_utils.h" namespace mindspore { namespace ops { -constexpr auto kNameGeLU = "GeLU"; +constexpr auto kNameGeLU = prim::kGeLU; class GeLU : public PrimitiveC { public: - GeLU() : PrimitiveC(kNameGeLU) {} + GeLU() : PrimitiveC(kNameGeLU) { InitIOName({"x"}, {"output"}); } ~GeLU() = default; MS_DECLARE_PARENT(GeLU, PrimitiveC); void Init() {} }; -AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); + using PrimGeLUPtr = std::shared_ptr; } // namespace ops } // namespace mindspore - #endif // MINDSPORE_CORE_OPS_GELU_H_ diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index a1b1475c48..793709a1b2 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3318,7 +3318,7 @@ class Gelu(PrimitiveWithInfer): return input_x -class GeLU(PrimitiveWithInfer): +class GeLU(Primitive): r""" Gaussian Error Linear Units activation function. @@ -3359,12 +3359,7 @@ class GeLU(PrimitiveWithInfer): """Initialize GeLU""" self.init_prim_io_names(inputs=['x'], outputs=['output']) - def infer_shape(self, input_x): - return input_x - def infer_dtype(self, input_x): - validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name) - return input_x class FastGelu(PrimitiveWithInfer):