You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_trace.h 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. /**
  2. * \file imperative/python/src/module_trace.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megbrain/imperative/transformations/trace.h"
  13. #include "megbrain/imperative/utils/map.h"
  14. #include "./tensor.h"
  15. namespace mgb::imperative::python {
  16. namespace py = pybind11;
  17. class ModuleTraceTransformation final : public Transformation {
  18. private:
  19. py::function m_hook_fn;
  20. int m_enabled = 0;
  21. std::vector<ValueRef> apply_module_trace_hook(
  22. const OpDef& op, Span<ValueRef> input_values) {
  23. py::list input_tws;
  24. for (auto&& input_value : input_values) {
  25. input_tws.append(TensorWrapper::make(py_tensor_type, input_value));
  26. }
  27. py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws);
  28. std::vector<ValueRef> outputs;
  29. for (auto&& output_tw : output_tws) {
  30. outputs.push_back(
  31. TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data());
  32. }
  33. return outputs;
  34. }
  35. public:
  36. ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
  37. std::vector<ValueRef> apply_transformation(
  38. const Operator& op, Span<ValueRef> inputs) override {
  39. if (op.is<ApplyOp>() && m_enabled > 0) {
  40. auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
  41. return outputs;
  42. } else {
  43. return imperative::apply(op, inputs);
  44. }
  45. }
  46. ValueRef unwrap(ValueRef value) override { return value; }
  47. std::string name() const override { return "ModuleTraceTransformation"; }
  48. };
  49. } // namespace mgb::imperative::python