| @@ -1,53 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/add.h" | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto add_prim = primitive->cast<PrimAddPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add_prim); | |||||
| auto op_name = add_prim->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | |||||
| } | |||||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { | |||||
| MS_LOG(EXCEPTION) << "nullptr"; | |||||
| } | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("x", input_args[0]->BuildType()); | |||||
| types.emplace("y", input_args[1]->BuildType()); | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||||
| return TypeIdToType(infer_type); | |||||
| } | |||||
| } // namespace | |||||
| AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||||
| InferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameAdd, Add); | |||||
| } // namespace mindspore | |||||