| @@ -160,7 +160,7 @@ private: | |||||
| template <typename TItem> | template <typename TItem> | ||||
| void register_converter() { | void register_converter() { | ||||
| m_table[typeid(TItem)] = [](const any_t& input) { | m_table[typeid(TItem)] = [](const any_t& input) { | ||||
| return variant_t(*input.as<TItem>()); | |||||
| return variant_t(input.cast<TItem>()); | |||||
| }; | }; | ||||
| } | } | ||||
| @@ -11,7 +11,6 @@ | |||||
| #pragma once | #pragma once | ||||
| #include <any> | |||||
| #include <bitset> | #include <bitset> | ||||
| #include <chrono> | #include <chrono> | ||||
| #include <deque> | #include <deque> | ||||
| @@ -28,6 +27,7 @@ | |||||
| #include "megbrain/imperative/op_def.h" | #include "megbrain/imperative/op_def.h" | ||||
| #include "megbrain/imperative/physical_tensor.h" | #include "megbrain/imperative/physical_tensor.h" | ||||
| #include "megbrain/imperative/utils/any.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace imperative { | namespace imperative { | ||||
| @@ -51,48 +51,6 @@ public: | |||||
| static std::shared_ptr<CompNode::Event> record_device(CompNode device); | static std::shared_ptr<CompNode::Event> record_device(CompNode device); | ||||
| }; | }; | ||||
| class AnyPtr { | |||||
| public: | |||||
| struct Deleter { | |||||
| void* object; | |||||
| void (*method)(void*, void*); | |||||
| void operator()(void* ptr) { method(object, ptr); } | |||||
| }; | |||||
| private: | |||||
| using holder_t = std::unique_ptr<void, Deleter>; | |||||
| const std::type_info* m_type = nullptr; | |||||
| holder_t m_holder = nullptr; | |||||
| public: | |||||
| AnyPtr() = default; | |||||
| template < | |||||
| typename T, | |||||
| typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, AnyPtr>>> | |||||
| explicit AnyPtr(T* value, Deleter deleter) { | |||||
| m_type = &typeid(T); | |||||
| m_holder = {value, deleter}; | |||||
| } | |||||
| template <typename T> | |||||
| T* as() { | |||||
| mgb_assert(is_exactly<T>(), "type mismatch"); | |||||
| return reinterpret_cast<T*>(m_holder.get()); | |||||
| } | |||||
| template <typename T> | |||||
| const T* as() const { | |||||
| mgb_assert(is_exactly<T>(), "type mismatch"); | |||||
| return reinterpret_cast<const T*>(m_holder.get()); | |||||
| } | |||||
| template <typename T> | |||||
| bool is_exactly() const { | |||||
| return std::type_index{typeid(T)} == std::type_index{*m_type}; | |||||
| } | |||||
| const std::type_info& type() const { return *m_type; } | |||||
| bool operator==(std::nullptr_t nptr) const { return m_holder == nullptr; } | |||||
| operator bool() const { return m_holder != nullptr; } | |||||
| }; | |||||
| class Profiler { | class Profiler { | ||||
| public: | public: | ||||
| struct Record { | struct Record { | ||||
| @@ -128,7 +86,6 @@ private: | |||||
| std::thread::id m_thread_id; | std::thread::id m_thread_id; | ||||
| std::vector<Record> m_records; | std::vector<Record> m_records; | ||||
| std::atomic<Status> m_status = Running; | std::atomic<Status> m_status = Running; | ||||
| std::unordered_map<std::type_index, AnyPtr> m_mem_pools; | |||||
| static std::vector<entry_t> sm_records; | static std::vector<entry_t> sm_records; | ||||
| static options_t sm_profile_options; | static options_t sm_profile_options; | ||||
| @@ -161,42 +118,21 @@ public: | |||||
| return *tm_profiler; | return *tm_profiler; | ||||
| } | } | ||||
| template <typename T> | |||||
| static MemPool<T>& get_mem_pool() { | |||||
| thread_local MemPool<T>* t_pool = nullptr; | |||||
| if (t_pool == nullptr) { | |||||
| auto& pool = get_instance().m_mem_pools[typeid(MemPool<T>)]; | |||||
| if (pool == nullptr) { | |||||
| pool = | |||||
| AnyPtr(new MemPool<T>(), | |||||
| {nullptr, [](void*, void* ptr) { | |||||
| delete reinterpret_cast<MemPool<T>*>(ptr); | |||||
| }}); | |||||
| } | |||||
| t_pool = pool.as<MemPool<T>>(); | |||||
| } | |||||
| return *t_pool; | |||||
| } | |||||
| static uint64_t next_id() { return sm_last_id++; } | static uint64_t next_id() { return sm_last_id++; } | ||||
| template <typename T, typename... TArgs> | template <typename T, typename... TArgs> | ||||
| static uint64_t record(TArgs&&... args) { | static uint64_t record(TArgs&&... args) { | ||||
| auto& profiler = get_instance(); | auto& profiler = get_instance(); | ||||
| auto& mem_pool = get_mem_pool<T>(); | |||||
| // auto& mem_pool = get_mem_pool<T>(); | |||||
| if constexpr (sm_debug) { | if constexpr (sm_debug) { | ||||
| Status expected = Running; | Status expected = Running; | ||||
| mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording)); | mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording)); | ||||
| } | } | ||||
| uint64_t id = next_id(); | uint64_t id = next_id(); | ||||
| profiler::Time time = sm_timer.record_host(); | profiler::Time time = sm_timer.record_host(); | ||||
| auto deleter = [](void* obj, void* ptr) { | |||||
| reinterpret_cast<MemPool<T>*>(obj)->free(reinterpret_cast<T*>(ptr)); | |||||
| }; | |||||
| profiler.m_records.emplace_back( | profiler.m_records.emplace_back( | ||||
| id, profiler.m_thread_id, time, | id, profiler.m_thread_id, time, | ||||
| AnyPtr{mem_pool.alloc(T{std::forward<TArgs>(args)...}), | |||||
| {&mem_pool, deleter}}); | |||||
| AnyPtr::make<T>(T{std::forward<TArgs&&>(args)...})); | |||||
| if constexpr (sm_debug) { | if constexpr (sm_debug) { | ||||
| Status expected = Recording; | Status expected = Recording; | ||||
| mgb_assert(profiler.m_status.compare_exchange_strong(expected, Running)); | mgb_assert(profiler.m_status.compare_exchange_strong(expected, Running)); | ||||
| @@ -241,7 +177,7 @@ public: | |||||
| bundle.options = get_options(); | bundle.options = get_options(); | ||||
| bundle.start_at = sm_start_at; | bundle.start_at = sm_start_at; | ||||
| bundle.thread_dict = get_thread_dict(); | bundle.thread_dict = get_thread_dict(); | ||||
| return std::move(bundle); | |||||
| return bundle; | |||||
| } | } | ||||
| static option_t get_option(std::string key, option_t default_val) { | static option_t get_option(std::string key, option_t default_val) { | ||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/allocator.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <typeindex> | |||||
| #include "megbrain/utils/mempool.h" | |||||
| #include "megbrain/utils/metahelper.h" | |||||
| namespace mgb::imperative { | |||||
| template <typename T> | |||||
| class Allocator { | |||||
| public: | |||||
| using pointer = T*; | |||||
| using const_pointer = const T*; | |||||
| using void_pointer = void*; | |||||
| using const_void_pointer = const void*; | |||||
| using value_type = T; | |||||
| using size_type = std::size_t; | |||||
| using diffenence_type = std::ptrdiff_t; | |||||
| using pool_type = MemPoolStorage; | |||||
| private: | |||||
| pool_type* m_pool = nullptr; | |||||
| public: | |||||
| Allocator(pool_type* pool) : m_pool(pool) {} | |||||
| T* allocate(size_type n) { | |||||
| mgb_assert(n == 1); | |||||
| return m_pool->alloc(sizeof(T)); | |||||
| } | |||||
| void deallocate(pointer* p, size_type n) { | |||||
| mgb_assert(n == 1); | |||||
| m_pool->free(p); | |||||
| } | |||||
| bool operator==(const Allocator& rhs) const { return m_pool == rhs.m_pool; } | |||||
| bool operator!=(const Allocator& rhs) const { return m_pool != rhs.m_pool; } | |||||
| }; | |||||
| template <typename T> | |||||
| class ThreadLocalAllocatorAdapter { | |||||
| public: | |||||
| using value_type = T; | |||||
| using size_type = std::size_t; | |||||
| using pointer = T*; | |||||
| public: | |||||
| T* allocate(size_type n) { mgb_assert(false); } | |||||
| void deallocate(pointer* p, size_type n) { mgb_assert(false); } | |||||
| bool operator==(const ThreadLocalAllocatorAdapter& rhs) const { return true; } | |||||
| bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; } | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/any.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <typeindex> | |||||
| #include "megbrain/imperative/utils/local_ptr.h" | |||||
| namespace mgb::imperative { | |||||
| class AnyMixinBase { | |||||
| private: | |||||
| const std::type_info* m_type = nullptr; | |||||
| public: | |||||
| AnyMixinBase() = default; | |||||
| const std::type_info& type() const { return *m_type; } | |||||
| friend class AnyPtr; | |||||
| }; | |||||
| template <typename T> | |||||
| class AnyMixin : public AnyMixinBase, public T { | |||||
| public: | |||||
| AnyMixin(T&& val) : T(std::move(val)) {} | |||||
| }; | |||||
| class AnyPtr { | |||||
| public: | |||||
| using storage_t = LocalPtr<AnyMixinBase>; | |||||
| private: | |||||
| storage_t m_storage; | |||||
| public: | |||||
| const std::type_info& type() const { return m_storage->type(); } | |||||
| template <typename T> | |||||
| const T& cast() const { | |||||
| mgb_assert(is_exactly<T>(), "type mismatch"); | |||||
| return *static_cast<const AnyMixin<T>*>(m_storage.get()); | |||||
| } | |||||
| template <typename T> | |||||
| bool is_exactly() const { | |||||
| return std::type_index{typeid(T)} == std::type_index{type()}; | |||||
| } | |||||
| bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; } | |||||
| bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; } | |||||
| operator bool() const { return m_storage != nullptr; } | |||||
| template <typename T, typename... TArgs> | |||||
| static AnyPtr make(TArgs&&... args) { | |||||
| AnyPtr ret; | |||||
| ret.m_storage = LocalPtr<AnyMixinBase>::make<AnyMixin<T>>( | |||||
| std::forward<TArgs&&>(args)...); | |||||
| ret.m_storage->m_type = &typeid(T); | |||||
| return ret; | |||||
| } | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,96 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/visit.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <chrono> | |||||
| #include <future> | |||||
| #include <vector> | |||||
| #include "megbrain/utils/metahelper.h" | |||||
| #include "megbrain/utils/small_vector.h" | |||||
| namespace mgb::imperative { | |||||
| class BoxBase : public NonCopyableObj { | |||||
| public: | |||||
| virtual void reset() = 0; | |||||
| virtual void set_exception(std::exception_ptr exc) = 0; | |||||
| virtual bool try_set_exception(std::exception_ptr exc) = 0; | |||||
| }; | |||||
| /** | |||||
| * \brief An reusable promise | |||||
| * | |||||
| * \tparam T type of value | |||||
| */ | |||||
| template <typename T> | |||||
| class Box final : public BoxBase { | |||||
| private: | |||||
| std::promise<T> m_promise; | |||||
| std::shared_future<T> m_future; | |||||
| std::mutex m_mutex; | |||||
| bool m_value_set; | |||||
| bool m_exception_set; | |||||
| public: | |||||
| Box() { reset(); } | |||||
| const T& get_value() { return m_future.get(); } | |||||
| T take_value() { | |||||
| T value = m_future.get(); | |||||
| reset(); | |||||
| return value; | |||||
| } | |||||
| void set_value(T value) { | |||||
| MGB_LOCK_GUARD(m_mutex); | |||||
| m_promise.set_value(std::move(value)); | |||||
| m_value_set = true; | |||||
| } | |||||
| bool try_set_value(T value) { | |||||
| MGB_LOCK_GUARD(m_mutex); | |||||
| if (m_exception_set) { | |||||
| return false; | |||||
| } | |||||
| m_promise.set_value(std::move(value)); | |||||
| m_value_set = true; | |||||
| return true; | |||||
| } | |||||
| void set_exception(std::exception_ptr exc) override { | |||||
| MGB_LOCK_GUARD(m_mutex); | |||||
| m_promise.set_exception(exc); | |||||
| m_exception_set = true; | |||||
| } | |||||
| bool try_set_exception(std::exception_ptr exc) override { | |||||
| MGB_LOCK_GUARD(m_mutex); | |||||
| if (m_value_set) { | |||||
| return false; | |||||
| } | |||||
| m_promise.set_exception(exc); | |||||
| m_exception_set = true; | |||||
| return true; | |||||
| } | |||||
| void reset() override { | |||||
| MGB_LOCK_GUARD(m_mutex); | |||||
| m_promise = {}; | |||||
| m_future = m_promise.get_future(); | |||||
| m_value_set = false; | |||||
| m_exception_set = false; | |||||
| } | |||||
| /** | |||||
| * \brief make an empty box | |||||
| * | |||||
| * \return std::shared_ptr<Box> | |||||
| */ | |||||
| static std::shared_ptr<Box> make() { return std::make_shared<Box>(); } | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/span.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <iomanip> | |||||
| #include <memory> | |||||
| #include <sstream> | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| template <typename T> | |||||
| class CleanupGuard { | |||||
| private: | |||||
| T m_callback; | |||||
| public: | |||||
| explicit CleanupGuard(T cb) : m_callback{std::move(cb)} {} | |||||
| ~CleanupGuard() { m_callback(); } | |||||
| }; | |||||
| inline std::string quoted(std::string str) { | |||||
| std::stringstream ss; | |||||
| ss << std::quoted(str); | |||||
| return ss.str(); | |||||
| } | |||||
| } // namespace imperative | |||||
| } // namespace mgb | |||||
| @@ -0,0 +1,245 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/intrusive_list.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megbrain/utils/metahelper.h" | |||||
| namespace mgb::imperative::utils::intrusive_list { | |||||
| // copy policy | |||||
| struct after_t {}; | |||||
| struct before_t {}; | |||||
| struct disable_t {}; | |||||
| template <typename T> | |||||
| struct Tail; | |||||
| // invariant: next->prev == this | |||||
| template <typename T> | |||||
| struct Head { | |||||
| Tail<T>* next; | |||||
| Head(Tail<T>* node = nullptr) : next(node) {} | |||||
| Head(const Head<T>&) = delete; | |||||
| Head<T>& operator=(const Head<T>&) = delete; | |||||
| Head(Head<T>&& rhs) : next(rhs.next) { | |||||
| rhs.next = nullptr; | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| } | |||||
| Head<T>& operator=(Head<T>&& rhs) { | |||||
| mgb_assert(!next); | |||||
| next = rhs.next; | |||||
| rhs.next = nullptr; | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| ~Head() { | |||||
| if (next) { | |||||
| next->prev = nullptr; | |||||
| } | |||||
| } | |||||
| }; | |||||
| // invariant: prev->next == this | |||||
| template <typename T> | |||||
| struct Tail { | |||||
| Head<T>* prev; | |||||
| Tail(Head<T>* node = nullptr) : prev(node) {} | |||||
| Tail(const Tail<T>&) = delete; | |||||
| Tail<T>& operator=(const Tail<T>&) = delete; | |||||
| Tail(Tail<T>&& rhs) : prev(rhs.prev) { | |||||
| rhs.prev = nullptr; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| } | |||||
| Tail<T>& operator=(Tail<T>&& rhs) { | |||||
| mgb_assert(!prev); | |||||
| prev = rhs.prev; | |||||
| rhs.prev = nullptr; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| ~Tail() { | |||||
| if (prev) { | |||||
| prev->next = nullptr; | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <typename T, typename policy> | |||||
| struct Node; | |||||
| template <typename T> | |||||
| class Iterator { | |||||
| T* ptr; | |||||
| void inc() { ptr = static_cast<T*>(ptr->Head<T>::next); } | |||||
| void dec() { ptr = static_cast<T*>(ptr->Head<T>::prev); } | |||||
| public: | |||||
| Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {} | |||||
| Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {} | |||||
| template <typename policy> | |||||
| Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {} | |||||
| T& operator*() { return *static_cast<T*>(ptr); } | |||||
| T* operator->() { return static_cast<T*>(ptr); } | |||||
| operator bool() { return ptr; } | |||||
| bool operator==(const Iterator<T>& rhs) { return ptr == rhs.ptr; } | |||||
| Iterator& operator++() { | |||||
| inc(); | |||||
| return *this; | |||||
| } | |||||
| Iterator& operator--() { | |||||
| dec(); | |||||
| return *this; | |||||
| } | |||||
| Iterator operator++(int) { | |||||
| auto ret = *this; | |||||
| inc(); | |||||
| return ret; | |||||
| } | |||||
| Iterator operator--(int) { | |||||
| auto ret = *this; | |||||
| dec(); | |||||
| return ret; | |||||
| } | |||||
| }; | |||||
| // Node in a doubly linked list. Unlike std::list, nodes are not owned by a container. | |||||
| // Instead, nodes may join or leave a list freely. | |||||
| // NOTE: Derived classes have to explicitly declare copy / assignment as default, | |||||
| // otherwise the compiler generated version would use the const T& signature, | |||||
| // which is deleted. | |||||
| template <typename T = void, typename policy = disable_t> | |||||
| struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>, | |||||
| Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> { | |||||
| private: | |||||
| using this_t = Node<T, policy>; | |||||
| using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>; | |||||
| public: | |||||
| using head_t = Head<U>; | |||||
| using tail_t = Tail<U>; | |||||
| using head_t::next; | |||||
| using tail_t::prev; | |||||
| Node() = default; | |||||
| Node(const this_t&) = delete; | |||||
| this_t& operator=(const this_t&) = delete; | |||||
| //! constructed node is inserted after the input node | |||||
| Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) { | |||||
| node.next = this; | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| } | |||||
| //! constructed node is inserted before the input node | |||||
| Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) { | |||||
| node.prev = this; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| } | |||||
| Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) { | |||||
| rhs.prev = nullptr; | |||||
| rhs.next = nullptr; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| } | |||||
| Node& operator=(this_t&& rhs) { | |||||
| unlink(); | |||||
| prev = rhs.prev; | |||||
| next = rhs.next; | |||||
| rhs.prev = nullptr; | |||||
| rhs.next = nullptr; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| template < | |||||
| typename p = policy, | |||||
| typename = std::enable_if_t< | |||||
| std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||||
| Node(this_t& rhs) : Node(policy{}, rhs) {} | |||||
| template < | |||||
| typename p = policy, | |||||
| typename = std::enable_if_t< | |||||
| std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||||
| this_t& operator=(this_t& rhs) { | |||||
| insert(policy{}, rhs); | |||||
| return *this; | |||||
| } | |||||
| void unlink() { | |||||
| if (prev) { | |||||
| prev->next = next; | |||||
| } | |||||
| if (next) { | |||||
| next->prev = prev; | |||||
| } | |||||
| prev = nullptr; | |||||
| next = nullptr; | |||||
| } | |||||
| //! this node is unlinked from its list and inserted after the input node | |||||
| void insert(after_t, head_t& node) { | |||||
| unlink(); | |||||
| prev = &node; | |||||
| next = node.next; | |||||
| node.next = this; | |||||
| if (next) { | |||||
| next->prev = this; | |||||
| } | |||||
| } | |||||
| //! this node is unlinked from its list and inserted before the input node | |||||
| void insert(before_t, tail_t& node) { | |||||
| unlink(); | |||||
| next = &node; | |||||
| prev = node.prev; | |||||
| node.prev = this; | |||||
| if (prev) { | |||||
| prev->next = this; | |||||
| } | |||||
| } | |||||
| void insert_before(tail_t& node) { insert(before_t{}, node); } | |||||
| void insert_after(head_t& node) { insert(after_t{}, node); } | |||||
| ~Node() { unlink(); } | |||||
| }; | |||||
| } // namespace mgb::imperative::utils::intrusive_list | |||||
| @@ -0,0 +1,285 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/local_ptr.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <optional> | |||||
| #include "megbrain/imperative/utils/mempool.h" | |||||
| #include "megbrain/utils/metahelper.h" | |||||
| namespace mgb::imperative { | |||||
| template <typename T> | |||||
| class LocalPtrStorage : public NonCopyableObj { | |||||
| private: | |||||
| size_t m_ref_count = 0; | |||||
| size_t m_weak_count = 0; | |||||
| T* m_pointer = nullptr; | |||||
| void (*reset)(LocalPtrStorage*) = nullptr; | |||||
| void (*free)(LocalPtrStorage*) = nullptr; | |||||
| void inc_ref() { m_ref_count++; } | |||||
| void dec_ref() { | |||||
| m_ref_count--; | |||||
| if (m_ref_count == 0) { | |||||
| reset(this); | |||||
| m_pointer = nullptr; | |||||
| reset = nullptr; | |||||
| if (m_weak_count == 0) { | |||||
| free(this); | |||||
| // dead | |||||
| } | |||||
| } | |||||
| } | |||||
| void inc_weak_ref() { m_weak_count++; } | |||||
| void dec_weak_ref() { | |||||
| m_weak_count--; | |||||
| if ((m_weak_count + m_ref_count) == 0) { | |||||
| free(this); | |||||
| // dead | |||||
| } | |||||
| } | |||||
| template <typename U> | |||||
| friend class LocalPtr; | |||||
| template <typename U> | |||||
| friend class LocalWeakPtr; | |||||
| public: | |||||
| }; | |||||
| template <typename T, typename TDerived> | |||||
| class LocalPtrStorgeImpl : public LocalPtrStorage<T> { | |||||
| private: | |||||
| std::optional<TDerived> m_value; | |||||
| void* m_pool = nullptr; | |||||
| template <typename U> | |||||
| friend class LocalPtr; | |||||
| template <typename U> | |||||
| friend class LocalWeakPtr; | |||||
| }; | |||||
| template <typename T> | |||||
| class LocalWeakPtr; | |||||
| /** | |||||
| * \brief thread-unsafe smart pointer | |||||
| * | |||||
| * \tparam T type of value | |||||
| */ | |||||
| template <typename T> | |||||
| class LocalPtr { | |||||
| public: | |||||
| using storage_t = LocalPtrStorage<T>; | |||||
| using pool_t = MemPool<storage_t>; | |||||
| using weak_type = LocalWeakPtr<T>; | |||||
| private: | |||||
| storage_t* m_storage = nullptr; | |||||
| void emplace(storage_t* ptr) { | |||||
| if (ptr) { | |||||
| ptr->inc_ref(); | |||||
| m_storage = ptr; | |||||
| } | |||||
| } | |||||
| LocalPtr(storage_t* ptr) { emplace(ptr); } | |||||
| public: | |||||
| LocalPtr() = default; | |||||
| LocalPtr(const LocalPtr& rhs) { (*this) = rhs; } | |||||
| LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); } | |||||
| LocalPtr& operator=(const LocalPtr& rhs) { | |||||
| if (this == &rhs) { | |||||
| return *this; | |||||
| } | |||||
| auto storage = rhs.m_storage; | |||||
| if (storage) { | |||||
| storage->inc_ref(); | |||||
| } | |||||
| if (m_storage) { | |||||
| m_storage->dec_ref(); | |||||
| // rhs.m_storage may be invalid here | |||||
| } | |||||
| m_storage = storage; | |||||
| return *this; | |||||
| } | |||||
| LocalPtr& operator=(LocalPtr&& rhs) { | |||||
| if (this == &rhs) { | |||||
| return *this; | |||||
| } | |||||
| std::swap(m_storage, rhs.m_storage); | |||||
| rhs.reset(); | |||||
| return *this; | |||||
| } | |||||
| bool operator==(const LocalPtr& rhs) const { return m_storage == rhs.m_storage; } | |||||
| bool operator!=(const LocalPtr& rhs) const { return m_storage != rhs.m_storage; } | |||||
| size_t hash() const { return reinterpret_cast<uintptr_t>(m_storage); } | |||||
| ~LocalPtr() { reset(); } | |||||
| /** | |||||
| * \brief Construct an instance of TDerived and return an LocalPtr | |||||
| * | |||||
| * There is an memory pool for each (T, TDerived) pair | |||||
| * | |||||
| * \tparam TDerived type of concrete instance, should be subclass of T | |||||
| * \tparam TArgs | |||||
| * \param args constructor arguments | |||||
| * \return LocalPtr points to the instance | |||||
| */ | |||||
| template <typename TDerived = T, typename... TArgs> | |||||
| static LocalPtr make(TArgs&&... args) { | |||||
| static_assert(std::is_base_of_v<T, TDerived>); | |||||
| using storage_impl_t = LocalPtrStorgeImpl<T, TDerived>; | |||||
| constexpr auto normalize_size = [](size_t size) { | |||||
| size_t normalized_size = 64; | |||||
| while (normalized_size < size) { | |||||
| normalized_size *= 2; | |||||
| } | |||||
| return normalized_size; | |||||
| }; | |||||
| using raw_storage_t = | |||||
| std::aligned_storage_t<normalize_size(sizeof(storage_impl_t))>; | |||||
| static_assert(alignof(raw_storage_t) % alignof(storage_impl_t) == 0); | |||||
| static_assert(sizeof(raw_storage_t) >= sizeof(storage_impl_t)); | |||||
| using pool_t = MemPool<raw_storage_t>; | |||||
| pool_t& pool = MemPoolUtils<raw_storage_t>::get_thread_local(); | |||||
| auto* raw_storage = pool.alloc_raw(); | |||||
| auto* storage = reinterpret_cast<storage_impl_t*>(raw_storage); | |||||
| new (storage) storage_impl_t(); | |||||
| storage->m_value.emplace(std::forward<TArgs&&>(args)...); | |||||
| storage->m_pointer = &*storage->m_value; | |||||
| storage->reset = [](storage_t* storage) { | |||||
| auto* storage_impl = static_cast<storage_impl_t*>(storage); | |||||
| storage_impl->m_value.reset(); | |||||
| storage_impl->m_pointer = nullptr; | |||||
| }; | |||||
| storage->free = [](storage_t* storage_base) { | |||||
| auto* storage = static_cast<storage_impl_t*>(storage_base); | |||||
| auto* pool = reinterpret_cast<pool_t*>(storage->m_pool); | |||||
| storage->m_pool = nullptr; | |||||
| storage->~storage_impl_t(); | |||||
| auto* raw_storage = reinterpret_cast<raw_storage_t*>(storage); | |||||
| pool->free_raw(raw_storage); | |||||
| }; | |||||
| storage->m_pool = &pool; | |||||
| return {(storage_t*)storage}; | |||||
| } | |||||
| T& operator*() const { return *get(); } | |||||
| T* get() const { | |||||
| if ((!m_storage) || !m_storage->m_pointer) { | |||||
| return nullptr; | |||||
| } | |||||
| return m_storage->m_pointer; | |||||
| } | |||||
| T* operator->() const { return get(); } | |||||
| size_t ref_count() const { return m_storage->m_ref_count; } | |||||
| bool unique() const { return ref_count() == 1; } | |||||
| void reset() { | |||||
| if (m_storage) { | |||||
| m_storage->dec_ref(); | |||||
| m_storage = nullptr; | |||||
| } | |||||
| } | |||||
| operator bool() const { return bool(m_storage); } | |||||
| bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; } | |||||
| bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; } | |||||
| template <typename U> | |||||
| friend class LocalWeakPtr; | |||||
| }; | |||||
| template <typename T> | |||||
| class LocalWeakPtr { | |||||
| public: | |||||
| using storage_t = LocalPtrStorage<T>; | |||||
| private: | |||||
| storage_t* m_storage = nullptr; | |||||
| void emplace(storage_t* ptr) { | |||||
| if (ptr) { | |||||
| ptr->inc_weak_ref(); | |||||
| m_storage = ptr; | |||||
| } | |||||
| } | |||||
| public: | |||||
| LocalWeakPtr() = default; | |||||
| LocalWeakPtr(const LocalPtr<T>& rhs) { emplace(rhs.m_storage); } | |||||
| LocalWeakPtr(const LocalWeakPtr& rhs) { (*this) = rhs; } | |||||
| LocalWeakPtr(LocalWeakPtr&& rhs) { (*this) = std::move(rhs); } | |||||
| LocalWeakPtr& operator=(const LocalWeakPtr& rhs) { | |||||
| if (this == &rhs) { | |||||
| return *this; | |||||
| } | |||||
| reset(); | |||||
| emplace(rhs.m_storage); | |||||
| return *this; | |||||
| } | |||||
| LocalWeakPtr& operator=(LocalWeakPtr&& rhs) { | |||||
| if (this == &rhs) { | |||||
| return *this; | |||||
| } | |||||
| std::swap(m_storage, rhs.m_storage); | |||||
| rhs.reset(); | |||||
| return *this; | |||||
| } | |||||
| ~LocalWeakPtr() { reset(); } | |||||
| void reset() { | |||||
| if (m_storage) { | |||||
| m_storage->dec_weak_ref(); | |||||
| m_storage = nullptr; | |||||
| } | |||||
| } | |||||
| LocalPtr<T> lock() const { | |||||
| if (m_storage && m_storage->m_ref_count) { | |||||
| return {m_storage}; | |||||
| } | |||||
| return {}; | |||||
| } | |||||
| bool operator==(const LocalWeakPtr& rhs) const { | |||||
| return m_storage == rhs.m_storage; | |||||
| } | |||||
| bool operator!=(const LocalWeakPtr& rhs) const { | |||||
| return m_storage != rhs.m_storage; | |||||
| } | |||||
| size_t hash() const { return reinterpret_cast<uintptr_t>(m_storage); } | |||||
| }; | |||||
| template <typename T, typename TDerived, typename... TArgs> | |||||
| LocalPtr<T> make_local(TArgs&&... args) { | |||||
| return LocalPtr<T>::template make<TDerived>(std::forward<TArgs&&>(args)...); | |||||
| } | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,157 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/map.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <optional> | |||||
| #include "megbrain/utils/metahelper.h" | |||||
| namespace mgb::imperative { | |||||
| /** | |||||
| * \brief an hash map optimized for weak pointer as key | |||||
| * | |||||
| * Keys were scanned automatically, so values referenced by invalid keys whould be | |||||
| * released soon | |||||
| * | |||||
| * \tparam TKey key type, requires(bool(key.lock())) | |||||
| * \tparam TValue value type | |||||
| */ | |||||
| template <typename TKey, typename TValue> | |||||
| class WeakKeyMap : public NonCopyableObj { | |||||
| public: | |||||
| using storage_t = std::unordered_map<TKey, TValue>; | |||||
| private: | |||||
| storage_t m_storage; | |||||
| typename storage_t::iterator m_cursor = m_storage.begin(); | |||||
| /** | |||||
| * \brief select a key and verify that whether it is invalid. If yes, erase it | |||||
| * | |||||
| */ | |||||
| void _step() { | |||||
| if (m_cursor == m_storage.end()) { | |||||
| m_cursor = m_storage.begin(); | |||||
| return; | |||||
| } | |||||
| auto key = m_cursor->first; | |||||
| if (!key.lock()) { | |||||
| m_cursor = m_storage.erase(m_cursor); | |||||
| } else { | |||||
| ++m_cursor; | |||||
| } | |||||
| } | |||||
| public: | |||||
| size_t count(TKey key) { | |||||
| _step(); | |||||
| _step(); | |||||
| return m_storage.count(key); | |||||
| } | |||||
| TValue& at(TKey key) const { return m_storage.at(key); } | |||||
| TValue& at(TKey key) { | |||||
| _step(); | |||||
| _step(); | |||||
| return m_storage.at(key); | |||||
| } | |||||
| TValue& operator[](TKey key) { | |||||
| _step(); | |||||
| _step(); | |||||
| if (m_storage.count(key)) { | |||||
| return m_storage.at(key); | |||||
| } else { | |||||
| size_t bucket_count = m_storage.bucket_count(); | |||||
| TValue& result = m_storage[key]; | |||||
| if (bucket_count != m_storage.bucket_count()) { | |||||
| m_cursor = m_storage.begin(); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| } | |||||
| std::optional<TValue> try_get(TKey key) const { | |||||
| auto iter = m_storage.find(key); | |||||
| if (iter == m_storage.end()) { | |||||
| return {}; | |||||
| } | |||||
| return {iter->second}; | |||||
| } | |||||
| std::optional<TValue> try_get(TKey key) { | |||||
| _step(); | |||||
| _step(); | |||||
| return ((const WeakKeyMap*)this)->try_get(std::move(key)); | |||||
| } | |||||
| }; | |||||
| template <typename TKey, typename TValue> | |||||
| class WeakValueMap : public NonCopyableObj { | |||||
| public: | |||||
| using storage_t = std::unordered_map<TKey, TValue>; | |||||
| private: | |||||
| storage_t m_storage; | |||||
| typename storage_t::iterator m_cursor = m_storage.begin(); | |||||
| /** | |||||
| * \brief select a key and verify that whether it is invalid. If yes, erase it | |||||
| * | |||||
| */ | |||||
| void _step() { | |||||
| if (m_cursor == m_storage.end()) { | |||||
| m_cursor = m_storage.begin(); | |||||
| return; | |||||
| } | |||||
| auto value = m_cursor->second; | |||||
| if (!value.lock()) { | |||||
| m_cursor = m_storage.erase(m_cursor); | |||||
| } else { | |||||
| ++m_cursor; | |||||
| } | |||||
| } | |||||
| public: | |||||
| size_t count(TKey key) { | |||||
| _step(); | |||||
| _step(); | |||||
| return m_storage.count(key); | |||||
| } | |||||
| TValue& at(TKey key) const { return m_storage.at(key); } | |||||
| TValue& at(TKey key) { | |||||
| _step(); | |||||
| _step(); | |||||
| return m_storage.at(key); | |||||
| } | |||||
| TValue& operator[](TKey key) { | |||||
| _step(); | |||||
| _step(); | |||||
| if (m_storage.count(key)) { | |||||
| return m_storage.at(key); | |||||
| } else { | |||||
| size_t bucket_count = m_storage.bucket_count(); | |||||
| TValue& result = m_storage[key]; | |||||
| if (bucket_count != m_storage.bucket_count()) { | |||||
| m_cursor = m_storage.begin(); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| } | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/mempool.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <mutex> | |||||
| #include <thread> | |||||
| #include <unordered_map> | |||||
| #include "megbrain/utils/mempool.h" | |||||
| #include "megbrain/utils/metahelper.h" | |||||
| namespace mgb::imperative { | |||||
| template <typename T> | |||||
| class MemPoolUtils { | |||||
| private: | |||||
| static std::mutex sm_mutex; | |||||
| static std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>> | |||||
| sm_instances; | |||||
| static thread_local MemPool<T>* tm_instance; | |||||
| static MemPool<T>* sm_instance; | |||||
| public: | |||||
| static MemPool<T>& get_thread_local() { | |||||
| if (!tm_instance) { | |||||
| MGB_LOCK_GUARD(sm_mutex); | |||||
| auto& instance = sm_instances[std::this_thread::get_id()]; | |||||
| if (!instance) { // thread id may be duplicated | |||||
| instance = std::make_unique<MemPool<T>>(); | |||||
| } | |||||
| tm_instance = instance.get(); | |||||
| } | |||||
| return *tm_instance; | |||||
| } | |||||
| static MemPool<T>& get_static() { | |||||
| if (!sm_instance) { | |||||
| MGB_LOCK_GUARD(sm_mutex); | |||||
| auto& instance = sm_instances[{}]; | |||||
| if (!instance) { // double check | |||||
| instance = std::make_unique<MemPool<T>>(); | |||||
| sm_instance = instance.get(); | |||||
| } | |||||
| mgb_assert(sm_instance); | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <typename T> | |||||
| std::mutex MemPoolUtils<T>::sm_mutex; | |||||
| template <typename T> | |||||
| std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>> | |||||
| MemPoolUtils<T>::sm_instances; | |||||
| template <typename T> | |||||
| thread_local MemPool<T>* MemPoolUtils<T>::tm_instance; | |||||
| template <typename T> | |||||
| MemPool<T>* MemPoolUtils<T>::sm_instance; | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/span.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <array> | |||||
| #include <vector> | |||||
| #include "megbrain/utils/small_vector.h" | |||||
| namespace mgb::imperative { | |||||
| /** | |||||
| * \brief wrapper for c-style array | |||||
| * | |||||
| * \tparam T value type | |||||
| */ | |||||
| template <typename T> | |||||
| class Span { | |||||
| private: | |||||
| const T* m_begin = nullptr; | |||||
| const T* m_end = nullptr; | |||||
| public: | |||||
| Span() {} | |||||
| Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {} | |||||
| Span(const T* begin, size_t size) : Span(begin, begin + size) {} | |||||
| template <typename TContainer> | |||||
| Span(TContainer& container) : Span(container.data(), container.size()) {} | |||||
| const T* begin() const { return m_begin; } | |||||
| const T* end() const { return m_end; } | |||||
| const T* data() const { return m_begin; } | |||||
| size_t size() const { return m_end - m_begin; } | |||||
| template <typename TContainer> | |||||
| TContainer copy_into() { | |||||
| return TContainer(m_begin, m_end); | |||||
| } | |||||
| const T& operator[](size_t idx) const { return m_begin[idx]; } | |||||
| const T& at(size_t idx) const { return m_begin[idx]; } | |||||
| const T& item() const { | |||||
| mgb_assert( | |||||
| m_end - m_begin == 1, "size mismatch: %zu vs %zu", (m_end - m_begin), | |||||
| (size_t)1); | |||||
| return m_begin[0]; | |||||
| } | |||||
| template <size_t N> | |||||
| const std::array<T, N>& as_array() { | |||||
| mgb_assert( | |||||
| m_end - m_begin == N, "size mismatch: %zu vs %zu", (m_end - m_begin), | |||||
| N); | |||||
| return *reinterpret_cast<const std::array<T, N>*>(m_begin); | |||||
| } | |||||
| Span sub(size_t begin, size_t length) { | |||||
| mgb_assert(begin + length <= m_end - m_begin); | |||||
| return {m_begin + begin, length}; | |||||
| } | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -16,6 +16,7 @@ | |||||
| #include <tuple> | #include <tuple> | ||||
| #include <type_traits> | #include <type_traits> | ||||
| #include "megbrain/imperative/utils/span.h" | |||||
| #include "megbrain/tensor.h" | #include "megbrain/tensor.h" | ||||
| #include "megbrain/utils/small_vector.h" | #include "megbrain/utils/small_vector.h" | ||||
| @@ -59,6 +60,22 @@ struct ToStringTrait<SmallVector<T, N>> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename T> | |||||
| struct ToStringTrait<std::vector<T>> { | |||||
| std::string operator()(const std::vector<T>& v) const { | |||||
| if (v.empty()) { | |||||
| return "[]"; | |||||
| } | |||||
| std::string result = "["; | |||||
| result += to_string(v[0]); | |||||
| for (size_t i = 1; i < v.size(); ++i) { | |||||
| result += ", "; | |||||
| result += to_string(v[i]); | |||||
| } | |||||
| return result + "]"; | |||||
| } | |||||
| }; | |||||
| template <typename T> | template <typename T> | ||||
| struct ToStringTrait<std::shared_ptr<T>> { | struct ToStringTrait<std::shared_ptr<T>> { | ||||
| std::string operator()(const std::shared_ptr<T>& sp) const { | std::string operator()(const std::shared_ptr<T>& sp) const { | ||||
| @@ -115,4 +132,36 @@ struct ToStringTrait<CompNode> { | |||||
| std::string operator()(CompNode device) const { return device.to_string(); } | std::string operator()(CompNode device) const { return device.to_string(); } | ||||
| }; | }; | ||||
| inline std::string string_join(Span<std::string> span, char delimiter = ',') { | |||||
| std::string buffer = "["; | |||||
| for (size_t i = 1; i < span.size(); ++i) { | |||||
| if (i) { | |||||
| buffer.push_back(delimiter); | |||||
| } | |||||
| buffer.append(span[0]); | |||||
| } | |||||
| return buffer + "]"; | |||||
| } | |||||
| template <typename T> | |||||
| struct ToStringTrait<Span<T>> { | |||||
| std::string operator()(Span<T> span) const { | |||||
| if (span.size() == 0) { | |||||
| return "[]"; | |||||
| } | |||||
| std::string result = "["; | |||||
| result += to_string(span[0]); | |||||
| for (size_t i = 1; i < span.size(); ++i) { | |||||
| result += ", "; | |||||
| result += to_string(span[i]); | |||||
| } | |||||
| return result + "]"; | |||||
| } | |||||
| }; | |||||
| template <> | |||||
| struct ToStringTrait<std::type_info> { | |||||
| std::string operator()(const std::type_info& info) const { return info.name(); } | |||||
| }; | |||||
| } // namespace mgb::imperative | } // namespace mgb::imperative | ||||
| @@ -0,0 +1,104 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/visit.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <vector> | |||||
| #include "megbrain/imperative/utils/span.h" | |||||
| #include "megbrain/tensor.h" | |||||
| namespace mgb::imperative { | |||||
| /** | |||||
| * \brief like TensorShape, but allow real scalar shape. | |||||
| * | |||||
| */ | |||||
| struct ValueShape { | |||||
| size_t shape[TensorShape::MAX_NDIM]; | |||||
| int ndim = 0; | |||||
| ValueShape() = default; | |||||
| ValueShape(std::initializer_list<size_t> dims) { | |||||
| for (auto&& dim : dims) { | |||||
| shape[ndim++] = dim; | |||||
| } | |||||
| } | |||||
| ValueShape(Span<size_t> dims) { | |||||
| for (auto&& dim : dims) { | |||||
| shape[ndim++] = dim; | |||||
| } | |||||
| } | |||||
| size_t& operator[](int axis) { return shape[axis]; } | |||||
| size_t operator[](int axis) const { return shape[axis]; } | |||||
| size_t at(int axis) const { | |||||
| mgb_assert(axis < ndim); | |||||
| return shape[axis]; | |||||
| } | |||||
| size_t total_nr_elems() const { | |||||
| size_t prod = 1; | |||||
| for (int i = 0; i < ndim; ++i) { | |||||
| prod *= shape[i]; | |||||
| } | |||||
| return prod; | |||||
| } | |||||
| bool is_scalar() const { return ndim == 0; } | |||||
| std::string to_string() const { | |||||
| std::string buffer = "{"; | |||||
| for (size_t i = 0; i < ndim; ++i) { | |||||
| if (i) { | |||||
| buffer.append(","); | |||||
| } | |||||
| buffer.append(std::to_string(shape[i])); | |||||
| } | |||||
| buffer.append("}"); | |||||
| return buffer; | |||||
| } | |||||
| static ValueShape from(TensorShape tensor_shape) { | |||||
| mgb_assert(tensor_shape.ndim); | |||||
| return Span<size_t>{tensor_shape.shape, tensor_shape.ndim}; | |||||
| } | |||||
| TensorShape as_tensor_shape() const { | |||||
| mgb_assert(ndim != 0); | |||||
| TensorShape ret; | |||||
| for (size_t i = 0; i < ndim; ++i) { | |||||
| ret.shape[i] = shape[i]; | |||||
| } | |||||
| ret.ndim = ndim; | |||||
| return ret; | |||||
| } | |||||
| bool operator==(const ValueShape& rhs) const { | |||||
| if (ndim != rhs.ndim) { | |||||
| return false; | |||||
| } | |||||
| for (size_t i = 0; i < ndim; ++i) { | |||||
| if (shape[i] != rhs.shape[i]) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| }; | |||||
| static_assert(sizeof(size_t) >= sizeof(int)); | |||||
| static_assert(TensorShape::MAX_NDIM == 7); | |||||
| static_assert(sizeof(ValueShape) <= sizeof(size_t) * 8); | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/visit.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <vector> | |||||
| #include "megbrain/utils/small_vector.h" | |||||
| namespace mgb::imperative { | |||||
| template <typename... TVisitors> | |||||
| class Visitor : public TVisitors... { | |||||
| public: | |||||
| using TVisitors::operator()...; | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -28,10 +28,10 @@ TEST(TestProfiler, ImperativeLogProfile) { | |||||
| auto results = imperative::Profiler::collect(); | auto results = imperative::Profiler::collect(); | ||||
| imperative::Profiler::stop_profile(); | imperative::Profiler::stop_profile(); | ||||
| mgb_assert(results.entries.size() == 2); | mgb_assert(results.entries.size() == 2); | ||||
| auto* event_start = results.entries[0].data.as<profiler::CustomEvent>(); | |||||
| auto* event_finish = results.entries[1].data.as<profiler::CustomFinishEvent>(); | |||||
| mgb_assert(event_start && event_start->title == "XXX"); | |||||
| mgb_assert(event_finish && event_finish->title == "XXX"); | |||||
| auto& event_start = results.entries[0].data.cast<profiler::CustomEvent>(); | |||||
| auto& event_finish = results.entries[1].data.cast<profiler::CustomFinishEvent>(); | |||||
| mgb_assert(event_start.title == "XXX"); | |||||
| mgb_assert(event_finish.title == "XXX"); | |||||
| mgb_assert(results.entries[0].time < results.entries[1].time); | mgb_assert(results.entries[0].time < results.entries[1].time); | ||||
| mgb_assert(results.entries[0].id < results.entries[1].id); | mgb_assert(results.entries[0].id < results.entries[1].id); | ||||
| } | } | ||||