You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

treap.h 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_UTIL_TREAP_H_
  17. #define DATASET_UTIL_TREAP_H_
  18. #include <functional>
  19. #include <iterator>
  20. #include <stack>
  21. #include <utility>
  22. #include <vector>
  23. namespace mindspore {
  24. namespace dataset {
  25. // A treap is a combination of binary search tree and heap. Each key is given a priority. The priority
  26. // for any non-leaf node is greater than or equal to the priority of its children.
  27. // @tparam K
  28. // Data type of key
  29. // @tparam P
  30. // Data type of priority
  31. // @tparam KC
  32. // Class to compare key. Default to std::less
  33. // @tparam KP
  34. // Class to compare priority. Default to std:less
  35. template <typename K, typename P, typename KC = std::less<K>, typename KP = std::less<P>>
  36. class Treap {
  37. public:
  38. using key_type = K;
  39. using priority_type = P;
  40. using key_compare = KC;
  41. using priority_compare = KP;
  42. struct NodeValue {
  43. key_type key;
  44. priority_type priority;
  45. };
  46. class TreapNode {
  47. public:
  48. TreapNode() : left(nullptr), right(nullptr) {}
  49. ~TreapNode() {
  50. left = nullptr;
  51. right = nullptr;
  52. }
  53. NodeValue nv;
  54. TreapNode *left;
  55. TreapNode *right;
  56. };
  57. // search API
  58. // @param k
  59. // key to search for
  60. // @return
  61. // a pair is returned. The 2nd value of type bool indicate if the search is successful.
  62. // If true, the first value of the pair contains the key and the priority.
  63. std::pair<NodeValue, bool> Search(key_type k) const {
  64. auto *n = Search(root_, k);
  65. if (n != nullptr) {
  66. return std::make_pair(n->nv, true);
  67. } else {
  68. return std::make_pair(NodeValue{key_type(), priority_type()}, false);
  69. }
  70. }
  71. // @return
  72. // Return the root of the heap. It has the highest priority. But not necessarily the first key.
  73. std::pair<NodeValue, bool> Top() const {
  74. if (root_) {
  75. return std::make_pair(root_->nv, true);
  76. } else {
  77. return std::make_pair(NodeValue{key_type(), priority_type()}, false);
  78. }
  79. }
  80. // Remove the root of the heap.
  81. void Pop() {
  82. if (root_) {
  83. DeleteKey(root_->nv.key);
  84. }
  85. }
  86. // Insert API.
  87. // @param k
  88. // The key to insert.
  89. // @param p
  90. // The priority of the key.
  91. void Insert(key_type k, priority_type p) { root_ = Insert(root_, k, p); }
  92. // Delete a key.
  93. // @param k
  94. void DeleteKey(key_type k) { root_ = DeleteNode(root_, k); }
  95. Treap() : root_(nullptr), count_(0) { free_list_.reserve(kResvSz); }
  96. ~Treap() noexcept {
  97. DeleteTreap(root_);
  98. while (!free_list_.empty()) {
  99. TreapNode *n = free_list_.back();
  100. delete (n);
  101. free_list_.pop_back();
  102. }
  103. }
  104. class iterator : public std::iterator<std::forward_iterator_tag, TreapNode> {
  105. public:
  106. explicit iterator(Treap *tr) : tr_(tr), cur_(nullptr) {
  107. if (tr_) {
  108. cur_ = tr_->root_;
  109. while (cur_) {
  110. stack_.push(cur_);
  111. cur_ = cur_->left;
  112. }
  113. }
  114. if (!stack_.empty()) {
  115. cur_ = stack_.top();
  116. } else {
  117. cur_ = nullptr;
  118. }
  119. }
  120. ~iterator() {
  121. tr_ = nullptr;
  122. cur_ = nullptr;
  123. }
  124. NodeValue &operator*() { return cur_->nv; }
  125. NodeValue *operator->() { return &(cur_->nv); }
  126. const TreapNode &operator*() const { return *cur_; }
  127. const TreapNode *operator->() const { return cur_; }
  128. bool operator==(const iterator &rhs) const { return cur_ == rhs.cur_; }
  129. bool operator!=(const iterator &rhs) const { return cur_ != rhs.cur_; }
  130. // Prefix increment
  131. iterator &operator++() {
  132. if (cur_) {
  133. stack_.pop();
  134. if (cur_->right) {
  135. TreapNode *n = cur_->right;
  136. while (n) {
  137. stack_.push(n);
  138. n = n->left;
  139. }
  140. }
  141. }
  142. if (!stack_.empty()) {
  143. cur_ = stack_.top();
  144. } else {
  145. cur_ = nullptr;
  146. }
  147. return *this;
  148. }
  149. // Postfix increment
  150. iterator operator++(int junk) {
  151. iterator tmp(*this);
  152. if (cur_) {
  153. stack_.pop();
  154. if (cur_->right) {
  155. TreapNode *n = cur_->right;
  156. while (n) {
  157. stack_.push(n);
  158. n = n->left;
  159. }
  160. }
  161. }
  162. if (!stack_.empty()) {
  163. cur_ = stack_.top();
  164. } else {
  165. cur_ = nullptr;
  166. }
  167. return tmp;
  168. }
  169. private:
  170. Treap *tr_;
  171. TreapNode *cur_;
  172. std::stack<TreapNode *> stack_;
  173. };
  174. class const_iterator : public std::iterator<std::forward_iterator_tag, TreapNode> {
  175. public:
  176. explicit const_iterator(const Treap *tr) : tr_(tr), cur_(nullptr) {
  177. if (tr_) {
  178. cur_ = tr_->root_;
  179. while (cur_) {
  180. stack_.push(cur_);
  181. cur_ = cur_->left;
  182. }
  183. }
  184. if (!stack_.empty()) {
  185. cur_ = stack_.top();
  186. } else {
  187. cur_ = nullptr;
  188. }
  189. }
  190. ~const_iterator() {
  191. tr_ = nullptr;
  192. cur_ = nullptr;
  193. }
  194. const NodeValue &operator*() const { return cur_->nv; }
  195. const NodeValue *operator->() const { return &(cur_->nv); }
  196. bool operator==(const const_iterator &rhs) const { return cur_ == rhs.cur_; }
  197. bool operator!=(const const_iterator &rhs) const { return cur_ != rhs.cur_; }
  198. // Prefix increment
  199. const_iterator &operator++() {
  200. if (cur_) {
  201. stack_.pop();
  202. if (cur_->right) {
  203. TreapNode *n = cur_->right;
  204. while (n) {
  205. stack_.push(n);
  206. n = n->left;
  207. }
  208. }
  209. }
  210. if (!stack_.empty()) {
  211. cur_ = stack_.top();
  212. } else {
  213. cur_ = nullptr;
  214. }
  215. return *this;
  216. }
  217. // Postfix increment
  218. const_iterator operator++(int junk) {
  219. iterator tmp(*this);
  220. if (cur_) {
  221. stack_.pop();
  222. if (cur_->right) {
  223. TreapNode *n = cur_->right;
  224. while (n) {
  225. stack_.push(n);
  226. n = n->left;
  227. }
  228. }
  229. }
  230. if (!stack_.empty()) {
  231. cur_ = stack_.top();
  232. } else {
  233. cur_ = nullptr;
  234. }
  235. return tmp;
  236. }
  237. private:
  238. const Treap *tr_;
  239. TreapNode *cur_;
  240. std::stack<TreapNode *> stack_;
  241. };
  242. iterator begin() { return iterator(this); }
  243. iterator end() { return iterator(nullptr); }
  244. const_iterator begin() const { return const_iterator(this); }
  245. const_iterator end() const { return const_iterator(nullptr); }
  246. const_iterator cbegin() { return const_iterator(this); }
  247. const_iterator cend() { return const_iterator(nullptr); }
  248. bool empty() { return root_ == nullptr; }
  249. size_t size() { return count_; }
  250. private:
  251. TreapNode *NewNode() {
  252. TreapNode *n = nullptr;
  253. if (!free_list_.empty()) {
  254. n = free_list_.back();
  255. free_list_.pop_back();
  256. new (n) TreapNode();
  257. } else {
  258. n = new TreapNode();
  259. }
  260. return n;
  261. }
  262. void FreeNode(TreapNode *n) { free_list_.push_back(n); }
  263. void DeleteTreap(TreapNode *n) noexcept {
  264. if (n == nullptr) {
  265. return;
  266. }
  267. TreapNode *x = n->left;
  268. TreapNode *y = n->right;
  269. delete (n);
  270. DeleteTreap(x);
  271. DeleteTreap(y);
  272. }
  273. TreapNode *RightRotate(TreapNode *y) {
  274. TreapNode *x = y->left;
  275. TreapNode *T2 = x->right;
  276. x->right = y;
  277. y->left = T2;
  278. return x;
  279. }
  280. TreapNode *LeftRotate(TreapNode *x) {
  281. TreapNode *y = x->right;
  282. TreapNode *T2 = y->left;
  283. y->left = x;
  284. x->right = T2;
  285. return y;
  286. }
  287. TreapNode *Search(TreapNode *n, key_type k) const {
  288. key_compare keyCompare;
  289. if (n == nullptr) {
  290. return n;
  291. } else if (keyCompare(k, n->nv.key)) {
  292. return Search(n->left, k);
  293. } else if (keyCompare(n->nv.key, k)) {
  294. return Search(n->right, k);
  295. } else {
  296. return n;
  297. }
  298. }
  299. TreapNode *Insert(TreapNode *n, key_type k, priority_type p) {
  300. key_compare keyCompare;
  301. priority_compare priorityCompare;
  302. if (n == nullptr) {
  303. n = NewNode();
  304. n->nv.key = k;
  305. n->nv.priority = p;
  306. count_++;
  307. return n;
  308. }
  309. if (keyCompare(k, n->nv.key)) {
  310. n->left = Insert(n->left, k, p);
  311. if (priorityCompare(n->nv.priority, n->left->nv.priority)) {
  312. n = RightRotate(n);
  313. }
  314. } else if (keyCompare(n->nv.key, k)) {
  315. n->right = Insert(n->right, k, p);
  316. if (priorityCompare(n->nv.priority, n->right->nv.priority)) {
  317. n = LeftRotate(n);
  318. }
  319. } else {
  320. // If we insert the same key again, do nothing.
  321. return n;
  322. }
  323. return n;
  324. }
  325. TreapNode *DeleteNode(TreapNode *n, key_type k) {
  326. key_compare keyCompare;
  327. priority_compare priorityCompare;
  328. if (n == nullptr) {
  329. return n;
  330. }
  331. if (keyCompare(k, n->nv.key)) {
  332. n->left = DeleteNode(n->left, k);
  333. } else if (keyCompare(n->nv.key, k)) {
  334. n->right = DeleteNode(n->right, k);
  335. } else if (n->left == nullptr) {
  336. TreapNode *t = n;
  337. n = n->right;
  338. FreeNode(t);
  339. count_--;
  340. } else if (n->right == nullptr) {
  341. TreapNode *t = n;
  342. n = n->left;
  343. FreeNode(t);
  344. count_--;
  345. } else if (priorityCompare(n->left->nv.priority, n->right->nv.priority)) {
  346. n = LeftRotate(n);
  347. n->left = DeleteNode(n->left, k);
  348. } else {
  349. n = RightRotate(n);
  350. n->right = DeleteNode(n->right, k);
  351. }
  352. return n;
  353. }
  354. static constexpr int kResvSz = 512;
  355. TreapNode *root_;
  356. size_t count_;
  357. std::vector<TreapNode *> free_list_;
  358. };
  359. } // namespace dataset
  360. } // namespace mindspore
  361. #endif // DATASET_UTIL_TREAP_H_