|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -13,40 +13,51 @@ |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <set> |
|
|
|
#include <map> |
|
|
|
#include "ops/gelu.h" |
|
|
|
#include <string> |
|
|
|
#include <memory> |
|
|
|
#include <algorithm> |
|
|
|
#include <map> |
|
|
|
#include <set> |
|
|
|
#include <vector> |
|
|
|
|
|
|
|
#include "ops/gelu.h" |
|
|
|
#include "ops/op_utils.h" |
|
|
|
#include "utils/check_convert_utils.h" |
|
|
|
#include "abstract/primitive_infer_map.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace ops { |
|
|
|
namespace { |
|
|
|
abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; |
|
|
|
return std::make_shared<abstract::Shape>(input_shape); |
|
|
|
} |
|
|
|
|
|
|
|
TypePtr GeLUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
for (const auto &item : input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(item); |
|
|
|
auto op_name = primitive->name(); |
|
|
|
CheckAndConvertUtils::CheckInteger("gelu infer", input_args.size(), kEqual, 1, op_name); |
|
|
|
MS_EXCEPTION_IF_NULL(input_args[0]); |
|
|
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); |
|
|
|
auto in_shape = shape_map[kShape]; |
|
|
|
auto min_shape = shape_map[kMinShape]; |
|
|
|
auto max_shape = shape_map[kMaxShape]; |
|
|
|
if (min_shape.size() != 0 && max_shape.size() != 0) { |
|
|
|
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape); |
|
|
|
} |
|
|
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; |
|
|
|
return std::make_shared<abstract::Shape>(in_shape); |
|
|
|
} |
|
|
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
auto op_name = prim->name(); |
|
|
|
CheckAndConvertUtils::CheckInteger("gelu infer", input_args.size(), kEqual, 1, op_name); |
|
|
|
std::map<std::string, TypePtr> types; |
|
|
|
types.emplace("input_x", input_args[0]->BuildType()); |
|
|
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; |
|
|
|
MS_EXCEPTION_IF_NULL(input_args[0]); |
|
|
|
types.emplace("x", input_args[0]->BuildType()); |
|
|
|
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
return std::make_shared<abstract::AbstractTensor>(GeLUInferType(primitive, input_args), |
|
|
|
GeLUInferShape(primitive, input_args)->shape()); |
|
|
|
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), |
|
|
|
InferShape(primitive, input_args)->shape()); |
|
|
|
} |
|
|
|
REGISTER_PRIMITIVE_C(kNameGeLU, GeLU); |
|
|
|
REGISTER_PRIMITIVE_EVAL_IMPL(GeLU, prim::kPrimGeLU, GeLUInfer, nullptr, true); |
|
|
|
|
|
|
|
} // namespace ops |
|
|
|
} // namespace mindspore |