You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

grad.h 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. /**
  2. * \file imperative/python/src/grad.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "./tensor.h"
  13. #include "megbrain/imperative/ops/utility.h"
  14. #include "megbrain/imperative/transformations/grad.h"
  15. #include "megbrain/utils/small_vector.h"
  16. #include <memory>
  17. #include <optional>
  18. namespace mgb::imperative::python {
  19. struct GradKeyWrapper : NonCopyableObj {
  20. using wrap_t = pyext17::wrap<GradKeyWrapper>;
  21. static constexpr auto tp_name = pybind11::detail::_("GradKey");
  22. std::string m_name;
  23. std::shared_ptr<GradKey> m_key;
  24. std::shared_ptr<GradTransformation> m_transformation;
  25. std::unique_ptr<CleanupGuard<>> m_transformation_guard;
  26. GradKeyWrapper();
  27. PyObject* get_name();
  28. void set_name(pybind11::handle name);
  29. void attach(PyObject* const* args, size_t nargs);
  30. static void backward(GradKeyWrapper* self, pybind11::list, pybind11::list);
  31. static pybind11::function get_backward_closure(
  32. GradKeyWrapper* self, pybind11::list);
  33. PyObject* is_attached_to(PyObject* const* args, size_t nargs);
  34. void enter();
  35. void exit();
  36. void suppress();
  37. void resume();
  38. static GradKeyWrapper* get(std::shared_ptr<GradKey> key);
  39. ~GradKeyWrapper();
  40. };
  41. } // namespace mgb::imperative::python
  42. namespace pybind11::detail {
  43. template <>
  44. struct type_caster<mgb::imperative::python::GradKeyWrapper>
  45. : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};
  46. } // namespace pybind11::detail