
helper.cpp

/**
 * \file imperative/src/test/helper.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "helper.h"

#include "megbrain/graph.h"
#include "megbrain/opr/io.h"

#include <pybind11/embed.h>
#include <pybind11/numpy.h>

#include <memory>

namespace py = pybind11;

namespace mgb {
namespace imperative {
namespace {

#define XSTR(s) STR(s)
#define STR(s) #s
#define CONCAT(a, b) a##b
#define PYINIT(name) CONCAT(PyInit_, name)
#define pyinit PYINIT(MODULE_NAME)
#define UNUSED __attribute__((unused))

extern "C" PyObject* pyinit();

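// PyEnv owns the embedded Python interpreter used by these tests. The
// MODULE_NAME extension module (supplied as a compile-time definition) is
// registered via PyImport_AppendInittab before the interpreter starts, so it
// can be imported like any other module.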
class PyEnv {
    static std::unique_ptr<PyEnv> m_instance;
    std::unique_ptr<py::scoped_interpreter> m_interpreter;
    PyEnv();

public:
    static PyEnv& instance();
    static py::module get();
};

std::unique_ptr<PyEnv> PyEnv::m_instance = nullptr;

PyEnv::PyEnv() {
    mgb_assert(!m_instance);
    auto err = PyImport_AppendInittab(XSTR(MODULE_NAME), &pyinit);
    mgb_assert(!err);
    m_interpreter.reset(new py::scoped_interpreter());
}

PyEnv& PyEnv::instance() {
    if (!m_instance) {
        m_instance.reset(new PyEnv());
    }
    return *m_instance;
}

py::module PyEnv::get() {
    instance();
    return py::module::import(XSTR(MODULE_NAME));
}

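// Debug helpers: cast a tensor to its Python wrapper and call its numpy()
// method to obtain a py::array. Note that print(array(x)) below resolves to
// pybind11::print through argument-dependent lookup on py::array.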
py::array array(const Tensor& x) {
    PyEnv::get();
    return py::cast(x).attr("numpy")();
}

py::array array(const HostTensorND& x) {
    return array(*Tensor::make(x));
}

py::array array(const DeviceTensorND& x) {
    return array(*Tensor::make(x));
}

UNUSED void print(const Tensor& x) {
    return print(array(x));
}

UNUSED void print(const HostTensorND& x) {
    return print(array(x));
}

UNUSED void print(const DeviceTensorND& x) {
    return print(array(x));
}

UNUSED void print(const char* s) {
    PyEnv::instance();
    py::print(s);
}

} // anonymous namespace

OprChecker::OprChecker(std::shared_ptr<OpDef> opdef) : m_op(opdef) {}

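// Runs m_op twice on the same inputs: once symbolically through a compiled
// ComputingGraph (apply_on_var_node) and once imperatively on physical
// tensors (apply_on_physical_tensor), then checks that both paths agree.
// Output indices listed in bypass are excluded from the comparison.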
void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
    HostTensorGenerator<> gen;
    size_t nr_inps = inp_keys.size();
    SmallVector<HostTensorND> host_inp(nr_inps);
    VarNodeArray sym_inp(nr_inps);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;

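    // Each InputSpec is either a bare TensorShape (filled with generated
    // values) or a concrete HostTensorND; either way it becomes a
    // SharedDeviceTensor node feeding the graph.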
    for (size_t i = 0; i < nr_inps; ++i) {
        // TODO: remove std::visit to support macOS 10.12
        host_inp[i] = std::visit(
                [&gen](auto&& arg) -> HostTensorND {
                    using T = std::decay_t<decltype(arg)>;
                    if constexpr (std::is_same_v<TensorShape, T>) {
                        return *gen(arg);
                    } else {
                        static_assert(std::is_same_v<HostTensorND, T>);
                        return arg;
                    }
                },
                inp_keys[i]);
        sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
    }

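    // Symbolic path: apply the op on the graph's var nodes and copy every
    // output back to host memory via callback outputs.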
    auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp);
    size_t nr_oups = sym_oup.size();
    ComputingGraph::OutputSpec oup_spec(nr_oups);
    SmallVector<HostTensorND> host_sym_oup(nr_oups);
    for (size_t i = 0; i < nr_oups; ++i) {
        oup_spec[i] = make_callback_copy(sym_oup[i], host_sym_oup[i]);
    }
    auto func = graph->compile(oup_spec);

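    // Imperative path: apply the same op directly on physical tensors made
    // from the identical host inputs.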
    SmallVector<TensorPtr> imp_physical_inp(nr_inps);
    for (size_t i = 0; i < nr_inps; ++i) {
        imp_physical_inp[i] = Tensor::make(host_inp[i]);
    }
    SmallVector<LogicalTensorDesc> output_descs;
    auto imp_oup = OpDef::apply_on_physical_tensor(
            *m_op, imp_physical_inp, output_descs, false);
    mgb_assert(imp_oup.size() == nr_oups);

    // check input not modified
    for (size_t i = 0; i < imp_physical_inp.size(); ++i) {
        HostTensorND hv;
        hv.copy_from(imp_physical_inp[i]->dev_tensor()).sync();
        MGB_ASSERT_TENSOR_EQ(hv, host_inp[i]);
    }

    SmallVector<HostTensorND> host_imp_oup(nr_oups);
    for (size_t i = 0; i < nr_oups; ++i) {
        host_imp_oup[i].copy_from(imp_oup[i]->dev_tensor()).sync();
    }

    func->execute().wait(); // run last because it may contain inplace operations

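    // Compare the two paths output by output, skipping indices the caller
    // asked to bypass.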
    for (size_t i = 0; i < nr_oups; ++i) {
        if (bypass.find(i) != bypass.end())
            continue;
        MGB_ASSERT_TENSOR_EQ(host_sym_oup[i], host_imp_oup[i]);
    }
}

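// Minimal usage sketch (not part of the original file): given some
// std::shared_ptr<OpDef> op built elsewhere, a checker test would look like
//
//     OprChecker(op).run({TensorShape{2, 3}, TensorShape{2, 3}}, {});
//
// where each InputSpec may also be a pre-filled HostTensorND and the second
// argument lists output indices to skip in the comparison.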
TEST(TestHelper, PyModule) {
    py::module m = PyEnv::get();
    py::print(m);
    py::print(py::cast(DeviceTensorND()));
}

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}