You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transformation.h 2.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. /**
  2. * \file imperative/python/src/transformation.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  #include <algorithm>
  #include <array>
  #include <cstddef>
  #include <memory>
  #include <optional>
  #include <string>
  #include <vector>

  #include "pybind11/pybind11.h"

  #include "megbrain/imperative/dispatch.h"
  #include "megbrain/imperative/transformation.h"
  #include "megbrain/imperative/value.h"
  #include "megbrain/utils/small_vector.h"
  19. namespace mgb::imperative::python {
  20. struct TransformationManager {
  21. enum Segment {
  22. ModuleTrace,
  23. Grad,
  24. Scalar,
  25. Trace,
  26. Eval,
  27. };
  28. std::array<std::vector<std::shared_ptr<Transformation>>, 5> segments;
  29. template <Segment segment>
  30. void register_at(std::shared_ptr<Transformation> transformation) {
  31. mgb_assert(segment < segments.size());
  32. std::shared_ptr<Transformation> next;
  33. for (size_t i = segment; i < segments.size(); ++i) {
  34. if (!segments[i].empty()) {
  35. next = segments[i].back();
  36. break;
  37. }
  38. }
  39. if (!next) {
  40. transformation->register_at(Transformation::bottom());
  41. } else {
  42. transformation->register_at(next->pos());
  43. }
  44. segments[segment].push_back(transformation);
  45. }
  46. template <Segment segment>
  47. void unregister(std::shared_ptr<Transformation> transformation) noexcept {
  48. mgb_assert(segment < segments.size());
  49. auto iter = std::find(
  50. segments[segment].begin(), segments[segment].end(), transformation);
  51. mgb_assert(iter != segments[segment].end());
  52. transformation->unregister();
  53. segments[segment].erase(iter);
  54. }
  55. static TransformationManager& get_instance() {
  56. static TransformationManager sl_instance;
  57. return sl_instance;
  58. }
  59. };
  60. class PyValue final : public PrimitiveValue<PyValue, pybind11::object> {
  61. public:
  62. using PrimitiveValue::PrimitiveValue;
  63. std::string to_string() const {
  64. return pybind11::str((const pybind11::object&)*this).cast<std::string>();
  65. }
  66. };
  67. } // namespace mgb::imperative::python