You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

executor.cc 16 kB

5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/executor.h"
  17. #include "backend/session/executor_manager.h"
  18. #include <algorithm>
  19. #include <exception>
  20. #include <set>
  21. #include "runtime/device/kernel_runtime_manager.h"
  22. #include "utils/comm_manager.h"
  23. #include "utils/scoped_long_running.h"
  24. #include "pybind_api/ir/tensor_py.h"
  25. #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
  26. #include "ps/ps_cache/ps_cache_manager.h"
  27. #endif
  28. using mindspore::tensor::TensorPy;
  29. namespace mindspore {
  30. namespace session {
  31. namespace {
  32. void GetNeedNotifyTensors(const VectorRef *outputs, std::set<TensorPtr> *result) {
  33. MS_EXCEPTION_IF_NULL(outputs);
  34. MS_EXCEPTION_IF_NULL(result);
  35. for (auto &item : *outputs) {
  36. if (utils::isa<VectorRefPtr>(item)) {
  37. auto vector_ref = utils::cast<VectorRef>(item);
  38. GetNeedNotifyTensors(&vector_ref, result);
  39. } else if (utils::isa<tensor::TensorPtr>(item)) {
  40. auto tensor = utils::cast<tensor::TensorPtr>(item);
  41. result->emplace(tensor);
  42. }
  43. }
  44. }
  45. bool TensorInVector(const VectorRef *outputs) {
  46. MS_EXCEPTION_IF_NULL(outputs);
  47. for (auto &item : *outputs) {
  48. if (utils::isa<VectorRefPtr>(item)) {
  49. auto vector_ref = utils::cast<VectorRef>(item);
  50. if (TensorInVector(&vector_ref)) {
  51. return true;
  52. }
  53. } else if (utils::isa<tensor::TensorPtr>(item)) {
  54. return true;
  55. }
  56. }
  57. return false;
  58. }
  59. bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
  60. MS_EXCEPTION_IF_NULL(task);
  61. for (auto &input : task->input_need_wait_tensors_) {
  62. MS_EXCEPTION_IF_NULL(input);
  63. if (input->NeedWait()) {
  64. return false;
  65. }
  66. }
  67. auto session = task->session_;
  68. MS_EXCEPTION_IF_NULL(session);
  69. auto graph = session->GetGraph(task->graph_id_);
  70. if (graph != nullptr) {
  71. return graph->IsPreGraphFinished();
  72. }
  73. return true;
  74. }
  75. void WaitLockedInputs(const std::shared_ptr<RunGraphTask> &task) {
  76. bool need_lock = false;
  77. for (auto &tensor : task->input_tensors_) {
  78. if (tensor->NeedWait()) {
  79. if (tensor->IsGraphOutput()) {
  80. task->input_need_wait_tensors_.emplace_back(tensor);
  81. } else {
  82. need_lock = true;
  83. }
  84. }
  85. }
  86. if (need_lock) {
  87. mindspore::ScopedLongRunning long_running;
  88. for (auto &input_tensor : task->input_tensors_) {
  89. if (input_tensor->NeedWait() && !input_tensor->IsGraphOutput()) {
  90. MsException::Instance().CheckException();
  91. input_tensor->Wait();
  92. }
  93. }
  94. MsException::Instance().CheckException();
  95. }
  96. // need lock input parameters for optimizer
  97. for (auto &need_lock_tensor : task->input_need_lock_tensors_) {
  98. need_lock_tensor->SetNeedWait(true);
  99. }
  100. }
  101. } // namespace
// Compile the ANF nodes of one graph segment; the resulting graph id is
// stored in graph_id_ for the caller to read after the task completes.
void CompileNodesTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(segment_);
  graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
}
// Compile a whole func graph; the resulting graph id is stored in graph_id_.
void CompileGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
}
// Build (codegen/allocate) the graph previously compiled under graph_id_.
void BuildGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildGraphImpl(graph_id_);
}
// Execute a previously compiled graph on the worker thread.
// Loads inputs, runs the graph, refreshes output tensors, and — success or
// failure — unlocks (SetNeedWait(false)) every tensor consumers may be
// blocked on, so waiters are released instead of deadlocking.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_LOG(INFO) << "Start run graph " << graph_id_;
  auto graph = session_->GetGraph(graph_id_);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Invalid graph id " << graph_id_;
    return;
  }
  graph->ResetGraphRunningStatus();
  if (device::KernelRuntime::UseMemScheduler()) {
    graph->SetOutputNodeToTensor(node_to_tensor_);
  }
  try {
    session_->LoadInputs(graph_id_, input_tensors_);
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    std::map<DeviceAddressPtr, DeviceAddressPtr> new_to_old_device_address;
    session_->UpdateOutputTensors(&outputs_, tensor_to_node_, &new_to_old_device_address);
  } catch (const std::exception &e) {
    // Record the error and keep going: the notify loop below must still run.
    session_->ReportErrorMessage();
    ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
    MsException::Instance().SetException();
  }
  MS_LOG(INFO) << "End run graph " << graph_id_;
  graph->OnRunGraphFinished();
  // Notify both the locked input tensors and every output tensor.
  std::set<TensorPtr> need_notify_tensors(input_need_lock_tensors_.begin(), input_need_lock_tensors_.end());
  GetNeedNotifyTensors(&outputs_, &need_notify_tensors);
  for (auto &tensor : need_notify_tensors) {
    if (tensor != nullptr) {
      tensor->SetNeedWait(false);
    }
  }
  ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
}
// Run a single op; results are written into outputs_.
void RunOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_);
}
// Run the ops of a graph individually; results are written into outputs_.
void RunOpsInGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
}
// Create a communication group synchronously; success is stored in result_.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
// Destroy a communication group; success is stored in result_.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
// Construct the executor and immediately spawn its worker thread, which
// serves queued tasks until an ExitTask is processed (see WorkerJoin).
Executor::Executor(const std::string &device_name, uint32_t device_id) {
  device_name_ = device_name;
  device_id_ = device_id;
  worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
}
  163. Executor::~Executor() {
  164. try {
  165. WorkerJoin();
  166. } catch (const std::exception &e) {
  167. MS_LOG(ERROR) << "Executor call destructor failed: " << e.what();
  168. } catch (...) {
  169. MS_LOG(ERROR) << "KernelGraph call destructor failed";
  170. }
  171. }
// Stop the worker thread by enqueuing an ExitTask and joining it.
void Executor::WorkerJoin() {
  // Avoid worker thread join itself which will cause deadlock
  if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
    {
      // Notify while still holding task_mutex_ so the push and wakeup are
      // observed together by the worker.
      std::lock_guard<std::mutex> lock(task_mutex_);
      auto task = std::make_shared<ExitTask>();
      ready_tasks_.push(task);
      task_cond_var_.notify_all();
    }
    worker_->join();
  }
}
// Main loop of the worker thread: pop tasks off ready_tasks_ and execute
// them until an ExitTask is dequeued.
void Executor::WorkerLoop() {
  while (true) {
    std::shared_ptr<Task> task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
      task = ready_tasks_.front();
      ready_tasks_.pop();
    }
    MS_EXCEPTION_IF_NULL(task);
    // Snapshot type and sync flag before Run(): `task` is moved into
    // done_tasks_ below and must not be touched afterwards.
    enum TaskType task_type = task->type_;
    bool task_sync_flag = task->sync_run_;
    if (task_type == kExit) {
      OnWorkerExit();
      return;
    }
    try {
      if (task->session_ != nullptr) {
        task->session_->SetThreadContext();
      }
      task->Run();
      if (task->session_ != nullptr) {
        task->session_->ReportWarningMessage();
      }
    } catch (const std::exception &e) {
      // Report and record the exception; the loop itself keeps running so
      // the sync caller (if any) can be woken and rethrow via CheckException.
      if (task->session_ != nullptr) {
        task->session_->ReportErrorMessage();
      }
      ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
      MsException::Instance().SetException();
    }
    {
      std::lock_guard<std::mutex> lock(done_task_mutex_);
      done_tasks_.emplace_back(std::move(task));
    }
    // Wake synchronous callers blocked in RunTask(). Asynchronous RunGraph
    // tasks signal completion through OnEvent(kRunGraphFinished) instead.
    if (task_type != kRunGraph || task_sync_flag) {
      std::lock_guard<std::mutex> lock(task_mutex_);
      sync_run_task_finished_ = true;
      sync_cond_var_.notify_all();
    }
  }
}
  226. std::vector<std::shared_ptr<RunGraphTask>> Executor::GetReadyTasksFromPendingList() {
  227. std::vector<std::shared_ptr<RunGraphTask>> ready_tasks;
  228. std::lock_guard<std::mutex> lock(pending_task_mutex_);
  229. for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
  230. auto task = *iter;
  231. if (IsTaskReady(task)) {
  232. (void)ready_tasks.emplace_back(task);
  233. iter = pending_tasks_.erase(iter);
  234. } else {
  235. ++iter;
  236. }
  237. }
  238. return ready_tasks;
  239. }
  240. void Executor::OnEvent(const ExecutorEvent &event) {
  241. if (event == ExecutorEvent::kRunGraphFinished) {
  242. OnRunGraphFinished();
  243. } else if (event == ExecutorEvent::kClear) {
  244. OnClear();
  245. } else if (event == ExecutorEvent::kException) {
  246. OnException();
  247. }
  248. }
// Shut down the worker thread and drop all completed tasks.
void Executor::OnClear() {
  {
    // NOTE(review): ScopedLongRunning presumably marks this thread as
    // long-running (e.g. releases the GIL) while joining — confirm.
    mindspore::ScopedLongRunning long_running;
    WorkerJoin();
  }
  ClearDoneTasks();
}
  256. void Executor::OnException() {
  257. std::vector<std::shared_ptr<Task>> done_tasks;
  258. {
  259. std::lock_guard<std::mutex> lock(task_mutex_);
  260. while (!ready_tasks_.empty()) {
  261. (void)done_tasks.emplace_back(ready_tasks_.front());
  262. ready_tasks_.pop();
  263. }
  264. }
  265. {
  266. std::lock_guard<std::mutex> lock(pending_task_mutex_);
  267. (void)std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(done_tasks));
  268. pending_tasks_.clear();
  269. }
  270. {
  271. std::lock_guard<std::mutex> lock(done_task_mutex_);
  272. (void)done_tasks_.insert(done_tasks_.end(), done_tasks.begin(), done_tasks.end());
  273. }
  274. }
// Called when a RunGraph task completes: promote newly-ready pending tasks
// into the ready queue, wake the worker, and wake callers blocked on
// reenter_cond_var_ in RunGraphAsync.
void Executor::OnRunGraphFinished() {
  auto ready_tasks = GetReadyTasksFromPendingList();
  std::lock_guard<std::mutex> lock(task_mutex_);
  for (auto &task : ready_tasks) {
    ready_tasks_.push(task);
  }
  if (!ready_tasks.empty()) {
    task_cond_var_.notify_all();
  }
  // NOTE(review): notified without holding reenter_mutex_; legal, but relies
  // on waiters re-checking their predicate — verify no lost-wakeup window.
  reenter_cond_var_.notify_all();
}
// Release all completed tasks (and whatever resources they still hold).
void Executor::ClearDoneTasks() {
  std::lock_guard<std::mutex> lock(done_task_mutex_);
  done_tasks_.clear();
}
  290. void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run) {
  291. if (sync) {
  292. ClearDoneTasks();
  293. }
  294. {
  295. std::lock_guard<std::mutex> lock(task_mutex_);
  296. sync_run_task_finished_ = false;
  297. ready_tasks_.push(task);
  298. }
  299. task_cond_var_.notify_all();
  300. if (sync && !sync_run_task_finished_) {
  301. std::unique_lock<std::mutex> lock(task_mutex_);
  302. if (sync && long_run) {
  303. mindspore::ScopedLongRunning long_running;
  304. sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
  305. } else {
  306. sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
  307. }
  308. }
  309. ClearDoneTasks();
  310. MsException::Instance().CheckException();
  311. }
// Compile a graph segment synchronously on the worker thread.
// Returns the id of the compiled graph.
GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
                               const AnfNodePtrList &outputs) {
  auto task = std::make_shared<CompileNodesTask>();
  task->session_ = session;
  task->segment_ = segment;
  task->output_nodes_ = outputs;
  RunTask(task, true);
  return task->graph_id_;
}
// Compile a whole func graph synchronously on the worker thread.
// Returns the id of the compiled graph.
GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  auto task = std::make_shared<CompileGraphTask>();
  task->session_ = session;
  task->func_graph_ = func_graph.get();
  RunTask(task, true);
  return task->graph_id_;
}
// Build a previously compiled graph synchronously on the worker thread.
void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  auto task = std::make_shared<BuildGraphTask>();
  task->session_ = session;
  task->graph_id_ = graphId;
  RunTask(task, true);
}
// Run a compiled graph synchronously: output tensors are created up front,
// a copy of the output vector is kept on the task, and the caller blocks in
// long-running mode until the worker completes the task.
void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
                        const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_);
  task->outputs_ = *outputs;
  task->sync_run_ = true;
  RunTask(task, true, true);
}
  347. void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
  348. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  349. MS_EXCEPTION_IF_NULL(session);
  350. MS_EXCEPTION_IF_NULL(outputs);
  351. auto task = std::make_shared<RunGraphTask>();
  352. task->session_ = session;
  353. task->graph_id_ = graph_id;
  354. task->input_tensors_ = inputs;
  355. task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
  356. auto graph = session->GetGraph(task->graph_id_);
  357. if (graph != nullptr && !graph->IsPostGraphFinished()) {
  358. mindspore::ScopedLongRunning long_running;
  359. std::unique_lock<std::mutex> lock(reenter_mutex_);
  360. reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); });
  361. MsException::Instance().CheckException();
  362. }
  363. session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_);
  364. // maintain a copy of output vector
  365. task->outputs_ = *outputs;
  366. // Run graph synchronously when the graph require gil.
  367. if (graph != nullptr && graph->is_need_gil()) {
  368. std::unique_lock<std::mutex> lock(reenter_mutex_);
  369. reenter_cond_var_.wait(lock, [&graph] { return graph->IsPreGraphFinished(); });
  370. MsException::Instance().CheckException();
  371. task->sync_run_ = true;
  372. RunTask(task, true, true);
  373. return;
  374. }
  375. // sync run graph without output tensor(int dataset graph)
  376. if ((!TensorInVector(outputs) && !graph->HasPostGraph())) {
  377. task->sync_run_ = true;
  378. RunTask(task, true, true);
  379. return;
  380. }
  381. WaitLockedInputs(task);
  382. for (auto &tensor_node : task->tensor_to_node_) {
  383. tensor_node.first->SetNeedWait(true);
  384. }
  385. {
  386. std::lock_guard<std::mutex> lock(pending_task_mutex_);
  387. if (!IsTaskReady(task)) {
  388. ClearDoneTasks();
  389. pending_tasks_.push_back(task);
  390. return;
  391. }
  392. }
  393. RunTask(task, false);
  394. }
// Run a single op. On the GPU target the op runs on the calling thread
// (with the GIL released when the Python interpreter is initialized); on
// other targets it is dispatched to the worker thread as a RunOpTask.
void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
                     std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                     const std::vector<int64_t> &tensors_mask) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(op_run_info);
  auto ms_context = MsContext::GetInstance();
  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (target == kGPUDevice) {
    // Wait for producers to release the input tensors before running.
    for (auto &tensor : *input_tensors) {
      if (tensor->NeedWait()) {
        tensor->Wait();
      }
    }
    {
      // Release GIL before calling into (potentially long-running) C++ code
      if (Py_IsInitialized()) {
        py::gil_scoped_release release;
        session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
      } else {
        session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
      }
    }
  } else {
    auto task = std::make_shared<RunOpTask>();
    task->session_ = session;
    task->op_run_info_ = op_run_info;
    task->graph_info_ = graph_info;
    task->input_tensors_ = input_tensors;
    task->tensors_mask_ = tensors_mask;
    for (auto &tensor : *input_tensors) {
      if (tensor->NeedWait()) {
        tensor->Wait();
      }
    }
    RunTask(task, true, true);
    *outputs = task->outputs_;
  }
}
// Run the ops of a graph one by one, synchronously, on the worker thread;
// results are copied back into `outputs`.
void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
                             const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunOpsInGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  RunTask(task, true, true);
  *outputs = task->outputs_;
}
// Synchronously create a communication group on the worker thread.
// Returns true on success.
bool Executor::CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks) {
  auto task = std::make_shared<CreateCommGroupTask>();
  task->group_name_ = group_name;
  task->ranks_ = ranks;
  RunTask(task, true);
  return task->result_;
}
// Synchronously destroy a communication group on the worker thread.
// Returns true on success.
bool Executor::DestroyCommGroup(const std::string &group_name) {
  auto task = std::make_shared<DestroyCommGroupTask>();
  task->group_name_ = group_name;
  RunTask(task, true);
  return task->result_;
}
// Runs on the worker thread just before it exits: on Ascend, release the
// kernel runtime associated with this executor's device.
void Executor::OnWorkerExit() {
  if (device_name_ == kAscendDevice) {
    device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  }
}
  464. } // namespace session
  465. } // namespace mindspore