GitOrigin-RevId: 5dddc68a84
tags/v1.3.0
| @@ -420,7 +420,7 @@ CompNode::Impl* AtlasCompNode::load_atlas(const Locator& locator, | |||||
| for (int i = 0; i < sd.nr_node; ++i) { | for (int i = 0; i < sd.nr_node; ++i) { | ||||
| auto&& cur = sd.node[i]; | auto&& cur = sd.node[i]; | ||||
| if (cur.m_initialized) { | if (cur.m_initialized) { | ||||
| if (cur.m_locator_logical == locator_logical) { | |||||
| if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||||
| return &cur; | return &cur; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -604,7 +604,7 @@ CompNode::Impl* CambriconCompNode::load_cambricon( | |||||
| for (int i = 0; i < sd.nr_node; ++i) { | for (int i = 0; i < sd.nr_node; ++i) { | ||||
| auto&& cur = sd.node[i]; | auto&& cur = sd.node[i]; | ||||
| if (cur.m_initialized) { | if (cur.m_initialized) { | ||||
| if (cur.m_locator_logical == locator_logical) { | |||||
| if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||||
| return &cur; | return &cur; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -250,6 +250,10 @@ void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { | |||||
| void CompNode::Locator::set_unspec_device_type(DeviceType type) { | void CompNode::Locator::set_unspec_device_type(DeviceType type) { | ||||
| mgb_assert(type != DeviceType::UNSPEC); | mgb_assert(type != DeviceType::UNSPEC); | ||||
| if (type != DeviceType::CPU && type != DeviceType::CUDA) { | |||||
| mgb_log_warn("to resolve unspec device type as one except " | |||||
| "CUDA and CPU may lead to unknown problems."); | |||||
| } | |||||
| g_unspec_locator_type = type; | g_unspec_locator_type = type; | ||||
| } | } | ||||
| @@ -723,12 +723,13 @@ struct CpuCompNode::Pool { | |||||
| impl_storage[MAX_NR_COMP_NODE]; | impl_storage[MAX_NR_COMP_NODE]; | ||||
| size_t nr_used_impl_storage = 0; | size_t nr_used_impl_storage = 0; | ||||
| ThinHashMap<std::pair<int, int>, | |||||
| std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>> logical2impl; | |||||
| std::unordered_map<CompNode::LocatorPairHashKey, | |||||
| std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>, | |||||
| CompNode::LocatorPairHashKey::Hash> locator2impl; | |||||
| ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> physical2queue; | ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> physical2queue; | ||||
| ThinHashMap<std::pair<int, int>, | |||||
| std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>> | |||||
| logical2impl_multi_thread; | |||||
| std::unordered_map<CompNode::LocatorPairHashKey, | |||||
| std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>, | |||||
| CompNode::LocatorPairHashKey::Hash> locator2impl_multi_thread; | |||||
| ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> | ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> | ||||
| physical2queue_multithead; | physical2queue_multithead; | ||||
| }; | }; | ||||
| @@ -792,14 +793,9 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||||
| MGB_LOCK_GUARD(sm_pool->mtx); | MGB_LOCK_GUARD(sm_pool->mtx); | ||||
| // encode both device ID and type into a int | // encode both device ID and type into a int | ||||
| int compact_logical_device = locator_logical.device; | |||||
| mgb_assert(compact_logical_device >= -1 || | |||||
| compact_logical_device <= Locator::DEVICE_CPU_DEFAULT); | |||||
| if (locator_logical.type == CompNode::DeviceType::UNSPEC) { | |||||
| compact_logical_device += std::numeric_limits<int>::min() + 1; | |||||
| mgb_assert(compact_logical_device < | |||||
| Locator::DEVICE_MULTITHREAD_DEFAULT); | |||||
| } else { | |||||
| mgb_assert(locator_logical.device >= -1 || | |||||
| locator_logical.device <= Locator::DEVICE_CPU_DEFAULT); | |||||
| if (locator_logical.type != CompNode::DeviceType::UNSPEC) { | |||||
| mgb_assert(locator_logical.type == CompNode::DeviceType::CPU || | mgb_assert(locator_logical.type == CompNode::DeviceType::CPU || | ||||
| locator_logical.type == CompNode::DeviceType::MULTITHREAD); | locator_logical.type == CompNode::DeviceType::MULTITHREAD); | ||||
| } | } | ||||
| @@ -811,8 +807,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||||
| pqueue = std::make_shared<WorkerQueue>(locator); | pqueue = std::make_shared<WorkerQueue>(locator); | ||||
| pqueue_weak = pqueue; | pqueue_weak = pqueue; | ||||
| } | } | ||||
| auto&& pimpl = sm_pool->logical2impl[{compact_logical_device, | |||||
| locator_logical.stream}]; | |||||
| auto&& pimpl = sm_pool->locator2impl[{locator, | |||||
| locator_logical}]; | |||||
| if (!pimpl) { | if (!pimpl) { | ||||
| mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | ||||
| "too many cpu comp nodes; max %d allowed", | "too many cpu comp nodes; max %d allowed", | ||||
| @@ -833,8 +829,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||||
| pqueue = std::make_shared<WorkerQueue>(locator); | pqueue = std::make_shared<WorkerQueue>(locator); | ||||
| pqueue_weak = pqueue; | pqueue_weak = pqueue; | ||||
| } | } | ||||
| auto&& pimpl = sm_pool->logical2impl_multi_thread[{ | |||||
| compact_logical_device, locator_logical.nr_threads}]; | |||||
| auto&& pimpl = sm_pool->locator2impl_multi_thread[{ | |||||
| locator, locator_logical}]; | |||||
| if (!pimpl) { | if (!pimpl) { | ||||
| mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | ||||
| "too many cpu multithread comp nodes; max %d allowed", | "too many cpu multithread comp nodes; max %d allowed", | ||||
| @@ -854,9 +850,9 @@ void CpuCompNode::sync_all() { | |||||
| return; | return; | ||||
| MGB_LOCK_GUARD(sm_pool->mtx); | MGB_LOCK_GUARD(sm_pool->mtx); | ||||
| for (auto &&i: sm_pool->logical2impl) | |||||
| for (auto &&i: sm_pool->locator2impl) | |||||
| i.second->sync(); | i.second->sync(); | ||||
| for (auto&& i : sm_pool->logical2impl_multi_thread) | |||||
| for (auto&& i : sm_pool->locator2impl_multi_thread) | |||||
| i.second->sync(); | i.second->sync(); | ||||
| } | } | ||||
| @@ -718,7 +718,7 @@ CompNode::Impl* CudaCompNode::load_cuda( | |||||
| for (int i = 0; i < sd.nr_node; ++ i) { | for (int i = 0; i < sd.nr_node; ++ i) { | ||||
| auto &&cur = sd.node[i]; | auto &&cur = sd.node[i]; | ||||
| if (cur.m_initialized) { | if (cur.m_initialized) { | ||||
| if (cur.m_locator_logical == locator_logical) { | |||||
| if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||||
| return &cur; | return &cur; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -606,7 +606,7 @@ CompNode::Impl* ROCmCompNode::load_rocm(const Locator& locator, | |||||
| for (int i = 0; i < sd.nr_node; ++i) { | for (int i = 0; i < sd.nr_node; ++i) { | ||||
| auto&& cur = sd.node[i]; | auto&& cur = sd.node[i]; | ||||
| if (cur.m_initialized) { | if (cur.m_initialized) { | ||||
| if (cur.m_locator_logical == locator_logical) { | |||||
| if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||||
| return &cur; | return &cur; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -168,6 +168,22 @@ class CompNode { | |||||
| return type == rhs.type && device == rhs.device && | return type == rhs.type && device == rhs.device && | ||||
| stream == rhs.stream; | stream == rhs.stream; | ||||
| } | } | ||||
| }; | |||||
| struct LocatorPairHashKey { | |||||
| Locator locator, locator_logical; | |||||
| bool operator==(const LocatorPairHashKey& rhs) const { | |||||
| return locator == rhs.locator && locator_logical == rhs.locator_logical; | |||||
| } | |||||
| struct Hash { | |||||
| size_t operator()(const LocatorPairHashKey& k) const { | |||||
| return hash_pair_combine(mgb::hash(k.locator), | |||||
| mgb::hash(k.locator_logical)); | |||||
| } | |||||
| }; | |||||
| }; | }; | ||||
| //! predefined special streams | //! predefined special streams | ||||
| @@ -537,6 +553,7 @@ class CompNode { | |||||
| friend class CompNodeEnv; | friend class CompNodeEnv; | ||||
| friend struct HashTrait<CompNode>; | friend struct HashTrait<CompNode>; | ||||
| friend struct HashTrait<CompNode::Locator>; | |||||
| friend class CompNodeImplHelper; | friend class CompNodeImplHelper; | ||||
| public: | public: | ||||
| CompNode(ImplBase* impl) : m_impl{impl} {} | CompNode(ImplBase* impl) : m_impl{impl} {} | ||||
| @@ -686,6 +703,15 @@ struct HashTrait<CompNode> { | |||||
| } | } | ||||
| }; | }; | ||||
| template<> | |||||
| struct HashTrait<CompNode::Locator> { | |||||
| static size_t eval(const CompNode::Locator &val) { | |||||
| return static_cast<size_t>(val.device) | |||||
| + (static_cast<size_t>(val.type) << 4) | |||||
| + (static_cast<size_t>(val.stream) << 8); | |||||
| } | |||||
| }; | |||||
| namespace comp_node_detail { | namespace comp_node_detail { | ||||
| /*! | /*! | ||||
| @@ -86,19 +86,34 @@ TEST(TestCompNode, SetDefaultDev) { | |||||
| CompNode::finalize(); | CompNode::finalize(); | ||||
| using L = CompNode::Locator; | using L = CompNode::Locator; | ||||
| auto orig_dt = L::parse("xpu").to_physical(), | auto orig_dt = L::parse("xpu").to_physical(), | ||||
| orig_gpu = L::parse("gpux").to_physical(); | |||||
| orig_gpu = L::parse("gpux").to_physical(), | |||||
| orig_cpu = L::parse("cpux").to_physical(); | |||||
| constexpr auto CUDA = CompNode::DeviceType::CUDA; | constexpr auto CUDA = CompNode::DeviceType::CUDA; | ||||
| constexpr auto CPU = CompNode::DeviceType::CPU; | |||||
| L::set_unspec_device_type(CUDA); | L::set_unspec_device_type(CUDA); | ||||
| L::set_device_map(CUDA, -1, 2); | |||||
| auto run = []() { | |||||
| ASSERT_EQ(CompNode::load("xpu").locator(), L::parse("gpu2")); | |||||
| auto run = [](int device) { | |||||
| ASSERT_EQ(CompNode::load("xpu").locator(), | |||||
| L::parse("gpu" + std::to_string(device))); | |||||
| }; | |||||
| auto run_cpu = [](int device) { | |||||
| ASSERT_EQ(CompNode::load("cpux").locator(), | |||||
| L::parse("cpu" + std::to_string(device))); | |||||
| }; | }; | ||||
| MGB_TRY { | MGB_TRY { | ||||
| run(); | |||||
| L::set_device_map(CUDA, -1, 2); | |||||
| run(2); | |||||
| L::set_device_map(CUDA, -1, 1); | |||||
| run(1); | |||||
| L::set_device_map(CPU, -1, 2); | |||||
| run_cpu(2); | |||||
| L::set_device_map(CPU, -1, 1); | |||||
| run_cpu(1); | |||||
| } MGB_FINALLY({ | } MGB_FINALLY({ | ||||
| L::set_unspec_device_type(orig_dt.type); | L::set_unspec_device_type(orig_dt.type); | ||||
| L::set_device_map(CUDA, -1, orig_gpu.device); | L::set_device_map(CUDA, -1, orig_gpu.device); | ||||
| L::set_device_map(CPU, -1, orig_cpu.device); | |||||
| }); | }); | ||||
| CompNode::finalize(); | CompNode::finalize(); | ||||
| } | } | ||||