You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

executor.cc 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "backend/session/executor.h"
#include <algorithm>
#include <exception>
#include <utility>
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/comm_manager.h"
#include "utils/scoped_long_running.h"
  22. namespace mindspore {
  23. namespace session {
  24. namespace {
  25. void UpdateOutputTensors(const VectorRef *outputs,
  26. const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  27. MS_EXCEPTION_IF_NULL(outputs);
  28. for (auto item : *outputs) {
  29. if (utils::isa<VectorRefPtr>(item)) {
  30. auto vector_ref = utils::cast<VectorRef>(item);
  31. UpdateOutputTensors(&vector_ref, tensor_to_node);
  32. } else if (utils::isa<tensor::TensorPtr>(item)) {
  33. auto tensor = utils::cast<tensor::TensorPtr>(item);
  34. MS_EXCEPTION_IF_NULL(tensor);
  35. auto iter = tensor_to_node.find(tensor);
  36. if (iter != tensor_to_node.end()) {
  37. auto &node = iter->second.first;
  38. auto &output_index = iter->second.second;
  39. auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  40. tensor->set_device_address(address);
  41. if (AnfAlgo::IsDynamicShape(node)) {
  42. auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
  43. ShapeVector int_shape;
  44. std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
  45. tensor->set_shape(int_shape);
  46. }
  47. }
  48. if (tensor->NeedSyncDeviceToHostImmediately()) {
  49. tensor->data_sync(false);
  50. tensor->set_device_address(nullptr);
  51. tensor->set_sync_status(kNeedSyncHostToDevice);
  52. }
  53. }
  54. }
  55. }
  56. void NotifyOutputTensors(const VectorRef *outputs) {
  57. MS_EXCEPTION_IF_NULL(outputs);
  58. for (auto item : *outputs) {
  59. if (utils::isa<VectorRefPtr>(item)) {
  60. auto vector_ref = utils::cast<VectorRef>(item);
  61. NotifyOutputTensors(&vector_ref);
  62. } else if (utils::isa<tensor::TensorPtr>(item)) {
  63. auto tensor = utils::cast<tensor::TensorPtr>(item);
  64. MS_EXCEPTION_IF_NULL(tensor);
  65. tensor->SetNeedWait(false);
  66. }
  67. }
  68. }
  69. bool TensorInVector(const VectorRef *outputs) {
  70. MS_EXCEPTION_IF_NULL(outputs);
  71. for (auto item : *outputs) {
  72. if (utils::isa<VectorRefPtr>(item)) {
  73. auto vector_ref = utils::cast<VectorRef>(item);
  74. if (TensorInVector(&vector_ref)) {
  75. return true;
  76. }
  77. } else if (utils::isa<tensor::TensorPtr>(item)) {
  78. return true;
  79. }
  80. }
  81. return false;
  82. }
  83. } // namespace
// Compile a graph from a node list on the worker thread; the resulting graph
// id is stored in graph_id_ for the synchronous caller to read back.
void CompileNodesTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(nodes_, output_nodes_);
}
// Compile a whole FuncGraph on the worker thread; graph_id_ carries the result
// back to the synchronous caller.
void CompileGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
}
// Build (codegen/allocate) the previously compiled graph on the worker thread.
void BuildGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildGraphImpl(graph_id_);
}
// Execute a graph on the worker thread and publish its outputs.
// Exceptions from execution are captured into MsException (to be re-raised in
// the waiting caller via CheckException) rather than propagated — the cleanup
// below must run unconditionally so waiters are never left blocked.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  try {
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    UpdateOutputTensors(&outputs_, tensor_to_node_);
  } catch (const std::exception &e) {
    MsException::GetInstance().SetException();
  }
  // Deliberately outside the try: even on failure, wake consumers of the
  // output tensors and unlock the input tensors, or they would wait forever.
  NotifyOutputTensors(&outputs_);
  for (auto &tensor : input_need_lock_tensors_) {
    tensor->SetNeedWait(false);
  }
  // Let the executor promote any pending tasks whose dependencies we just met.
  ExecutorManager::Instance().OnRunGraphFinished();
}
// Build a single-op graph (PyNative mode) on the worker thread.
void BuildOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildOpImpl(*op_run_info_, graph_info_, input_tensors_, tensors_mask_);
}
// Run a single op (PyNative mode) on the worker thread; results land in
// outputs_ for the synchronous caller to copy out.
void RunOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_);
}
// Create a communication group synchronously; success flag stored in result_.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
// Destroy a communication group; success flag stored in result_.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
  120. Executor::Executor(const std::string &device_name, uint32_t device_id) {
  121. device_name_ = device_name;
  122. device_id_ = device_id;
  123. worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
  124. }
// Stop and join the worker before members (queues, mutexes) are destroyed.
Executor::~Executor() { WorkerJoin(); }
  126. void Executor::WorkerJoin() {
  127. // Avoid worker thread join itself which will cause deadlock
  128. if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
  129. {
  130. std::unique_lock<std::mutex> lock(task_mutex_);
  131. auto task = std::make_shared<ExitTask>();
  132. ready_tasks_.push(task);
  133. task_cond_var_.notify_all();
  134. }
  135. worker_->join();
  136. }
  137. }
  138. void Executor::WorkerLoop() {
  139. while (true) {
  140. std::shared_ptr<Task> task;
  141. {
  142. std::unique_lock<std::mutex> lock(task_mutex_);
  143. task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
  144. task = ready_tasks_.front();
  145. ready_tasks_.pop();
  146. }
  147. if (task->type_ == kExit) {
  148. OnWorkerExit();
  149. return;
  150. }
  151. try {
  152. task->Run();
  153. } catch (const std::exception &e) {
  154. MsException::GetInstance().SetException();
  155. }
  156. if (task->type_ != kRunGraph || task->sync_run_) {
  157. task = nullptr;
  158. sync_cond_var_.notify_all();
  159. } else {
  160. task = nullptr;
  161. }
  162. }
  163. }
  164. std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
  165. std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
  166. std::unique_lock<std::mutex> lock(pending_task_mutex_);
  167. for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
  168. auto task = *iter;
  169. if (IsTaskReady(task)) {
  170. new_ready_tasks.emplace_back(task);
  171. pending_tasks_.erase(iter++);
  172. } else {
  173. iter++;
  174. }
  175. }
  176. return new_ready_tasks;
  177. }
  178. void Executor::OnRunGraphFinished() {
  179. auto new_ready_tasks = GetNewReadyTasks();
  180. std::unique_lock<std::mutex> lock(task_mutex_);
  181. for (auto &task : new_ready_tasks) {
  182. ready_tasks_.push(task);
  183. }
  184. if (new_ready_tasks.size() > 0) {
  185. task_cond_var_.notify_all();
  186. }
  187. }
  188. bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
  189. MS_EXCEPTION_IF_NULL(task);
  190. for (auto &input : task->input_need_wait_tensors_) {
  191. MS_EXCEPTION_IF_NULL(input);
  192. if (input->NeedWait()) {
  193. return false;
  194. }
  195. }
  196. return true;
  197. }
// Enqueue a task, block until the worker signals completion, then re-raise any
// exception the task captured on the worker thread.
// NOTE(review): the wait has no predicate, so a spurious wakeup — or a
// notify_all issued for a *different* concurrent sync task — can return before
// this task has run. Fixing it needs a per-task completion flag in the Task
// declaration (not visible in this file); confirm and add a predicate.
void Executor::SyncRunTask(const std::shared_ptr<Task> &task) {
  std::unique_lock<std::mutex> lock(task_mutex_);
  ready_tasks_.push(task);
  task_cond_var_.notify_all();
  // Holding task_mutex_ until wait() releases it guarantees the worker cannot
  // pop and finish the task before we start waiting.
  sync_cond_var_.wait(lock);
  MsException::GetInstance().CheckException();
}
  205. GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  206. auto task = std::make_shared<CompileNodesTask>();
  207. task->session_ = session;
  208. task->nodes_ = lst;
  209. task->output_nodes_ = outputs;
  210. SyncRunTask(task);
  211. return task->graph_id_;
  212. }
  213. GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  214. auto task = std::make_shared<CompileGraphTask>();
  215. task->session_ = session;
  216. task->func_graph_ = func_graph;
  217. SyncRunTask(task);
  218. return task->graph_id_;
  219. }
  220. void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  221. auto task = std::make_shared<BuildGraphTask>();
  222. task->session_ = session;
  223. task->graph_id_ = graphId;
  224. SyncRunTask(task);
  225. }
  226. void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
  227. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  228. MS_EXCEPTION_IF_NULL(session);
  229. MS_EXCEPTION_IF_NULL(outputs);
  230. auto task = std::make_shared<RunGraphTask>();
  231. task->session_ = session;
  232. task->graph_id_ = graph_id;
  233. task->input_tensors_ = inputs;
  234. session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  235. task->outputs_ = *outputs;
  236. task->sync_run_ = true;
  237. mindspore::ScopedLongRunning long_running;
  238. SyncRunTask(task);
  239. }
  240. void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
  241. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  242. MS_EXCEPTION_IF_NULL(session);
  243. MS_EXCEPTION_IF_NULL(outputs);
  244. auto task = std::make_shared<RunGraphTask>();
  245. task->session_ = session;
  246. task->graph_id_ = graph_id;
  247. task->input_tensors_ = inputs;
  248. task->input_need_lock_tensors_ = session->GetNeedLockInputTensors(graph_id, inputs);
  249. for (auto &tensor : inputs) {
  250. if (tensor->NeedWait()) {
  251. if (tensor->IsGraphOutput()) {
  252. task->input_need_wait_tensors_.emplace_back(tensor);
  253. } else {
  254. mindspore::ScopedLongRunning long_running;
  255. tensor->Wait();
  256. }
  257. }
  258. }
  259. for (auto &tensor : task->input_need_lock_tensors_) {
  260. tensor->SetNeedWait(true);
  261. }
  262. session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  263. // maintain a copy of output vector
  264. task->outputs_ = *outputs;
  265. // sync run graph without output tensor(int dataset graph)
  266. if (!TensorInVector(outputs)) {
  267. task->sync_run_ = true;
  268. mindspore::ScopedLongRunning long_running;
  269. SyncRunTask(task);
  270. return;
  271. }
  272. bool ready = IsTaskReady(task);
  273. if (!ready) {
  274. std::unique_lock<std::mutex> lock(pending_task_mutex_);
  275. pending_tasks_.push_back(task);
  276. return;
  277. }
  278. std::unique_lock<std::mutex> lock(task_mutex_);
  279. ready_tasks_.push(task);
  280. task_cond_var_.notify_all();
  281. }
  282. void Executor::BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  283. const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {
  284. auto task = std::make_shared<BuildOpTask>();
  285. task->session_ = session;
  286. task->op_run_info_ = op_run_info;
  287. task->graph_info_ = graph_info;
  288. task->input_tensors_ = input_tensors;
  289. task->tensors_mask_ = tensors_mask;
  290. SyncRunTask(task);
  291. }
  292. void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  293. const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
  294. auto task = std::make_shared<RunOpTask>();
  295. task->session_ = session;
  296. task->op_run_info_ = op_run_info;
  297. task->graph_info_ = graph_info;
  298. task->input_tensors_ = input_tensors;
  299. SyncRunTask(task);
  300. *outputs = task->outputs_;
  301. }
  302. bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
  303. auto task = std::make_shared<CreateCommGroupTask>();
  304. task->group_name_ = group_name;
  305. task->ranks_ = ranks;
  306. SyncRunTask(task);
  307. return task->result_;
  308. }
  309. bool Executor::DestroyCommGroup(const std::string &group_name) {
  310. auto task = std::make_shared<DestroyCommGroupTask>();
  311. task->group_name_ = group_name;
  312. SyncRunTask(task);
  313. return task->result_;
  314. }
// Called on the worker thread just before it exits (kExit task).
// Ascend kernel runtime must be released from the thread that used it, which
// is why this runs here rather than in the destructor's thread.
void Executor::OnWorkerExit() {
  if (device_name_ == kAscendDevice) {
    device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  }
}
  320. } // namespace session
  321. } // namespace mindspore