You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cell_py.cc 2.9 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pybind_api/ir/cell_py.h"
  17. #include <string>
  18. #include "pybind_api/api_register.h"
  19. #include "abstract/abstract_value.h"
  20. #include "pipeline/jit/parse/python_adapter.h"
  21. namespace mindspore {
  22. void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &obj) {
  23. ValuePtr converted_ret = nullptr;
  24. MS_EXCEPTION_IF_NULL(cell);
  25. if (py::isinstance<py::module>(obj)) {
  26. MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module";
  27. }
  28. bool converted = parse::ConvertData(obj, &converted_ret, true);
  29. if (!converted) {
  30. MS_LOG(DEBUG) << "Attribute convert error with type: " << std::string(py::str(obj));
  31. } else {
  32. MS_LOG(DEBUG) << cell->ToString() << " add attr " << name << converted_ret->ToString();
  33. cell->AddAttr(name, converted_ret);
  34. }
  35. }
  36. // Define python 'Cell' class.
  37. REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) {
  38. (void)py::class_<Cell, std::shared_ptr<Cell>>(*m, "Cell_")
  39. .def(py::init<std::string &>())
  40. .def("__str__", &Cell::ToString)
  41. .def("_add_attr", &CellPy::AddAttr, "Add Cell attr.")
  42. .def("_del_attr", &Cell::DelAttr, "Delete Cell attr.")
  43. .def(
  44. "construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; },
  45. "construct")
  46. .def(py::pickle(
  47. [](const Cell &cell) { // __getstate__
  48. /* Return a tuple that fully encodes the state of the object */
  49. return py::make_tuple(py::str(cell.name()));
  50. },
  51. [](const py::tuple &tup) { // __setstate__
  52. if (tup.size() != 1) {
  53. throw std::runtime_error("Invalid state!");
  54. }
  55. /* Create a new C++ instance */
  56. Cell data(tup[0].cast<std::string>());
  57. return data;
  58. }));
  59. }));
  60. } // namespace mindspore