/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_H_ #define MINDSPORE_CCSRC_UTILS_BASE_REF_H_ #include #include #include #include #include #include #include #include #include #include "ir/value.h" namespace mindspore { class BaseRef; class VectorRef; class SetRef; class RunFunctionRef; using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; using const_reverse_iterator = std::vector::const_reverse_iterator; using RunFunc = std::function; using RunFuncPtr = std::shared_ptr; template using remove_reference_t = typename std::remove_reference::type; template using remove_const_t = typename std::remove_const::type; template using is_base = std::is_base_of>; template using is_value = std::is_base_of>; template using is_base_ref = std::is_base_of>; iterator ConstIteratorCast(std::vector *v, const_iterator iter); inline std::shared_ptr MakeNode(const std::vector &elements) { return std::make_shared(elements); } inline std::shared_ptr MakeNode(std::initializer_list elements) { return std::make_shared(elements); } // Anfnode, Funcgraph and some not value node class template >::value && is_base::value, int>::type = 0> inline BasePtr MakeNode(const T &v) { return v; } template >::value && !is_base_ref::value, int>::type = 0> inline BasePtr MakeNode(const T &v) { return MakeValue(v); } inline std::shared_ptr MakeNode(const VectorRef &a) { return std::make_shared(std::move(a)); } inline std::shared_ptr MakeNode(const AnfNodePtrList &a) { std::vector ret; (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; }); return std::make_shared(ret); } inline std::shared_ptr MakeNode(const SetRef &a) { return std::make_shared(std::move(a)); } inline std::shared_ptr MakeNode(const RunFuncPtr &a) { return std::make_shared(a); } class BaseRef : public Base { public: BaseRef() : m_ptr(nullptr) {} BaseRef(const BaseRef &other); virtual std::shared_ptr copy() const { return m_ptr; } BaseRef(BaseRef &&other) : Base(other) { m_ptr = other.m_ptr; other.m_ptr = nullptr; } // right reference constructor template ::type, BaseRef>::value, T>::type> BaseRef(T &&t) { // NOLINT m_ptr = MakeNode(t); } ~BaseRef() override { m_ptr = nullptr; } MS_DECLARE_PARENT(BaseRef, Base) bool operator!=(const BaseRef &other) const { return !(operator==(other)); } virtual bool operator==(const BaseRef &other) const; // left reference virtual BaseRef &operator=(const BaseRef &other); // right reference virtual BaseRef &operator=(BaseRef &&other); std::size_t hash() const override { if (m_ptr == nullptr) { MS_LOG(ERROR) << "Invalid m_ptr"; return 0; } return m_ptr->hash(); } std::string ToString() const override; bool is_null() const { return m_ptr == nullptr; } virtual uint32_t type() const; BasePtr m_ptr; // point to real data }; using BaseRefPtr = std::shared_ptr; struct BaseRefHash { std::size_t operator()(const BaseRef &c) const { return c.hash(); } }; struct BaseRefLess { bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); } }; namespace utils { // judge isa relation // examples: isa(handle), isa(handle) template ::value && !is_base_ref::value, int>::type = 0> bool isa(const BaseRef &handle) { if (!handle.m_ptr) { return false; } return handle.m_ptr->isa(); } // noderef isa ptr isa(x) or isa() template ::value, typename T::element_type>::type, typename std::enable_if::value || is_base_ref::value, int>::type = 0> bool isa(const BaseRef &handle) { if (handle.m_ptr == nullptr) { return typeid(handle.m_ptr) == typeid(T); } if (handle.m_ptr->isa()) { return true; } // constptr isa can be true return std::dynamic_pointer_cast(handle.m_ptr) != nullptr; } // isa(handle) template ::type::element_type> bool isa(const BaseRef &handle) { if (handle.m_ptr == nullptr) { return false; } return handle.m_ptr->isa(); } // isa(handle), judge reference or ptr template ::value, int>::type = 0> bool isa(const BaseRef &handle) { static const uint32_t tid = Base::GetTypeId(typeid(T).name()); return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa()); } // valueref -> C++ type // cast(handle) template ::value && !is_shared_ptr::value, int>::type = 0> T cast(const BaseRef &handle) { T ret = GetValue(std::static_pointer_cast(handle.m_ptr)); return std::move(ret); } // valueref -> valueref type // cast(handle) template ::value, int>::type = 0> const T &cast(const BaseRef &handle) { if (handle.m_ptr) { return static_cast(*handle.m_ptr); } return std::move(static_cast(handle)); } // valueref -> nodeptr type // cast(handle) template ::value, typename T::element_type>::type, typename std::enable_if::value && std::is_base_of::value, int>::type = 0> T cast(const BaseRef &handle) { if (!handle.m_ptr) { MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null"; } auto m = handle.m_ptr->cast(); if (nullptr != m) { return m; } return std::static_pointer_cast(handle.m_ptr); } } // namespace utils class VectorRef : public BaseRef { public: using value_type = BaseRef; VectorRef() {} explicit VectorRef(const std::vector &elements) : elements_(elements) {} VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {} // left reference virtual VectorRef &operator=(const VectorRef &other); ~VectorRef() override = default; std::shared_ptr copy() const override { return std::make_shared(elements_); } bool empty() const { return (elements_.size() == 0); } std::size_t size() const { return elements_.size(); } MS_DECLARE_PARENT(VectorRef, BaseRef) const BaseRef &operator[](const std::size_t &dim) const { if (dim >= size()) { MS_LOG(EXCEPTION) << "Out of the size of the tuple."; } return elements_[dim]; } BaseRef &operator[](const std::size_t &dim) { if (dim >= size()) { MS_LOG(EXCEPTION) << "Out of the size of the tuple."; } return elements_[dim]; } uint32_t type() const override { return tid(); } std::string ToString() const override; std::vector &elements() { return elements_; } void clear() { elements_.clear(); } bool operator==(const BaseRef &other) const override; bool operator==(const VectorRef &other) const; void push_back(const BaseRef &value) { elements_.push_back(value); } void push_back(BaseRef &&value) { elements_.push_back(value); } void emplace_back(const BaseRef &value) { elements_.emplace_back(value); } void emplace_back(BaseRef &&value) { elements_.emplace_back(value); } template void insert(const iterator pos, const InputIt first, const InputIt last) { (void)elements_.insert(pos, first, last); } template void insert(const const_iterator cpos, const InputIt first, const InputIt last) { auto pos = ConstIteratorCast(&elements_, cpos); (void)elements_.insert(pos, first, last); } const_iterator begin() const { return elements_.begin(); } const_iterator end() const { return elements_.end(); } const_reverse_iterator rbegin() const { return elements_.rbegin(); } const_reverse_iterator rend() const { return elements_.rend(); } iterator erase(const const_iterator cpos) { auto pos = ConstIteratorCast(&elements_, cpos); return elements_.erase(pos); } iterator erase(const const_iterator cfirst, const const_iterator clast) { auto first = ConstIteratorCast(&elements_, cfirst); auto last = ConstIteratorCast(&elements_, clast); return elements_.erase(first, last); } std::size_t hash() const override { std::stringstream buffer; buffer << ToString(); return std::hash()(buffer.str()); } std::vector elements_; }; using VectorRefPtr = std::shared_ptr; using set_iterator = std::set::iterator; using const_set_iterator = std::set::const_iterator; struct VectorRefHash { std::size_t operator()(const VectorRef &c) const { return c.hash(); } }; class SetRef : public BaseRef { public: SetRef() {} explicit SetRef(const std::set &elements) : elements_(elements) {} SetRef(const std::initializer_list elements) : elements_(elements.begin(), elements.end()) {} SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {} // left reference virtual SetRef &operator=(const SetRef &other); bool operator==(const BaseRef &other) const override; bool operator==(const SetRef &other) const; ~SetRef() override = default; std::shared_ptr copy() const override { return std::make_shared(elements_); } bool empty() const { return (elements_.size() == 0); } std::size_t size() const { return elements_.size(); } MS_DECLARE_PARENT(SetRef, BaseRef) uint32_t type() const override { return tid(); } std::string ToString() const override; std::set &elements() { return elements_; } void clear() { elements_.clear(); } void insert(const BaseRef &elem) { (void)elements_.insert(elem); } const_set_iterator begin() const { return elements_.begin(); } const_set_iterator end() const { return elements_.end(); } template void insert(const InputIt first, const InputIt last) { (void)elements_.insert(first, last); } std::size_t count(const BaseRef &elem) const { return elements_.count(elem); } const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); } std::set elements_; }; using SetRefPtr = std::shared_ptr; class RunFunctionRef : public BaseRef { public: RunFunctionRef() {} explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {} ~RunFunctionRef() override = default; MS_DECLARE_PARENT(RunFunctionRef, BaseRef) uint32_t type() const override { return tid(); } std::string ToString() const override { return std::string("RunFunctionRef"); } bool operator==(const BaseRef &other) const override; bool operator==(const RunFunctionRef &other) const; RunFuncPtr func_; }; } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_NODE_REF_H_