GitOrigin-RevId: 47373d291d
tags/v0.4.0
| @@ -38,9 +38,12 @@ def set_default_device(device: str = "xpux"): | |||||
| :param device: default device type. The type can be 'cpu0', 'cpu1', etc., | :param device: default device type. The type can be 'cpu0', 'cpu1', etc., | ||||
| or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use. | or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use. | ||||
| To specify multiple devices, use cpu0:1 or gpu0:2. | |||||
| 'cpux' and 'gpux' can also be used to specify any number of cpu or gpu devices. | 'cpux' and 'gpux' can also be used to specify any number of cpu or gpu devices. | ||||
| 'multithread' device type is available during inference; it implements | |||||
| multi-threading parallelism at the operator level. For example, | |||||
| 'multithread4' will compute with 4 threads. | |||||
| The default value is 'xpux' to specify any device available. | The default value is 'xpux' to specify any device available. | ||||
| It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | ||||
| @@ -603,11 +603,11 @@ Args Args::from_argv(int argc, char **argv) { | |||||
| ++ i; | ++ i; | ||||
| ret.multithread_number = std::stoi(argv[i]); | ret.multithread_number = std::stoi(argv[i]); | ||||
| ret.load_config.comp_node_mapper = | ret.load_config.comp_node_mapper = | ||||
| [nr_thread = | |||||
| [nr_threads = | |||||
| ret.multithread_number](CompNode::Locator& loc) { | ret.multithread_number](CompNode::Locator& loc) { | ||||
| loc.type = CompNode::DeviceType::MULTITHREAD; | loc.type = CompNode::DeviceType::MULTITHREAD; | ||||
| loc.device = 0; | loc.device = 0; | ||||
| loc.stream = nr_thread; | |||||
| loc.nr_threads = nr_threads; | |||||
| }; | }; | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -615,11 +615,12 @@ Args Args::from_argv(int argc, char **argv) { | |||||
| mgb_log_warn("use multithread:default mode"); | mgb_log_warn("use multithread:default mode"); | ||||
| ++i; | ++i; | ||||
| ret.multithread_number = std::stoi(argv[i]); | ret.multithread_number = std::stoi(argv[i]); | ||||
| ret.load_config.comp_node_mapper = [nr_thread = | |||||
| ret.multithread_number](CompNode::Locator& loc) { | |||||
| ret.load_config.comp_node_mapper = [nr_threads = | |||||
| ret.multithread_number]( | |||||
| CompNode::Locator& loc) { | |||||
| loc.type = CompNode::DeviceType::MULTITHREAD; | loc.type = CompNode::DeviceType::MULTITHREAD; | ||||
| loc.device = CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; | loc.device = CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; | ||||
| loc.stream = nr_thread; | |||||
| loc.nr_threads = nr_threads; | |||||
| }; | }; | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -127,13 +127,19 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { | |||||
| // current parsing location | // current parsing location | ||||
| const char *ptr = id.data(); | const char *ptr = id.data(); | ||||
| if (id == "cpu:default") { | if (id == "cpu:default") { | ||||
| return {DeviceType::CPU, DEVICE_CPU_DEFAULT, 0}; | |||||
| return {DeviceType::CPU, DEVICE_CPU_DEFAULT, {0}}; | |||||
| } | } | ||||
| if (!strncmp(ptr, "multithread:default", 19)) { | if (!strncmp(ptr, "multithread:default", 19)) { | ||||
| //! the multithread default compnode string like "multithread:default:x" | //! the multithread default compnode string like "multithread:default:x" | ||||
| ptr += 20; | |||||
| int nr_thread =std::stoi(ptr); | |||||
| return {DeviceType::MULTITHREAD, DEVICE_MULTITHREAD_DEFAULT, nr_thread}; | |||||
| if (id.size() > 20) { | |||||
| ptr += 20; | |||||
| int nr_thread = std::stoi(ptr); | |||||
| return {DeviceType::MULTITHREAD, | |||||
| DEVICE_MULTITHREAD_DEFAULT, | |||||
| {nr_thread}}; | |||||
| } else { | |||||
| err(); | |||||
| } | |||||
| } | } | ||||
| DeviceType dev_type; | DeviceType dev_type; | ||||
| @@ -192,8 +198,16 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { | |||||
| int num_stream = parse_int(); | int num_stream = parse_int(); | ||||
| if (*ptr) | if (*ptr) | ||||
| err(); | err(); | ||||
| //! a multithread locator whose thread number (held in num_dev at this point) is zero is illegal | |||||
| if (dev_type == DeviceType::MULTITHREAD) { | |||||
| if (num_dev == 0) { | |||||
| err(); | |||||
| } | |||||
| //! after the swap, num_stream stores the number of threads | |||||
| std::swap(num_dev, num_stream); | |||||
| } | |||||
| return {dev_type, num_dev, num_stream}; | |||||
| return {dev_type, num_dev, {num_stream}}; | |||||
| } | } | ||||
| void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { | void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { | ||||
| @@ -242,16 +256,22 @@ CompNode::Locator CompNode::Locator::to_physical() const { | |||||
| stream_physical = 1023; | stream_physical = 1023; | ||||
| } | } | ||||
| } | } | ||||
| return {type_physical, device_physical, stream_physical}; | |||||
| return {type_physical, device_physical, {stream_physical}}; | |||||
| } | } | ||||
| std::string CompNode::Locator::to_string() const { | std::string CompNode::Locator::to_string() const { | ||||
| if (device == DEVICE_CPU_DEFAULT) { | if (device == DEVICE_CPU_DEFAULT) { | ||||
| return "cpu:default"; | return "cpu:default"; | ||||
| } else if (device == DEVICE_MULTITHREAD_DEFAULT) { | } else if (device == DEVICE_MULTITHREAD_DEFAULT) { | ||||
| std::string ret="multithread:default:"; | |||||
| std::string ret = "multithread:default:"; | |||||
| ret.append(get_stream_str(stream)); | ret.append(get_stream_str(stream)); | ||||
| return ret; | return ret; | ||||
| } else if (type == DeviceType::MULTITHREAD) { | |||||
| std::string ret("multithread"); | |||||
| ret.append(get_stream_str(stream)) | |||||
| .append(":") | |||||
| .append(get_stream_str(device)); | |||||
| return ret; | |||||
| } | } | ||||
| char numstr[32]; | char numstr[32]; | ||||
| if (device == -1) { | if (device == -1) { | ||||
| @@ -380,9 +380,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase { | |||||
| m_locator_logical(locator_logical) { | m_locator_logical(locator_logical) { | ||||
| auto cn = make_comp_node_from_impl(this); | auto cn = make_comp_node_from_impl(this); | ||||
| if (locator.type == DeviceType::MULTITHREAD) { | if (locator.type == DeviceType::MULTITHREAD) { | ||||
| //! When multi-thread the stream stand for thread number | |||||
| m_thread_pool = std::unique_ptr<ThreadPool>( | |||||
| new ThreadPool(static_cast<size_t>(locator.stream))); | |||||
| m_thread_pool = std::unique_ptr<ThreadPool>(new ThreadPool( | |||||
| static_cast<size_t>(locator.nr_threads))); | |||||
| mgb_assert(m_thread_pool, "ThradPool create failed"); | |||||
| } | } | ||||
| if (locator.type == DeviceType::CPU) { | if (locator.type == DeviceType::CPU) { | ||||
| @@ -398,7 +398,6 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase { | |||||
| cn); | cn); | ||||
| } | } | ||||
| } else if (locator.type == DeviceType::MULTITHREAD) { | } else if (locator.type == DeviceType::MULTITHREAD) { | ||||
| mgb_assert(m_thread_pool, "ThradPool create failed"); | |||||
| if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) { | if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) { | ||||
| m_env.init_cpu( | m_env.init_cpu( | ||||
| {std::make_shared<InplaceCPUDispatcher>( | {std::make_shared<InplaceCPUDispatcher>( | ||||
| @@ -745,15 +744,14 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||||
| } else { | } else { | ||||
| mgb_assert(locator.type == DeviceType::MULTITHREAD); | mgb_assert(locator.type == DeviceType::MULTITHREAD); | ||||
| auto&& pqueue_weak = sm_pool->physical2queue_multithead[{ | auto&& pqueue_weak = sm_pool->physical2queue_multithead[{ | ||||
| locator.device, locator.stream}]; | |||||
| locator.device, locator.nr_threads}]; | |||||
| auto pqueue = pqueue_weak.lock(); | auto pqueue = pqueue_weak.lock(); | ||||
| if (!pqueue) { | if (!pqueue) { | ||||
| pqueue = std::make_shared<WorkerQueue>(locator); | pqueue = std::make_shared<WorkerQueue>(locator); | ||||
| pqueue_weak = pqueue; | pqueue_weak = pqueue; | ||||
| } | } | ||||
| auto&& pimpl = sm_pool->logical2impl_multi_thread[{ | auto&& pimpl = sm_pool->logical2impl_multi_thread[{ | ||||
| static_cast<int>(compact_logical_device), | |||||
| locator_logical.stream}]; | |||||
| compact_logical_device, locator_logical.nr_threads}]; | |||||
| if (!pimpl) { | if (!pimpl) { | ||||
| mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | ||||
| "too many cpu multithread comp nodes; max %d allowed", | "too many cpu multithread comp nodes; max %d allowed", | ||||
| @@ -153,8 +153,12 @@ class CompNode { | |||||
| int device = -1; | int device = -1; | ||||
| //! multiple streams can execute on one computing device and share | //! multiple streams can execute on one computing device and share | ||||
| //! memory | |||||
| int stream = 0; | |||||
| //! memory; when the comp node type is multithread, this field also | |||||
| //! stands for nr_threads | |||||
| union { | |||||
| int stream = 0; | |||||
| int nr_threads; | |||||
| }; | |||||
| /*! | /*! | ||||
| * \brief parse a string identifier | * \brief parse a string identifier | ||||
| @@ -162,7 +166,7 @@ class CompNode { | |||||
| * currently supported ID format: (gpu|cpu)<n>[:m] where n is the | * currently supported ID format: (gpu|cpu)<n>[:m] where n is the | ||||
| * device number, possibly with m as the stream id. | * device number, possibly with m as the stream id. | ||||
| */ | */ | ||||
| static Locator parse(const std::string &id); | |||||
| static Locator parse(const std::string& id); | |||||
| /*! | /*! | ||||
| * \brief set mapping between device numbers of a device type | * \brief set mapping between device numbers of a device type | ||||
| @@ -28,9 +28,7 @@ using namespace mgb; | |||||
| TEST(TestCompNode, Parse) { | TEST(TestCompNode, Parse) { | ||||
| using L = CompNode::Locator; | using L = CompNode::Locator; | ||||
| using D = CompNode::DeviceType; | using D = CompNode::DeviceType; | ||||
| auto make_lc = [](D t, int dev, int s) -> L { | |||||
| return {t, dev, s}; | |||||
| }; | |||||
| auto make_lc = [](D t, int dev, int s) -> L { return {t, dev, {s}}; }; | |||||
| ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0)); | ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0)); | ||||
| ASSERT_EQ(L::parse("xpux:23"), make_lc(D::UNSPEC, -1, 23)); | ASSERT_EQ(L::parse("xpux:23"), make_lc(D::UNSPEC, -1, 23)); | ||||
| @@ -47,10 +45,9 @@ TEST(TestCompNode, Parse) { | |||||
| ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0)); | ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0)); | ||||
| ASSERT_EQ(L::parse("xpu23:1"), make_lc(D::UNSPEC, 23, 1)); | ASSERT_EQ(L::parse("xpu23:1"), make_lc(D::UNSPEC, 23, 1)); | ||||
| ASSERT_EQ(L::parse("cpu:default"), | |||||
| make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0)); | |||||
| ASSERT_EQ(L::parse("multithread0:2"), make_lc(D::MULTITHREAD, 0, 2)); | |||||
| ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 1, 3)); | |||||
| ASSERT_EQ(L::parse("cpu:default"), make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0)); | |||||
| ASSERT_EQ(L::parse("multithread2:0"), make_lc(D::MULTITHREAD, 0, 2)); | |||||
| ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 3, 1)); | |||||
| ASSERT_EQ(L::parse("multithread:default:2"), | ASSERT_EQ(L::parse("multithread:default:2"), | ||||
| make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2)); | make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2)); | ||||
| @@ -65,6 +62,10 @@ TEST(TestCompNode, Parse) { | |||||
| ASSERT_THROW(L::parse("heaxgon0"), MegBrainError); | ASSERT_THROW(L::parse("heaxgon0"), MegBrainError); | ||||
| ASSERT_THROW(L::parse("rcom0"), MegBrainError); | ASSERT_THROW(L::parse("rcom0"), MegBrainError); | ||||
| ASSERT_THROW(L::parse("cmabricon0"), MegBrainError); | ASSERT_THROW(L::parse("cmabricon0"), MegBrainError); | ||||
| ASSERT_THROW(L::parse("multithread"), MegBrainError); | |||||
| ASSERT_THROW(L::parse("multithread1:"), MegBrainError); | |||||
| ASSERT_THROW(L::parse("multithread1:default"), MegBrainError); | |||||
| ASSERT_THROW(L::parse("multithread1:default:0"), MegBrainError); | |||||
| } | } | ||||
| TEST(TestCompNode, SetDefaultDev) { | TEST(TestCompNode, SetDefaultDev) { | ||||
| @@ -107,12 +108,12 @@ TEST(TestCompNode, Load) { | |||||
| #endif | #endif | ||||
| #if MGB_HAVE_THREAD | #if MGB_HAVE_THREAD | ||||
| auto cn_multi_thread0 = CompNode::load("multithread0:2"); | |||||
| auto cn_multi_thread1 = CompNode::load("multithread1:2"); | |||||
| ASSERT_EQ(CompNode::load("multithread0:2"), cn_multi_thread0); | |||||
| ASSERT_EQ(CompNode::load("multithread1:2"), cn_multi_thread1); | |||||
| ASSERT_NE(CompNode::load("multithread0:4"), cn_multi_thread0); | |||||
| ASSERT_NE(CompNode::load("multithread1:4"), cn_multi_thread1); | |||||
| auto cn_multi_thread0 = CompNode::load("multithread2:0"); | |||||
| auto cn_multi_thread1 = CompNode::load("multithread2:1"); | |||||
| ASSERT_EQ(CompNode::load("multithread2:0"), cn_multi_thread0); | |||||
| ASSERT_EQ(CompNode::load("multithread2:1"), cn_multi_thread1); | |||||
| ASSERT_NE(CompNode::load("multithread4:0"), cn_multi_thread0); | |||||
| ASSERT_NE(CompNode::load("multithread4:1"), cn_multi_thread1); | |||||
| auto cn_multi_default0 = CompNode::load("multithread:default:2"); | auto cn_multi_default0 = CompNode::load("multithread:default:2"); | ||||
| auto cn_multi_default1 = CompNode::load("multithread:default:4"); | auto cn_multi_default1 = CompNode::load("multithread:default:4"); | ||||
| @@ -139,7 +140,7 @@ TEST(TestCompNode, FreeAfterFinalize) { | |||||
| auto type = static_cast<CompNode::DeviceType>(i); | auto type = static_cast<CompNode::DeviceType>(i); | ||||
| if (!CompNode::get_device_count(type)) | if (!CompNode::get_device_count(type)) | ||||
| continue; | continue; | ||||
| auto cn = CompNode::load(CompNode::Locator{type}); | |||||
| auto cn = CompNode::load(CompNode::Locator{type, -1, {0}}); | |||||
| auto ptr = cn.alloc_device(123); | auto ptr = cn.alloc_device(123); | ||||
| CompNode::finalize(); | CompNode::finalize(); | ||||
| cn.free_device(ptr); | cn.free_device(ptr); | ||||
| @@ -190,13 +191,13 @@ TEST(TestCompNodeCPU, CoreAffinity) { | |||||
| size_t data0, data1 = 0; | size_t data0, data1 = 0; | ||||
| auto empty_task = []() {}; | auto empty_task = []() {}; | ||||
| auto cn0 = CompNode::load("cpu:default"), cn1 = CompNode::load("cpu0"), | auto cn0 = CompNode::load("cpu:default"), cn1 = CompNode::load("cpu0"), | ||||
| cn2 = CompNode::load("multithread0:2"); | |||||
| cn2 = CompNode::load("multithread2:0"); | |||||
| auto binding0 = [&](size_t) { data0 = 10; }; | auto binding0 = [&](size_t) { data0 = 10; }; | ||||
| CompNodeEnv::from_comp_node(cn0).cpu_env().set_affinity(binding0); | CompNodeEnv::from_comp_node(cn0).cpu_env().set_affinity(binding0); | ||||
| CompNodeEnv::from_comp_node(cn0).cpu_env().dispatch(empty_task); | CompNodeEnv::from_comp_node(cn0).cpu_env().dispatch(empty_task); | ||||
| cn0.sync(); | cn0.sync(); | ||||
| auto binding1 = [&](size_t) { data1 = 20; }; | |||||
| auto binding1 = [&](size_t ) { data1 = 20; }; | |||||
| CompNodeEnv::from_comp_node(cn1).cpu_env().set_affinity(binding1); | CompNodeEnv::from_comp_node(cn1).cpu_env().set_affinity(binding1); | ||||
| CompNodeEnv::from_comp_node(cn1).cpu_env().dispatch(empty_task); | CompNodeEnv::from_comp_node(cn1).cpu_env().dispatch(empty_task); | ||||
| cn1.sync(); | cn1.sync(); | ||||
| @@ -238,7 +239,7 @@ TEST(TestCompNode, CPU_MULTI_THREAD) { | |||||
| }; | }; | ||||
| for (auto&& str : std::vector<std::string>{ | for (auto&& str : std::vector<std::string>{ | ||||
| "multithread0:2", "multithread0:4", "multithread:default:4"}) { | |||||
| "multithread2:0", "multithread4:0", "multithread:default:4"}) { | |||||
| auto cn0 = CompNode::load("cpu0"), cn1 = CompNode::load(str); | auto cn0 = CompNode::load("cpu0"), cn1 = CompNode::load(str); | ||||
| std::thread wk_thread0{std::ref(worker), std::ref(dst0), std::ref(cn0)}; | std::thread wk_thread0{std::ref(worker), std::ref(dst0), std::ref(cn0)}; | ||||
| std::thread wk_thread1{std::ref(worker), std::ref(dst1), std::ref(cn1)}; | std::thread wk_thread1{std::ref(worker), std::ref(dst1), std::ref(cn1)}; | ||||
| @@ -271,9 +272,9 @@ TEST(TestCompNodeCPU, PhysicalDispatch) { | |||||
| L::set_device_map(DT, ID, 0); | L::set_device_map(DT, ID, 0); | ||||
| L::set_device_map(DT, ID + 1, 0); | L::set_device_map(DT, ID + 1, 0); | ||||
| L::set_device_map(DT, ID + 2, 1); | L::set_device_map(DT, ID + 2, 1); | ||||
| auto cn0 = CompNode::load({DT, ID, 0}), | |||||
| cn1 = CompNode::load({DT, ID + 1, 0}), | |||||
| cn2 = CompNode::load({DT, ID + 2, 0}); | |||||
| auto cn0 = CompNode::load({DT, ID, {0}}), | |||||
| cn1 = CompNode::load({DT, ID + 1, {0}}), | |||||
| cn2 = CompNode::load({DT, ID + 2, {0}}); | |||||
| #if MGB_HAVE_THREAD | #if MGB_HAVE_THREAD | ||||
| ASSERT_NE(cn0, cn1); | ASSERT_NE(cn0, cn1); | ||||
| #else | #else | ||||
| @@ -532,10 +533,10 @@ TEST(TestCompNode, MultipleLoad) { | |||||
| for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) { | for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) { | ||||
| auto dt = static_cast<CompNode::DeviceType>(i); | auto dt = static_cast<CompNode::DeviceType>(i); | ||||
| if (CompNode::get_device_count(dt)) { | if (CompNode::get_device_count(dt)) { | ||||
| auto cn = CompNode::load({dt}); | |||||
| auto cn = CompNode::load({dt, 0, {0}}); | |||||
| mgb_log("comp node %s is available", cn.to_string().c_str()); | mgb_log("comp node %s is available", cn.to_string().c_str()); | ||||
| run(cn); | run(cn); | ||||
| cn = CompNode::load({dt}); | |||||
| cn = CompNode::load({dt, 0, {0}}); | |||||
| run(cn); | run(cn); | ||||
| } | } | ||||
| } | } | ||||
| @@ -591,7 +592,7 @@ TYPED_TEST(TestCPUCompSeqRec, run_default_cpu) { | |||||
| comp_node_test::seq_rec::run<TypeParam>(CompNode::load("cpu:default")); | comp_node_test::seq_rec::run<TypeParam>(CompNode::load("cpu:default")); | ||||
| } | } | ||||
| TYPED_TEST(TestCPUCompSeqRec, run_multi_thread) { | TYPED_TEST(TestCPUCompSeqRec, run_multi_thread) { | ||||
| auto cn = CompNode::load("multithread0:4"); | |||||
| auto cn = CompNode::load("multithread4:0"); | |||||
| comp_node_test::seq_rec::run<TypeParam>(cn); | comp_node_test::seq_rec::run<TypeParam>(cn); | ||||
| } | } | ||||