/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_ #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) #include #include // for sig_atomic_t #endif #include #include #include #include #include #include "minddata/dataset/util/allocator.h" #include "minddata/dataset/util/intrp_service.h" #include "minddata/dataset/util/lock.h" #include "minddata/dataset/util/services.h" #include "minddata/dataset/util/status.h" #include "minddata/dataset/util/task.h" namespace mindspore { namespace dataset { namespace thread { using id = std::thread::id; } // namespace thread namespace this_thread { inline thread::id get_id() { return std::this_thread::get_id(); } } // namespace this_thread class TaskManager : public Service { public: friend class Services; friend class TaskGroup; ~TaskManager() override; TaskManager(const TaskManager &) = delete; TaskManager &operator=(const TaskManager &) = delete; static Status CreateInstance() { std::call_once(init_instance_flag_, [&]() -> Status { auto &svcManager = Services::GetInstance(); RETURN_IF_NOT_OK(svcManager.AddHook(&instance_)); return Status::OK(); }); return Status::OK(); } static TaskManager &GetInstance() noexcept { return *instance_; } Status DoServiceStart() override; Status DoServiceStop() override; // A public global interrupt flag for signal handlers volatile sig_atomic_t global_interrupt_; // API // This takes the same parameter as Task constructor. Take a look // of the test-thread.cc for usage. Status CreateAsyncTask(const std::string &my_name, const std::function &f, TaskGroup *vg, Task **); // Same usage as boot thread group Status join_all(); void interrupt_all() noexcept; // Locate a particular Task. static Task *FindMe(); static void InterruptGroup(Task &); static Status GetMasterThreadRc(); static void InterruptMaster(const Status &rc = Status::OK()); static void WakeUpWatchDog() { #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) TaskManager &tm = TaskManager::GetInstance(); (void)sem_post(&tm.sem_); #endif } void ReturnFreeTask(Task *p) noexcept; Status GetFreeTask(const std::string &my_name, const std::function &f, Task **p); Status WatchDog(); private: static std::once_flag init_instance_flag_; static TaskManager *instance_; RWLock lru_lock_; SpinLock free_lock_; SpinLock tg_lock_; std::shared_ptr master_; List lru_; List free_lst_; #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) sem_t sem_; #endif TaskGroup *watchdog_grp_; std::set grp_list_; Task *watchdog_; TaskManager(); }; // A group of related tasks. class TaskGroup : public Service { public: friend class Task; friend class TaskManager; Status CreateAsyncTask(const std::string &my_name, const std::function &f, Task **pTask = nullptr); void interrupt_all() noexcept; Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking); int size() const noexcept { return grp_list_.count; } Status DoServiceStart() override { return Status::OK(); } Status DoServiceStop() override; TaskGroup(); ~TaskGroup() override; Status GetTaskErrorIfAny(); std::shared_ptr GetIntrpService(); private: Status rc_; // Can't use rw_lock_ as we will lead to deadlatch. Create another mutex to serialize access to rc_. std::mutex rc_mux_; RWLock rw_lock_; List grp_list_; std::shared_ptr intrp_svc_; }; namespace this_thread { inline bool is_interrupted() { TaskManager &tm = TaskManager::GetInstance(); if (tm.global_interrupt_ == 1) { return true; } Task *my_task = TaskManager::FindMe(); return my_task->Interrupted(); } inline bool is_master_thread() { Task *my_task = TaskManager::FindMe(); return my_task->IsMasterThread(); } inline Status GetInterruptStatus() { Task *my_task = TaskManager::FindMe(); return my_task->GetInterruptStatus(); } } // namespace this_thread #define RETURN_IF_INTERRUPTED() \ do { \ if (mindspore::dataset::this_thread::is_interrupted()) { \ return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \ } \ } while (false) } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_