You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

ordered_set.h 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_UTILS_ORDERED_SET_H_
  19. #define MINDSPORE_CCSRC_UTILS_ORDERED_SET_H_
#include <algorithm>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "utils/log_adapter.h"
  29. namespace mindspore {
  30. // Implementation of OrderedSet that keeps insertion order
  31. // using map as set, and use list as a sequential container to record elements to keep insertion order
  32. template <class T, class Hash = std::hash<T>, class KeyEqual = std::equal_to<T>>
  33. class OrderedSet {
  34. public:
  35. using element_type = T;
  36. using hasher = Hash;
  37. using equal = KeyEqual;
  38. using sequential_type = std::list<element_type>;
  39. using vector_type = std::vector<element_type>;
  40. using iterator = typename sequential_type::iterator;
  41. using const_iterator = typename sequential_type::const_iterator;
  42. using reverse_iterator = typename sequential_type::reverse_iterator;
  43. using const_reverse_iterator = typename sequential_type::const_reverse_iterator;
  44. using map_type = std::unordered_map<element_type, iterator, hasher, equal>;
  45. using ordered_set_type = OrderedSet<element_type, hasher, equal>;
  46. OrderedSet() = default;
  47. ~OrderedSet() = default;
  48. // OrderedSet use an iterator to list as mapped value to improve the performance of insertion and deletion,
  49. // So copy of OrderedSet should re-build value of the map key to make it pointer to the new list,, thus we use
  50. // traversal to build elements.
  51. OrderedSet(const OrderedSet &os) {
  52. for (auto &item : os.ordered_data_) {
  53. add(item);
  54. }
  55. }
  56. explicit OrderedSet(const sequential_type &other) {
  57. for (auto &item : other) {
  58. add(item);
  59. }
  60. }
  61. // Explicitly construct an OrderedSet use vector
  62. explicit OrderedSet(const vector_type &other) {
  63. for (auto &item : other) {
  64. add(item);
  65. }
  66. }
  67. OrderedSet &operator=(const OrderedSet &os) {
  68. if (this != &os) {
  69. for (auto &item : os.ordered_data_) {
  70. add(item);
  71. }
  72. }
  73. return *this;
  74. }
  75. // Add an element to the OrderedSet, without judging return value
  76. void add(const element_type &e) { (void)insert(e); }
  77. // insert an element to the OrderedSet
  78. std::pair<iterator, bool> insert(const element_type &e) {
  79. iterator empty_itr;
  80. std::pair<element_type, typename map_type::mapped_type> map_pair = std::make_pair(e, empty_itr);
  81. auto result = mapped_data_.insert(map_pair);
  82. auto &seq_idx = result.first->second;
  83. // if insert success;
  84. if (result.second) {
  85. auto it = ordered_data_.insert(ordered_data_.end(), e);
  86. seq_idx = it;
  87. }
  88. return std::pair<iterator, bool>(seq_idx, result.second);
  89. }
  90. // Remove an element, if removed return true, otherwise return false
  91. bool erase(const element_type &e) {
  92. auto pos = mapped_data_.find(e);
  93. if (pos == mapped_data_.end()) {
  94. return false;
  95. }
  96. // erase the sequential data first
  97. (void)ordered_data_.erase(pos->second);
  98. (void)mapped_data_.erase(pos);
  99. return true;
  100. }
  101. // Return the container size
  102. std::size_t size() const { return mapped_data_.size(); }
  103. bool empty() const { return mapped_data_.size() == 0; }
  104. // Return the string contents in orderset, using ordered_data
  105. std::string toString() {
  106. std::ostringstream res;
  107. res << "orderset content:\n";
  108. for (auto &item : ordered_data_) {
  109. res << std::to_string(reinterpret_cast<uintptr_t>(item.get())) << " ";
  110. }
  111. return res.str();
  112. }
  113. // Clear the elements
  114. void clear() {
  115. mapped_data_.clear();
  116. ordered_data_.clear();
  117. }
  118. // Compare two orderedset, if the order is not equal shall return false
  119. bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; }
  120. // Remove and return the first element in the OrderedSet
  121. T pop() {
  122. if (ordered_data_.size() != 0) {
  123. T res = ordered_data_.front();
  124. (void)mapped_data_.erase(res);
  125. (void)ordered_data_.erase(ordered_data_.begin());
  126. return res;
  127. }
  128. MS_LOG(EXCEPTION) << "pop() on empty OrderedSet";
  129. }
  130. T back() {
  131. if (ordered_data_.size() != 0) {
  132. return ordered_data_.back();
  133. }
  134. MS_LOG(EXCEPTION) << "back() on empty OrderedSet";
  135. }
  136. // Return true if there are no common elements
  137. bool is_disjoint(const OrderedSet &other) {
  138. for (auto &item : other.ordered_data_) {
  139. if (mapped_data_.find(item) != mapped_data_.end()) {
  140. return false;
  141. }
  142. }
  143. return true;
  144. }
  145. // Test whether this is subset of other
  146. bool is_subset(const OrderedSet &other) {
  147. for (auto &item : ordered_data_) {
  148. if (other.mapped_data_.find(item) == other.mapped_data_.end()) {
  149. return false;
  150. }
  151. }
  152. return true;
  153. }
  154. // Add elements in other to this orderedset
  155. void update(const OrderedSet &other) {
  156. for (auto &item : other.ordered_data_) {
  157. add(item);
  158. }
  159. }
  160. void update(const std::shared_ptr<OrderedSet> &other) { update(*other); }
  161. void update(const sequential_type &other) {
  162. for (auto &item : other) {
  163. add(item);
  164. }
  165. }
  166. void update(const vector_type &other) {
  167. for (auto &item : other) {
  168. add(item);
  169. }
  170. }
  171. ordered_set_type get_union(const OrderedSet &other) {
  172. ordered_set_type res(ordered_data_);
  173. res.update(other);
  174. return res;
  175. }
  176. // Get the union with other set, this operator may cost time because of copy
  177. ordered_set_type operator|(const OrderedSet &other) { return get_union(other); }
  178. // Return the intersection of two sets
  179. ordered_set_type intersection(const OrderedSet &other) {
  180. ordered_set_type res(ordered_data_);
  181. for (auto &item : ordered_data_) {
  182. if (other.mapped_data_.find(item) == other.mapped_data_.end()) {
  183. (void)res.erase(item);
  184. }
  185. }
  186. return res;
  187. }
  188. ordered_set_type operator&(const OrderedSet &other) { return intersection(other); }
  189. // Return the symmetric difference of two sets
  190. ordered_set_type symmetric_difference(const OrderedSet &other) {
  191. ordered_set_type res(ordered_data_);
  192. for (auto &item : other.ordered_data_) {
  193. if (mapped_data_.find(item) != mapped_data_.end()) {
  194. (void)res.erase(item);
  195. } else {
  196. res.add(item);
  197. }
  198. }
  199. return res;
  200. }
  201. ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); }
  202. // Remove elements which is also in others.
  203. void difference_update(const OrderedSet &other) {
  204. // use vector traversal, to keep ordrer
  205. for (auto &item : other.ordered_data_) {
  206. (void)erase(item);
  207. }
  208. }
  209. void difference_update(const sequential_type &other) {
  210. for (auto &item : other) {
  211. (void)erase(item);
  212. }
  213. }
  214. void difference_update(const vector_type &other) {
  215. for (auto &item : other) {
  216. (void)erase(item);
  217. }
  218. }
  219. // Return the set with elements that are not in the others
  220. ordered_set_type difference(const OrderedSet &other) {
  221. ordered_set_type res(ordered_data_);
  222. res.difference_update(other);
  223. return res;
  224. }
  225. ordered_set_type operator-(const OrderedSet &other) { return difference(other); }
  226. bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); }
  227. // Return the count of an element in set
  228. std::size_t count(const element_type &e) const { return mapped_data_.count(e); }
  229. iterator begin() { return ordered_data_.begin(); }
  230. iterator end() { return ordered_data_.end(); }
  231. const_iterator begin() const { return ordered_data_.cbegin(); }
  232. const_iterator end() const { return ordered_data_.cend(); }
  233. const_iterator cbegin() const { return ordered_data_.cbegin(); }
  234. const_iterator cend() const { return ordered_data_.cend(); }
  235. private:
  236. map_type mapped_data_;
  237. sequential_type ordered_data_;
  238. };
  239. } // namespace mindspore
  240. #endif // MINDSPORE_CCSRC_UTILS_ORDERED_SET_H_