You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

executor.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/executor.h"
  17. #include "backend/session/executor_manager.h"
  18. #include <algorithm>
  19. #include <exception>
  20. #include "runtime/device/kernel_runtime_manager.h"
  21. #include "utils/comm_manager.h"
  22. #include "utils/scoped_long_running.h"
  23. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  24. #include "ps/ps_cache/ps_cache_manager.h"
  25. #endif
  26. namespace mindspore {
  27. namespace session {
  28. namespace {
// Rebind every output tensor reachable from `outputs` (recursing into nested
// VectorRefs) to the device address produced by the executed kernel, refresh
// shapes of dynamic-shape outputs, and immediately sync back tensors that
// must be readable on host right away.
void UpdateOutputTensors(const VectorRef *outputs,
                         const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  MS_EXCEPTION_IF_NULL(outputs);
  for (auto item : *outputs) {
    if (utils::isa<VectorRefPtr>(item)) {
      // Nested output structure: recurse.
      auto vector_ref = utils::cast<VectorRef>(item);
      UpdateOutputTensors(&vector_ref, tensor_to_node);
    } else if (utils::isa<tensor::TensorPtr>(item)) {
      auto tensor = utils::cast<tensor::TensorPtr>(item);
      MS_EXCEPTION_IF_NULL(tensor);
      auto iter = tensor_to_node.find(tensor);
      if (iter != tensor_to_node.end()) {
        auto &node = iter->second.first;
        auto &output_index = iter->second.second;
        // Point the tensor at the kernel's output device memory.
        auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
        tensor->set_device_address(address);
        if (AnfAlgo::IsDynamicShape(node)) {
          // Dynamic-shape kernels only know the real shape after execution.
          auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
          ShapeVector int_shape;
          std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
          tensor->set_shape(int_shape);
        }
      }
      if (tensor->NeedSyncDeviceToHostImmediately()) {
        // Copy device data to host now, then drop the device address so the
        // tensor is treated as host-resident until re-uploaded.
        tensor->data_sync(false);
        tensor->set_device_address(nullptr);
        tensor->set_sync_status(kNeedSyncHostToDevice);
      }
    }
  }
}
  60. void NotifyOutputTensors(const VectorRef *outputs) {
  61. MS_EXCEPTION_IF_NULL(outputs);
  62. for (auto item : *outputs) {
  63. if (utils::isa<VectorRefPtr>(item)) {
  64. auto vector_ref = utils::cast<VectorRef>(item);
  65. NotifyOutputTensors(&vector_ref);
  66. } else if (utils::isa<tensor::TensorPtr>(item)) {
  67. auto tensor = utils::cast<tensor::TensorPtr>(item);
  68. MS_EXCEPTION_IF_NULL(tensor);
  69. tensor->SetNeedWait(false);
  70. }
  71. }
  72. }
  73. bool TensorInVector(const VectorRef *outputs) {
  74. MS_EXCEPTION_IF_NULL(outputs);
  75. for (auto item : *outputs) {
  76. if (utils::isa<VectorRefPtr>(item)) {
  77. auto vector_ref = utils::cast<VectorRef>(item);
  78. if (TensorInVector(&vector_ref)) {
  79. return true;
  80. }
  81. } else if (utils::isa<tensor::TensorPtr>(item)) {
  82. return true;
  83. }
  84. }
  85. return false;
  86. }
  87. } // namespace
// Compile the captured graph segment into a kernel graph; the resulting
// graph id is stored in graph_id_ for the enqueuing thread to read back.
void CompileNodesTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(segment_);
  graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
}
// Compile a whole FuncGraph; the resulting graph id is stored in graph_id_.
void CompileGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
}
// Build (e.g. allocate/initialize kernels for) the previously compiled graph
// identified by graph_id_.
void BuildGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildGraphImpl(graph_id_);
}
// Execute the graph on the worker thread: run it, rebind output tensors to
// the device results, then — regardless of success — release every waiter so
// no consumer blocks forever.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_LOG(INFO) << "Start run graph " << graph_id_;
  auto graph = session_->GetGraph(graph_id_);
  if (graph == nullptr) {
    // Graph may have been cleared already; report and bail out.
    MS_LOG(ERROR) << "Invalid graph id " << graph_id_;
    return;
  }
  graph->ResetGraphRunningStatus();
  try {
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    UpdateOutputTensors(&outputs_, tensor_to_node_);
  } catch (const std::exception &e) {
    // Capture the exception for re-raise on the caller thread; keep going so
    // the unblocking steps below still run.
    ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
    MsException::Instance().SetException();
  }
  MS_LOG(INFO) << "End run graph " << graph_id_;
  graph->OnRunGraphFinished();
  // Unlock inputs locked for the optimizer, wake consumers of the outputs,
  // then let the manager promote pending dependent tasks — in that order.
  for (auto &tensor : input_need_lock_tensors_) {
    tensor->SetNeedWait(false);
  }
  NotifyOutputTensors(&outputs_);
  ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
}
// Run a single op (PyNative mode) through the session implementation.
void RunOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_);
}
// Run a compiled graph op-by-op (PyNative mode); results land in outputs_.
void RunOpsInGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
}
// Synchronously create a communication group; success is stored in result_.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
// Destroy a communication group; success is stored in result_.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
// Bind the executor to a device and spawn its single worker thread.
// Members are assigned before the thread is created so WorkerLoop never
// observes them uninitialized.
Executor::Executor(const std::string &device_name, uint32_t device_id) {
  device_name_ = device_name;
  device_id_ = device_id;
  worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
}
// Shut down the worker thread before the queues/mutexes it uses are destroyed.
Executor::~Executor() { WorkerJoin(); }
// Stop the worker thread by queueing an ExitTask and joining it.
// Safe to call from any thread, including re-entrantly from events the
// worker itself dispatches.
void Executor::WorkerJoin() {
  // Avoid worker thread join itself which will cause deadlock
  if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
    {
      std::lock_guard<std::mutex> lock(task_mutex_);
      auto task = std::make_shared<ExitTask>();
      ready_tasks_.push(task);
      task_cond_var_.notify_all();
    }
    worker_->join();
  }
}
  153. void Executor::WorkerLoop() {
  154. while (true) {
  155. std::shared_ptr<Task> task;
  156. {
  157. std::unique_lock<std::mutex> lock(task_mutex_);
  158. task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
  159. task = ready_tasks_.front();
  160. ready_tasks_.pop();
  161. }
  162. if (task->type_ == kExit) {
  163. OnWorkerExit();
  164. return;
  165. }
  166. try {
  167. task->Run();
  168. } catch (const std::exception &e) {
  169. ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
  170. MsException::Instance().SetException();
  171. }
  172. {
  173. std::lock_guard<std::mutex> lock(done_task_mutex_);
  174. done_tasks_.emplace_back(task);
  175. }
  176. if (task->type_ != kRunGraph || task->sync_run_) {
  177. sync_run_task_finished_ = true;
  178. sync_cond_var_.notify_all();
  179. }
  180. }
  181. }
  182. std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
  183. std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
  184. std::lock_guard<std::mutex> lock(pending_task_mutex_);
  185. for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
  186. auto task = *iter;
  187. if (IsTaskReady(task)) {
  188. new_ready_tasks.emplace_back(task);
  189. pending_tasks_.erase(iter++);
  190. } else {
  191. iter++;
  192. }
  193. }
  194. return new_ready_tasks;
  195. }
  196. void Executor::OnEvent(const ExecutorEvent &event) {
  197. if (event == ExecutorEvent::kRunGraphFinished) {
  198. OnRunGraphFinished();
  199. } else if (event == ExecutorEvent::kClear) {
  200. WorkerJoin();
  201. } else if (event == ExecutorEvent::kException) {
  202. OnException();
  203. }
  204. }
// After a worker-side exception, drain both the ready queue and the pending
// queue into done_tasks_ so no queued task runs after the failure; their
// resources are released by the next ClearDoneTasks(). The three locks are
// taken one at a time, never nested.
void Executor::OnException() {
  std::vector<std::shared_ptr<Task>> new_done_tasks;
  {
    std::lock_guard<std::mutex> lock(task_mutex_);
    while (!ready_tasks_.empty()) {
      new_done_tasks.emplace_back(ready_tasks_.front());
      ready_tasks_.pop();
    }
  }
  {
    std::lock_guard<std::mutex> lock(pending_task_mutex_);
    std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(new_done_tasks));
    pending_tasks_.clear();
  }
  {
    std::lock_guard<std::mutex> lock(done_task_mutex_);
    (void)done_tasks_.insert(done_tasks_.end(), new_done_tasks.begin(), new_done_tasks.end());
  }
}
  224. void Executor::OnRunGraphFinished() {
  225. auto new_ready_tasks = GetNewReadyTasks();
  226. std::lock_guard<std::mutex> lock(task_mutex_);
  227. for (auto &task : new_ready_tasks) {
  228. ready_tasks_.push(task);
  229. }
  230. if (!new_ready_tasks.empty()) {
  231. task_cond_var_.notify_all();
  232. }
  233. reenter_cond_var_.notify_all();
  234. }
// A RunGraphTask is ready to execute when none of its graph-output input
// tensors are still being produced and its predecessor graph (if it still
// exists) has finished.
bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
  MS_EXCEPTION_IF_NULL(task);
  for (auto &input : task->input_need_wait_tensors_) {
    MS_EXCEPTION_IF_NULL(input);
    if (input->NeedWait()) {
      return false;
    }
  }
  auto session = task->session_;
  MS_EXCEPTION_IF_NULL(session);
  auto graph = session->GetGraph(task->graph_id_);
  if (graph != nullptr) {
    return graph->IsPreGraphFinished();
  }
  // Graph already released: nothing left to wait for.
  return true;
}
// Drop references held by finished tasks (and, transitively, their captured
// tensors/outputs) under the done-task lock.
void Executor::ClearDoneTasks() {
  std::lock_guard<std::mutex> lock(done_task_mutex_);
  done_tasks_.clear();
}
  255. void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run) {
  256. {
  257. std::lock_guard<std::mutex> lock(task_mutex_);
  258. ready_tasks_.push(task);
  259. }
  260. sync_run_task_finished_ = false;
  261. task_cond_var_.notify_all();
  262. if (sync && !sync_run_task_finished_) {
  263. std::unique_lock<std::mutex> lock(task_mutex_);
  264. if (long_run) {
  265. mindspore::ScopedLongRunning long_running;
  266. sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
  267. } else {
  268. sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
  269. }
  270. }
  271. ClearDoneTasks();
  272. MsException::Instance().CheckException();
  273. }
  274. GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
  275. const AnfNodePtrList &outputs) {
  276. auto task = std::make_shared<CompileNodesTask>();
  277. task->session_ = session;
  278. task->segment_ = segment;
  279. task->output_nodes_ = outputs;
  280. RunTask(task, true);
  281. return task->graph_id_;
  282. }
  283. GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  284. auto task = std::make_shared<CompileGraphTask>();
  285. task->session_ = session;
  286. task->func_graph_ = func_graph.get();
  287. RunTask(task, true);
  288. return task->graph_id_;
  289. }
  290. void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  291. auto task = std::make_shared<BuildGraphTask>();
  292. task->session_ = session;
  293. task->graph_id_ = graphId;
  294. RunTask(task, true);
  295. }
// Synchronously run a compiled graph: pre-create the output tensor structure
// so the caller gets it immediately, then block until the worker finishes
// the RunGraphTask.
void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
                        const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  // Keep a copy of the output structure so the task can update the same
  // tensors on the worker thread.
  task->outputs_ = *outputs;
  task->sync_run_ = true;
  RunTask(task, true, true);
}
// Block until all of `task`'s input tensors that are NOT outputs of another
// graph become available. Graph-output inputs are not waited on here; they
// are recorded in input_need_wait_tensors_ and later checked by IsTaskReady.
void Executor::WaitTaskGraphAvailable(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
  bool need_lock = false;
  for (auto &tensor : task->input_tensors_) {
    if (tensor->NeedWait()) {
      if (tensor->IsGraphOutput()) {
        task->input_need_wait_tensors_.emplace_back(tensor);
      } else {
        need_lock = true;
      }
    }
  }
  if (need_lock) {
    // Long-running wait: check for worker-side exceptions before each wait so
    // a failure elsewhere surfaces instead of blocking forever.
    mindspore::ScopedLongRunning long_running;
    for (auto &tensor : task->input_tensors_) {
      if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
        MsException::Instance().CheckException();
        tensor->Wait();
      }
    }
    MsException::Instance().CheckException();
  }
  // need lock input parameters for optimizer
  for (auto &tensor : task->input_need_lock_tensors_) {
    tensor->SetNeedWait(true);
  }
}
// Asynchronously run a compiled graph. Blocks only when re-entering a graph
// whose previous iteration has not finished its post-graph phase, or when
// inputs force a host-side wait; graphs with no tensor outputs (dataset
// graphs) are run synchronously instead. Tasks with unmet dependencies are
// parked in pending_tasks_ and promoted later by OnRunGraphFinished.
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
                             const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
  auto graph = session->GetGraph(task->graph_id_);
  if (graph != nullptr && !graph->IsPostGraphFinished()) {
    // Re-entering the same graph: wait for the previous run's post-graph
    // phase while marked long-running, then surface any worker exception.
    mindspore::ScopedLongRunning long_running;
    std::unique_lock<std::mutex> lock(reenter_mutex_);
    reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); });
    MsException::Instance().CheckException();
  }
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  // maintain a copy of output vector
  task->outputs_ = *outputs;
  // sync run graph without output tensor(int dataset graph)
  if (!TensorInVector(outputs)) {
    task->sync_run_ = true;
    RunTask(task, true, true);
    return;
  }
  WaitTaskGraphAvailable(session, task);
  if (!IsTaskReady(task)) {
    // Dependencies unmet: park the task; OnRunGraphFinished promotes it to
    // the ready queue once its predecessors finish.
    std::lock_guard<std::mutex> lock(pending_task_mutex_);
    pending_tasks_.push_back(task);
    return;
  }
  RunTask(task, false);
}
  368. void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  369. std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
  370. const std::vector<int64_t> &tensors_mask) {
  371. MS_EXCEPTION_IF_NULL(session);
  372. for (auto &tensor : *input_tensors) {
  373. if (tensor->NeedWait()) {
  374. tensor->Wait();
  375. }
  376. }
  377. session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
  378. }
  379. void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
  380. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  381. MS_EXCEPTION_IF_NULL(session);
  382. MS_EXCEPTION_IF_NULL(outputs);
  383. auto task = std::make_shared<RunOpsInGraphTask>();
  384. task->session_ = session;
  385. task->graph_id_ = graph_id;
  386. task->input_tensors_ = inputs;
  387. RunTask(task, true, true);
  388. *outputs = task->outputs_;
  389. }
  390. bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
  391. auto task = std::make_shared<CreateCommGroupTask>();
  392. task->group_name_ = group_name;
  393. task->ranks_ = ranks;
  394. RunTask(task, true);
  395. return task->result_;
  396. }
  397. bool Executor::DestroyCommGroup(const std::string &group_name) {
  398. auto task = std::make_shared<DestroyCommGroupTask>();
  399. task->group_name_ = group_name;
  400. RunTask(task, true);
  401. return task->result_;
  402. }
// Worker-thread teardown hook, invoked when an ExitTask is processed.
// Ascend holds a per-device kernel runtime that is released here.
void Executor::OnWorkerExit() {
  if (device_name_ == kAscendDevice) {
    device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  }
}
  408. } // namespace session
  409. } // namespace mindspore