Browse Source

!1406 Simplify CondVar class

Merge pull request !1406 from JesseKLee/CondVar
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
e8980ed298
12 changed files with 88 additions and 82 deletions
  1. +20
    -22
      mindspore/ccsrc/dataset/util/cond_var.cc
  2. +1
    -1
      mindspore/ccsrc/dataset/util/cond_var.h
  3. +8
    -4
      mindspore/ccsrc/dataset/util/intrp_resource.h
  4. +6
    -13
      mindspore/ccsrc/dataset/util/intrp_service.cc
  5. +1
    -1
      mindspore/ccsrc/dataset/util/intrp_service.h
  6. +4
    -4
      mindspore/ccsrc/dataset/util/queue.h
  7. +6
    -0
      mindspore/ccsrc/dataset/util/services.h
  8. +9
    -1
      mindspore/ccsrc/dataset/util/task.cc
  9. +4
    -3
      mindspore/ccsrc/dataset/util/task.h
  10. +5
    -7
      mindspore/ccsrc/dataset/util/task_manager.cc
  11. +16
    -26
      mindspore/ccsrc/dataset/util/task_manager.h
  12. +8
    -0
      tests/ut/cpp/dataset/connector_test.cc

+ 20
- 22
mindspore/ccsrc/dataset/util/cond_var.cc View File

@@ -14,35 +14,34 @@
* limitations under the License.
*/
#include "dataset/util/cond_var.h"
#include <exception>
#include <utility>
#include "dataset/util/services.h"
#include "dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
CondVar::CondVar() : svc_(nullptr), my_name_(std::move(Services::GetUniqueID())) {}
CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {}

Status CondVar::Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred) {
// Append an additional condition on top of the given predicate.
// We will also bail out if this cv got interrupted.
auto f = [this, &pred]() -> bool { return (pred() || (CurState() == State::kInterrupted)); };
// If we have interrupt service, just wait on the cv unconditionally.
// Otherwise fall back to the old way of checking interrupt.
if (svc_) {
cv_.wait(*lck, f);
if (CurState() == State::kInterrupted) {
Task *my_task = TaskManager::FindMe();
if (my_task->IsMasterThread() && my_task->CaughtSevereException()) {
return TaskManager::GetMasterThreadRc();
} else {
return Status(StatusCode::kInterrupted);
try {
if (svc_ != nullptr) {
// If this cv registers with a global resource tracking, then wait unconditionally.
auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); };
cv_.wait(*lck, f);
// If we are interrupted, override the return value if this is the master thread.
// Master thread is being interrupted mostly because of some thread is reporting error.
RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus()));
} else {
// Otherwise we wake up once a while to check for interrupt (for this thread).
auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); };
while (!f()) {
(void)cv_.wait_for(*lck, std::chrono::milliseconds(1));
}
RETURN_IF_INTERRUPTED();
}
} else {
RETURN_IF_NOT_OK(interruptible_wait(&cv_, lck, pred));
if (CurState() == State::kInterrupted) {
return Status(StatusCode::kInterrupted);
}
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
@@ -66,10 +65,9 @@ Status CondVar::Register(std::shared_ptr<IntrpService> svc) {
return rc;
}

Status CondVar::Interrupt() {
RETURN_IF_NOT_OK(IntrpResource::Interrupt());
void CondVar::Interrupt() {
IntrpResource::Interrupt();
cv_.notify_all();
return Status::OK();
}

// Return (by value) the name stored in my_name_ for this condition variable.
std::string CondVar::my_name() const { return my_name_; }


+ 1
- 1
mindspore/ccsrc/dataset/util/cond_var.h View File

@@ -35,7 +35,7 @@ class CondVar : public IntrpResource {

Status Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred);

Status Interrupt() override;
void Interrupt() override;

void NotifyOne() noexcept;



+ 8
- 4
mindspore/ccsrc/dataset/util/intrp_resource.h View File

@@ -29,10 +29,7 @@ class IntrpResource {

virtual ~IntrpResource() = default;

virtual Status Interrupt() {
st_ = State::kInterrupted;
return Status::OK();
}
virtual void Interrupt() { st_ = State::kInterrupted; }

// Put the resource back into State::kRunning, clearing any earlier interrupt.
virtual void ResetIntrpState() { st_ = State::kRunning; }

@@ -40,6 +37,13 @@ class IntrpResource {

// True when the current state, as reported by CurState(), is State::kInterrupted.
bool Interrupted() const { return CurState() == State::kInterrupted; }

// Express the interrupt state of this resource as a Status.
// Returns Status(StatusCode::kInterrupted) when Interrupted() is true and Status::OK() otherwise.
virtual Status GetInterruptStatus() const {
  return Interrupted() ? Status(StatusCode::kInterrupted) : Status::OK();
}

protected:
std::atomic<State> st_;
};


+ 6
- 13
mindspore/ccsrc/dataset/util/intrp_service.cc View File

@@ -27,7 +27,7 @@ IntrpService::~IntrpService() noexcept {
MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << ".";
if (!all_intrp_resources_.empty()) {
try {
(void)InterruptAll();
InterruptAll();
} catch (const std::exception &e) {
// Ignore all error as we can't throw in the destructor.
}
@@ -64,11 +64,9 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
std::ostringstream ss;
ss << this_thread::get_id();
MS_LOG(DEBUG) << "De-register resource with name " << name << ". Thread ID is " << ss.str() << ".";
auto it = all_intrp_resources_.find(name);
if (it != all_intrp_resources_.end()) {
(void)all_intrp_resources_.erase(it);
} else {
MS_LOG(DEBUG) << "Key " << name << " not found.";
auto n = all_intrp_resources_.erase(name);
if (n == 0) {
MS_LOG(INFO) << "Key " << name << " not found.";
}
} catch (std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
@@ -76,21 +74,16 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
return Status::OK();
}

Status IntrpService::InterruptAll() noexcept {
void IntrpService::InterruptAll() noexcept {
std::lock_guard<std::mutex> lck(mutex_);
Status rc;
for (auto const &it : all_intrp_resources_) {
std::string kName = it.first;
try {
Status rc2 = it.second->Interrupt();
if (rc2.IsError()) {
rc = rc2;
}
it.second->Interrupt();
} catch (const std::exception &e) {
// continue the clean up.
}
}
return rc;
}
} // namespace dataset
} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/dataset/util/intrp_service.h View File

@@ -47,7 +47,7 @@ class IntrpService : public Service {

Status Deregister(const std::string &name) noexcept;

Status InterruptAll() noexcept;
void InterruptAll() noexcept;

// Start hook override: performs no work and always returns Status::OK().
Status DoServiceStart() override { return Status::OK(); }



+ 4
- 4
mindspore/ccsrc/dataset/util/queue.h View File

@@ -110,7 +110,7 @@ class Queue {
empty_cv_.NotifyAll();
_lock.unlock();
} else {
(void)empty_cv_.Interrupt();
empty_cv_.Interrupt();
}
return rc;
}
@@ -125,7 +125,7 @@ class Queue {
empty_cv_.NotifyAll();
_lock.unlock();
} else {
(void)empty_cv_.Interrupt();
empty_cv_.Interrupt();
}
return rc;
}
@@ -141,7 +141,7 @@ class Queue {
empty_cv_.NotifyAll();
_lock.unlock();
} else {
(void)empty_cv_.Interrupt();
empty_cv_.Interrupt();
}
return rc;
}
@@ -160,7 +160,7 @@ class Queue {
full_cv_.NotifyAll();
_lock.unlock();
} else {
(void)full_cv_.Interrupt();
full_cv_.Interrupt();
}
return rc;
}


+ 6
- 0
mindspore/ccsrc/dataset/util/services.h View File

@@ -20,6 +20,7 @@
#include <mutex>
#include <string>
#include "dataset/util/memory_pool.h"
#include "dataset/util/allocator.h"
#include "dataset/util/service.h"

#define UNIQUEID_LEN 36
@@ -72,6 +73,11 @@ class Services {

static std::string GetUniqueID();

// Build an Allocator<T> on top of the memory pool returned by the Services singleton
// (Services::GetInstance().GetServiceMemPool()).
template <typename T>
static Allocator<T> GetAllocator() {
  auto service_pool = Services::GetInstance().GetServiceMemPool();
  return Allocator<T>(service_pool);
}

private:
static std::once_flag init_instance_flag_;
static std::unique_ptr<Services> instance_;


+ 9
- 1
mindspore/ccsrc/dataset/util/task.cc View File

@@ -72,7 +72,7 @@ void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine.
}
}

Status Task::GetTaskErrorIfAny() {
Status Task::GetTaskErrorIfAny() const {
std::lock_guard<std::mutex> lk(mux_);
if (caught_severe_exception_) {
return rc_;
@@ -141,5 +141,13 @@ TaskGroup *Task::MyTaskGroup() { return task_group_; }
// Store vg as the group pointer (task_group_) of this task.
void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; }

// The destructor only clears the group pointer; it does not delete the TaskGroup.
Task::~Task() { task_group_ = nullptr; }
// Choose the status to report for a possibly interrupted operation.
// rc: the status computed by the caller.
// Returns TaskManager::GetMasterThreadRc() when rc is an interrupt and the calling thread is the
// master thread; in every other case rc is returned unchanged.
Status Task::OverrideInterruptRc(const Status &rc) {
  const bool master_was_interrupted = rc.IsInterrupted() && this_thread::is_master_thread();
  if (!master_was_interrupted) {
    return rc;
  }
  // The master thread is usually interrupted because another thread reported an error,
  // so hand back the return code recorded for the master thread instead.
  return TaskManager::GetMasterThreadRc();
}
} // namespace dataset
} // namespace mindspore

+ 4
- 3
mindspore/ccsrc/dataset/util/task.h View File

@@ -60,7 +60,7 @@ class Task : public IntrpResource {

Task &operator=(Task &&) = delete;

Status GetTaskErrorIfAny();
Status GetTaskErrorIfAny() const;

// Replace the stored task name (my_name_) with newName.
void ChangeName(const std::string &newName) { my_name_ = newName; }

@@ -95,10 +95,10 @@ class Task : public IntrpResource {

// Delegate to wp_.Wait() and return its Status.
Status Wait() { return (wp_.Wait()); }

void set_task_group(TaskGroup *vg);
static Status OverrideInterruptRc(const Status &rc);

private:
std::mutex mux_;
mutable std::mutex mux_;
std::string my_name_;
Status rc_;
WaitPost wp_;
@@ -115,6 +115,7 @@ class Task : public IntrpResource {

void ShutdownGroup();
TaskGroup *MyTaskGroup();
void set_task_group(TaskGroup *vg);
};

extern thread_local Task *gMyTask;


+ 5
- 7
mindspore/ccsrc/dataset/util/task_manager.cc View File

@@ -84,7 +84,7 @@ void TaskManager::interrupt_all() noexcept {
svc->InterruptAll();
}
}
(void)master_->Interrupt();
master_->Interrupt();
}

// Return the Task pointer held in the thread_local gMyTask for the calling thread.
// NOTE(review): callers have historically checked the result for nullptr — confirm whether it can be null.
Task *TaskManager::FindMe() { return gMyTask; }
@@ -94,8 +94,7 @@ TaskManager::TaskManager() try : global_interrupt_(0),
free_lst_(&Task::free),
watchdog_grp_(nullptr),
watchdog_(nullptr) {
std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool();
Allocator<Task> alloc(mp);
auto alloc = Services::GetAllocator<Task>();
// Create a dummy Task for the master thread (this thread)
master_ = std::allocate_shared<Task>(alloc, "master", []() -> Status { return Status::OK(); });
master_->id_ = this_thread::get_id();
@@ -185,7 +184,7 @@ void TaskManager::InterruptMaster(const Status &rc) {
TaskManager &tm = TaskManager::GetInstance();
std::shared_ptr<Task> master = tm.master_;
std::lock_guard<std::mutex> lck(master->mux_);
(void)master->Interrupt();
master->Interrupt();
if (rc.IsError() && master->rc_.IsOk()) {
master->rc_ = rc;
master->caught_severe_exception_ = true;
@@ -277,7 +276,7 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio
return Status::OK();
}

void TaskGroup::interrupt_all() noexcept { (void)intrp_svc_->InterruptAll(); }
void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); }

Status TaskGroup::join_all() {
Status rc;
@@ -299,8 +298,7 @@ Status TaskGroup::DoServiceStop() {
}

TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) {
std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool();
Allocator<IntrpService> alloc(mp);
auto alloc = Services::GetAllocator<IntrpService>();
intrp_svc_ = std::allocate_shared<IntrpService>(alloc);
(void)Service::ServiceStart();
}


+ 16
- 26
mindspore/ccsrc/dataset/util/task_manager.h View File

@@ -154,37 +154,27 @@ inline bool is_interrupted() {
return true;
}
Task *my_task = TaskManager::FindMe();
return (my_task != nullptr) ? my_task->Interrupted() : false;
return my_task->Interrupted();
}

// Tell whether the calling thread is the master thread.
// Returns false when TaskManager::FindMe() yields nullptr (no Task is bound to the calling
// thread) instead of dereferencing a null pointer; otherwise returns Task::IsMasterThread().
inline bool is_master_thread() {
  Task *my_task = TaskManager::FindMe();
  return (my_task != nullptr) && my_task->IsMasterThread();
}

// Report the interrupt status of the Task bound to the calling thread.
// Returns Status::OK() when TaskManager::FindMe() yields nullptr (there is no per-task interrupt
// state to report); otherwise forwards to Task::GetInterruptStatus().
inline Status GetInterruptStatus() {
  Task *my_task = TaskManager::FindMe();
  if (my_task == nullptr) {
    return Status::OK();
  }
  return my_task->GetInterruptStatus();
}
} // namespace this_thread

#define RETURN_IF_INTERRUPTED() \
do { \
if (mindspore::dataset::this_thread::is_interrupted()) { \
Task *myTask = TaskManager::FindMe(); \
if (myTask->IsMasterThread() && myTask->CaughtSevereException()) { \
return TaskManager::GetMasterThreadRc(); \
} else { \
return Status(StatusCode::kInterrupted); \
} \
} \
// Leave the enclosing function (which must return Status) when the calling thread is interrupted.
// is_interrupted() can report true through a path that does not depend on the calling task's own
// state, so the macro always starts from an explicit kInterrupted status rather than querying the
// task (which could yield Status::OK() and make the early return look like success).
// Task::OverrideInterruptRc() then substitutes the master thread's return code when the caller is
// the master thread.
#define RETURN_IF_INTERRUPTED()                                                 \
  do {                                                                          \
    if (mindspore::dataset::this_thread::is_interrupted()) {                    \
      return Task::OverrideInterruptRc(Status(StatusCode::kInterrupted));       \
    }                                                                           \
  } while (false)

// Wait on cv until pred() becomes true, checking for interruption between short timed waits.
// cv:   condition variable to wait on.
// lk:   lock passed to cv->wait_for(). NOTE(review): presumably already held by the caller — confirm.
// pred: wake-up condition; Status::OK() is returned once it evaluates to true.
// Returns early through RETURN_IF_INTERRUPTED() when the calling thread is interrupted, or through
// RETURN_STATUS_UNEXPECTED() when wait_for throws a std::exception.
inline Status interruptible_wait(std::condition_variable *cv, std::unique_lock<std::mutex> *lk,
                                 const std::function<bool()> &pred) noexcept {
  while (!pred()) {
    // Check for interruption before every timed wait.
    RETURN_IF_INTERRUPTED();
    try {
      // Sleep for at most 1ms so the interrupt check above runs regularly.
      (void)cv->wait_for(*lk, std::chrono::milliseconds(1));
    } catch (std::exception &e) {
      // Anything thrown by wait_for is considered system error.
      RETURN_STATUS_UNEXPECTED(e.what());
    }
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore



+ 8
- 0
tests/ut/cpp/dataset/connector_test.cc View File

@@ -139,6 +139,9 @@ Status MindDataTestConnector::Run_test_0() {
10); // capacity of each queue
DS_ASSERT(my_conn != nullptr);

rc = my_conn->Register(tg_.get());
RETURN_IF_NOT_OK(rc);

// Spawn a thread to read input_ vector and put it in my_conn
rc = tg_->CreateAsyncTask("Worker Push",
std::bind(&MindDataTestConnector::FirstWorkerPush,
@@ -184,6 +187,11 @@ Status MindDataTestConnector::Run_test_1() {
l3_threads,
conn2_qcap);

rc = conn1->Register(tg_.get());
RETURN_IF_NOT_OK(rc);
rc = conn2->Register(tg_.get());
RETURN_IF_NOT_OK(rc);

// Instantiating the threads in the first layer
for (int i = 0; i < l1_threads; i++) {
rc = tg_->CreateAsyncTask("First Worker Push",


Loading…
Cancel
Save