You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

executor.cc 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "backend/session/executor.h"

#include <algorithm>
#include <exception>
#include <utility>

#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/comm_manager.h"
#include "utils/scoped_long_running.h"
#include "pybind_api/ir/tensor_py.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
using mindspore::tensor::TensorPy;
  28. namespace mindspore {
  29. namespace session {
  30. namespace {
  31. void UpdateOutputTensors(const VectorRef *outputs,
  32. const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  33. MS_EXCEPTION_IF_NULL(outputs);
  34. for (auto item : *outputs) {
  35. if (utils::isa<VectorRefPtr>(item)) {
  36. auto vector_ref = utils::cast<VectorRef>(item);
  37. UpdateOutputTensors(&vector_ref, tensor_to_node);
  38. } else if (utils::isa<tensor::TensorPtr>(item)) {
  39. auto tensor = utils::cast<tensor::TensorPtr>(item);
  40. MS_EXCEPTION_IF_NULL(tensor);
  41. auto iter = tensor_to_node.find(tensor);
  42. if (iter != tensor_to_node.end()) {
  43. auto &node = iter->second.first;
  44. auto &output_index = iter->second.second;
  45. auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  46. tensor->set_device_address(address);
  47. if (AnfAlgo::IsDynamicShape(node)) {
  48. auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
  49. ShapeVector int_shape;
  50. std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
  51. tensor->set_shape(int_shape);
  52. }
  53. }
  54. if (tensor->NeedSyncDeviceToHostImmediately()) {
  55. tensor->data_sync(false);
  56. tensor->set_device_address(nullptr);
  57. tensor->set_sync_status(kNeedSyncHostToDevice);
  58. }
  59. }
  60. }
  61. }
  62. void NotifyOutputTensors(const VectorRef *outputs) {
  63. MS_EXCEPTION_IF_NULL(outputs);
  64. for (auto item : *outputs) {
  65. if (utils::isa<VectorRefPtr>(item)) {
  66. auto vector_ref = utils::cast<VectorRef>(item);
  67. NotifyOutputTensors(&vector_ref);
  68. } else if (utils::isa<tensor::TensorPtr>(item)) {
  69. auto tensor = utils::cast<tensor::TensorPtr>(item);
  70. MS_EXCEPTION_IF_NULL(tensor);
  71. tensor->SetNeedWait(false);
  72. }
  73. }
  74. }
  75. bool TensorInVector(const VectorRef *outputs) {
  76. MS_EXCEPTION_IF_NULL(outputs);
  77. for (auto item : *outputs) {
  78. if (utils::isa<VectorRefPtr>(item)) {
  79. auto vector_ref = utils::cast<VectorRef>(item);
  80. if (TensorInVector(&vector_ref)) {
  81. return true;
  82. }
  83. } else if (utils::isa<tensor::TensorPtr>(item)) {
  84. return true;
  85. }
  86. }
  87. return false;
  88. }
  89. } // namespace
// Compile the node segment held by this task; the resulting graph id is
// stored in graph_id_ so the submitting thread can read it after completion.
void CompileNodesTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(segment_);
  graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
}
// Compile a whole func graph; the resulting graph id is stored in graph_id_
// for the submitting thread.
void CompileGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
}
// Build (codegen/allocate) the previously compiled graph identified by graph_id_.
void BuildGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildGraphImpl(graph_id_);
}
// Execute a compiled graph on the worker thread.
// Exceptions from the run are recorded in MsException (re-raised later on the
// submitting thread via CheckException) rather than escaping the worker loop.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_LOG(INFO) << "Start run graph " << graph_id_;
  auto graph = session_->GetGraph(graph_id_);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Invalid graph id " << graph_id_;
    return;
  }
  graph->ResetGraphRunningStatus();
  try {
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    UpdateOutputTensors(&outputs_, tensor_to_node_);
  } catch (const std::exception &e) {
    ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
    MsException::Instance().SetException();
  }
  MS_LOG(INFO) << "End run graph " << graph_id_;
  // Cleanup runs even when the graph threw: waiters must always be released,
  // otherwise threads blocked on these tensors/events would hang forever.
  graph->OnRunGraphFinished();
  for (auto &tensor : input_need_lock_tensors_) {
    tensor->SetNeedWait(false);
  }
  NotifyOutputTensors(&outputs_);
  ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
}
// Execute a single op; results are written into outputs_ for the submitter.
void RunOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_);
}
// Execute a graph op-by-op; results are written into outputs_ for the submitter.
void RunOpsInGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
}
// Create a communication group synchronously; result_ records success/failure.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
// Destroy a communication group; result_ records success/failure.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
  137. Executor::Executor(const std::string &device_name, uint32_t device_id) {
  138. device_name_ = device_name;
  139. device_id_ = device_id;
  140. worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
  141. }
// Join the worker thread before destruction so it never outlives its owner.
Executor::~Executor() { WorkerJoin(); }
  143. void Executor::WorkerJoin() {
  144. // Avoid worker thread join itself which will cause deadlock
  145. if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
  146. {
  147. std::lock_guard<std::mutex> lock(task_mutex_);
  148. auto task = std::make_shared<ExitTask>();
  149. ready_tasks_.push(task);
  150. task_cond_var_.notify_all();
  151. }
  152. worker_->join();
  153. }
  154. }
// Main loop of the worker thread: pop tasks from ready_tasks_ and run them
// until an ExitTask arrives. Exceptions are captured into MsException so the
// submitting thread can re-raise them; the loop itself never throws.
void Executor::WorkerLoop() {
  while (true) {
    std::shared_ptr<Task> task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
      task = ready_tasks_.front();
      ready_tasks_.pop();
    }
    if (task->type_ == kExit) {
      OnWorkerExit();
      return;
    }
    try {
      task->Run();
    } catch (const std::exception &e) {
      // Record the failure; submitters re-raise via MsException::CheckException.
      ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
      MsException::Instance().SetException();
    }
    {
      std::lock_guard<std::mutex> lock(done_task_mutex_);
      done_tasks_.emplace_back(task);
    }
    // Wake synchronous submitters. Async RunGraph tasks (kRunGraph with
    // sync_run_ == false) are the only ones nobody blocks on.
    // NOTE(review): the flag is written without task_mutex_ held while waiters
    // check it under that mutex — presumably it is atomic; confirm in the header.
    if (task->type_ != kRunGraph || task->sync_run_) {
      sync_run_task_finished_ = true;
      sync_cond_var_.notify_all();
    }
  }
}
  184. std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
  185. std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
  186. std::lock_guard<std::mutex> lock(pending_task_mutex_);
  187. for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
  188. auto task = *iter;
  189. if (IsTaskReady(task)) {
  190. new_ready_tasks.emplace_back(task);
  191. pending_tasks_.erase(iter++);
  192. } else {
  193. iter++;
  194. }
  195. }
  196. return new_ready_tasks;
  197. }
  198. void Executor::OnEvent(const ExecutorEvent &event) {
  199. if (event == ExecutorEvent::kRunGraphFinished) {
  200. OnRunGraphFinished();
  201. } else if (event == ExecutorEvent::kClear) {
  202. WorkerJoin();
  203. } else if (event == ExecutorEvent::kException) {
  204. OnException();
  205. }
  206. }
  207. void Executor::OnException() {
  208. std::vector<std::shared_ptr<Task>> new_done_tasks;
  209. {
  210. std::lock_guard<std::mutex> lock(task_mutex_);
  211. while (!ready_tasks_.empty()) {
  212. new_done_tasks.emplace_back(ready_tasks_.front());
  213. ready_tasks_.pop();
  214. }
  215. }
  216. {
  217. std::lock_guard<std::mutex> lock(pending_task_mutex_);
  218. std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(new_done_tasks));
  219. pending_tasks_.clear();
  220. }
  221. {
  222. std::lock_guard<std::mutex> lock(done_task_mutex_);
  223. (void)done_tasks_.insert(done_tasks_.end(), new_done_tasks.begin(), new_done_tasks.end());
  224. }
  225. }
// Called when a RunGraph task finishes: promote any pending tasks whose
// dependencies are now satisfied into the ready queue and wake the worker,
// then wake threads blocked in RunGraphAsync waiting for post-graph work.
void Executor::OnRunGraphFinished() {
  auto new_ready_tasks = GetNewReadyTasks();
  std::lock_guard<std::mutex> lock(task_mutex_);
  for (auto &task : new_ready_tasks) {
    ready_tasks_.push(task);
  }
  if (!new_ready_tasks.empty()) {
    task_cond_var_.notify_all();
  }
  // NOTE(review): reenter_cond_var_ is notified while holding task_mutex_, but
  // its waiters (RunGraphAsync) lock reenter_mutex_ — verify the waited
  // predicate cannot miss this wakeup.
  reenter_cond_var_.notify_all();
}
  237. bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
  238. MS_EXCEPTION_IF_NULL(task);
  239. for (auto &input : task->input_need_wait_tensors_) {
  240. MS_EXCEPTION_IF_NULL(input);
  241. if (input->NeedWait()) {
  242. return false;
  243. }
  244. }
  245. auto session = task->session_;
  246. MS_EXCEPTION_IF_NULL(session);
  247. auto graph = session->GetGraph(task->graph_id_);
  248. if (graph != nullptr) {
  249. return graph->IsPreGraphFinished();
  250. }
  251. return true;
  252. }
// Release all completed tasks (and the resources their shared_ptrs hold).
void Executor::ClearDoneTasks() {
  std::lock_guard<std::mutex> lock(done_task_mutex_);
  done_tasks_.clear();
}
  257. void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run) {
  258. {
  259. std::lock_guard<std::mutex> lock(task_mutex_);
  260. ready_tasks_.push(task);
  261. }
  262. sync_run_task_finished_ = false;
  263. task_cond_var_.notify_all();
  264. if (sync && !sync_run_task_finished_) {
  265. std::unique_lock<std::mutex> lock(task_mutex_);
  266. if (long_run) {
  267. mindspore::ScopedLongRunning long_running;
  268. sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
  269. } else {
  270. sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
  271. }
  272. }
  273. ClearDoneTasks();
  274. MsException::Instance().CheckException();
  275. }
// Compile a graph segment synchronously on the worker thread and return the
// id assigned to the compiled graph.
GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
const AnfNodePtrList &outputs) {
  auto task = std::make_shared<CompileNodesTask>();
  task->session_ = session;
  task->segment_ = segment;
  task->output_nodes_ = outputs;
  RunTask(task, true);
  return task->graph_id_;
}
// Compile a whole func graph synchronously on the worker thread and return
// the id assigned to the compiled graph.
GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  auto task = std::make_shared<CompileGraphTask>();
  task->session_ = session;
  task->func_graph_ = func_graph.get();
  RunTask(task, true);
  return task->graph_id_;
}
// Build a previously compiled graph synchronously on the worker thread.
void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  auto task = std::make_shared<BuildGraphTask>();
  task->session_ = session;
  task->graph_id_ = graphId;
  RunTask(task, true);
}
// Run a graph and block until it finishes. Output tensors are created up
// front so `outputs` is populated before dispatch; the worker then fills in
// their data/addresses.
void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  // Keep a copy of the output structure in the task; it updates these tensors.
  task->outputs_ = *outputs;
  task->sync_run_ = true;
  RunTask(task, true, true);
}
  311. void Executor::WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
  312. bool need_lock = false;
  313. for (auto &tensor : task->input_tensors_) {
  314. if (tensor->NeedWait()) {
  315. if (tensor->IsGraphOutput()) {
  316. task->input_need_wait_tensors_.emplace_back(tensor);
  317. } else {
  318. need_lock = true;
  319. }
  320. }
  321. }
  322. if (need_lock) {
  323. mindspore::ScopedLongRunning long_running;
  324. for (auto &tensor : task->input_tensors_) {
  325. if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
  326. MsException::Instance().CheckException();
  327. tensor->Wait();
  328. }
  329. }
  330. MsException::Instance().CheckException();
  331. }
  332. // need lock input parameters for optimizer
  333. for (auto &tensor : task->input_need_lock_tensors_) {
  334. tensor->SetNeedWait(true);
  335. }
  336. }
// Dispatch a graph run. Despite the name, the run degrades to synchronous
// when the graph produces no output tensors, and if the task's dependencies
// are not yet satisfied it is parked in pending_tasks_ instead of running.
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
  auto graph = session->GetGraph(task->graph_id_);
  if (graph != nullptr && !graph->IsPostGraphFinished()) {
    // Block until the previous run's post-graph phase completes; woken by
    // OnRunGraphFinished, then re-raise any exception recorded by the worker.
    mindspore::ScopedLongRunning long_running;
    std::unique_lock<std::mutex> lock(reenter_mutex_);
    reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); });
    MsException::Instance().CheckException();
  }
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  // Keep a copy of the output structure in the task; it updates these tensors.
  task->outputs_ = *outputs;
  // No output tensors (e.g. a dataset graph): run synchronously instead.
  if (!TensorInVector(outputs)) {
    task->sync_run_ = true;
    RunTask(task, true, true);
    return;
  }
  WaitLockedInputs(session, task);
  {
    std::lock_guard<std::mutex> lock(pending_task_mutex_);
    if (!IsTaskReady(task)) {
      // Not ready: park it; OnRunGraphFinished promotes it once ready.
      pending_tasks_.push_back(task);
      return;
    }
  }
  RunTask(task, false);
}
  372. void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  373. std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
  374. const std::vector<int64_t> &tensors_mask) {
  375. MS_EXCEPTION_IF_NULL(session);
  376. auto ms_context = MsContext::GetInstance();
  377. auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  378. if (target == kGPUDevice) {
  379. for (auto &tensor : *input_tensors) {
  380. if (tensor->NeedWait()) {
  381. tensor->Wait();
  382. }
  383. }
  384. {
  385. // Release GIL before calling into (potentially long-running) C++ code
  386. if (Py_IsInitialized()) {
  387. py::gil_scoped_release release;
  388. session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
  389. } else {
  390. session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
  391. }
  392. }
  393. } else {
  394. auto task = std::make_shared<RunOpTask>();
  395. task->session_ = session;
  396. task->op_run_info_ = op_run_info;
  397. task->graph_info_ = graph_info;
  398. task->input_tensors_ = input_tensors;
  399. task->tensors_mask_ = tensors_mask;
  400. for (auto &tensor : *input_tensors) {
  401. if (tensor->NeedWait()) {
  402. tensor->Wait();
  403. }
  404. }
  405. RunTask(task, true, true);
  406. *outputs = task->outputs_;
  407. }
  408. }
// Run a graph op-by-op, synchronously on the worker thread, and copy the
// task's results back into `outputs`.
void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunOpsInGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  RunTask(task, true, true);
  *outputs = task->outputs_;
}
  420. bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
  421. auto task = std::make_shared<CreateCommGroupTask>();
  422. task->group_name_ = group_name;
  423. task->ranks_ = ranks;
  424. RunTask(task, true);
  425. return task->result_;
  426. }
// Destroy a communication group synchronously on the worker thread; returns
// whether destruction succeeded.
bool Executor::DestroyCommGroup(const std::string &group_name) {
  auto task = std::make_shared<DestroyCommGroupTask>();
  task->group_name_ = group_name;
  RunTask(task, true);
  return task->result_;
}
// Worker-thread teardown hook: on Ascend, release the kernel runtime bound to
// this executor's device before the thread exits.
void Executor::OnWorkerExit() {
  if (device_name_ == kAscendDevice) {
    device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  }
}
  438. } // namespace session
  439. } // namespace mindspore