You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transformation.h 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #pragma once
  2. #include <optional>
  3. #include <string>
  4. #include "pybind11/pybind11.h"
  5. #include "megbrain/imperative/dispatch.h"
  6. #include "megbrain/imperative/transformation.h"
  7. #include "megbrain/imperative/utils/helper.h"
  8. #include "megbrain/imperative/value.h"
  9. #include "megbrain/utils/small_vector.h"
  10. namespace mgb::imperative::python {
  11. struct TransformationManager {
  12. public:
  13. enum Segment {
  14. ModuleTrace,
  15. GroupComm,
  16. DTypePromote,
  17. DimExpansion,
  18. Format,
  19. Grad,
  20. Scalar,
  21. Symbol,
  22. Trace,
  23. Eval,
  24. };
  25. std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments;
  26. private:
  27. template <Segment segment>
  28. void unregister(std::shared_ptr<Transformation> transformation) noexcept {
  29. mgb_assert(segment < segments.size());
  30. auto iter = std::find(
  31. segments[segment].begin(), segments[segment].end(), transformation);
  32. mgb_assert(iter != segments[segment].end());
  33. transformation->unregister();
  34. segments[segment].erase(iter);
  35. }
  36. public:
  37. template <Segment segment>
  38. [[nodiscard]] std::unique_ptr<CleanupGuard<>> register_at(
  39. std::shared_ptr<Transformation> transformation) {
  40. mgb_assert(segment < segments.size());
  41. std::shared_ptr<Transformation> next;
  42. for (size_t i = segment; i < segments.size(); ++i) {
  43. if (!segments[i].empty()) {
  44. next = segments[i].back();
  45. break;
  46. }
  47. }
  48. if (!next) {
  49. transformation->register_at(Transformation::bottom());
  50. } else {
  51. transformation->register_at(next->pos());
  52. }
  53. segments[segment].push_back(transformation);
  54. return std::make_unique<CleanupGuard<>>(
  55. [this, transformation]() { unregister<segment>(transformation); });
  56. }
  57. static TransformationManager& get_instance() {
  58. static TransformationManager sl_instance;
  59. return sl_instance;
  60. }
  61. };
  62. class PyValue final : public PrimitiveValue<PyValue, pybind11::object> {
  63. public:
  64. using PrimitiveValue::PrimitiveValue;
  65. std::string to_string() const {
  66. return pybind11::str((const pybind11::object&)*this).cast<std::string>();
  67. }
  68. };
  69. } // namespace mgb::imperative::python