Merge pull request !1406 from JesseKLee/CondVartags/v0.5.0-beta
| @@ -14,35 +14,34 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/util/cond_var.h" | |||
| #include <exception> | |||
| #include <utility> | |||
| #include "dataset/util/services.h" | |||
| #include "dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CondVar::CondVar() : svc_(nullptr), my_name_(std::move(Services::GetUniqueID())) {} | |||
| CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {} | |||
| Status CondVar::Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred) { | |||
| // Append an additional condition on top of the given predicate. | |||
| // We will also bail out if this cv got interrupted. | |||
| auto f = [this, &pred]() -> bool { return (pred() || (CurState() == State::kInterrupted)); }; | |||
| // If we have interrupt service, just wait on the cv unconditionally. | |||
| // Otherwise fall back to the old way of checking interrupt. | |||
| if (svc_) { | |||
| cv_.wait(*lck, f); | |||
| if (CurState() == State::kInterrupted) { | |||
| Task *my_task = TaskManager::FindMe(); | |||
| if (my_task->IsMasterThread() && my_task->CaughtSevereException()) { | |||
| return TaskManager::GetMasterThreadRc(); | |||
| } else { | |||
| return Status(StatusCode::kInterrupted); | |||
| try { | |||
| if (svc_ != nullptr) { | |||
| // If this cv registers with a global resource tracking, then wait unconditionally. | |||
| auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); }; | |||
| cv_.wait(*lck, f); | |||
| // If we are interrupted, override the return value if this is the master thread. | |||
| // Master thread is being interrupted mostly because of some thread is reporting error. | |||
| RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus())); | |||
| } else { | |||
| // Otherwise we wake up once a while to check for interrupt (for this thread). | |||
| auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); }; | |||
| while (!f()) { | |||
| (void)cv_.wait_for(*lck, std::chrono::milliseconds(1)); | |||
| } | |||
| RETURN_IF_INTERRUPTED(); | |||
| } | |||
| } else { | |||
| RETURN_IF_NOT_OK(interruptible_wait(&cv_, lck, pred)); | |||
| if (CurState() == State::kInterrupted) { | |||
| return Status(StatusCode::kInterrupted); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -66,10 +65,9 @@ Status CondVar::Register(std::shared_ptr<IntrpService> svc) { | |||
| return rc; | |||
| } | |||
| Status CondVar::Interrupt() { | |||
| RETURN_IF_NOT_OK(IntrpResource::Interrupt()); | |||
| void CondVar::Interrupt() { | |||
| IntrpResource::Interrupt(); | |||
| cv_.notify_all(); | |||
| return Status::OK(); | |||
| } | |||
| std::string CondVar::my_name() const { return my_name_; } | |||
| @@ -35,7 +35,7 @@ class CondVar : public IntrpResource { | |||
| Status Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred); | |||
| Status Interrupt() override; | |||
| void Interrupt() override; | |||
| void NotifyOne() noexcept; | |||
| @@ -29,10 +29,7 @@ class IntrpResource { | |||
| virtual ~IntrpResource() = default; | |||
| virtual Status Interrupt() { | |||
| st_ = State::kInterrupted; | |||
| return Status::OK(); | |||
| } | |||
| virtual void Interrupt() { st_ = State::kInterrupted; } | |||
| virtual void ResetIntrpState() { st_ = State::kRunning; } | |||
| @@ -40,6 +37,13 @@ class IntrpResource { | |||
| bool Interrupted() const { return CurState() == State::kInterrupted; } | |||
| virtual Status GetInterruptStatus() const { | |||
| if (Interrupted()) { | |||
| return Status(StatusCode::kInterrupted); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| protected: | |||
| std::atomic<State> st_; | |||
| }; | |||
| @@ -27,7 +27,7 @@ IntrpService::~IntrpService() noexcept { | |||
| MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << "."; | |||
| if (!all_intrp_resources_.empty()) { | |||
| try { | |||
| (void)InterruptAll(); | |||
| InterruptAll(); | |||
| } catch (const std::exception &e) { | |||
| // Ignore all error as we can't throw in the destructor. | |||
| } | |||
| @@ -64,11 +64,9 @@ Status IntrpService::Deregister(const std::string &name) noexcept { | |||
| std::ostringstream ss; | |||
| ss << this_thread::get_id(); | |||
| MS_LOG(DEBUG) << "De-register resource with name " << name << ". Thread ID is " << ss.str() << "."; | |||
| auto it = all_intrp_resources_.find(name); | |||
| if (it != all_intrp_resources_.end()) { | |||
| (void)all_intrp_resources_.erase(it); | |||
| } else { | |||
| MS_LOG(DEBUG) << "Key " << name << " not found."; | |||
| auto n = all_intrp_resources_.erase(name); | |||
| if (n == 0) { | |||
| MS_LOG(INFO) << "Key " << name << " not found."; | |||
| } | |||
| } catch (std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| @@ -76,21 +74,16 @@ Status IntrpService::Deregister(const std::string &name) noexcept { | |||
| return Status::OK(); | |||
| } | |||
| Status IntrpService::InterruptAll() noexcept { | |||
| void IntrpService::InterruptAll() noexcept { | |||
| std::lock_guard<std::mutex> lck(mutex_); | |||
| Status rc; | |||
| for (auto const &it : all_intrp_resources_) { | |||
| std::string kName = it.first; | |||
| try { | |||
| Status rc2 = it.second->Interrupt(); | |||
| if (rc2.IsError()) { | |||
| rc = rc2; | |||
| } | |||
| it.second->Interrupt(); | |||
| } catch (const std::exception &e) { | |||
| // continue the clean up. | |||
| } | |||
| } | |||
| return rc; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -47,7 +47,7 @@ class IntrpService : public Service { | |||
| Status Deregister(const std::string &name) noexcept; | |||
| Status InterruptAll() noexcept; | |||
| void InterruptAll() noexcept; | |||
| Status DoServiceStart() override { return Status::OK(); } | |||
| @@ -110,7 +110,7 @@ class Queue { | |||
| empty_cv_.NotifyAll(); | |||
| _lock.unlock(); | |||
| } else { | |||
| (void)empty_cv_.Interrupt(); | |||
| empty_cv_.Interrupt(); | |||
| } | |||
| return rc; | |||
| } | |||
| @@ -125,7 +125,7 @@ class Queue { | |||
| empty_cv_.NotifyAll(); | |||
| _lock.unlock(); | |||
| } else { | |||
| (void)empty_cv_.Interrupt(); | |||
| empty_cv_.Interrupt(); | |||
| } | |||
| return rc; | |||
| } | |||
| @@ -141,7 +141,7 @@ class Queue { | |||
| empty_cv_.NotifyAll(); | |||
| _lock.unlock(); | |||
| } else { | |||
| (void)empty_cv_.Interrupt(); | |||
| empty_cv_.Interrupt(); | |||
| } | |||
| return rc; | |||
| } | |||
| @@ -160,7 +160,7 @@ class Queue { | |||
| full_cv_.NotifyAll(); | |||
| _lock.unlock(); | |||
| } else { | |||
| (void)full_cv_.Interrupt(); | |||
| full_cv_.Interrupt(); | |||
| } | |||
| return rc; | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include <mutex> | |||
| #include <string> | |||
| #include "dataset/util/memory_pool.h" | |||
| #include "dataset/util/allocator.h" | |||
| #include "dataset/util/service.h" | |||
| #define UNIQUEID_LEN 36 | |||
| @@ -72,6 +73,11 @@ class Services { | |||
| static std::string GetUniqueID(); | |||
| template <typename T> | |||
| static Allocator<T> GetAllocator() { | |||
| return Allocator<T>(Services::GetInstance().GetServiceMemPool()); | |||
| } | |||
| private: | |||
| static std::once_flag init_instance_flag_; | |||
| static std::unique_ptr<Services> instance_; | |||
| @@ -72,7 +72,7 @@ void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine. | |||
| } | |||
| } | |||
| Status Task::GetTaskErrorIfAny() { | |||
| Status Task::GetTaskErrorIfAny() const { | |||
| std::lock_guard<std::mutex> lk(mux_); | |||
| if (caught_severe_exception_) { | |||
| return rc_; | |||
| @@ -141,5 +141,13 @@ TaskGroup *Task::MyTaskGroup() { return task_group_; } | |||
| void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; } | |||
| Task::~Task() { task_group_ = nullptr; } | |||
| Status Task::OverrideInterruptRc(const Status &rc) { | |||
| if (rc.IsInterrupted() && this_thread::is_master_thread()) { | |||
| // If we are interrupted, override the return value if this is the master thread. | |||
| // Master thread is being interrupted mostly because of some thread is reporting error. | |||
| return TaskManager::GetMasterThreadRc(); | |||
| } | |||
| return rc; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -60,7 +60,7 @@ class Task : public IntrpResource { | |||
| Task &operator=(Task &&) = delete; | |||
| Status GetTaskErrorIfAny(); | |||
| Status GetTaskErrorIfAny() const; | |||
| void ChangeName(const std::string &newName) { my_name_ = newName; } | |||
| @@ -95,10 +95,10 @@ class Task : public IntrpResource { | |||
| Status Wait() { return (wp_.Wait()); } | |||
| void set_task_group(TaskGroup *vg); | |||
| static Status OverrideInterruptRc(const Status &rc); | |||
| private: | |||
| std::mutex mux_; | |||
| mutable std::mutex mux_; | |||
| std::string my_name_; | |||
| Status rc_; | |||
| WaitPost wp_; | |||
| @@ -115,6 +115,7 @@ class Task : public IntrpResource { | |||
| void ShutdownGroup(); | |||
| TaskGroup *MyTaskGroup(); | |||
| void set_task_group(TaskGroup *vg); | |||
| }; | |||
| extern thread_local Task *gMyTask; | |||
| @@ -84,7 +84,7 @@ void TaskManager::interrupt_all() noexcept { | |||
| svc->InterruptAll(); | |||
| } | |||
| } | |||
| (void)master_->Interrupt(); | |||
| master_->Interrupt(); | |||
| } | |||
| Task *TaskManager::FindMe() { return gMyTask; } | |||
| @@ -94,8 +94,7 @@ TaskManager::TaskManager() try : global_interrupt_(0), | |||
| free_lst_(&Task::free), | |||
| watchdog_grp_(nullptr), | |||
| watchdog_(nullptr) { | |||
| std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool(); | |||
| Allocator<Task> alloc(mp); | |||
| auto alloc = Services::GetAllocator<Task>(); | |||
| // Create a dummy Task for the master thread (this thread) | |||
| master_ = std::allocate_shared<Task>(alloc, "master", []() -> Status { return Status::OK(); }); | |||
| master_->id_ = this_thread::get_id(); | |||
| @@ -185,7 +184,7 @@ void TaskManager::InterruptMaster(const Status &rc) { | |||
| TaskManager &tm = TaskManager::GetInstance(); | |||
| std::shared_ptr<Task> master = tm.master_; | |||
| std::lock_guard<std::mutex> lck(master->mux_); | |||
| (void)master->Interrupt(); | |||
| master->Interrupt(); | |||
| if (rc.IsError() && master->rc_.IsOk()) { | |||
| master->rc_ = rc; | |||
| master->caught_severe_exception_ = true; | |||
| @@ -277,7 +276,7 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio | |||
| return Status::OK(); | |||
| } | |||
| void TaskGroup::interrupt_all() noexcept { (void)intrp_svc_->InterruptAll(); } | |||
| void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); } | |||
| Status TaskGroup::join_all() { | |||
| Status rc; | |||
| @@ -299,8 +298,7 @@ Status TaskGroup::DoServiceStop() { | |||
| } | |||
| TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) { | |||
| std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool(); | |||
| Allocator<IntrpService> alloc(mp); | |||
| auto alloc = Services::GetAllocator<IntrpService>(); | |||
| intrp_svc_ = std::allocate_shared<IntrpService>(alloc); | |||
| (void)Service::ServiceStart(); | |||
| } | |||
| @@ -154,37 +154,27 @@ inline bool is_interrupted() { | |||
| return true; | |||
| } | |||
| Task *my_task = TaskManager::FindMe(); | |||
| return (my_task != nullptr) ? my_task->Interrupted() : false; | |||
| return my_task->Interrupted(); | |||
| } | |||
| inline bool is_master_thread() { | |||
| Task *my_task = TaskManager::FindMe(); | |||
| return my_task->IsMasterThread(); | |||
| } | |||
| inline Status GetInterruptStatus() { | |||
| Task *my_task = TaskManager::FindMe(); | |||
| return my_task->GetInterruptStatus(); | |||
| } | |||
| } // namespace this_thread | |||
| #define RETURN_IF_INTERRUPTED() \ | |||
| do { \ | |||
| if (mindspore::dataset::this_thread::is_interrupted()) { \ | |||
| Task *myTask = TaskManager::FindMe(); \ | |||
| if (myTask->IsMasterThread() && myTask->CaughtSevereException()) { \ | |||
| return TaskManager::GetMasterThreadRc(); \ | |||
| } else { \ | |||
| return Status(StatusCode::kInterrupted); \ | |||
| } \ | |||
| } \ | |||
| #define RETURN_IF_INTERRUPTED() \ | |||
| do { \ | |||
| if (mindspore::dataset::this_thread::is_interrupted()) { \ | |||
| return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \ | |||
| } \ | |||
| } while (false) | |||
| inline Status interruptible_wait(std::condition_variable *cv, std::unique_lock<std::mutex> *lk, | |||
| const std::function<bool()> &pred) noexcept { | |||
| if (!pred()) { | |||
| do { | |||
| RETURN_IF_INTERRUPTED(); | |||
| try { | |||
| (void)cv->wait_for(*lk, std::chrono::milliseconds(1)); | |||
| } catch (std::exception &e) { | |||
| // Anything thrown by wait_for is considered system error. | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| } while (!pred()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -139,6 +139,9 @@ Status MindDataTestConnector::Run_test_0() { | |||
| 10); // capacity of each queue | |||
| DS_ASSERT(my_conn != nullptr); | |||
| rc = my_conn->Register(tg_.get()); | |||
| RETURN_IF_NOT_OK(rc); | |||
| // Spawn a thread to read input_ vector and put it in my_conn | |||
| rc = tg_->CreateAsyncTask("Worker Push", | |||
| std::bind(&MindDataTestConnector::FirstWorkerPush, | |||
| @@ -184,6 +187,11 @@ Status MindDataTestConnector::Run_test_1() { | |||
| l3_threads, | |||
| conn2_qcap); | |||
| rc = conn1->Register(tg_.get()); | |||
| RETURN_IF_NOT_OK(rc); | |||
| rc = conn2->Register(tg_.get()); | |||
| RETURN_IF_NOT_OK(rc); | |||
| // Instantiating the threads in the first layer | |||
| for (int i = 0; i < l1_threads; i++) { | |||
| rc = tg_->CreateAsyncTask("First Worker Push", | |||