You cannot select more than 25 topics. A topic must start with a Chinese character, a letter, or a number; it can include dashes ('-') and can be up to 35 characters long.

executor.cc 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/executor.h"
  17. #include <exception>
  18. #include "runtime/device/kernel_runtime_manager.h"
  19. #include "backend/session/executor_manager.h"
  20. #include "utils/comm_manager.h"
  21. #include "utils/scoped_long_running.h"
  22. namespace mindspore {
  23. namespace session {
  24. namespace {
  25. void UpdateOutputTensors(const VectorRef *outputs,
  26. const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  27. MS_EXCEPTION_IF_NULL(outputs);
  28. for (auto item : *outputs) {
  29. if (utils::isa<VectorRefPtr>(item)) {
  30. auto vector_ref = utils::cast<VectorRef>(item);
  31. UpdateOutputTensors(&vector_ref, tensor_to_node);
  32. } else if (utils::isa<tensor::TensorPtr>(item)) {
  33. auto tensor = utils::cast<tensor::TensorPtr>(item);
  34. MS_EXCEPTION_IF_NULL(tensor);
  35. auto iter = tensor_to_node.find(tensor);
  36. if (iter != tensor_to_node.end()) {
  37. auto &node = iter->second.first;
  38. auto &output_index = iter->second.second;
  39. auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  40. tensor->set_device_address(address);
  41. if (AnfAlgo::IsDynamicShape(node)) {
  42. auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
  43. ShapeVector int_shape;
  44. std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
  45. tensor->set_shape(int_shape);
  46. }
  47. }
  48. if (tensor->NeedSyncDeviceToHostImmediately()) {
  49. tensor->data_sync(false);
  50. tensor->set_device_address(nullptr);
  51. tensor->set_sync_status(kNeedSyncHostToDevice);
  52. }
  53. tensor->SetNeedWait(false);
  54. }
  55. }
  56. }
  57. bool TensorInVector(const VectorRef *outputs) {
  58. MS_EXCEPTION_IF_NULL(outputs);
  59. for (auto item : *outputs) {
  60. if (utils::isa<VectorRefPtr>(item)) {
  61. auto vector_ref = utils::cast<VectorRef>(item);
  62. if (TensorInVector(&vector_ref)) {
  63. return true;
  64. }
  65. } else if (utils::isa<tensor::TensorPtr>(item)) {
  66. return true;
  67. }
  68. }
  69. return false;
  70. }
  71. } // namespace
  72. void CompileNodesTask::Run() {
  73. MS_EXCEPTION_IF_NULL(session_);
  74. graph_id_ = session_->CompileGraphImpl(nodes_, output_nodes_);
  75. }
  76. void CompileGraphTask::Run() {
  77. MS_EXCEPTION_IF_NULL(session_);
  78. graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
  79. }
  80. void BuildGraphTask::Run() {
  81. MS_EXCEPTION_IF_NULL(session_);
  82. session_->BuildGraphImpl(graph_id_);
  83. }
  84. void RunGraphTask::Run() {
  85. MS_EXCEPTION_IF_NULL(session_);
  86. try {
  87. session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
  88. } catch (const std::exception &e) {
  89. MsException::GetInstance().SetException();
  90. }
  91. UpdateOutputTensors(&outputs_, tensor_to_node_);
  92. for (auto &tensor : input_need_lock_tensors_) {
  93. tensor->SetNeedWait(false);
  94. }
  95. ExecutorManager::Instance().OnRunGraphFinished();
  96. }
  97. void BuildOpTask::Run() {
  98. MS_EXCEPTION_IF_NULL(session_);
  99. session_->BuildOpImpl(*op_run_info_, graph_info_, input_tensors_, tensors_mask_);
  100. }
  101. void RunOpTask::Run() {
  102. MS_EXCEPTION_IF_NULL(session_);
  103. session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_);
  104. }
  105. void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
  106. void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
  107. Executor::Executor(const std::string &device_name, uint32_t device_id) {
  108. device_name_ = device_name;
  109. device_id_ = device_id;
  110. worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
  111. }
  112. Executor::~Executor() { WorkerJoin(); }
  113. void Executor::WorkerJoin() {
  114. // Avoid worker thread join itself which will cause deadlock
  115. if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
  116. {
  117. std::unique_lock<std::mutex> lock(task_mutex_);
  118. auto task = std::make_shared<ExitTask>();
  119. ready_tasks_.push(task);
  120. task_cond_var_.notify_all();
  121. }
  122. worker_->join();
  123. }
  124. }
  125. void Executor::WorkerLoop() {
  126. while (true) {
  127. std::shared_ptr<Task> task;
  128. {
  129. std::unique_lock<std::mutex> lock(task_mutex_);
  130. task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
  131. task = ready_tasks_.front();
  132. ready_tasks_.pop();
  133. }
  134. if (task->type_ == kExit) {
  135. OnWorkerExit();
  136. return;
  137. }
  138. try {
  139. task->Run();
  140. } catch (const std::exception &e) {
  141. MsException::GetInstance().SetException();
  142. }
  143. if (task->type_ != kRunGraph || task->sync_run_) {
  144. task = nullptr;
  145. sync_cond_var_.notify_all();
  146. } else {
  147. task = nullptr;
  148. }
  149. }
  150. }
  151. std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
  152. std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
  153. std::unique_lock<std::mutex> lock(pending_task_mutex_);
  154. for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
  155. auto task = *iter;
  156. if (IsTaskReady(task)) {
  157. new_ready_tasks.emplace_back(task);
  158. pending_tasks_.erase(iter++);
  159. } else {
  160. iter++;
  161. }
  162. }
  163. return new_ready_tasks;
  164. }
  165. void Executor::OnRunGraphFinished() {
  166. auto new_ready_tasks = GetNewReadyTasks();
  167. std::unique_lock<std::mutex> lock(task_mutex_);
  168. for (auto &task : new_ready_tasks) {
  169. ready_tasks_.push(task);
  170. }
  171. if (new_ready_tasks.size() > 0) {
  172. task_cond_var_.notify_all();
  173. }
  174. }
  175. bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
  176. MS_EXCEPTION_IF_NULL(task);
  177. for (auto &input : task->input_need_wait_tensors_) {
  178. MS_EXCEPTION_IF_NULL(input);
  179. if (input->NeedWait()) {
  180. return false;
  181. }
  182. }
  183. return true;
  184. }
// Push a task onto the worker queue, block until the worker signals
// completion, then re-raise on this thread any exception the worker stashed.
void Executor::SyncRunTask(const std::shared_ptr<Task> &task) {
  std::unique_lock<std::mutex> lock(task_mutex_);
  ready_tasks_.push(task);
  task_cond_var_.notify_all();
  // NOTE(review): wait() has no predicate, so a spurious wakeup — or a
  // notification meant for a different finished task — could return before
  // *this* task has completed. Confirm callers tolerate that, or add a
  // per-task done flag and wait on it.
  sync_cond_var_.wait(lock);
  MsException::GetInstance().CheckException();
}
  192. GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  193. auto task = std::make_shared<CompileNodesTask>();
  194. task->session_ = session;
  195. task->nodes_ = lst;
  196. task->output_nodes_ = outputs;
  197. SyncRunTask(task);
  198. return task->graph_id_;
  199. }
  200. GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  201. auto task = std::make_shared<CompileGraphTask>();
  202. task->session_ = session;
  203. task->func_graph_ = func_graph;
  204. SyncRunTask(task);
  205. return task->graph_id_;
  206. }
  207. void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  208. auto task = std::make_shared<BuildGraphTask>();
  209. task->session_ = session;
  210. task->graph_id_ = graphId;
  211. SyncRunTask(task);
  212. }
  213. void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
  214. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  215. MS_EXCEPTION_IF_NULL(session);
  216. MS_EXCEPTION_IF_NULL(outputs);
  217. auto task = std::make_shared<RunGraphTask>();
  218. task->session_ = session;
  219. task->graph_id_ = graph_id;
  220. task->input_tensors_ = inputs;
  221. session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  222. task->outputs_ = *outputs;
  223. task->sync_run_ = true;
  224. mindspore::ScopedLongRunning long_running;
  225. SyncRunTask(task);
  226. }
  227. void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
  228. const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  229. MS_EXCEPTION_IF_NULL(session);
  230. MS_EXCEPTION_IF_NULL(outputs);
  231. if (session != nullptr) {
  232. RunGraph(session, graph_id, inputs, outputs);
  233. return;
  234. }
  235. auto task = std::make_shared<RunGraphTask>();
  236. task->session_ = session;
  237. task->graph_id_ = graph_id;
  238. task->input_tensors_ = inputs;
  239. task->input_need_lock_tensors_ = session->GetNeedLockInputTensors(graph_id, inputs);
  240. // lock inputs
  241. for (auto &tensor : inputs) {
  242. if (tensor->NeedWait()) {
  243. task->input_need_wait_tensors_.emplace_back(tensor);
  244. }
  245. }
  246. for (auto &tensor : task->input_need_lock_tensors_) {
  247. tensor->SetNeedWait(true);
  248. }
  249. session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  250. // maintain a copy of output vector
  251. task->outputs_ = *outputs;
  252. // sync run graph without output tensor(int dataset graph)
  253. if (!TensorInVector(outputs)) {
  254. task->sync_run_ = true;
  255. mindspore::ScopedLongRunning long_running;
  256. SyncRunTask(task);
  257. return;
  258. }
  259. bool ready = IsTaskReady(task);
  260. if (!ready) {
  261. std::unique_lock<std::mutex> lock(pending_task_mutex_);
  262. pending_tasks_.push_back(task);
  263. return;
  264. }
  265. std::unique_lock<std::mutex> lock(task_mutex_);
  266. ready_tasks_.push(task);
  267. task_cond_var_.notify_all();
  268. }
  269. void Executor::BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  270. const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {
  271. auto task = std::make_shared<BuildOpTask>();
  272. task->session_ = session;
  273. task->op_run_info_ = op_run_info;
  274. task->graph_info_ = graph_info;
  275. task->input_tensors_ = input_tensors;
  276. task->tensors_mask_ = tensors_mask;
  277. SyncRunTask(task);
  278. }
  279. void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  280. const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
  281. auto task = std::make_shared<RunOpTask>();
  282. task->session_ = session;
  283. task->op_run_info_ = op_run_info;
  284. task->graph_info_ = graph_info;
  285. task->input_tensors_ = input_tensors;
  286. SyncRunTask(task);
  287. *outputs = task->outputs_;
  288. }
  289. bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
  290. auto task = std::make_shared<CreateCommGroupTask>();
  291. task->group_name_ = group_name;
  292. task->ranks_ = ranks;
  293. SyncRunTask(task);
  294. return task->result_;
  295. }
  296. bool Executor::DestroyCommGroup(const std::string &group_name) {
  297. auto task = std::make_shared<DestroyCommGroupTask>();
  298. task->group_name_ = group_name;
  299. SyncRunTask(task);
  300. return task->result_;
  301. }
  302. void Executor::OnWorkerExit() {
  303. if (device_name_ == kAscendDevice) {
  304. device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  305. }
  306. }
  307. } // namespace session
  308. } // namespace mindspore