You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

executor.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "backend/session/executor.h"
#include "backend/session/executor_manager.h"
#include <algorithm>
#include <exception>
#include <iterator>
#include <memory>
#include <utility>
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/comm_manager.h"
#include "utils/scoped_long_running.h"
#include "pybind_api/ir/tensor_py.h"
#if (ENABLE_CPU && !_WIN32)
#include "ps/ps_cache/ps_cache_manager.h"
#endif
  27. using mindspore::tensor::TensorPy;
  28. namespace mindspore {
  29. namespace session {
  30. namespace {
  31. void SetOutputTensorsWaitStatus(const VectorRef *outputs) {
  32. MS_EXCEPTION_IF_NULL(outputs);
  33. for (auto &item : *outputs) {
  34. if (utils::isa<VectorRefPtr>(item)) {
  35. auto vector_ref = utils::cast<VectorRef>(item);
  36. SetOutputTensorsWaitStatus(&vector_ref);
  37. } else if (utils::isa<tensor::TensorPtr>(item)) {
  38. auto tensor = utils::cast<tensor::TensorPtr>(item);
  39. MS_EXCEPTION_IF_NULL(tensor);
  40. tensor->SetNeedWait(false);
  41. }
  42. }
  43. }
  44. bool TensorInVector(const VectorRef *outputs) {
  45. MS_EXCEPTION_IF_NULL(outputs);
  46. for (auto &item : *outputs) {
  47. if (utils::isa<VectorRefPtr>(item)) {
  48. auto vector_ref = utils::cast<VectorRef>(item);
  49. if (TensorInVector(&vector_ref)) {
  50. return true;
  51. }
  52. } else if (utils::isa<tensor::TensorPtr>(item)) {
  53. return true;
  54. }
  55. }
  56. return false;
  57. }
  58. bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
  59. MS_EXCEPTION_IF_NULL(task);
  60. for (auto &input : task->input_need_wait_tensors_) {
  61. MS_EXCEPTION_IF_NULL(input);
  62. if (input->NeedWait()) {
  63. return false;
  64. }
  65. }
  66. auto session = task->session_;
  67. MS_EXCEPTION_IF_NULL(session);
  68. auto graph = session->GetGraph(task->graph_id_);
  69. if (graph != nullptr) {
  70. return graph->IsPreGraphFinished();
  71. }
  72. return true;
  73. }
  74. void WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
  75. bool need_lock = false;
  76. for (auto &tensor : task->input_tensors_) {
  77. if (tensor->NeedWait()) {
  78. if (tensor->IsGraphOutput()) {
  79. task->input_need_wait_tensors_.emplace_back(tensor);
  80. } else {
  81. need_lock = true;
  82. }
  83. }
  84. }
  85. if (need_lock) {
  86. mindspore::ScopedLongRunning long_running;
  87. for (auto &tensor : task->input_tensors_) {
  88. if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
  89. MsException::Instance().CheckException();
  90. tensor->Wait();
  91. }
  92. }
  93. MsException::Instance().CheckException();
  94. }
  95. // need lock input parameters for optimizer
  96. for (auto &tensor : task->input_need_lock_tensors_) {
  97. tensor->SetNeedWait(true);
  98. }
  99. }
  100. } // namespace
  101. void CompileNodesTask::Run() {
  102. MS_EXCEPTION_IF_NULL(session_);
  103. MS_EXCEPTION_IF_NULL(segment_);
  104. graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
  105. }
  106. void CompileGraphTask::Run() {
  107. MS_EXCEPTION_IF_NULL(session_);
  108. graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
  109. }
  110. void BuildGraphTask::Run() {
  111. MS_EXCEPTION_IF_NULL(session_);
  112. session_->BuildGraphImpl(graph_id_);
  113. }
// Execute one compiled graph on the worker thread: load inputs, run it, and
// refresh the output tensors. A thrown exception is recorded (executor event
// plus stored MsException) rather than propagated, because the bookkeeping
// after the try block must still run so waiters are unblocked.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_LOG(INFO) << "Start run graph " << graph_id_;
  auto graph = session_->GetGraph(graph_id_);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Invalid graph id " << graph_id_;
    return;
  }
  graph->ResetGraphRunningStatus();
  try {
    session_->LoadInputs(graph_id_, input_tensors_);
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    session_->UpdateOutputTensors(&outputs_, tensor_to_node_);
  } catch (const std::exception &e) {
    // Record the failure for the submitting thread; cleanup below still runs.
    ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
    MsException::Instance().SetException();
  }
  MS_LOG(INFO) << "End run graph " << graph_id_;
  graph->OnRunGraphFinished();
  // Unlock the optimizer parameters that were locked before the run.
  for (auto &tensor : input_need_lock_tensors_) {
    tensor->SetNeedWait(false);
  }
  // Wake any consumer waiting on this graph's output tensors.
  SetOutputTensorsWaitStatus(&outputs_);
  ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
}
  139. void RunOpTask::Run() {
  140. MS_EXCEPTION_IF_NULL(session_);
  141. session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_);
  142. }
  143. void RunOpsInGraphTask::Run() {
  144. MS_EXCEPTION_IF_NULL(session_);
  145. session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
  146. }
  147. void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
  148. void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
  149. Executor::Executor(const std::string &device_name, uint32_t device_id) {
  150. device_name_ = device_name;
  151. device_id_ = device_id;
  152. worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
  153. }
  154. Executor::~Executor() { WorkerJoin(); }
// Stop the worker thread by queueing an ExitTask and joining it. The join is
// skipped when called from the worker thread itself (a thread joining itself
// would deadlock) or when the thread is no longer joinable.
void Executor::WorkerJoin() {
  // Avoid worker thread join itself which will cause deadlock
  if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
    {
      std::lock_guard<std::mutex> lock(task_mutex_);
      auto task = std::make_shared<ExitTask>();
      ready_tasks_.push(task);
      task_cond_var_.notify_all();
    }
    worker_->join();
  }
}
// Worker thread main loop: pop tasks from ready_tasks_ and run them until an
// ExitTask is dequeued. Exceptions from a task are recorded, not propagated,
// so the loop keeps serving subsequent tasks.
void Executor::WorkerLoop() {
  while (true) {
    std::shared_ptr<Task> task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
      task = ready_tasks_.front();
      ready_tasks_.pop();
    }
    if (task->type_ == kExit) {
      OnWorkerExit();
      return;
    }
    try {
      task->Run();
    } catch (const std::exception &e) {
      // Store the exception; the thread that submitted the task re-raises it
      // via MsException::CheckException().
      ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
      MsException::Instance().SetException();
    }
    {
      std::lock_guard<std::mutex> lock(done_task_mutex_);
      done_tasks_.emplace_back(task);
    }
    // Async RunGraph tasks have no waiter; every other task type (and sync
    // RunGraph) wakes the thread blocked in RunTask().
    if (task->type_ != kRunGraph || task->sync_run_) {
      std::lock_guard<std::mutex> lock(task_mutex_);
      sync_run_task_finished_ = true;
      sync_cond_var_.notify_all();
    }
  }
}
  197. std::vector<std::shared_ptr<RunGraphTask>> Executor::GetReadyTasksFromPendingList() {
  198. std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
  199. std::lock_guard<std::mutex> lock(pending_task_mutex_);
  200. for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
  201. auto task = *iter;
  202. if (IsTaskReady(task)) {
  203. new_ready_tasks.emplace_back(task);
  204. pending_tasks_.erase(iter++);
  205. } else {
  206. iter++;
  207. }
  208. }
  209. return new_ready_tasks;
  210. }
  211. void Executor::OnEvent(const ExecutorEvent &event) {
  212. if (event == ExecutorEvent::kRunGraphFinished) {
  213. OnRunGraphFinished();
  214. } else if (event == ExecutorEvent::kClear) {
  215. OnClear();
  216. } else if (event == ExecutorEvent::kException) {
  217. OnException();
  218. }
  219. }
  220. void Executor::OnClear() {
  221. WorkerJoin();
  222. ClearDoneTasks();
  223. }
// After a task failure: move every queued (ready) and parked (pending) task
// to the done list so nothing keeps waiting on a pipeline that will not make
// progress. The three locks are taken sequentially, never nested.
void Executor::OnException() {
  std::vector<std::shared_ptr<Task>> new_done_tasks;
  {
    // Drain tasks queued for the worker but not yet started.
    std::lock_guard<std::mutex> lock(task_mutex_);
    while (!ready_tasks_.empty()) {
      new_done_tasks.emplace_back(ready_tasks_.front());
      ready_tasks_.pop();
    }
  }
  {
    // Drain tasks parked while waiting for their inputs.
    std::lock_guard<std::mutex> lock(pending_task_mutex_);
    std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(new_done_tasks));
    pending_tasks_.clear();
  }
  {
    std::lock_guard<std::mutex> lock(done_task_mutex_);
    (void)done_tasks_.insert(done_tasks_.end(), new_done_tasks.begin(), new_done_tasks.end());
  }
}
  243. void Executor::OnRunGraphFinished() {
  244. auto new_ready_tasks = GetReadyTasksFromPendingList();
  245. std::lock_guard<std::mutex> lock(task_mutex_);
  246. for (auto &task : new_ready_tasks) {
  247. ready_tasks_.push(task);
  248. }
  249. if (!new_ready_tasks.empty()) {
  250. task_cond_var_.notify_all();
  251. }
  252. reenter_cond_var_.notify_all();
  253. }
  254. void Executor::ClearDoneTasks() {
  255. std::lock_guard<std::mutex> lock(done_task_mutex_);
  256. done_tasks_.clear();
  257. }
  258. void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run) {
  259. if (sync) {
  260. ClearDoneTasks();
  261. }
  262. {
  263. std::lock_guard<std::mutex> lock(task_mutex_);
  264. sync_run_task_finished_ = false;
  265. ready_tasks_.push(task);
  266. }
  267. task_cond_var_.notify_all();
  268. if (sync && !sync_run_task_finished_) {
  269. std::unique_lock<std::mutex> lock(task_mutex_);
  270. if (long_run) {
  271. mindspore::ScopedLongRunning long_running;
  272. sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
  273. } else {
  274. sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
  275. }
  276. }
  277. ClearDoneTasks();
  278. MsException::Instance().CheckException();
  279. }
  280. GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
  281. const AnfNodePtrList &outputs) {
  282. auto task = std::make_shared<CompileNodesTask>();
  283. task->session_ = session;
  284. task->segment_ = segment;
  285. task->output_nodes_ = outputs;
  286. RunTask(task, true);
  287. return task->graph_id_;
  288. }
  289. GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  290. auto task = std::make_shared<CompileGraphTask>();
  291. task->session_ = session;
  292. task->func_graph_ = func_graph.get();
  293. RunTask(task, true);
  294. return task->graph_id_;
  295. }
  296. void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  297. auto task = std::make_shared<BuildGraphTask>();
  298. task->session_ = session;
  299. task->graph_id_ = graphId;
  300. RunTask(task, true);
  301. }
  302. void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
  303. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  304. MS_EXCEPTION_IF_NULL(session);
  305. MS_EXCEPTION_IF_NULL(outputs);
  306. auto task = std::make_shared<RunGraphTask>();
  307. task->session_ = session;
  308. task->graph_id_ = graph_id;
  309. task->input_tensors_ = inputs;
  310. session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  311. task->outputs_ = *outputs;
  312. task->sync_run_ = true;
  313. RunTask(task, true, true);
  314. }
// Submit a graph run without blocking the caller on the execution itself.
// The call may still block: first until a previous run of the same graph has
// finished its post-graph phase, then in WaitLockedInputs on inputs that are
// not graph outputs. Graphs whose outputs contain no tensor (dataset graphs)
// are run synchronously instead.
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
                             const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
  auto graph = session->GetGraph(task->graph_id_);
  if (graph != nullptr && !graph->IsPostGraphFinished()) {
    // Wait for the previous run of this graph to finish its post-graph work;
    // OnRunGraphFinished signals reenter_cond_var_.
    mindspore::ScopedLongRunning long_running;
    std::unique_lock<std::mutex> lock(reenter_mutex_);
    reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); });
    MsException::Instance().CheckException();
  }
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  // maintain a copy of output vector
  task->outputs_ = *outputs;
  // sync run graph without output tensor (in dataset graph)
  if (!TensorInVector(outputs)) {
    task->sync_run_ = true;
    RunTask(task, true, true);
    return;
  }
  WaitLockedInputs(session, task);
  {
    std::lock_guard<std::mutex> lock(pending_task_mutex_);
    if (!IsTaskReady(task)) {
      // Not ready yet: park the task; OnRunGraphFinished promotes it later.
      pending_tasks_.push_back(task);
      return;
    }
  }
  RunTask(task, false);
}
  350. void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  351. std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
  352. const std::vector<int64_t> &tensors_mask) {
  353. MS_EXCEPTION_IF_NULL(session);
  354. auto ms_context = MsContext::GetInstance();
  355. auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  356. if (target == kGPUDevice) {
  357. for (auto &tensor : *input_tensors) {
  358. if (tensor->NeedWait()) {
  359. tensor->Wait();
  360. }
  361. }
  362. {
  363. // Release GIL before calling into (potentially long-running) C++ code
  364. if (Py_IsInitialized()) {
  365. py::gil_scoped_release release;
  366. session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
  367. } else {
  368. session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
  369. }
  370. }
  371. } else {
  372. auto task = std::make_shared<RunOpTask>();
  373. task->session_ = session;
  374. task->op_run_info_ = op_run_info;
  375. task->graph_info_ = graph_info;
  376. task->input_tensors_ = input_tensors;
  377. task->tensors_mask_ = tensors_mask;
  378. for (auto &tensor : *input_tensors) {
  379. if (tensor->NeedWait()) {
  380. tensor->Wait();
  381. }
  382. }
  383. RunTask(task, true, true);
  384. *outputs = task->outputs_;
  385. }
  386. }
  387. void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
  388. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  389. MS_EXCEPTION_IF_NULL(session);
  390. MS_EXCEPTION_IF_NULL(outputs);
  391. auto task = std::make_shared<RunOpsInGraphTask>();
  392. task->session_ = session;
  393. task->graph_id_ = graph_id;
  394. task->input_tensors_ = inputs;
  395. RunTask(task, true, true);
  396. *outputs = task->outputs_;
  397. }
  398. bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
  399. auto task = std::make_shared<CreateCommGroupTask>();
  400. task->group_name_ = group_name;
  401. task->ranks_ = ranks;
  402. RunTask(task, true);
  403. return task->result_;
  404. }
  405. bool Executor::DestroyCommGroup(const std::string &group_name) {
  406. auto task = std::make_shared<DestroyCommGroupTask>();
  407. task->group_name_ = group_name;
  408. RunTask(task, true);
  409. return task->result_;
  410. }
  411. void Executor::OnWorkerExit() {
  412. if (device_name_ == kAscendDevice) {
  413. device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  414. }
  415. }
  416. } // namespace session
  417. } // namespace mindspore