| @@ -77,6 +77,7 @@ Status IntrpService::Deregister(const std::string &name) noexcept { | |||
| } | |||
| Status IntrpService::InterruptAll() noexcept { | |||
| std::lock_guard<std::mutex> lck(mutex_); | |||
| Status rc; | |||
| for (auto const &it : all_intrp_resources_) { | |||
| std::string kName = it.first; | |||
| @@ -25,8 +25,9 @@ thread_local Task *gMyTask = nullptr; | |||
| void Task::operator()() { | |||
| gMyTask = this; | |||
| id_ = this_thread::get_id(); | |||
| std::stringstream ss; | |||
| ss << this_thread::get_id(); | |||
| ss << id_; | |||
| MS_LOG(INFO) << my_name_ << " Thread ID " << ss.str() << " Started."; | |||
| try { | |||
| // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set | |||
| @@ -97,8 +98,7 @@ Status Task::Run() { | |||
| Status rc; | |||
| if (running_ == false) { | |||
| try { | |||
| thrd_ = std::thread(std::ref(*this)); | |||
| id_ = thrd_.get_id(); | |||
| thrd_ = std::async(std::launch::async, std::ref(*this)); | |||
| running_ = true; | |||
| caught_severe_exception_ = false; | |||
| } catch (const std::exception &e) { | |||
| @@ -110,16 +110,25 @@ Status Task::Run() { | |||
| Status Task::Join() { | |||
| if (running_) { | |||
| RETURN_UNEXPECTED_IF_NULL(MyTaskGroup()); | |||
| auto interrupt_svc = MyTaskGroup()->GetIntrpService(); | |||
| try { | |||
| thrd_.join(); | |||
| // There is a race condition in the global resource tracking such that a thread can miss the | |||
| // interrupt and becomes blocked on a conditional variable forever. As a result, calling | |||
| // join() will not come back. We need some timeout version of join such that if the thread | |||
| // doesn't come back in a reasonable of time, we will send the interrupt again. | |||
| while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { | |||
| // We can't tell which conditional_variable this thread is waiting on. So we may need | |||
| // to interrupt everything one more time. | |||
| MS_LOG(DEBUG) << "Some threads not responding. Interrupt again"; | |||
| interrupt_svc->InterruptAll(); | |||
| } | |||
| std::stringstream ss; | |||
| ss << get_id(); | |||
| MS_LOG(INFO) << MyName() << " Thread ID " << ss.str() << " Stopped."; | |||
| running_ = false; | |||
| RETURN_IF_NOT_OK(wp_.Deregister()); | |||
| if (MyTaskGroup()) { | |||
| RETURN_IF_NOT_OK(MyTaskGroup()->GetIntrpService()->Deregister(ss.str())); | |||
| } | |||
| RETURN_IF_NOT_OK(interrupt_svc->Deregister(ss.str())); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| @@ -19,6 +19,7 @@ | |||
| #include <chrono> | |||
| #include <exception> | |||
| #include <functional> | |||
| #include <future> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <mutex> | |||
| @@ -106,7 +107,7 @@ class Task : public IntrpResource { | |||
| std::function<Status()> fnc_obj_; | |||
| // Misc fields used by TaskManager. | |||
| TaskGroup *task_group_; | |||
| std::thread thrd_; | |||
| std::future<void> thrd_; | |||
| std::thread::id id_; | |||
| bool is_master_; | |||
| volatile bool running_; | |||
| @@ -116,7 +116,7 @@ TaskManager::TaskManager() try : global_interrupt_(0), | |||
| TaskManager::~TaskManager() { | |||
| if (watchdog_) { | |||
| WakeUpWatchDog(); | |||
| watchdog_->thrd_.join(); | |||
| watchdog_->Join(); | |||
| // watchdog_grp_ and watchdog_ pointers come from Services::GetInstance().GetServiceMemPool() which we will free it | |||
| // on shutdown. So no need to free these pointers one by one. | |||
| watchdog_grp_ = nullptr; | |||
| @@ -19,35 +19,19 @@ | |||
| #include "dataset/util/task_manager.h" | |||
| using namespace mindspore::dataset; | |||
| using namespace std::placeholders; | |||
| class MindDataTestTaskManager : public UT::Common { | |||
| public: | |||
| MindDataTestTaskManager() {} | |||
| MindDataTestTaskManager() {} | |||
| void SetUp() { Services::CreateInstance(); | |||
| } | |||
| void SetUp() { Services::CreateInstance(); } | |||
| }; | |||
| std::atomic<int> v(0); | |||
| Status f(TaskGroup &vg){ | |||
| for (int i = 0; i < 1; i++) { | |||
| RETURN_IF_NOT_OK(vg.CreateAsyncTask("Infinity", [&]() -> Status { | |||
| TaskManager::FindMe()->Post(); | |||
| int a = v.fetch_add(1); | |||
| MS_LOG(DEBUG) << a << std::endl; | |||
| return f(vg); | |||
| })); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| TEST_F(MindDataTestTaskManager, Test1) { | |||
| // Clear the rc of the master thread if any | |||
| (void) TaskManager::GetMasterThreadRc(); | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| TaskGroup vg; | |||
| Status vg_rc = vg.CreateAsyncTask("Test error", [this]() -> Status { | |||
| Status vg_rc = vg.CreateAsyncTask("Test error", []() -> Status { | |||
| TaskManager::FindMe()->Post(); | |||
| throw std::bad_alloc(); | |||
| }); | |||
| @@ -55,6 +39,46 @@ TEST_F(MindDataTestTaskManager, Test1) { | |||
| ASSERT_TRUE(vg.join_all().IsOk()); | |||
| ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOutofMemory()); | |||
| // Test the error is passed back to the master thread. | |||
| Status rc = TaskManager::GetMasterThreadRc(); | |||
| ASSERT_TRUE(rc.IsOutofMemory()); | |||
| // Some compiler may choose to run the next line in parallel with the above 3 lines | |||
| // and this will cause some mismatch once a while. | |||
| // To block this racing condition, we need to create a dependency that the next line | |||
| // depends on previous lines. | |||
| if (vg.GetTaskErrorIfAny().IsError()) { | |||
| Status rc = TaskManager::GetMasterThreadRc(); | |||
| ASSERT_TRUE(rc.IsOutofMemory()); | |||
| } | |||
| } | |||
| TEST_F(MindDataTestTaskManager, Test2) { | |||
| // This testcase will spawn about 10 threads and block on a conditional variable. | |||
| // The master thread will try to interrupt them almost at the same time. This can | |||
| // cause a racing condition that some threads may miss the interrupt and blocked. | |||
| // The new logic of Task::Join() will do a time-out join and wake up all those | |||
| // threads that miss the interrupt. | |||
| // Clear the rc of the master thread if any | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| TaskGroup vg; | |||
| CondVar cv; | |||
| Status rc; | |||
| rc = cv.Register(vg.GetIntrpService()); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| auto block_forever = [&cv]() -> Status { | |||
| std::mutex mux; | |||
| std::unique_lock<std::mutex> lck(mux); | |||
| TaskManager::FindMe()->Post(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(1)); | |||
| RETURN_IF_NOT_OK(cv.Wait(&lck, []() -> bool { return false; })); | |||
| return Status::OK(); | |||
| }; | |||
| auto f = [&vg, &block_forever]() -> Status { | |||
| for (auto i = 0; i < 10; ++i) { | |||
| RETURN_IF_NOT_OK(vg.CreateAsyncTask("Spawn block threads", block_forever)); | |||
| } | |||
| return Status::OK(); | |||
| }; | |||
| rc = f(); | |||
| vg.interrupt_all(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Now we test the async Join | |||
| ASSERT_TRUE(vg.join_all().IsOk()); | |||
| } | |||