You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_trace.h 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. /**
  2. * \file imperative/python/src/module_trace.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megbrain/imperative/transformations/trace.h"
  13. #include "megbrain/imperative/utils/map.h"
  14. #include "megbrain/imperative/utils/stats.h"
  15. #include "./tensor.h"
  16. namespace mgb::imperative::python {
  17. namespace py = pybind11;
  18. class ModuleTraceTransformation final : public Transformation {
  19. private:
  20. py::function m_hook_fn;
  21. int m_enabled = 0;
  22. ValueRefList apply_module_trace_hook(const OpDef& op, Span<ValueRef> input_values) {
  23. py::list input_tws;
  24. for (auto&& input_value : input_values) {
  25. input_tws.append(TensorWrapper::make(py_tensor_type, input_value));
  26. }
  27. py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws);
  28. ValueRefList outputs(output_tws.size());
  29. auto it = outputs.begin();
  30. for (auto&& output_tw : output_tws) {
  31. *(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data();
  32. }
  33. return outputs;
  34. }
  35. public:
  36. ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
  37. ValueRefList apply_transformation(
  38. const Operator& op, Span<ValueRef> inputs) override {
  39. if (op.is<ApplyOp>() && m_enabled > 0) {
  40. auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
  41. return outputs;
  42. } else {
  43. return imperative::apply(op, inputs);
  44. }
  45. }
  46. void enable() { m_enabled = 1; }
  47. void disable() { m_enabled = 0; }
  48. bool enabled() const { return m_enabled; }
  49. ValueRef unwrap(ValueRef value) override { return value; }
  50. std::string name() const override { return "ModuleTraceTransformation"; }
  51. };
  52. } // namespace mgb::imperative::python