You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

signature_py.cc 3.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ir/signature.h"
  17. #include "pybind11/operators.h"
  18. #include "pybind_api/api_register.h"
  19. #include "pipeline/jit/parse/data_converter.h"
  20. namespace py = pybind11;
  21. namespace mindspore {
  22. static ValuePtr PyArgToValue(const py::object &arg) {
  23. if (py::isinstance<SignatureEnumKind>(arg) &&
  24. py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
  25. return nullptr;
  26. }
  27. return parse::data_converter::PyDataToValue(arg);
  28. }
  29. // Bind SignatureEnumRW as a python class.
  30. REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
  31. (void)py::class_<Signature>(*m, "Signature")
  32. .def(py::init([](std::string name, SignatureEnumRW rw, SignatureEnumKind kind,
  33. py::object arg_default, SignatureEnumDType dtype) {
  34. auto default_value = PyArgToValue(arg_default);
  35. return Signature(name, rw, kind, default_value, dtype);
  36. }));
  37. (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
  38. .value("RW_READ", SignatureEnumRW::kRWRead)
  39. .value("RW_WRITE", SignatureEnumRW::kRWWrite)
  40. .value("RW_REF", SignatureEnumRW::kRWRef)
  41. .value("RW_EMPTY_DEFAULT_VALUE", SignatureEnumRW::kRWEmptyDefaultValue);
  42. (void)py::enum_<SignatureEnumKind>(*m, "signature_kind", py::arithmetic())
  43. .value("KIND_POSITIONAL_KEYWORD", SignatureEnumKind::kKindPositionalKeyword)
  44. .value("KIND_VAR_POSITIONAL", SignatureEnumKind::kKindVarPositional)
  45. .value("KIND_KEYWORD_ONLY", SignatureEnumKind::kKindKeywordOnly)
  46. .value("KIND_VAR_KEYWARD", SignatureEnumKind::kKindVarKeyword)
  47. .value("KIND_EMPTY_DEFAULT_VALUE", SignatureEnumKind::kKindEmptyDefaultValue);
  48. (void)py::enum_<SignatureEnumDType>(*m, "signature_dtype", py::arithmetic())
  49. .value("T", SignatureEnumDType::kDType)
  50. .value("T1", SignatureEnumDType::kDType1)
  51. .value("T2", SignatureEnumDType::kDType2)
  52. .value("T3", SignatureEnumDType::kDType3)
  53. .value("T4", SignatureEnumDType::kDType4)
  54. .value("T5", SignatureEnumDType::kDType5)
  55. .value("T6", SignatureEnumDType::kDType6)
  56. .value("T7", SignatureEnumDType::kDType7)
  57. .value("T8", SignatureEnumDType::kDType8)
  58. .value("T9", SignatureEnumDType::kDType9)
  59. .value("T_EMPTY_DEFAULT_VALUE", SignatureEnumDType::kDTypeEmptyDefaultValue);
  60. }));
  61. } // namespace mindspore