You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

acos.cc 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ops/acos.h"
  17. #include <string>
  18. #include <algorithm>
  19. #include <map>
  20. #include <set>
  21. #include <vector>
  22. #include "ops/op_utils.h"
  23. #include "utils/check_convert_utils.h"
  24. #include "abstract/primitive_infer_map.h"
  25. namespace mindspore {
  26. namespace ops {
  27. namespace {
  28. abstract::ShapePtr ACosInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  29. MS_EXCEPTION_IF_NULL(primitive);
  30. auto prim_name = primitive->name();
  31. (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
  32. auto x = input_args[0]->BuildShape();
  33. const int64_t max_dim = 8;
  34. auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  35. (void)CheckAndConvertUtils::CheckInteger("The dimension of Acos input", SizeToLong(in_shape.size()), kLessThan,
  36. max_dim, prim_name);
  37. MS_EXCEPTION_IF_NULL(x);
  38. auto shape_element = x->cast<abstract::ShapePtr>();
  39. MS_EXCEPTION_IF_NULL(shape_element);
  40. return shape_element;
  41. }
  42. TypePtr ACosInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  43. MS_EXCEPTION_IF_NULL(primitive);
  44. auto prim_name = primitive->name();
  45. MS_EXCEPTION_IF_NULL(input_args[0]);
  46. auto x_type = input_args[0]->BuildType();
  47. const std::set valid_types = {kFloat16, kFloat32, kFloat64};
  48. (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
  49. return x_type;
  50. }
  51. } // namespace
  52. AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
  53. const std::vector<AbstractBasePtr> &input_args) {
  54. MS_EXCEPTION_IF_NULL(primitive);
  55. const int64_t input_num = 1;
  56. CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  57. auto infer_type = ACosInferType(primitive, input_args);
  58. auto infer_shape = ACosInferShape(primitive, input_args);
  59. return abstract::MakeAbstract(infer_shape, infer_type);
  60. }
  61. REGISTER_PRIMITIVE_EVAL_IMPL(ACos, prim::kPrimACos, ACosInfer, nullptr, true);
  62. } // namespace ops
  63. } // namespace mindspore