You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

module.cpp 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. /**
  2. * \file imperative/python/src/module.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include <pybind11/eval.h>
  12. #define DO_IMPORT_ARRAY
  13. #include "./helper.h"
  14. #include "./numpy_dtypes.h"
  15. #include "megdnn/handle.h"
  16. #include "./common.h"
  17. #include "./graph_rt.h"
  18. #include "./imperative_rt.h"
  19. #include "./ops.h"
  20. #include "./utils.h"
  21. #include "./tensor.h"
  22. namespace py = pybind11;
  23. using namespace mgb::imperative::python;
  // Name given to the extension module by PYBIND11_MODULE below. A value
  // supplied by the build (e.g. -DMODULE_NAME=...) is kept; otherwise the
  // module is called imperative_rt.
  #if !defined(MODULE_NAME)
  #define MODULE_NAME imperative_rt
  #endif
  27. namespace megdnn {
  28. extern const std::shared_ptr<Handle>& inplace_cpu_handle(int debug_level = 0);
  29. }
  30. PYBIND11_MODULE(MODULE_NAME, m) {
  31. // initialize numpy
  32. if ([]() {
  33. import_array1(1);
  34. return 0;
  35. }()) {
  36. throw py::error_already_set();
  37. }
  38. megdnn::inplace_cpu_handle();
  39. py::module::import("sys").attr("modules")[m.attr("__name__")] = m;
  40. m.attr("__package__") = m.attr("__name__");
  41. m.attr("__builtins__") = py::module::import("builtins");
  42. auto atexit = py::module::import("atexit");
  43. atexit.attr("register")(py::cpp_function([]() {
  44. py::gil_scoped_release _;
  45. py_task_q.wait_all_task_finish();
  46. }));
  47. auto common = submodule(m, "common");
  48. auto utils = submodule(m, "utils");
  49. auto imperative = submodule(m, "imperative");
  50. auto graph = submodule(m, "graph");
  51. auto ops = submodule(m, "ops");
  52. init_common(common);
  53. init_utils(utils);
  54. init_imperative_rt(imperative);
  55. init_graph_rt(graph);
  56. init_ops(ops);
  57. py::exec(
  58. R"(
  59. from .common import *
  60. from .utils import *
  61. from .imperative import *
  62. from .graph import *
  63. from .ops import OpDef
  64. )",
  65. py::getattr(m, "__dict__"));
  66. init_tensor(submodule(m, "core2"));
  67. }