You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transformation.h 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. /**
  2. * \file imperative/python/src/transformation.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  #include <algorithm>
  #include <array>
  #include <cstddef>
  #include <memory>
  #include <optional>
  #include <string>
  #include <utility>
  #include <vector>

  #include "pybind11/pybind11.h"

  #include "megbrain/imperative/dispatch.h"
  #include "megbrain/imperative/transformation.h"
  #include "megbrain/imperative/value.h"
  #include "megbrain/utils/small_vector.h"
  19. namespace mgb::imperative::python {
  20. struct TransformationManager {
  21. enum Segment {
  22. ModuleTrace,
  23. DTypePromote,
  24. DimExpansion,
  25. Grad,
  26. Scalar,
  27. Trace,
  28. Eval,
  29. };
  30. std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments;
  31. template <Segment segment>
  32. void register_at(std::shared_ptr<Transformation> transformation) {
  33. mgb_assert(segment < segments.size());
  34. std::shared_ptr<Transformation> next;
  35. for (size_t i = segment; i < segments.size(); ++i) {
  36. if (!segments[i].empty()) {
  37. next = segments[i].back();
  38. break;
  39. }
  40. }
  41. if (!next) {
  42. transformation->register_at(Transformation::bottom());
  43. } else {
  44. transformation->register_at(next->pos());
  45. }
  46. segments[segment].push_back(transformation);
  47. }
  48. template <Segment segment>
  49. void unregister(std::shared_ptr<Transformation> transformation) noexcept {
  50. mgb_assert(segment < segments.size());
  51. auto iter = std::find(
  52. segments[segment].begin(), segments[segment].end(), transformation);
  53. mgb_assert(iter != segments[segment].end());
  54. transformation->unregister();
  55. segments[segment].erase(iter);
  56. }
  57. static TransformationManager& get_instance() {
  58. static TransformationManager sl_instance;
  59. return sl_instance;
  60. }
  61. };
  62. class PyValue final : public PrimitiveValue<PyValue, pybind11::object> {
  63. public:
  64. using PrimitiveValue::PrimitiveValue;
  65. std::string to_string() const {
  66. return pybind11::str((const pybind11::object&)*this).cast<std::string>();
  67. }
  68. };
  69. } // namespace mgb::imperative::python