You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module.cpp 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #include <pybind11/eval.h>
  2. #define DO_IMPORT_ARRAY
  3. #include "./helper.h"
  4. #include "./numpy_dtypes.h"
  5. #include "megdnn/handle.h"
  6. #include "./common.h"
  7. #include "./graph_rt.h"
  8. #include "./imperative_rt.h"
  9. #include "./ops.h"
  10. #include "./utils.h"
  11. #include "./tensor.h"
  12. namespace py = pybind11;
  13. using namespace mgb::imperative::python;
  14. #ifndef MODULE_NAME
  15. #define MODULE_NAME imperative_rt
  16. #endif
  17. namespace megdnn {
  18. extern const std::shared_ptr<Handle>& inplace_cpu_handle(int debug_level = 0);
  19. }
  20. PYBIND11_MODULE(MODULE_NAME, m) {
  21. // initialize numpy
  22. if ([]() {
  23. import_array1(1);
  24. return 0;
  25. }()) {
  26. throw py::error_already_set();
  27. }
  28. megdnn::inplace_cpu_handle();
  29. py::module::import("sys").attr("modules")[m.attr("__name__")] = m;
  30. m.attr("__package__") = m.attr("__name__");
  31. m.attr("__builtins__") = py::module::import("builtins");
  32. auto atexit = py::module::import("atexit");
  33. atexit.attr("register")(py::cpp_function([]() {
  34. py::gil_scoped_release _;
  35. py_task_q.wait_all_task_finish();
  36. }));
  37. auto common = submodule(m, "common");
  38. auto utils = submodule(m, "utils");
  39. auto imperative = submodule(m, "imperative");
  40. auto graph = submodule(m, "graph");
  41. auto ops = submodule(m, "ops");
  42. init_common(common);
  43. init_utils(utils);
  44. init_imperative_rt(imperative);
  45. init_graph_rt(graph);
  46. init_ops(ops);
  47. py::exec(
  48. R"(
  49. from .common import *
  50. from .utils import *
  51. from .imperative import *
  52. from .graph import *
  53. from .ops import OpDef
  54. )",
  55. py::getattr(m, "__dict__"));
  56. init_tensor(submodule(m, "core2"));
  57. }