You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_trace.h 1.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. #pragma once
  2. #include <list>
  3. #include "megbrain/imperative/transformations/trace.h"
  4. #include "megbrain/imperative/utils/map.h"
  5. #include "./tensor.h"
  6. namespace mgb::imperative::python {
  7. namespace py = pybind11;
  8. class ModuleTraceTransformation final : public Transformation {
  9. private:
  10. py::function m_hook_fn;
  11. int m_enabled = 0;
  12. ValueRefList apply_module_trace_hook(const OpDef& op, Span<ValueRef> input_values) {
  13. py::list input_tws;
  14. for (auto&& input_value : input_values) {
  15. input_tws.append(TensorWrapper::make(py_tensor_type, input_value));
  16. }
  17. py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws);
  18. ValueRefList outputs(output_tws.size());
  19. auto it = outputs.begin();
  20. for (auto&& output_tw : output_tws) {
  21. *(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data();
  22. }
  23. return outputs;
  24. }
  25. public:
  26. ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
  27. ValueRefList apply_transformation(
  28. const Operator& op, Span<ValueRef> inputs) override {
  29. if (op.is<ApplyOp>() && m_enabled > 0) {
  30. auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
  31. return outputs;
  32. } else {
  33. return imperative::apply(op, inputs);
  34. }
  35. }
  36. void enable() { m_enabled = 1; }
  37. void disable() { m_enabled = 0; }
  38. bool enabled() const { return m_enabled; }
  39. ValueRef unwrap(ValueRef value) override { return value; }
  40. std::string name() const override { return "ModuleTraceTransformation"; }
  41. };
  42. } // namespace mgb::imperative::python