/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ #define MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ #include #include namespace mindspore { template class UnionFindSet { public: UnionFindSet() : union_find_set_() {} ~UnionFindSet() = default; void Add(const T &elem) { if (union_find_set_.find(elem) != union_find_set_.end()) { return; } union_find_set_[elem] = elem; } T Find(const T &key) { T key_parent = key; auto iter = union_find_set_.find(key_parent); if (iter == union_find_set_.end()) { MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent; } while (key_parent != iter->second) { key_parent = iter->second; iter = union_find_set_.find(key_parent); if (iter == union_find_set_.end()) { MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent; } } T tmp = key; T tmp_parent; while (tmp != key_parent) { iter = union_find_set_.find(tmp); if (iter == union_find_set_.end()) { MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << tmp; } tmp_parent = iter->second; union_find_set_[tmp] = key_parent; tmp = tmp_parent; } return key_parent; } void Union(const T &left, const T &right) { union_find_set_[Find(left)] = Find(right); } std::map> GetSets() { std::map> result; for (auto &iter : union_find_set_) { (void)Find(iter.first); } for (auto &iter : union_find_set_) { T parent = Find(iter.first); result[parent].insert(iter.first); } return result; } private: std::map union_find_set_; }; } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_