You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

union_find_set.h 2.5 kB

5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
  19. #define MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
  20. #include <map>
  21. #include <set>
  22. namespace mindspore {
  23. template <class T>
  24. class UnionFindSet {
  25. public:
  26. UnionFindSet() : union_find_set_() {}
  27. ~UnionFindSet() = default;
  28. void Add(const T &elem) {
  29. if (union_find_set_.find(elem) != union_find_set_.end()) {
  30. return;
  31. }
  32. union_find_set_[elem] = elem;
  33. }
  34. T Find(const T &key) {
  35. T key_parent = key;
  36. auto iter = union_find_set_.find(key_parent);
  37. if (iter == union_find_set_.end()) {
  38. MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent;
  39. }
  40. while (key_parent != iter->second) {
  41. key_parent = iter->second;
  42. iter = union_find_set_.find(key_parent);
  43. if (iter == union_find_set_.end()) {
  44. MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent;
  45. }
  46. }
  47. T tmp = key;
  48. T tmp_parent;
  49. while (tmp != key_parent) {
  50. iter = union_find_set_.find(tmp);
  51. if (iter == union_find_set_.end()) {
  52. MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << tmp;
  53. }
  54. tmp_parent = iter->second;
  55. union_find_set_[tmp] = key_parent;
  56. tmp = tmp_parent;
  57. }
  58. return key_parent;
  59. }
  60. void Union(const T &left, const T &right) { union_find_set_[Find(left)] = Find(right); }
  61. std::map<T, std::set<T>> GetSets() {
  62. std::map<T, std::set<T>> result;
  63. for (auto &iter : union_find_set_) {
  64. (void)Find(iter.first);
  65. }
  66. for (auto &iter : union_find_set_) {
  67. T parent = Find(iter.first);
  68. result[parent].insert(iter.first);
  69. }
  70. return result;
  71. }
  72. private:
  73. std::map<T, T> union_find_set_;
  74. };
  75. } // namespace mindspore
  76. #endif // MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_