You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

executor.cc 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/executor.h"
  17. #include <algorithm>
  18. #include <exception>
  19. #include "runtime/device/kernel_runtime_manager.h"
  20. #include "backend/session/executor_manager.h"
  21. #include "utils/comm_manager.h"
  22. #include "utils/scoped_long_running.h"
  23. namespace mindspore {
  24. namespace session {
  25. namespace {
  26. void UpdateOutputTensors(const VectorRef *outputs,
  27. const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  28. MS_EXCEPTION_IF_NULL(outputs);
  29. for (auto item : *outputs) {
  30. if (utils::isa<VectorRefPtr>(item)) {
  31. auto vector_ref = utils::cast<VectorRef>(item);
  32. UpdateOutputTensors(&vector_ref, tensor_to_node);
  33. } else if (utils::isa<tensor::TensorPtr>(item)) {
  34. auto tensor = utils::cast<tensor::TensorPtr>(item);
  35. MS_EXCEPTION_IF_NULL(tensor);
  36. auto iter = tensor_to_node.find(tensor);
  37. if (iter != tensor_to_node.end()) {
  38. auto &node = iter->second.first;
  39. auto &output_index = iter->second.second;
  40. auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  41. tensor->set_device_address(address);
  42. if (AnfAlgo::IsDynamicShape(node)) {
  43. auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
  44. ShapeVector int_shape;
  45. std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
  46. tensor->set_shape(int_shape);
  47. }
  48. }
  49. if (tensor->NeedSyncDeviceToHostImmediately()) {
  50. tensor->data_sync(false);
  51. tensor->set_device_address(nullptr);
  52. tensor->set_sync_status(kNeedSyncHostToDevice);
  53. }
  54. }
  55. }
  56. }
  57. void NotifyOutputTensors(const VectorRef *outputs) {
  58. MS_EXCEPTION_IF_NULL(outputs);
  59. for (auto item : *outputs) {
  60. if (utils::isa<VectorRefPtr>(item)) {
  61. auto vector_ref = utils::cast<VectorRef>(item);
  62. NotifyOutputTensors(&vector_ref);
  63. } else if (utils::isa<tensor::TensorPtr>(item)) {
  64. auto tensor = utils::cast<tensor::TensorPtr>(item);
  65. MS_EXCEPTION_IF_NULL(tensor);
  66. tensor->SetNeedWait(false);
  67. }
  68. }
  69. }
  70. bool TensorInVector(const VectorRef *outputs) {
  71. MS_EXCEPTION_IF_NULL(outputs);
  72. for (auto item : *outputs) {
  73. if (utils::isa<VectorRefPtr>(item)) {
  74. auto vector_ref = utils::cast<VectorRef>(item);
  75. if (TensorInVector(&vector_ref)) {
  76. return true;
  77. }
  78. } else if (utils::isa<tensor::TensorPtr>(item)) {
  79. return true;
  80. }
  81. }
  82. return false;
  83. }
  84. } // namespace
// Compile a segment of nodes into a kernel graph; the resulting graph id is
// stored on the task for the submitting thread to read back after SyncRunTask.
void CompileNodesTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(segment_);
  graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
}
// Compile a whole FuncGraph; the resulting graph id is stored on the task for
// the submitting thread to read back after SyncRunTask.
void CompileGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
}
// Build (e.g. allocate/codegen) the previously compiled graph identified by
// graph_id_ on the worker thread.
void BuildGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildGraphImpl(graph_id_);
}
// Execute a compiled graph on the worker thread.  Any exception thrown while
// running is captured into MsException (re-thrown later on the caller's thread
// via MsException::CheckException) so the unlock/notify steps below always run.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  try {
    MS_LOG(INFO) << "Start run graph " << graph_id_;
    auto graph = session_->GetGraph(graph_id_);
    MS_EXCEPTION_IF_NULL(graph);
    graph->ResetGraphRunningStatus();
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    graph->OnRunGraphFinished();
    // Attach freshly produced device addresses (and dynamic shapes) to the
    // pre-created output tensors.
    UpdateOutputTensors(&outputs_, tensor_to_node_);
    MS_LOG(INFO) << "End run graph " << graph_id_;
  } catch (const std::exception &e) {
    MsException::GetInstance().SetException();
  }
  // Even on failure: unlock input tensors and wake anyone waiting on the
  // outputs, otherwise dependent tasks would block forever.
  for (auto &tensor : input_need_lock_tensors_) {
    tensor->SetNeedWait(false);
  }
  NotifyOutputTensors(&outputs_);
  // Give every executor a chance to schedule pending tasks that depended on
  // this graph having finished.
  ExecutorManager::Instance().OnRunGraphFinished();
}
// Build a single op (identified by graph_info_) via the session on the worker
// thread.
void BuildOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildOpImpl(*op_run_info_, graph_info_, input_tensors_, tensors_mask_);
}
// Execute a previously built single op and collect its results into outputs_.
void RunOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_);
}
// Create a communication group synchronously; result_ records success/failure.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
// Destroy a communication group; result_ records success/failure.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
// Start the dedicated worker thread for this (device name, device id) pair.
// The thread runs WorkerLoop until an ExitTask is queued by WorkerJoin.
Executor::Executor(const std::string &device_name, uint32_t device_id) {
  device_name_ = device_name;
  device_id_ = device_id;
  worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
}
// Join the worker thread before destruction so no task touches freed state.
Executor::~Executor() { WorkerJoin(); }
  134. void Executor::WorkerJoin() {
  135. // Avoid worker thread join itself which will cause deadlock
  136. if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
  137. {
  138. std::unique_lock<std::mutex> lock(task_mutex_);
  139. auto task = std::make_shared<ExitTask>();
  140. ready_tasks_.push(task);
  141. task_cond_var_.notify_all();
  142. }
  143. worker_->join();
  144. }
  145. }
// Worker thread main loop: block until a task is queued, run it, record it in
// done_tasks_ and wake synchronous submitters.  Terminates when an ExitTask
// (pushed by WorkerJoin) is dequeued.
void Executor::WorkerLoop() {
  while (true) {
    std::shared_ptr<Task> task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      // Predicate guards against spurious wakeups.
      task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
      task = ready_tasks_.front();
      ready_tasks_.pop();
    }
    if (task->type_ == kExit) {
      OnWorkerExit();
      return;
    }
    try {
      task->Run();
    } catch (const std::exception &e) {
      // Park the exception; it is re-thrown on the submitting thread by
      // MsException::CheckException() inside SyncRunTask.
      MsException::GetInstance().SetException();
    }
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      // Keep the finished task alive until the waiting submitter has read its
      // results; done_tasks_ is cleared when the next task is submitted.
      done_tasks_.emplace_back(task);
    }
    // Asynchronous graph runs have no waiter, so only wake sync submitters.
    if (task->type_ != kRunGraph || task->sync_run_) {
      sync_cond_var_.notify_all();
    }
  }
}
  173. std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
  174. std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
  175. std::unique_lock<std::mutex> lock(pending_task_mutex_);
  176. for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
  177. auto task = *iter;
  178. if (IsTaskReady(task)) {
  179. new_ready_tasks.emplace_back(task);
  180. pending_tasks_.erase(iter++);
  181. } else {
  182. iter++;
  183. }
  184. }
  185. return new_ready_tasks;
  186. }
  187. void Executor::OnRunGraphFinished() {
  188. auto new_ready_tasks = GetNewReadyTasks();
  189. std::unique_lock<std::mutex> lock(task_mutex_);
  190. for (auto &task : new_ready_tasks) {
  191. ready_tasks_.push(task);
  192. }
  193. if (new_ready_tasks.size() > 0) {
  194. task_cond_var_.notify_all();
  195. }
  196. reenter_cond_var_.notify_all();
  197. }
  198. bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
  199. MS_EXCEPTION_IF_NULL(task);
  200. for (auto &input : task->input_need_wait_tensors_) {
  201. MS_EXCEPTION_IF_NULL(input);
  202. if (input->NeedWait()) {
  203. return false;
  204. }
  205. }
  206. auto session = task->session_;
  207. MS_EXCEPTION_IF_NULL(session);
  208. auto graph = session->GetGraph(task->graph_id_);
  209. if (graph != nullptr) {
  210. return graph->IsPreGraphFinished();
  211. }
  212. return true;
  213. }
  214. void Executor::SyncRunTask(const std::shared_ptr<Task> &task) {
  215. std::unique_lock<std::mutex> lock(task_mutex_);
  216. ready_tasks_.push(task);
  217. done_tasks_.clear();
  218. task_cond_var_.notify_all();
  219. sync_cond_var_.wait(lock);
  220. MsException::GetInstance().CheckException();
  221. }
  222. GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
  223. const AnfNodePtrList &outputs) {
  224. auto task = std::make_shared<CompileNodesTask>();
  225. task->session_ = session;
  226. task->segment_ = segment;
  227. task->output_nodes_ = outputs;
  228. SyncRunTask(task);
  229. return task->graph_id_;
  230. }
  231. GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  232. auto task = std::make_shared<CompileGraphTask>();
  233. task->session_ = session;
  234. task->func_graph_ = func_graph;
  235. SyncRunTask(task);
  236. return task->graph_id_;
  237. }
  238. void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  239. auto task = std::make_shared<BuildGraphTask>();
  240. task->session_ = session;
  241. task->graph_id_ = graphId;
  242. SyncRunTask(task);
  243. }
// Run a compiled graph synchronously.  Output tensors are created before the
// task is queued, so `outputs` already references the tensors the worker will
// fill in; the calling thread blocks in SyncRunTask until the run completes.
void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
                        const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  // The task keeps its own copy of the output structure; the worker writes
  // results through the shared tensor objects, not through `outputs` itself.
  task->outputs_ = *outputs;
  task->sync_run_ = true;
  // NOTE(review): ScopedLongRunning presumably marks this thread as blocking
  // for a long time (e.g. so the frontend can release locks) — confirm.
  mindspore::ScopedLongRunning long_running;
  SyncRunTask(task);
}
// Queue a graph run without waiting for it to finish.  Output tensors are
// created immediately (still flagged "need wait"), so the caller can pass them
// to downstream graphs before this run has executed.
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
                             const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
  for (auto &tensor : inputs) {
    if (tensor->NeedWait()) {
      if (tensor->IsGraphOutput()) {
        // Input produced by another graph: defer the wait to the scheduler
        // (IsTaskReady) rather than blocking here.
        task->input_need_wait_tensors_.emplace_back(tensor);
      } else {
        // Plain host tensor still settling: wait for it inline.
        mindspore::ScopedLongRunning long_running;
        tensor->Wait();
      }
    }
  }
  // Mark the locked inputs busy so later submissions wait for this run.
  for (auto &tensor : task->input_need_lock_tensors_) {
    tensor->SetNeedWait(true);
  }
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  // maintain a copy of output vector
  task->outputs_ = *outputs;
  // A run with no output tensors (e.g. a dataset graph) is executed
  // synchronously since nothing downstream can wait on its results.
  if (!TensorInVector(outputs)) {
    task->sync_run_ = true;
    mindspore::ScopedLongRunning long_running;
    SyncRunTask(task);
    return;
  }
  // Re-entering the same graph: block until its post-graph finished the
  // previous round (reenter_cond_var_ is signalled by OnRunGraphFinished).
  auto graph = session->GetGraph(task->graph_id_);
  if (graph != nullptr) {
    if (!graph->IsPostGraphFinished()) {
      mindspore::ScopedLongRunning long_running;
      std::unique_lock<std::mutex> lock(reenter_mutex_);
      reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); });
    }
  }
  // If an input is still being produced by an unfinished graph, park the task;
  // OnRunGraphFinished will move it to the ready queue once it unblocks.
  bool ready = IsTaskReady(task);
  if (!ready) {
    std::unique_lock<std::mutex> lock(pending_task_mutex_);
    pending_tasks_.push_back(task);
    return;
  }
  std::unique_lock<std::mutex> lock(task_mutex_);
  ready_tasks_.push(task);
  // Release results of previously finished tasks before submitting a new one.
  done_tasks_.clear();
  task_cond_var_.notify_all();
}
  309. void Executor::BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  310. const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask) {
  311. auto task = std::make_shared<BuildOpTask>();
  312. task->session_ = session;
  313. task->op_run_info_ = op_run_info;
  314. task->graph_info_ = graph_info;
  315. task->input_tensors_ = input_tensors;
  316. task->tensors_mask_ = tensors_mask;
  317. SyncRunTask(task);
  318. }
  319. void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  320. const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
  321. auto task = std::make_shared<RunOpTask>();
  322. task->session_ = session;
  323. task->op_run_info_ = op_run_info;
  324. task->graph_info_ = graph_info;
  325. task->input_tensors_ = input_tensors;
  326. for (auto &tensor : input_tensors) {
  327. if (tensor->NeedWait()) {
  328. tensor->Wait();
  329. }
  330. }
  331. SyncRunTask(task);
  332. *outputs = task->outputs_;
  333. }
  334. bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
  335. auto task = std::make_shared<CreateCommGroupTask>();
  336. task->group_name_ = group_name;
  337. task->ranks_ = ranks;
  338. SyncRunTask(task);
  339. return task->result_;
  340. }
  341. bool Executor::DestroyCommGroup(const std::string &group_name) {
  342. auto task = std::make_shared<DestroyCommGroupTask>();
  343. task->group_name_ = group_name;
  344. SyncRunTask(task);
  345. return task->result_;
  346. }
  347. void Executor::OnWorkerExit() {
  348. if (device_name_ == kAscendDevice) {
  349. device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  350. }
  351. }
  352. } // namespace session
  353. } // namespace mindspore