| @@ -77,6 +77,7 @@ Status IntrpService::Deregister(const std::string &name) noexcept { | |||||
| } | } | ||||
| Status IntrpService::InterruptAll() noexcept { | Status IntrpService::InterruptAll() noexcept { | ||||
| std::lock_guard<std::mutex> lck(mutex_); | |||||
| Status rc; | Status rc; | ||||
| for (auto const &it : all_intrp_resources_) { | for (auto const &it : all_intrp_resources_) { | ||||
| std::string kName = it.first; | std::string kName = it.first; | ||||
| @@ -25,8 +25,9 @@ thread_local Task *gMyTask = nullptr; | |||||
| void Task::operator()() { | void Task::operator()() { | ||||
| gMyTask = this; | gMyTask = this; | ||||
| id_ = this_thread::get_id(); | |||||
| std::stringstream ss; | std::stringstream ss; | ||||
| ss << this_thread::get_id(); | |||||
| ss << id_; | |||||
| MS_LOG(INFO) << my_name_ << " Thread ID " << ss.str() << " Started."; | MS_LOG(INFO) << my_name_ << " Thread ID " << ss.str() << " Started."; | ||||
| try { | try { | ||||
| // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set | // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set | ||||
| @@ -97,8 +98,7 @@ Status Task::Run() { | |||||
| Status rc; | Status rc; | ||||
| if (running_ == false) { | if (running_ == false) { | ||||
| try { | try { | ||||
| thrd_ = std::thread(std::ref(*this)); | |||||
| id_ = thrd_.get_id(); | |||||
| thrd_ = std::async(std::launch::async, std::ref(*this)); | |||||
| running_ = true; | running_ = true; | ||||
| caught_severe_exception_ = false; | caught_severe_exception_ = false; | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| @@ -110,16 +110,25 @@ Status Task::Run() { | |||||
| Status Task::Join() { | Status Task::Join() { | ||||
| if (running_) { | if (running_) { | ||||
| RETURN_UNEXPECTED_IF_NULL(MyTaskGroup()); | |||||
| auto interrupt_svc = MyTaskGroup()->GetIntrpService(); | |||||
| try { | try { | ||||
| thrd_.join(); | |||||
| // There is a race condition in the global resource tracking such that a thread can miss the | |||||
| // interrupt and becomes blocked on a conditional variable forever. As a result, calling | |||||
| // join() will not come back. We need some timeout version of join such that if the thread | |||||
| // doesn't come back in a reasonable of time, we will send the interrupt again. | |||||
| while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { | |||||
| // We can't tell which conditional_variable this thread is waiting on. So we may need | |||||
| // to interrupt everything one more time. | |||||
| MS_LOG(DEBUG) << "Some threads not responding. Interrupt again"; | |||||
| interrupt_svc->InterruptAll(); | |||||
| } | |||||
| std::stringstream ss; | std::stringstream ss; | ||||
| ss << get_id(); | ss << get_id(); | ||||
| MS_LOG(INFO) << MyName() << " Thread ID " << ss.str() << " Stopped."; | MS_LOG(INFO) << MyName() << " Thread ID " << ss.str() << " Stopped."; | ||||
| running_ = false; | running_ = false; | ||||
| RETURN_IF_NOT_OK(wp_.Deregister()); | RETURN_IF_NOT_OK(wp_.Deregister()); | ||||
| if (MyTaskGroup()) { | |||||
| RETURN_IF_NOT_OK(MyTaskGroup()->GetIntrpService()->Deregister(ss.str())); | |||||
| } | |||||
| RETURN_IF_NOT_OK(interrupt_svc->Deregister(ss.str())); | |||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| RETURN_STATUS_UNEXPECTED(e.what()); | RETURN_STATUS_UNEXPECTED(e.what()); | ||||
| } | } | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <chrono> | #include <chrono> | ||||
| #include <exception> | #include <exception> | ||||
| #include <functional> | #include <functional> | ||||
| #include <future> | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <mutex> | #include <mutex> | ||||
| @@ -106,7 +107,7 @@ class Task : public IntrpResource { | |||||
| std::function<Status()> fnc_obj_; | std::function<Status()> fnc_obj_; | ||||
| // Misc fields used by TaskManager. | // Misc fields used by TaskManager. | ||||
| TaskGroup *task_group_; | TaskGroup *task_group_; | ||||
| std::thread thrd_; | |||||
| std::future<void> thrd_; | |||||
| std::thread::id id_; | std::thread::id id_; | ||||
| bool is_master_; | bool is_master_; | ||||
| volatile bool running_; | volatile bool running_; | ||||
| @@ -116,7 +116,7 @@ TaskManager::TaskManager() try : global_interrupt_(0), | |||||
| TaskManager::~TaskManager() { | TaskManager::~TaskManager() { | ||||
| if (watchdog_) { | if (watchdog_) { | ||||
| WakeUpWatchDog(); | WakeUpWatchDog(); | ||||
| watchdog_->thrd_.join(); | |||||
| watchdog_->Join(); | |||||
| // watchdog_grp_ and watchdog_ pointers come from Services::GetInstance().GetServiceMemPool() which we will free it | // watchdog_grp_ and watchdog_ pointers come from Services::GetInstance().GetServiceMemPool() which we will free it | ||||
| // on shutdown. So no need to free these pointers one by one. | // on shutdown. So no need to free these pointers one by one. | ||||
| watchdog_grp_ = nullptr; | watchdog_grp_ = nullptr; | ||||
| @@ -19,35 +19,19 @@ | |||||
| #include "dataset/util/task_manager.h" | #include "dataset/util/task_manager.h" | ||||
| using namespace mindspore::dataset; | using namespace mindspore::dataset; | ||||
| using namespace std::placeholders; | |||||
| class MindDataTestTaskManager : public UT::Common { | class MindDataTestTaskManager : public UT::Common { | ||||
| public: | public: | ||||
| MindDataTestTaskManager() {} | |||||
| MindDataTestTaskManager() {} | |||||
| void SetUp() { Services::CreateInstance(); | |||||
| } | |||||
| void SetUp() { Services::CreateInstance(); } | |||||
| }; | }; | ||||
| std::atomic<int> v(0); | |||||
| Status f(TaskGroup &vg){ | |||||
| for (int i = 0; i < 1; i++) { | |||||
| RETURN_IF_NOT_OK(vg.CreateAsyncTask("Infinity", [&]() -> Status { | |||||
| TaskManager::FindMe()->Post(); | |||||
| int a = v.fetch_add(1); | |||||
| MS_LOG(DEBUG) << a << std::endl; | |||||
| return f(vg); | |||||
| })); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| TEST_F(MindDataTestTaskManager, Test1) { | TEST_F(MindDataTestTaskManager, Test1) { | ||||
| // Clear the rc of the master thread if any | // Clear the rc of the master thread if any | ||||
| (void) TaskManager::GetMasterThreadRc(); | |||||
| (void)TaskManager::GetMasterThreadRc(); | |||||
| TaskGroup vg; | TaskGroup vg; | ||||
| Status vg_rc = vg.CreateAsyncTask("Test error", [this]() -> Status { | |||||
| Status vg_rc = vg.CreateAsyncTask("Test error", []() -> Status { | |||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| throw std::bad_alloc(); | throw std::bad_alloc(); | ||||
| }); | }); | ||||
| @@ -55,6 +39,46 @@ TEST_F(MindDataTestTaskManager, Test1) { | |||||
| ASSERT_TRUE(vg.join_all().IsOk()); | ASSERT_TRUE(vg.join_all().IsOk()); | ||||
| ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOutofMemory()); | ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOutofMemory()); | ||||
| // Test the error is passed back to the master thread. | // Test the error is passed back to the master thread. | ||||
| Status rc = TaskManager::GetMasterThreadRc(); | |||||
| ASSERT_TRUE(rc.IsOutofMemory()); | |||||
| // Some compiler may choose to run the next line in parallel with the above 3 lines | |||||
| // and this will cause some mismatch once a while. | |||||
| // To block this racing condition, we need to create a dependency that the next line | |||||
| // depends on previous lines. | |||||
| if (vg.GetTaskErrorIfAny().IsError()) { | |||||
| Status rc = TaskManager::GetMasterThreadRc(); | |||||
| ASSERT_TRUE(rc.IsOutofMemory()); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestTaskManager, Test2) { | |||||
| // This testcase will spawn about 10 threads and block on a conditional variable. | |||||
| // The master thread will try to interrupt them almost at the same time. This can | |||||
| // cause a racing condition that some threads may miss the interrupt and blocked. | |||||
| // The new logic of Task::Join() will do a time-out join and wake up all those | |||||
| // threads that miss the interrupt. | |||||
| // Clear the rc of the master thread if any | |||||
| (void)TaskManager::GetMasterThreadRc(); | |||||
| TaskGroup vg; | |||||
| CondVar cv; | |||||
| Status rc; | |||||
| rc = cv.Register(vg.GetIntrpService()); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| auto block_forever = [&cv]() -> Status { | |||||
| std::mutex mux; | |||||
| std::unique_lock<std::mutex> lck(mux); | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::this_thread::sleep_for(std::chrono::milliseconds(1)); | |||||
| RETURN_IF_NOT_OK(cv.Wait(&lck, []() -> bool { return false; })); | |||||
| return Status::OK(); | |||||
| }; | |||||
| auto f = [&vg, &block_forever]() -> Status { | |||||
| for (auto i = 0; i < 10; ++i) { | |||||
| RETURN_IF_NOT_OK(vg.CreateAsyncTask("Spawn block threads", block_forever)); | |||||
| } | |||||
| return Status::OK(); | |||||
| }; | |||||
| rc = f(); | |||||
| vg.interrupt_all(); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| // Now we test the async Join | |||||
| ASSERT_TRUE(vg.join_all().IsOk()); | |||||
| } | } | ||||