GitOrigin-RevId: 0f910c34b6
tags/v1.9.0
| @@ -14,6 +14,7 @@ | |||||
| #include "megbrain/imperative/backward_graph_opt.h" | #include "megbrain/imperative/backward_graph_opt.h" | ||||
| #include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
| #include "megbrain/imperative/proxy_graph_detail.h" | #include "megbrain/imperative/proxy_graph_detail.h" | ||||
| #include "megbrain/imperative/resource_manager.h" | |||||
| #include "megbrain/utils/mempool.h" | #include "megbrain/utils/mempool.h" | ||||
| #include "range/v3/all.hpp" | #include "range/v3/all.hpp" | ||||
| @@ -1158,11 +1158,16 @@ void init_tensor(py::module m) { | |||||
| using Segment = TransformationManager::Segment; | using Segment = TransformationManager::Segment; | ||||
| auto* channel = interpreter::Interpreter::inst().create_channel().release(); | |||||
| using Channel = interpreter::Interpreter::Channel; | |||||
| auto* channel = | |||||
| imperative::ResourceManager::create_global<std::unique_ptr<Channel>>( | |||||
| interpreter::Interpreter::inst().create_channel()) | |||||
| ->get(); | |||||
| interpreter_for_py = channel; | interpreter_for_py = channel; | ||||
| transformations.register_at<Segment::Eval>( | transformations.register_at<Segment::Eval>( | ||||
| std::make_shared<InterpreterTransformation>( | std::make_shared<InterpreterTransformation>( | ||||
| std::unique_ptr<interpreter::Interpreter::Channel>(channel))); | |||||
| std::shared_ptr<Channel>(channel, [](Channel*) {}))); | |||||
| transformations.register_at<Segment::Scalar>( | transformations.register_at<Segment::Scalar>( | ||||
| std::make_shared<ScalarTransformation>()); | std::make_shared<ScalarTransformation>()); | ||||
| @@ -13,6 +13,7 @@ | |||||
| #include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
| #include "megbrain/imperative/blob_manager.h" | #include "megbrain/imperative/blob_manager.h" | ||||
| #include "megbrain/imperative/resource_manager.h" | |||||
| #include "megbrain/system.h" | #include "megbrain/system.h" | ||||
| #include "./event_pool.h" | #include "./event_pool.h" | ||||
| @@ -61,8 +62,8 @@ protected: | |||||
| public: | public: | ||||
| static AsyncReleaser* inst() { | static AsyncReleaser* inst() { | ||||
| static AsyncReleaser releaser; | |||||
| return &releaser; | |||||
| static auto* releaser = ResourceManager::create_global<AsyncReleaser>(); | |||||
| return releaser; | |||||
| } | } | ||||
| ~AsyncReleaser() { m_waiter.wait_task_queue_empty(); } | ~AsyncReleaser() { m_waiter.wait_task_queue_empty(); } | ||||
| @@ -10,6 +10,9 @@ | |||||
| */ | */ | ||||
| #include "./event_pool.h" | #include "./event_pool.h" | ||||
| #include <memory> | |||||
| #include "megbrain/imperative/resource_manager.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace imperative { | namespace imperative { | ||||
| @@ -17,22 +20,18 @@ namespace imperative { | |||||
| EventPool::EventPool(size_t flags) : m_flags{flags} {} | EventPool::EventPool(size_t flags) : m_flags{flags} {} | ||||
| EventPool& EventPool::with_timer() { | EventPool& EventPool::with_timer() { | ||||
| static Spinlock lock; | |||||
| static std::unique_ptr<EventPool> ptr; | |||||
| MGB_LOCK_GUARD(lock); | |||||
| if (!ptr || ptr->is_finalized()) { | |||||
| ptr.reset(new EventPool(CompNode::Event::NEED_TIMER)); | |||||
| } | |||||
| return *ptr; | |||||
| static auto* sm_pool = | |||||
| ResourceManager::create_global<CompNodeDependentResource<EventPool>>([] { | |||||
| return std::unique_ptr<EventPool>( | |||||
| new EventPool(CompNode::Event::NEED_TIMER)); | |||||
| }); | |||||
| return **sm_pool; | |||||
| } | } | ||||
| EventPool& EventPool::without_timer() { | EventPool& EventPool::without_timer() { | ||||
| static Spinlock lock; | |||||
| static std::unique_ptr<EventPool> ptr; | |||||
| MGB_LOCK_GUARD(lock); | |||||
| if (!ptr || ptr->is_finalized()) { | |||||
| ptr.reset(new EventPool()); | |||||
| } | |||||
| return *ptr; | |||||
| static auto* sm_pool = | |||||
| ResourceManager::create_global<CompNodeDependentResource<EventPool>>( | |||||
| [] { return std::unique_ptr<EventPool>(new EventPool()); }); | |||||
| return **sm_pool; | |||||
| } | } | ||||
| CompNode::Event* EventPool::alloc(CompNode cn) { | CompNode::Event* EventPool::alloc(CompNode cn) { | ||||
| CompNode::EventPool* pool; | CompNode::EventPool* pool; | ||||
| @@ -31,6 +31,8 @@ public: | |||||
| void free(CompNode::Event* event); | void free(CompNode::Event* event); | ||||
| std::shared_ptr<void> on_comp_node_finalize(); | std::shared_ptr<void> on_comp_node_finalize(); | ||||
| ~EventPool(); | ~EventPool(); | ||||
| using CompNodeDepedentObject::is_finalized; | |||||
| }; | }; | ||||
| } // namespace imperative | } // namespace imperative | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -14,6 +14,7 @@ | |||||
| #include <sstream> | #include <sstream> | ||||
| #include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
| #include "megbrain/imperative/resource_manager.h" | |||||
| #include "./op_trait.h" | #include "./op_trait.h" | ||||
| @@ -63,16 +64,16 @@ EncodedSubgraph OpDef::make_backward_graph( | |||||
| const SmallVector<bool>& output_has_grad) { | const SmallVector<bool>& output_has_grad) { | ||||
| using BackwardGraphCache = | using BackwardGraphCache = | ||||
| OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>; | OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>; | ||||
| thread_local auto cache = std::make_unique<BackwardGraphCache>(); | |||||
| thread_local auto& cache = *ResourceManager::create_local<BackwardGraphCache>(); | |||||
| BackwardGraphCache::key_t cache_key{ | BackwardGraphCache::key_t cache_key{ | ||||
| const_cast<OpDef&>(def).shared_from_this(), | const_cast<OpDef&>(def).shared_from_this(), | ||||
| inputs, | inputs, | ||||
| {input_requires_grad, output_has_grad}}; | {input_requires_grad, output_has_grad}}; | ||||
| auto iter = cache->find(cache_key); | |||||
| if (iter == cache->end()) { | |||||
| iter = cache->insert({cache_key, def.trait()->make_backward_graph( | |||||
| def, inputs, input_requires_grad, | |||||
| output_has_grad)}) | |||||
| auto iter = cache.find(cache_key); | |||||
| if (iter == cache.end()) { | |||||
| iter = cache.insert({cache_key, def.trait()->make_backward_graph( | |||||
| def, inputs, input_requires_grad, | |||||
| output_has_grad)}) | |||||
| .first; | .first; | ||||
| } | } | ||||
| return iter->second; | return iter->second; | ||||
| @@ -86,12 +87,12 @@ EncodedSubgraph OpDef::make_forward_graph( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | ||||
| using ForwardGraphCache = | using ForwardGraphCache = | ||||
| OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>; | OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>; | ||||
| thread_local auto cache = std::make_unique<ForwardGraphCache>(); | |||||
| thread_local auto& cache = *ResourceManager::create_local<ForwardGraphCache>(); | |||||
| ForwardGraphCache::key_t cache_key{ | ForwardGraphCache::key_t cache_key{ | ||||
| const_cast<OpDef&>(def).shared_from_this(), inputs}; | const_cast<OpDef&>(def).shared_from_this(), inputs}; | ||||
| auto iter = cache->find(cache_key); | |||||
| if (iter == cache->end()) { | |||||
| iter = cache->insert({cache_key, def.trait()->make_forward_graph(def, inputs)}) | |||||
| auto iter = cache.find(cache_key); | |||||
| if (iter == cache.end()) { | |||||
| iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)}) | |||||
| .first; | .first; | ||||
| } | } | ||||
| return iter->second; | return iter->second; | ||||
| @@ -9,6 +9,7 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include <atomic> | |||||
| #include <deque> | #include <deque> | ||||
| #include "megbrain/imperative/graph_cache.h" | #include "megbrain/imperative/graph_cache.h" | ||||
| @@ -16,6 +17,7 @@ | |||||
| #include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
| #include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
| #include "megbrain/imperative/ops/utility.h" | #include "megbrain/imperative/ops/utility.h" | ||||
| #include "megbrain/imperative/resource_manager.h" | |||||
| #include "megbrain/imperative/subgraph_detail.h" | #include "megbrain/imperative/subgraph_detail.h" | ||||
| #include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
| #include "megbrain/opr/tensor_gen.h" | #include "megbrain/opr/tensor_gen.h" | ||||
| @@ -510,16 +512,32 @@ struct ComputingGraphHolder { | |||||
| } | } | ||||
| }; | }; | ||||
| static std::atomic<size_t> nr_cg_cache = 0; | |||||
| template <HolderKind Kind> | template <HolderKind Kind> | ||||
| ComputingGraphHolder<Kind>& get_computing_graph( | ComputingGraphHolder<Kind>& get_computing_graph( | ||||
| std::shared_ptr<OpDef> compiled_op, | std::shared_ptr<OpDef> compiled_op, | ||||
| const SmallVector<LogicalTensorDesc>& descs) { | const SmallVector<LogicalTensorDesc>& descs) { | ||||
| using ComputingGraphHolderCache = | using ComputingGraphHolderCache = | ||||
| OpMethResultCache<std::deque<std::unique_ptr<ComputingGraphHolder<Kind>>>>; | OpMethResultCache<std::deque<std::unique_ptr<ComputingGraphHolder<Kind>>>>; | ||||
| thread_local auto cache = std::make_unique<ComputingGraphHolderCache>(); | |||||
| thread_local auto& cache = ([]() -> auto& { | |||||
| mgb_assert( | |||||
| nr_cg_cache++ < 5, | |||||
| "using subgraph in too many threads, this causes resource leakage"); | |||||
| #if MGB_CUDA && defined(WIN32) | |||||
| // FIXME: Create as global to skip resource finalize and windows with cuda | |||||
| // doesn't cleanup global resources | |||||
| return *ResourceManager::create_global<ComputingGraphHolderCache>(); | |||||
| #else | |||||
| // Otherwise this should be local because compnode may be unusable when global | |||||
| // resource finalizing. | |||||
| // For example, CpuCompNode.sync hang on because underlying thread died | |||||
| return *ResourceManager::create_local<ComputingGraphHolderCache>(); | |||||
| #endif | |||||
| })(); | |||||
| thread_local size_t nr_cg_holders = 0; | thread_local size_t nr_cg_holders = 0; | ||||
| typename ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs}; | typename ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs}; | ||||
| auto& cg_holder_queue = (*cache)[cache_key]; | |||||
| auto& cg_holder_queue = cache[cache_key]; | |||||
| std::unique_ptr<ComputingGraphHolder<Kind>> holder; | std::unique_ptr<ComputingGraphHolder<Kind>> holder; | ||||
| if (!cg_holder_queue.empty()) { | if (!cg_holder_queue.empty()) { | ||||
| // pick one | // pick one | ||||
| @@ -12,6 +12,7 @@ | |||||
| #include "megbrain/imperative.h" | #include "megbrain/imperative.h" | ||||
| #include "megbrain/imperative/blob_manager.h" | #include "megbrain/imperative/blob_manager.h" | ||||
| #include "megbrain/imperative/profiler.h" | #include "megbrain/imperative/profiler.h" | ||||
| #include "megbrain/imperative/resource_manager.h" | |||||
| #include "./async_releaser.h" | #include "./async_releaser.h" | ||||
| #include "./event_pool.h" | #include "./event_pool.h" | ||||
| @@ -30,13 +31,6 @@ class CompNodeSyncManager : public CompNodeDepedentObject { | |||||
| std::mutex m_mtx; | std::mutex m_mtx; | ||||
| public: | public: | ||||
| #if MGB_CUDA && defined(WIN32) | |||||
| //! FIXME: windows cuda driver shutdown before call atexit function even | |||||
| //! register atexit function after init cuda driver! as a workround | |||||
| //! recovery resource by OS temporarily, may need remove this after | |||||
| //! upgrade cuda runtime | |||||
| static bool is_into_atexit; | |||||
| #endif | |||||
| std::shared_ptr<void> on_comp_node_finalize() override { | std::shared_ptr<void> on_comp_node_finalize() override { | ||||
| MGB_LOCK_GUARD(m_mtx); | MGB_LOCK_GUARD(m_mtx); | ||||
| m_blob2event.clear(); | m_blob2event.clear(); | ||||
| @@ -44,17 +38,7 @@ public: | |||||
| } | } | ||||
| static CompNodeSyncManager& inst() { | static CompNodeSyncManager& inst() { | ||||
| static CompNodeSyncManager* sl_inst = new CompNodeSyncManager(); | |||||
| #if MGB_CUDA && defined(WIN32) | |||||
| //! FIXME: windows cuda driver shutdown before call atexit function even | |||||
| //! register atexit function after init cuda driver! as a workround | |||||
| //! recovery resource by OS temporarily, may need remove this after | |||||
| //! upgrade cuda runtime | |||||
| if (!is_into_atexit) { | |||||
| auto err = atexit([] { is_into_atexit = true; }); | |||||
| mgb_assert(!err, "failed to register atexit function"); | |||||
| } | |||||
| #endif | |||||
| static auto* sl_inst = ResourceManager::create_global<CompNodeSyncManager>(); | |||||
| return *sl_inst; | return *sl_inst; | ||||
| } | } | ||||
| @@ -73,13 +57,6 @@ public: | |||||
| m_blob2event.erase(blob); | m_blob2event.erase(blob); | ||||
| } | } | ||||
| }; | }; | ||||
| #if MGB_CUDA && defined(WIN32) | |||||
| //! FIXME: windows cuda driver shutdown before call atexit function even | |||||
| //! register atexit function after init cuda driver! as a workround | |||||
| //! recovery resource by OS temporarily, may need remove this after | |||||
| //! upgrade cuda runtime | |||||
| bool CompNodeSyncManager::is_into_atexit = false; | |||||
| #endif | |||||
| } // namespace | } // namespace | ||||
| @@ -106,15 +83,6 @@ Blob::Blob(CompNode cn, size_t sz) : m_comp_node{cn}, m_storage{}, m_size{sz} { | |||||
| Blob::~Blob() { | Blob::~Blob() { | ||||
| BlobManager::inst()->unregister_blob(this); | BlobManager::inst()->unregister_blob(this); | ||||
| #if MGB_CUDA && defined(WIN32) | |||||
| //! FIXME: windows cuda driver shutdown before call atexit function even | |||||
| //! register atexit function after init cuda driver! as a workround | |||||
| //! recovery resource by OS temporarily, may need remove this after | |||||
| //! upgrade cuda runtime | |||||
| if (CompNodeSyncManager::is_into_atexit) | |||||
| return; | |||||
| #endif | |||||
| CompNodeSyncManager::inst().remove(this); | CompNodeSyncManager::inst().remove(this); | ||||
| } | } | ||||
| @@ -242,8 +210,6 @@ void Tensor::static_initialize() { | |||||
| AsyncReleaser::inst(); | AsyncReleaser::inst(); | ||||
| CompNodeSyncManager::inst(); | CompNodeSyncManager::inst(); | ||||
| MultiCNConstTensorCache::inst(); | MultiCNConstTensorCache::inst(); | ||||
| // clean all CompNodeDepedentObjects | |||||
| mgb_assert(!atexit(CompNode::finalize), "atexit register failed"); | |||||
| } | } | ||||
| } // namespace imperative | } // namespace imperative | ||||
| @@ -0,0 +1,95 @@ | |||||
| /** | |||||
| * \file imperative/src/impl/resource_manager.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megbrain/imperative/resource_manager.h" | |||||
| #include <thread> | |||||
| #include <unordered_map> | |||||
| using namespace mgb; | |||||
| using namespace imperative; | |||||
| namespace { | |||||
| class LocalResourceManager; | |||||
| std::unordered_map<std::thread::id, std::shared_ptr<LocalResourceManager>> | |||||
| local_managers; | |||||
| std::mutex global_lock; | |||||
| bool throw_all_resources = false; | |||||
| class LocalResourceManager final : public ResourceManager { | |||||
| private: | |||||
| std::thread::id m_id; | |||||
| public: | |||||
| LocalResourceManager() : m_id(std::this_thread::get_id()) {} | |||||
| std::thread::id id() const { return m_id; } | |||||
| }; | |||||
| class GlobalResourceManager final : public ResourceManager { | |||||
| public: | |||||
| ~GlobalResourceManager() { | |||||
| #if MGB_CUDA && defined(WIN32) | |||||
| //! FIXME: windows cuda driver shutdown before call atexit function even | |||||
| //! register atexit function after init cuda driver! as a workround | |||||
| //! recovery resource by OS temporarily, may need remove this after | |||||
| //! upgrade cuda runtime | |||||
| throw_all_resources = true; | |||||
| #endif | |||||
| MGB_LOCK_GUARD(global_lock); | |||||
| local_managers.clear(); | |||||
| } | |||||
| }; | |||||
| class LocalResourceManagerRef : public NonCopyableObj { | |||||
| private: | |||||
| std::weak_ptr<LocalResourceManager> m_manager; | |||||
| public: | |||||
| LocalResourceManagerRef() { | |||||
| auto manager = std::make_shared<LocalResourceManager>(); | |||||
| mgb_assert( | |||||
| local_managers.insert({manager->id(), manager}).second, | |||||
| "duplicated local manager"); | |||||
| m_manager = manager; | |||||
| } | |||||
| ~LocalResourceManagerRef() { | |||||
| if (auto manager = m_manager.lock()) { | |||||
| local_managers.erase(manager->id()); | |||||
| } | |||||
| } | |||||
| ResourceManager& operator*() { return *m_manager.lock(); } | |||||
| }; | |||||
| } // namespace | |||||
| void ResourceManager::clear() { | |||||
| if (throw_all_resources) { | |||||
| new std::vector<std::any>(std::move(m_handles)); | |||||
| } | |||||
| for (auto iter = m_handles.rbegin(); iter != m_handles.rend(); ++iter) { | |||||
| (*iter) = {}; | |||||
| } | |||||
| } | |||||
| ResourceManager& ResourceManager::get_global() { | |||||
| static GlobalResourceManager sl_manager; | |||||
| return sl_manager; | |||||
| } | |||||
| ResourceManager& ResourceManager::get_local() { | |||||
| thread_local LocalResourceManagerRef tl_manager; | |||||
| return *tl_manager; | |||||
| } | |||||
| @@ -12,6 +12,7 @@ | |||||
| #include "megbrain/imperative/transformations/grad.h" | #include "megbrain/imperative/transformations/grad.h" | ||||
| #include "megbrain/imperative/graph_cache.h" | #include "megbrain/imperative/graph_cache.h" | ||||
| #include "megbrain/imperative/resource_manager.h" | |||||
| #include <range/v3/all.hpp> | #include <range/v3/all.hpp> | ||||
| @@ -24,7 +25,8 @@ static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_gra | |||||
| // hash | // hash | ||||
| using OptimizedBackwardGraphCache = OpMethResultCache< | using OptimizedBackwardGraphCache = OpMethResultCache< | ||||
| std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>; | std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>; | ||||
| thread_local auto cache = std::make_unique<OptimizedBackwardGraphCache>(); | |||||
| thread_local auto& cache = | |||||
| *ResourceManager::create_local<OptimizedBackwardGraphCache>(); | |||||
| OptimizedBackwardGraphCache::key_t cache_key{op}; | OptimizedBackwardGraphCache::key_t cache_key{op}; | ||||
| SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs; | SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs; | ||||
| std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>(); | std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>(); | ||||
| @@ -34,8 +36,8 @@ static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_gra | |||||
| input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>(); | input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>(); | ||||
| } | } | ||||
| auto iter = cache->find(cache_key); | |||||
| if (iter != cache->end()) { | |||||
| auto iter = cache.find(cache_key); | |||||
| if (iter != cache.end()) { | |||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| @@ -47,7 +49,7 @@ static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_gra | |||||
| if (!bg.graph.empty()) { | if (!bg.graph.empty()) { | ||||
| ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | ||||
| } | } | ||||
| cache->emplace(cache_key, ret); | |||||
| cache.emplace(cache_key, ret); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -14,6 +14,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include "megbrain/imperative/resource_manager.h" | |||||
| #include "megbrain/tensor.h" | #include "megbrain/tensor.h" | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -278,8 +279,9 @@ struct MultiCNConstTensorCache : CompNodeDepedentObject { | |||||
| } | } | ||||
| static MultiCNConstTensorCache& inst() { | static MultiCNConstTensorCache& inst() { | ||||
| static MultiCNConstTensorCache sl_inst; | |||||
| return sl_inst; | |||||
| static auto* sl_inst = | |||||
| ResourceManager::create_global<MultiCNConstTensorCache>(); | |||||
| return *sl_inst; | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/resource_manager.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <any> | |||||
| #include <functional> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <vector> | |||||
| #include "megbrain/common.h" | |||||
| #include "megbrain/utils/metahelper.h" | |||||
| #include "megbrain/utils/thread.h" | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| class ResourceManager : public NonCopyableObj { | |||||
| protected: | |||||
| std::vector<std::any> m_handles; | |||||
| std::mutex m_mutex; | |||||
| private: | |||||
| static ResourceManager& get_global(); | |||||
| static ResourceManager& get_local(); | |||||
| public: | |||||
| template <typename T, typename... TArgs> | |||||
| static T* create_global(TArgs&&... args) { | |||||
| mgb_log_debug("create global resource: %s", typeid(T).name()); | |||||
| auto instance = std::make_shared<T>(std::forward<TArgs&&>(args)...); | |||||
| auto& manager = get_global(); | |||||
| MGB_LOCK_GUARD(manager.m_mutex); | |||||
| manager.m_handles.push_back((std::any)instance); | |||||
| return instance.get(); | |||||
| } | |||||
| template <typename T, typename... TArgs> | |||||
| static T* create_local(TArgs&&... args) { | |||||
| mgb_log_debug("create local resource: %s", typeid(T).name()); | |||||
| auto instance = std::make_shared<T>(std::forward<TArgs&&>(args)...); | |||||
| get_local().m_handles.push_back((std::any)instance); | |||||
| return instance.get(); | |||||
| } | |||||
| void clear(); | |||||
| ~ResourceManager() { clear(); } | |||||
| }; | |||||
| template <typename T> | |||||
| class CompNodeDependentResource : public NonCopyableObj { | |||||
| private: | |||||
| std::function<std::unique_ptr<T>()> m_ctor; | |||||
| std::unique_ptr<T> m_ptr; | |||||
| Spinlock m_spin; | |||||
| public: | |||||
| explicit CompNodeDependentResource(std::function<std::unique_ptr<T>()> ctor) | |||||
| : m_ctor(ctor) {} | |||||
| T& operator*() { | |||||
| if ((!m_ptr) || m_ptr->is_finalized()) { | |||||
| m_ptr = m_ctor(); | |||||
| } | |||||
| return *m_ptr; | |||||
| } | |||||
| T* operator->() { | |||||
| if ((!m_ptr) || m_ptr->is_finalized()) { | |||||
| m_ptr = m_ctor(); | |||||
| } | |||||
| return m_ptr.get(); | |||||
| } | |||||
| }; | |||||
| } // namespace imperative | |||||
| } // namespace mgb | |||||
| @@ -63,10 +63,10 @@ public: | |||||
| using Channel = Interpreter::Channel; | using Channel = Interpreter::Channel; | ||||
| private: | private: | ||||
| std::unique_ptr<Channel> m_channel; | |||||
| std::shared_ptr<Channel> m_channel; | |||||
| public: | public: | ||||
| explicit InterpreterTransformation(std::unique_ptr<Channel> channel) | |||||
| explicit InterpreterTransformation(std::shared_ptr<Channel> channel) | |||||
| : m_channel{std::move(channel)} {} | : m_channel{std::move(channel)} {} | ||||
| Channel* channel() { return m_channel.get(); } | Channel* channel() { return m_channel.get(); } | ||||