
!15239 gelu

From: @shen_jingxing
Reviewed-by: 
Signed-off-by:
pull/15239/MERGE
mindspore-ci-bot (Gitee) committed 4 years ago
commit 27a04cb2e3
3 changed files with 36 additions and 33 deletions:

1. mindspore/core/ops/gelu.cc (+30 -19)
2. mindspore/core/ops/gelu.h (+5 -8)
3. mindspore/ops/operations/nn_ops.py (+1 -6)

mindspore/core/ops/gelu.cc (+30 -19)

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,40 +13,51 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 
-#include <set>
-#include <map>
-#include "ops/gelu.h"
+#include <string>
+#include <memory>
+#include <algorithm>
+#include <map>
+#include <set>
+#include <vector>
+
+#include "ops/gelu.h"
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
 #include "abstract/primitive_infer_map.h"
 
 namespace mindspore {
 namespace ops {
 namespace {
-abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  return std::make_shared<abstract::Shape>(input_shape);
-}
-
-TypePtr GeLUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
+  auto op_name = primitive->name();
+  CheckAndConvertUtils::CheckInteger("gelu infer", input_args.size(), kEqual, 1, op_name);
+  MS_EXCEPTION_IF_NULL(input_args[0]);
+  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
+  auto in_shape = shape_map[kShape];
+  auto min_shape = shape_map[kMinShape];
+  auto max_shape = shape_map[kMaxShape];
+  if (min_shape.size() != 0 && max_shape.size() != 0) {
+    return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
   }
-  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  return std::make_shared<abstract::Shape>(in_shape);
+}
+TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto op_name = prim->name();
+  CheckAndConvertUtils::CheckInteger("gelu infer", input_args.size(), kEqual, 1, op_name);
   std::map<std::string, TypePtr> types;
-  types.emplace("input_x", input_args[0]->BuildType());
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  MS_EXCEPTION_IF_NULL(input_args[0]);
+  types.emplace("x", input_args[0]->BuildType());
   return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
 }
 }  // namespace
 AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(GeLUInferType(primitive, input_args),
-                                                    GeLUInferShape(primitive, input_args)->shape());
+  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
+                                                    InferShape(primitive, input_args)->shape());
 }
 REGISTER_PRIMITIVE_C(kNameGeLU, GeLU);
 REGISTER_PRIMITIVE_EVAL_IMPL(GeLU, prim::kPrimGeLU, GeLUInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore

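Note: the reworked C++ infer functions add an explicit argument-count check, support min/max shapes for dynamic-shape inputs, and check the input under the name "x". From the Python side the observable contract is unchanged: GeLU is element-wise, so the inferred output keeps the input's shape, and only float16/float32 inputs are accepted. A minimal sanity check, assuming a working MindSpore install and using only the public ops.GeLU API:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

# GeLU is element-wise: the inferred output shape equals the input shape
# and the dtype is preserved (float16/float32 are the only valid inputs).
gelu = ops.GeLU()
x = Tensor(np.array([[-1.0, 0.0], [1.0, 3.0]]), ms.float32)
y = gelu(x)
print(y.shape, y.dtype)  # expected: (2, 2) Float32
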
mindspore/core/ops/gelu.h (+5 -8)

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,30 +13,27 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 
 #ifndef MINDSPORE_CORE_OPS_GELU_H_
 #define MINDSPORE_CORE_OPS_GELU_H_
 #include <vector>
 #include <memory>
 
 #include "ops/primitive_c.h"
 #include "ops/op_utils.h"
 #include "abstract/abstract_value.h"
 #include "utils/check_convert_utils.h"
 
 namespace mindspore {
 namespace ops {
-constexpr auto kNameGeLU = "GeLU";
+constexpr auto kNameGeLU = prim::kGeLU;
 class GeLU : public PrimitiveC {
  public:
-  GeLU() : PrimitiveC(kNameGeLU) {}
+  GeLU() : PrimitiveC(kNameGeLU) { InitIOName({"x"}, {"output"}); }
   ~GeLU() = default;
   MS_DECLARE_PARENT(GeLU, PrimitiveC);
   void Init() {}
 };
 AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args);
 
 using PrimGeLUPtr = std::shared_ptr<GeLU>;
 }  // namespace ops
 }  // namespace mindspore
 
 #endif  // MINDSPORE_CORE_OPS_GELU_H_

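Note: the header now registers the primitive's I/O names at construction time via InitIOName({"x"}, {"output"}), mirroring the init_prim_io_names(inputs=['x'], outputs=['output']) call kept on the Python side. A hedged way to observe those names from Python; reading them through Primitive.attrs with the 'input_names'/'output_names' keys is an assumption about the attribute layout, not something this PR defines:

from mindspore import ops

# init_prim_io_names / InitIOName store the names as primitive attributes.
# The attribute keys used below are assumed, not guaranteed by the public API.
gelu = ops.GeLU()
print(gelu.attrs.get('input_names'), gelu.attrs.get('output_names'))  # e.g. ['x'] ['output']
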
mindspore/ops/operations/nn_ops.py (+1 -6)

@@ -3318,7 +3318,7 @@ class Gelu(PrimitiveWithInfer):
         return input_x
 
 
-class GeLU(PrimitiveWithInfer):
+class GeLU(Primitive):
     r"""
     Gaussian Error Linear Units activation function.
 
@@ -3359,12 +3359,7 @@ class GeLU(PrimitiveWithInfer):
         """Initialize GeLU"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
 
-    def infer_shape(self, input_x):
-        return input_x
-
-    def infer_dtype(self, input_x):
-        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
-        return input_x
 
 
 class FastGelu(PrimitiveWithInfer):


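Note: with infer_shape/infer_dtype removed, shape and dtype inference for GeLU is handled entirely by the C++ GeLUInfer registered in gelu.cc, so the Python class only needs to derive from Primitive. A rough cross-check of the documented behaviour, GeLU(x) = x * P(X <= x) with X ~ N(0, 1); the tolerance and the exact exception raised for an invalid dtype are assumptions (backends may use a tanh approximation, and the error comes from the C++ type check):

import math
import numpy as np
from mindspore import Tensor, ops

def gelu_ref(v):
    # Reference GeLU: v * Phi(v), where Phi is the standard normal CDF.
    return v * 0.5 * (1.0 + math.erf(v / math.sqrt(2.0)))

x_np = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
y_ms = ops.GeLU()(Tensor(x_np)).asnumpy()
y_ref = np.array([gelu_ref(v) for v in x_np], dtype=np.float32)
print(np.allclose(y_ms, y_ref, atol=1e-3))  # expected: True

# Dtype validation now happens in the C++ type inference: non-float inputs
# should be rejected (the concrete exception type is an assumption).
try:
    ops.GeLU()(Tensor(np.array([1, 2, 3], dtype=np.int32)))
except (TypeError, ValueError, RuntimeError) as err:
    print("rejected int32 input:", type(err).__name__)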