You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imperative_rt.cpp 1.0 kB

12345678910111213141516171819202122232425262728293031
  1. #include "./imperative_rt.h"
  2. #include <pybind11/numpy.h>
  3. #include <pybind11/operators.h>
  4. #include <future>
  5. #include <unordered_map>
  6. #include <variant>
  7. #include "./common.h"
  8. #include "./helper.h"
  9. #include "megbrain/imperative.h"
  10. #include "megbrain/imperative/interpreter.h"
  11. #include "megbrain/imperative/ops/opr_attr.h"
  12. namespace py = pybind11;
  13. using namespace mgb;
  14. using namespace imperative;
  15. using namespace interpreter;
  16. void init_imperative_rt(py::module m) {
  17. auto make_backward_graph = [](const OpDef& def,
  18. const SmallVector<LogicalTensorDesc>& inputs,
  19. const SmallVector<bool>& input_requires_grad,
  20. const SmallVector<bool>& output_has_grad) {
  21. auto result = OpDef::make_backward_graph(
  22. def, inputs, input_requires_grad, output_has_grad);
  23. return std::make_tuple("backward_graph", result.input_mask, result.output_mask);
  24. };
  25. m.def("make_backward_graph", make_backward_graph);
  26. }