You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transformation.h 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. /**
  2. * \file imperative/python/src/transformation.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
#include "megbrain/imperative/transformation.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>
  13. namespace mgb::imperative::python {
  14. struct TransformationManager {
  15. enum Segment {
  16. ModuleTrace,
  17. Grad,
  18. Scalar,
  19. Trace,
  20. Eval,
  21. };
  22. std::array<std::vector<std::shared_ptr<Transformation>>, 5> segments;
  23. template <Segment segment>
  24. void register_at(std::shared_ptr<Transformation> transformation) {
  25. mgb_assert(segment < segments.size());
  26. std::shared_ptr<Transformation> next;
  27. for (size_t i = segment; i < segments.size(); ++i) {
  28. if (!segments[i].empty()) {
  29. next = segments[i].back();
  30. break;
  31. }
  32. }
  33. if (!next) {
  34. transformation->register_at(Transformation::bottom());
  35. } else {
  36. transformation->register_at(next->pos());
  37. }
  38. segments[segment].push_back(transformation);
  39. }
  40. template <Segment segment>
  41. void unregister(std::shared_ptr<Transformation> transformation) noexcept {
  42. mgb_assert(segment < segments.size());
  43. auto iter = std::find(
  44. segments[segment].begin(), segments[segment].end(), transformation);
  45. mgb_assert(iter != segments[segment].end());
  46. transformation->unregister();
  47. segments[segment].erase(iter);
  48. }
  49. static TransformationManager& get_instance() {
  50. static TransformationManager sl_instance;
  51. return sl_instance;
  52. }
  53. };
  54. } // namespace mgb::imperative::python