You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

grad.h 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. #pragma once
  2. #include "./tensor.h"
  3. #include "megbrain/imperative/ops/utility.h"
  4. #include "megbrain/imperative/transformations/grad.h"
  5. #include "megbrain/utils/small_vector.h"
  6. #include <memory>
  7. #include <optional>
  8. namespace mgb::imperative::python {
  9. struct GradKeyWrapper : NonCopyableObj {
  10. using wrap_t = pyext17::wrap<GradKeyWrapper>;
  11. static constexpr auto tp_name = pybind11::detail::_("GradKey");
  12. std::string m_name;
  13. std::shared_ptr<GradKey> m_key;
  14. std::shared_ptr<GradTransformation> m_transformation;
  15. std::unique_ptr<CleanupGuard<>> m_transformation_guard;
  16. GradKeyWrapper();
  17. PyObject* get_name();
  18. void set_name(pybind11::handle name);
  19. void attach(PyObject* const* args, size_t nargs);
  20. static void backward(GradKeyWrapper* self, pybind11::list, pybind11::list);
  21. static pybind11::function get_backward_closure(
  22. GradKeyWrapper* self, pybind11::list);
  23. PyObject* is_attached_to(PyObject* const* args, size_t nargs);
  24. void enter();
  25. void exit();
  26. void suppress();
  27. void resume();
  28. static GradKeyWrapper* get(std::shared_ptr<GradKey> key);
  29. ~GradKeyWrapper();
  30. };
  31. } // namespace mgb::imperative::python
  32. namespace pybind11::detail {
  33. template <>
  34. struct type_caster<mgb::imperative::python::GradKeyWrapper>
  35. : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};
  36. } // namespace pybind11::detail