You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_rt.h 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #pragma once
  2. #include "./helper.h"
  3. #include <future>
  4. #include <memory>
  5. #include <mutex>
  6. #include "megbrain/graph.h"
  7. #include "megbrain/plugin/opr_footprint.h"
  8. namespace py = pybind11;
  9. extern py::object Py_Varnode;
  10. template <typename T>
  11. class GraphNodePtr {
  12. std::shared_ptr<mgb::cg::ComputingGraph> m_graph;
  13. T* m_node;
  14. public:
  15. GraphNodePtr(T* node)
  16. : m_graph(node ? node->owner_graph()->shared_from_this() : nullptr),
  17. m_node(node) {}
  18. T* operator->() { return m_node; }
  19. T& operator*() { return *m_node; }
  20. operator bool() { return m_node; }
  21. T* get() { return m_node; }
  22. };
  23. PYBIND11_DECLARE_HOLDER_TYPE(T, GraphNodePtr<T>, true);
  /*!
   * \brief Type-independent interface over Rendezvous<R>.
   *
   * It lets code that does not know the value type R push an exception into a
   * rendezvous. Deleting through a base pointer is safe because the destructor
   * is virtual.
   */
  class RendezvousBase {
  public:
      virtual ~RendezvousBase() = default;

      //! Hand \p exc to the consuming side of the rendezvous.
      virtual void set_exception(std::exception_ptr exc) = 0;
  };
  29. template <typename R>
  30. class Rendezvous : public RendezvousBase {
  31. std::mutex m_lock;
  32. int m_read_ahead = 0;
  33. bool m_drop_next = false;
  34. std::promise<R> m_promise;
  35. Rendezvous() = default;
  36. struct Factory {
  37. template <typename... Args>
  38. static auto make_rendezvous(Args&&... args) {
  39. auto ptr = new Rendezvous<R>{std::forward(args)...};
  40. return std::shared_ptr<Rendezvous<R>>(ptr);
  41. }
  42. };
  43. public:
  44. Rendezvous(const Rendezvous& rhs) = delete;
  45. Rendezvous(Rendezvous&& rhs) = delete;
  46. Rendezvous& operator=(const Rendezvous& rhs) = delete;
  47. template <typename... Args>
  48. static auto make(Args&&... args) {
  49. return Factory::make_rendezvous(std::forward<Args>(args)...);
  50. }
  51. R get() {
  52. std::future<R> f;
  53. {
  54. MGB_LOCK_GUARD(m_lock);
  55. mgb_assert(m_read_ahead <= 0);
  56. mgb_assert(m_read_ahead >= -1);
  57. f = m_promise.get_future();
  58. if (m_read_ahead == -1) {
  59. m_promise = {};
  60. }
  61. ++m_read_ahead;
  62. }
  63. return f.get();
  64. }
  65. void drop() {
  66. MGB_LOCK_GUARD(m_lock);
  67. mgb_assert(m_read_ahead <= 0);
  68. mgb_assert(m_read_ahead >= -1);
  69. if (m_read_ahead == -1) {
  70. m_promise = {};
  71. } else {
  72. m_drop_next = true;
  73. }
  74. ++m_read_ahead;
  75. }
  76. template <typename T>
  77. void set(T&& value) {
  78. MGB_LOCK_GUARD(m_lock);
  79. mgb_assert(m_read_ahead >= 0);
  80. mgb_assert(m_read_ahead <= 1);
  81. if (m_drop_next) {
  82. m_drop_next = false;
  83. } else {
  84. m_promise.set_value(std::forward<T>(value));
  85. }
  86. if (m_read_ahead == 1) {
  87. m_promise = {};
  88. }
  89. --m_read_ahead;
  90. }
  91. void reset() {
  92. MGB_LOCK_GUARD(m_lock);
  93. m_promise = {};
  94. m_read_ahead = 0;
  95. m_drop_next = false;
  96. }
  97. void set_exception(std::exception_ptr e) {
  98. if (e) {
  99. MGB_LOCK_GUARD(m_lock);
  100. if (m_read_ahead >= 0) {
  101. mgb_assert(m_read_ahead <= 1);
  102. if (m_drop_next) {
  103. m_drop_next = false;
  104. } else {
  105. m_promise.set_exception(e);
  106. }
  107. if (m_read_ahead == 1) {
  108. m_promise = {};
  109. }
  110. --m_read_ahead;
  111. } else {
  112. mgb_assert(m_read_ahead == -1);
  113. // TODO: maybe exception should be ignored
  114. // if value was already set ?
  115. m_promise.set_exception(e);
  116. }
  117. }
  118. }
  119. };
  120. void init_graph_rt(pybind11::module m);