You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

intrusive_list.h 6.2 kB

(line numbers 1–227)
  1. /**
  2. * \file imperative/python/src/intrusive_list.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#pragma once

#include <type_traits>

#include "megbrain/utils/metahelper.h"
  12. namespace mgb::imperative::python::intrusive_list {
  13. // copy policy
  14. struct after_t {};
  15. struct before_t {};
  16. struct disable_t {};
  // Tail is declared first because Head and Tail point at each other.
template <typename T>
struct Tail;

//! Forward half of a link: refers to the Tail that follows it.
//! Invariant: if `next` is non-null then `next->prev == this`.
//! Copying is deleted. Moving transfers the link to the new object and
//! re-points the partner Tail at it. Destruction clears the partner's `prev`.
template <typename T>
struct Head {
    Tail<T>* next;

    //! Only stores `node`; it does not update `node->prev`, so the caller is
    //! responsible for establishing the invariant.
    Head(Tail<T>* node = nullptr) : next(node) {}
    Head(const Head<T>&) = delete;
    Head<T>& operator=(const Head<T>&) = delete;

    //! Take over rhs's link; rhs is left unlinked.
    Head(Head<T>&& rhs) noexcept : next(rhs.next) {
        rhs.next = nullptr;
        if (next) {
            next->prev = this;
        }
    }

    //! Take over rhs's link; rhs is left unlinked.
    //! Precondition (checked with mgb_assert): this Head is not linked.
    //! Self-move is a no-op rather than tripping that assertion.
    Head<T>& operator=(Head<T>&& rhs) {
        if (this == &rhs) {
            return *this;
        }
        mgb_assert(!next);
        next = rhs.next;
        rhs.next = nullptr;
        if (next) {
            next->prev = this;
        }
        return *this;
    }

    //! Detach the partner Tail so it does not keep a dangling back pointer.
    ~Head() {
        if (next) {
            next->prev = nullptr;
        }
    }
};

//! Backward half of a link: refers to the Head that precedes it.
//! Invariant: if `prev` is non-null then `prev->next == this`.
//! Mirror image of Head: copying is deleted, moving re-points the partner Head
//! at the new object, and destruction clears the partner's `next`.
template <typename T>
struct Tail {
    Head<T>* prev;

    //! Only stores `node`; it does not update `node->next`, so the caller is
    //! responsible for establishing the invariant.
    Tail(Head<T>* node = nullptr) : prev(node) {}
    Tail(const Tail<T>&) = delete;
    Tail<T>& operator=(const Tail<T>&) = delete;

    //! Take over rhs's link; rhs is left unlinked.
    Tail(Tail<T>&& rhs) noexcept : prev(rhs.prev) {
        rhs.prev = nullptr;
        if (prev) {
            prev->next = this;
        }
    }

    //! Take over rhs's link; rhs is left unlinked.
    //! Precondition (checked with mgb_assert): this Tail is not linked.
    //! Self-move is a no-op rather than tripping that assertion.
    Tail<T>& operator=(Tail<T>&& rhs) {
        if (this == &rhs) {
            return *this;
        }
        mgb_assert(!prev);
        prev = rhs.prev;
        rhs.prev = nullptr;
        if (prev) {
            prev->next = this;
        }
        return *this;
    }

    //! Detach the partner Head so it does not keep a dangling forward pointer.
    ~Tail() {
        if (prev) {
            prev->next = nullptr;
        }
    }
};
  74. template <typename T, typename policy> struct Node;
  75. template <typename T>
  76. class Iterator {
  77. T* ptr;
  78. void inc() {ptr = static_cast<T*>(ptr->Head<T>::next);}
  79. void dec() {ptr = static_cast<T*>(ptr->Head<T>::prev);}
  80. public:
  81. Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {}
  82. Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {}
  83. template<typename policy>
  84. Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {}
  85. T& operator*() {return *static_cast<T*>(ptr);}
  86. T* operator->() {return static_cast<T*>(ptr);}
  87. operator bool() {return ptr;}
  88. bool operator==(const Iterator<T>& rhs) {return ptr == rhs.ptr;}
  89. Iterator& operator++() {inc(); return *this;}
  90. Iterator& operator--() {dec(); return *this;}
  91. Iterator operator++(int) {auto ret = *this; inc(); return ret;}
  92. Iterator operator--(int) {auto ret = *this; dec(); return ret;}
  93. };
  94. // Node in a doubly linked list. Unlike std::list, nodes are not owned by a container.
  95. // Instead, nodes may join or leave a list freely.
  96. // NOTE: Derived classes have to explicitly declare copy / assignment as default,
  97. // otherwise the compiler generated version would use the const T& signature,
  98. // which is deleted.
  99. template <typename T = void, typename policy = disable_t>
  100. struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>,
  101. Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> {
  102. private:
  103. using this_t = Node<T, policy>;
  104. using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>;
  105. public:
  106. using head_t = Head<U>;
  107. using tail_t = Tail<U>;
  108. using head_t::next;
  109. using tail_t::prev;
  110. Node() = default;
  111. Node(const this_t&) = delete;
  112. this_t& operator=(const this_t&) = delete;
  113. //! constructed node is inserted after the input node
  114. Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) {
  115. node.next = this;
  116. if (next) {
  117. next->prev = this;
  118. }
  119. }
  120. //! constructed node is inserted before the input node
  121. Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) {
  122. node.prev = this;
  123. if (prev) {
  124. prev->next = this;
  125. }
  126. }
  127. Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) {
  128. rhs.prev = nullptr;
  129. rhs.next = nullptr;
  130. if (prev) {
  131. prev->next = this;
  132. }
  133. if (next) {
  134. next->prev = this;
  135. }
  136. }
  137. Node& operator=(this_t&& rhs) {
  138. unlink();
  139. prev = rhs.prev;
  140. next = rhs.next;
  141. rhs.prev = nullptr;
  142. rhs.next = nullptr;
  143. if (prev) {
  144. prev->next = this;
  145. }
  146. if (next) {
  147. next->prev = this;
  148. }
  149. return *this;
  150. }
  151. template<typename p = policy,
  152. typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
  153. Node(this_t& rhs) : Node(policy{}, rhs) {}
  154. template<typename p = policy,
  155. typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
  156. this_t& operator=(this_t& rhs) {
  157. insert(policy{}, rhs);
  158. return *this;
  159. }
  160. void unlink() {
  161. if (prev) {
  162. prev->next = next;
  163. }
  164. if (next) {
  165. next->prev = prev;
  166. }
  167. prev = nullptr;
  168. next = nullptr;
  169. }
  170. //! this node is unlinked from its list and inserted after the input node
  171. void insert(after_t, head_t& node) {
  172. unlink();
  173. prev = &node;
  174. next = node.next;
  175. node.next = this;
  176. if (next) {
  177. next->prev = this;
  178. }
  179. }
  180. //! this node is unlinked from its list and inserted before the input node
  181. void insert(before_t, tail_t& node) {
  182. unlink();
  183. next = &node;
  184. prev = node.prev;
  185. node.prev = this;
  186. if (prev) {
  187. prev->next = this;
  188. }
  189. }
  190. void insert_before(tail_t& node) {insert(before_t{}, node);}
  191. void insert_after(head_t& node) {insert(after_t{}, node);}
  192. ~Node() {
  193. unlink();
  194. }
  195. };
  196. } // namespace mgb::imperative::python::intrusive_list

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。