You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_trace.h 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #pragma once
  2. #include <list>
  3. #include "megbrain/imperative/transformations/trace.h"
  4. #include "megbrain/imperative/utils/map.h"
  5. #include "./tensor.h"
  6. namespace mgb::imperative::python {
  7. namespace py = pybind11;
  8. class ModuleTraceTransformation final : public Transformation {
  9. private:
  10. py::function m_hook_fn;
  11. int m_enabled = 0;
  12. ValueRefList apply_module_trace_hook(const OpDef& op, Span<ValueRef> input_values) {
  13. py::list input_tws;
  14. for (auto&& input_value : input_values) {
  15. input_tws.append(TensorWrapper::make(py_tensor_type, input_value));
  16. }
  17. py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws);
  18. ValueRefList outputs(output_tws.size());
  19. auto it = outputs.begin();
  20. for (auto&& output_tw : output_tws) {
  21. *(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data();
  22. }
  23. return outputs;
  24. }
  25. public:
  26. inline static WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
  27. ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
  28. ValueRefList apply_transformation(
  29. const Operator& op, Span<ValueRef> inputs) override {
  30. if (op.is<ApplyOp>() && m_enabled > 0) {
  31. auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
  32. return outputs;
  33. } else if (op.is<RenameValue>()) {
  34. auto outputs = imperative::apply(op, inputs);
  35. if (auto module_trace_info = module_trace_info_map.try_get(inputs[0])) {
  36. if (module_trace_info->ptr()) {
  37. auto node = module_trace_info.value();
  38. module_trace_info_map[outputs[0]] = module_trace_info.value();
  39. }
  40. }
  41. return outputs;
  42. } else {
  43. return imperative::apply(op, inputs);
  44. }
  45. }
  46. void enable() { m_enabled = 1; }
  47. void disable() { m_enabled = 0; }
  48. bool enabled() const { return m_enabled; }
  49. ValueRef unwrap(ValueRef value) override { return value; }
  50. std::string name() const override { return "ModuleTraceTransformation"; }
  51. };
  52. } // namespace mgb::imperative::python