You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_to_function.cc 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "operator/prim_to_function.h"
  17. #include <exception>
  18. #include <iostream>
  19. #include <string>
  20. namespace mindspore {
  21. // Namespace for primitive-related definitions.
  22. namespace prim {
  23. PrimToFunction::PrimToFunction()
  24. : prim_func_type_map_({// ONE_ARG prim
  25. {"bool_not", kPrimTypeOneArg},
  26. {"scalar_cos", kPrimTypeOneArg},
  27. {"scalar_exp", kPrimTypeOneArg},
  28. {"scalar_floor", kPrimTypeOneArg},
  29. {"scalar_log", kPrimTypeOneArg},
  30. {"scalar_sin", kPrimTypeOneArg},
  31. {"scalar_tan", kPrimTypeOneArg},
  32. {"scalar_trunc", kPrimTypeOneArg},
  33. {"typeof", kPrimTypeOneArg},
  34. {"scalar_uadd", kPrimTypeOneArg},
  35. {"scalar_usub", kPrimTypeOneArg},
  36. // TWO_ARGS prim
  37. {"scalar_add", kPrimTypeTwoArgs},
  38. {"bool_and", kPrimTypeTwoArgs},
  39. {"bool_eq", kPrimTypeTwoArgs},
  40. {"bool_or", kPrimTypeTwoArgs},
  41. {"scalar_div", kPrimTypeTwoArgs},
  42. {"scalar_eq", kPrimTypeTwoArgs},
  43. {"scalar_ge", kPrimTypeTwoArgs},
  44. {"scalar_gt", kPrimTypeTwoArgs},
  45. {"scalar_le", kPrimTypeTwoArgs},
  46. {"scalar_lt", kPrimTypeTwoArgs},
  47. {"scalar_ne", kPrimTypeTwoArgs},
  48. {"scalar_mod", kPrimTypeTwoArgs},
  49. {"scalar_mul", kPrimTypeTwoArgs},
  50. {"scalar_pow", kPrimTypeTwoArgs},
  51. {"scalar_sub", kPrimTypeTwoArgs},
  52. {"scalar_floordiv", kPrimTypeTwoArgs}}) {}
  53. bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const {
  54. bool result = false;
  55. if (func != nullptr) {
  56. int args_num = GetPrimType(prim);
  57. std::vector<TypePtr> one_arg{std::make_shared<Number>()};
  58. std::vector<TypePtr> two_args{std::make_shared<Number>(), std::make_shared<Number>()};
  59. TypePtr retval = std::make_shared<Number>();
  60. result = true;
  61. switch (args_num) {
  62. case kPrimTypeOneArg:
  63. *func = Function(one_arg, retval).DeepCopy()->cast<FunctionPtr>();
  64. break;
  65. case kPrimTypeTwoArgs:
  66. *func = Function(two_args, retval).DeepCopy()->cast<FunctionPtr>();
  67. break;
  68. default:
  69. result = false;
  70. break;
  71. }
  72. }
  73. return result;
  74. }
  75. int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const {
  76. MS_EXCEPTION_IF_NULL(prim);
  77. int prim_type = static_cast<int>(kPrimTypeUnknown);
  78. auto value = prim_func_type_map_.find(prim->name());
  79. if (value != prim_func_type_map_.end()) {
  80. prim_type = value->second;
  81. }
  82. return prim_type;
  83. }
  84. } // namespace prim
  85. } // namespace mindspore