| @@ -223,8 +223,10 @@ void test_multi_thread(bool multi_thread_compnode) { | |||||
| std::string model_path = "./shufflenet.mge"; | std::string model_path = "./shufflenet.mge"; | ||||
| size_t nr_threads = 2; | size_t nr_threads = 2; | ||||
| std::vector<std::thread::id> thread_ids(nr_threads); | |||||
| std::vector<size_t> thread_ids_user(nr_threads); | |||||
| std::vector<size_t> thread_ids_worker(nr_threads); | |||||
| auto runner = [&](size_t i) { | auto runner = [&](size_t i) { | ||||
| thread_ids_user[i] = std::hash<std::thread::id>{}(std::this_thread::get_id()); | |||||
| std::shared_ptr<Network> network = std::make_shared<Network>(config); | std::shared_ptr<Network> network = std::make_shared<Network>(config); | ||||
| Runtime::set_cpu_inplace_mode(network); | Runtime::set_cpu_inplace_mode(network); | ||||
| if (multi_thread_compnode) { | if (multi_thread_compnode) { | ||||
| @@ -232,11 +234,18 @@ void test_multi_thread(bool multi_thread_compnode) { | |||||
| } | } | ||||
| network->load_model(model_path); | network->load_model(model_path); | ||||
| Runtime::set_runtime_thread_affinity(network, [&thread_ids, i](int id) { | |||||
| if (id == 0) { | |||||
| thread_ids[i] = std::this_thread::get_id(); | |||||
| } | |||||
| }); | |||||
| Runtime::set_runtime_thread_affinity( | |||||
| network, [&multi_thread_compnode, &thread_ids_worker, i](int id) { | |||||
| if (multi_thread_compnode) { | |||||
| if (id == 1) { | |||||
| thread_ids_worker[i] = std::hash<std::thread::id>{}( | |||||
| std::this_thread::get_id()); | |||||
| } | |||||
| } else { | |||||
| thread_ids_worker[i] = std::hash<std::thread::id>{}( | |||||
| std::this_thread::get_id()); | |||||
| } | |||||
| }); | |||||
| std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0); | std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0); | ||||
| auto src_ptr = lite_tensor->get_memory_ptr(); | auto src_ptr = lite_tensor->get_memory_ptr(); | ||||
| @@ -250,11 +259,11 @@ void test_multi_thread(bool multi_thread_compnode) { | |||||
| std::vector<std::thread> threads; | std::vector<std::thread> threads; | ||||
| for (size_t i = 0; i < nr_threads; i++) { | for (size_t i = 0; i < nr_threads; i++) { | ||||
| threads.emplace_back(runner, i); | threads.emplace_back(runner, i); | ||||
| threads[i].join(); | |||||
| } | } | ||||
| for (size_t i = 0; i < nr_threads; i++) { | for (size_t i = 0; i < nr_threads; i++) { | ||||
| threads[i].join(); | |||||
| ASSERT_EQ(thread_ids_user[i], thread_ids_worker[i]); | |||||
| } | } | ||||
| ASSERT_NE(thread_ids[0], thread_ids[1]); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||