You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

task.cc 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "dataset/util/task.h"

  #include <exception>
  #include <functional>
  #include <mutex>
  #include <new>
  #include <sstream>
  #include <string>
  #include <thread>

  #include "common/utils.h"
  #include "dataset/util/task_manager.h"
  #include "dataset/util/de_error.h"
  #include "utils/log_adapter.h"
  21. namespace mindspore {
  22. namespace dataset {
  23. thread_local Task *gMyTask = nullptr;
  24. void Task::operator()() {
  25. gMyTask = this;
  26. std::stringstream ss;
  27. ss << this_thread::get_id();
  28. MS_LOG(INFO) << my_name_ << " Thread ID " << ss.str() << " Started.";
  29. try {
  30. // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set
  31. // the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can
  32. // get the thread id.
  33. TaskGroup *vg = MyTaskGroup();
  34. rc_ = vg->GetIntrpService()->Register(ss.str(), this);
  35. if (rc_.IsOk()) {
  36. // Now we can run the given task.
  37. rc_ = fnc_obj_();
  38. }
  39. // Some error codes are ignored, e.g. interrupt. Others we just shutdown the group.
  40. if (rc_.IsError() && !rc_.IsInterrupted()) {
  41. ShutdownGroup();
  42. }
  43. } catch (const std::bad_alloc &e) {
  44. rc_ = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, e.what());
  45. ShutdownGroup();
  46. } catch (const std::exception &e) {
  47. rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what());
  48. ShutdownGroup();
  49. }
  50. }
  51. void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine.
  52. {
  53. std::lock_guard<std::mutex> lk(mux_);
  54. caught_severe_exception_ = true;
  55. }
  56. TaskGroup *vg = MyTaskGroup();
  57. // If multiple threads hit severe errors in the same group. Keep the first one and
  58. // discard the rest.
  59. if (vg->rc_.IsOk()) {
  60. std::unique_lock<std::mutex> rcLock(vg->rc_mux_);
  61. // Check again after we get the lock
  62. if (vg->rc_.IsOk()) {
  63. vg->rc_ = rc_;
  64. rcLock.unlock();
  65. TaskManager::InterruptMaster(rc_);
  66. TaskManager::InterruptGroup(*gMyTask);
  67. }
  68. }
  69. }
  70. Status Task::GetTaskErrorIfAny() {
  71. std::lock_guard<std::mutex> lk(mux_);
  72. if (caught_severe_exception_) {
  73. return rc_;
  74. } else {
  75. return Status::OK();
  76. }
  77. }
  78. Task::Task(const std::string &myName, const std::function<Status()> &f)
  79. : my_name_(myName),
  80. rc_(),
  81. fnc_obj_(f),
  82. task_group_(nullptr),
  83. is_master_(false),
  84. running_(false),
  85. caught_severe_exception_(false) {
  86. IntrpResource::ResetIntrpState();
  87. wp_.ResetIntrpState();
  88. wp_.Clear();
  89. }
  90. Status Task::Run() {
  91. Status rc;
  92. if (running_ == false) {
  93. try {
  94. thrd_ = std::thread(std::ref(*this));
  95. id_ = thrd_.get_id();
  96. running_ = true;
  97. caught_severe_exception_ = false;
  98. } catch (const std::exception &e) {
  99. rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what());
  100. }
  101. }
  102. return rc;
  103. }
  104. Status Task::Join() {
  105. if (running_) {
  106. try {
  107. thrd_.join();
  108. std::stringstream ss;
  109. ss << get_id();
  110. MS_LOG(INFO) << MyName() << " Thread ID " << ss.str() << " Stopped.";
  111. running_ = false;
  112. RETURN_IF_NOT_OK(wp_.Deregister());
  113. if (MyTaskGroup()) {
  114. RETURN_IF_NOT_OK(MyTaskGroup()->GetIntrpService()->Deregister(ss.str()));
  115. }
  116. } catch (const std::exception &e) {
  117. RETURN_STATUS_UNEXPECTED(e.what());
  118. }
  119. }
  120. return Status::OK();
  121. }
  122. TaskGroup *Task::MyTaskGroup() { return task_group_; }
  123. void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; }
  124. Task::~Task() { task_group_ = nullptr; }
  125. } // namespace dataset
  126. } // namespace mindspore