You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ascend_kernel_runtime.cc 18 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/ascend/ascend_kernel_runtime.h"
  17. #include <string>
  18. #include <vector>
  19. #include <memory>
  20. #include <utility>
  21. #include <exception>
  22. #include <algorithm>
  23. #include "device/ascend/ascend_device_address.h"
  24. #include "utils/context/ms_context.h"
  25. #include "device/ascend/profiling/profiling_manager.h"
  26. #include "hccl/hcom.h"
  27. #include "common/trans.h"
  28. #include "runtime/context.h"
  29. #include "device/ascend/ascend_stream_assign.h"
  30. #include "device/ascend/ascend_memory_pool.h"
  31. #include "framework/ge_runtime/model_runner.h"
  32. #include "device/ascend/tasksink/task_generator.h"
  33. #include "session/anf_runtime_algorithm.h"
  34. #include "device/ascend/profiling/profiling_utils.h"
  35. #include "kernel/tbe/tbe_utils.h"
  36. #include "kernel/tbe/tbe_python_funcs.h"
  37. #include "pre_activate/mem_reuse/mem_reuse_checker.h"
  38. #include "device/ascend/ascend_memory_manager.h"
  39. using mindspore::device::ascend::ProfilingManager;
  40. using mindspore::device::ascend::ProfilingUtils;
  41. using mindspore::device::ascend::tasksink::TaskGenerator;
  42. using mindspore::kernel::tbe::TbeUtils;
  43. using std::vector;
  44. namespace mindspore {
  45. namespace device {
  46. namespace ascend {
// Output index 0 is used when reading a Parameter node's device address
// (see DumpParameters below). NOTE(review): "PRAMATER" is a typo for
// "PARAMETER", kept as-is because the identifier is referenced elsewhere.
static const size_t PRAMATER_OUTPUT_INDEX = 0;
  48. AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); }
  49. void AscendKernelRuntime::ClearGraphModelMap() {
  50. for (auto &iter : graph_model_map_) {
  51. MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
  52. auto ret = ge::model_runner::ModelRunner::Instance().UnloadModel(iter.first);
  53. if (!ret) {
  54. MS_LOG(ERROR) << "UnloadModel failed";
  55. }
  56. }
  57. }
  58. bool AscendKernelRuntime::NeedDestroyHccl() {
  59. auto context_ptr = MsContext::GetInstance();
  60. MS_EXCEPTION_IF_NULL(context_ptr);
  61. if (!context_ptr->enable_hccl()) {
  62. MS_LOG(INFO) << "hccl is not enabled";
  63. return false;
  64. }
  65. // Note: make sure hcom_connectivity_detection api never be used.
  66. return true;
  67. }
  68. void AscendKernelRuntime::ReleaseDeviceRes() {
  69. MS_LOG(INFO) << "ascend finalize start";
  70. // release ge runtime
  71. ClearGraphModelMap();
  72. auto context_ptr = MsContext::GetInstance();
  73. MS_EXCEPTION_IF_NULL(context_ptr);
  74. auto ret = rtSetDevice(context_ptr->device_id());
  75. if (ret != RT_ERROR_NONE) {
  76. MS_EXCEPTION(DeviceProcessError) << "rtSetDevice, ret[" << static_cast<int>(ret) << "]";
  77. }
  78. if (mem_manager_ != nullptr) {
  79. mem_manager_->FreeDeviceMemory();
  80. }
  81. (void)DestroyHccl();
  82. (void)ResetDevice();
  83. (void)ProfilingManager::GetInstance().StopProfiling();
  84. MS_LOG(INFO) << "ascend finalize end";
  85. }
  86. bool AscendKernelRuntime::Init() {
  87. if (initialized_) {
  88. return true;
  89. }
  90. bool ret = false;
  91. #ifdef ENABLE_DUMP_E2E
  92. ret = SetDumpConf();
  93. if (!ret) {
  94. MS_LOG(INFO) << "no dump conf to set!";
  95. }
  96. #endif
  97. ret = InitDevice();
  98. if (!ret) {
  99. return ret;
  100. }
  101. mem_manager_ = std::make_shared<AscendMemoryManager>();
  102. MS_EXCEPTION_IF_NULL(mem_manager_);
  103. mem_manager_->MallocDeviceMemory();
  104. ret = ProfilingManager::GetInstance().StartupProfiling(device_id_);
  105. if (!ret) {
  106. MS_EXCEPTION(DeviceProcessError) << "StartupProfiling failed.";
  107. }
  108. initialized_ = true;
  109. return ret;
  110. }
  111. #ifdef ENABLE_DUMP_E2E
  112. namespace {
  113. void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf) {
  114. MS_EXCEPTION_IF_NULL(graph);
  115. MS_EXCEPTION_IF_NULL(dump_conf);
  116. bool trans_flag = dump_conf->trans_flag();
  117. const auto &apply_kernels = graph->execution_order();
  118. for (const auto &node : apply_kernels) {
  119. MS_EXCEPTION_IF_NULL(node);
  120. auto node_name = AnfAlgo::GetCNodeName(node);
  121. std::string kernel_name = node->fullname_with_scope();
  122. if (!dump_conf->IsKernelNeedDump(kernel_name)) {
  123. continue;
  124. }
  125. const std::string strsrc = "/";
  126. const std::string strdst = "--";
  127. std::string::size_type pos = 0;
  128. std::string::size_type srclen = strsrc.size();
  129. std::string::size_type dstlen = strdst.size();
  130. while ((pos = kernel_name.find(strsrc, pos)) != std::string::npos) {
  131. kernel_name.replace(pos, srclen, strdst);
  132. pos += dstlen;
  133. }
  134. auto output_size = AnfAlgo::GetOutputTensorNum(node);
  135. for (size_t j = 0; j < output_size; ++j) {
  136. auto addr = AnfAlgo::GetOutputAddr(node, j);
  137. std::vector<int> int_shapes;
  138. if (trans_flag) {
  139. int_shapes = trans::GetRuntimePaddingShape(node, j);
  140. } else {
  141. auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
  142. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  143. [](size_t inner_item) { return SizeToInt(inner_item); });
  144. }
  145. auto type = AnfAlgo::GetOutputInferDataType(node, j);
  146. auto format = kOpFormat_DEFAULT;
  147. string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j);
  148. auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr);
  149. auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type);
  150. if (!ret) {
  151. MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath
  152. << ", host_format:" << format << ".!";
  153. }
  154. }
  155. }
  156. }
  157. void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf) {
  158. MS_EXCEPTION_IF_NULL(graph);
  159. MS_EXCEPTION_IF_NULL(dump_conf);
  160. bool trans_flag = dump_conf->trans_flag();
  161. const auto &parameters = graph->inputs();
  162. for (auto &item : parameters) {
  163. if (!item->isa<Parameter>()) {
  164. continue;
  165. }
  166. std::string parameter_name = item->fullname_with_scope();
  167. if (!dump_conf->IsKernelNeedDump(parameter_name)) {
  168. continue;
  169. }
  170. auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX);
  171. std::vector<int> int_shapes;
  172. if (trans_flag) {
  173. int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX);
  174. } else {
  175. auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX);
  176. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  177. [](size_t inner_item) { return SizeToInt(inner_item); });
  178. }
  179. auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX);
  180. auto format = kOpFormat_DEFAULT;
  181. string filepath = dump_path + '/' + parameter_name + '_' + "output_0";
  182. auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr);
  183. auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type);
  184. if (!ret) {
  185. MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath
  186. << ", host_format:" << format << ".!";
  187. }
  188. }
  189. }
  190. } // namespace
  191. #endif
  192. bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
  193. MS_EXCEPTION_IF_NULL(graph);
  194. #ifdef ENABLE_DUMP_E2E
  195. MS_LOG(INFO) << "start dump step";
  196. DumpConfPtr dump_conf = GetDumpConf();
  197. MS_EXCEPTION_IF_NULL(dump_conf);
  198. dump_conf->UpdataCurIter();
  199. bool dump_flag = dump_conf->dump_enable();
  200. if (!dump_flag) {
  201. MS_LOG(INFO) << "dump flag is disable, pass dump step";
  202. return true;
  203. }
  204. uint32_t cur_iter = dump_conf->cur_iter();
  205. if (dump_conf->dump_iter() != 0) {
  206. if (cur_iter != dump_conf->dump_iter()) {
  207. return true;
  208. }
  209. }
  210. MS_LOG(INFO) << "cur iter is " << cur_iter;
  211. std::string net_name = dump_conf->dump_net_name();
  212. std::string iterator = to_string(cur_iter);
  213. std::string dump_path = dump_conf->dump_path();
  214. if (dump_path.back() == '/') {
  215. dump_path = dump_path + net_name + '/' + iterator;
  216. } else {
  217. dump_path = dump_path + '/' + net_name + '/' + iterator;
  218. }
  219. // dump output
  220. DumpOutput(graph, dump_path, dump_conf);
  221. // dump parameters
  222. DumpParameters(graph, dump_path, dump_conf);
  223. #endif
  224. return true;
  225. }
  226. DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
  227. TypeId type_id) {
  228. return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
  229. }
// Generate the sinked task list for a compiled kernel graph and wrap it in a
// ge::model_runner::DavinciModel stored in graph_model_map_, keyed by graph id.
// Returns true without doing anything when task-sink mode is disabled or the
// graph has no compute tasks; throws on duplicate graph ids.
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
  if (graph == nullptr) {
    MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
  }
  MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_task_sink = context_ptr->enable_task_sink();
  if (!is_task_sink) {
    // Tasks are only generated in task-sink mode.
    return true;
  }
#ifdef MEM_REUSE_DEBUG
  if (!context_ptr->enable_mem_reuse()) {
    // Get normal graph ir for memreuse
    mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph);
  }
#endif
  vector<std::shared_ptr<TaskInfo>> task_info_list;
  auto anf_node_list = graph->execution_order();
  TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id());
  // Store the task_info_list
  auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list));
  if (!insert_ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
  }
  // Graph may have no compute node, such TensorAddGrad.
  if (task_info_list.empty()) {
    MS_LOG(WARNING) << "graph " << graph->graph_id() << " have no compute node";
    return true;
  }
  AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
  // the streams' flag not HEAD_STREAM
  std::vector<uint32_t> wait_active_stream_list;
  assign_instance.GetWaitStreams(&wait_active_stream_list);
  auto force_copy_stream_list = assign_instance.hcom_streams();
  MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum()
               << ", total event num:" << assign_instance.total_event_num()
               << ", wait_active_stream_list size:" << wait_active_stream_list.size()
               << ", force_copy_stream_list size:" << force_copy_stream_list.size();
  std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list;
  // NOTE(review): the literal 0s below are unused size/offset/id arguments of
  // DavinciModel's constructor — confirm against ge_runtime's model_runner.h
  // before reordering or changing any of them.
  std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>(
    task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0,
    0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0);
  auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model));
  if (!ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
  }
  MS_LOG(INFO) << "TaskGenerator GetTaskInfo end...";
  return true;
}
// Load the DavinciModel previously generated by GenTask for this graph into
// the GE model runner on the current device. Returns true (no-op) when task
// sink is disabled or the graph produced no tasks; returns false when the
// graph was never run through GenTask or loading fails. When profiling is
// active, reports the loaded task ids for this graph.
bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
  if (graph == nullptr) {
    MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. ";
  }
  MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_task_sink = context_ptr->enable_task_sink();
  if (!is_task_sink) {
    return true;
  }
  if (GraphWithEmptyTaskList(graph)) {
    MS_LOG(WARNING) << "LoadTask end, task list is empty";
    return true;
  }
  auto model_iter = graph_model_map_.find(graph->graph_id());
  if (model_iter == graph_model_map_.end()) {
    MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph LoadTask without GenTask.";
    return false;
  }
  // A null listener is passed deliberately: no completion callback is needed.
  std::shared_ptr<ge::ModelListener> listener;
  MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first;
  bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first,
                                                                           model_iter->second, listener);
  if (!status) {
    MS_LOG(ERROR) << "load task failed";
    return false;
  }
  if (ProfilingManager::GetInstance().IsProfiling()) {
    std::vector<uint32_t> task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);
    ProfilingUtils::ReportProfilingData(graph->graph_id(), task_ids);
  }
  return true;
}
  314. bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
  315. MS_EXCEPTION_IF_NULL(graph);
  316. MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id();
  317. auto context_ptr = MsContext::GetInstance();
  318. MS_EXCEPTION_IF_NULL(context_ptr);
  319. ge::InputData input_tensors = ge::InputData();
  320. ge::OutputData *output_tensors = nullptr;
  321. if (GraphWithEmptyTaskList(graph)) {
  322. MS_LOG(WARNING) << "RunTask end, no task info found";
  323. return true;
  324. }
  325. if (!CheckGraphIdValid(graph->graph_id())) {
  326. MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph RunTask without GenTask.";
  327. return false;
  328. }
  329. bool status = ge::model_runner::ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors);
  330. if (!status) {
  331. MS_LOG(INFO) << "run task failed";
  332. return false;
  333. }
  334. return true;
  335. }
  336. bool AscendKernelRuntime::SyncStream() {
  337. if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { // o for switch stream
  338. MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
  339. return false;
  340. }
  341. return true;
  342. }
  343. bool AscendKernelRuntime::InitDevice() {
  344. int device_count = 0;
  345. auto ret = rtGetDeviceCount(&device_count);
  346. if (ret != RT_ERROR_NONE) {
  347. MS_EXCEPTION(DeviceProcessError) << "rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
  348. }
  349. ret = rtSetDevice(device_id_);
  350. if (ret != RT_ERROR_NONE) {
  351. MS_EXCEPTION(DeviceProcessError) << "rtSetDevice, ret[" << static_cast<int>(ret) << "]";
  352. }
  353. auto context_ptr = MsContext::GetInstance();
  354. MS_EXCEPTION_IF_NULL(context_ptr);
  355. if (context_ptr == nullptr) {
  356. MS_LOG(ERROR) << "get MsContext instance failed";
  357. return false;
  358. }
  359. if (context_ptr->enable_hccl()) {
  360. if (!HcclInit()) {
  361. MS_LOG(ERROR) << "HcclInit init failed";
  362. return false;
  363. }
  364. }
  365. ret = rtCtxCreate(&rt_context_, 0, device_id_);
  366. if (ret != RT_ERROR_NONE) {
  367. MS_EXCEPTION(DeviceProcessError) << "rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
  368. }
  369. ret = rtCtxSetCurrent(rt_context_);
  370. if (ret != RT_ERROR_NONE) {
  371. MS_EXCEPTION(DeviceProcessError) << "rtCtxSetCurrent, ret[" << ret << "]";
  372. }
  373. ret = rtStreamCreate(&stream_, 0);
  374. if (ret != RT_ERROR_NONE) {
  375. MS_LOG(EXCEPTION) << "rtStreamCreate, ret[" << ret << "]";
  376. }
  377. return true;
  378. }
  379. bool AscendKernelRuntime::ResetDevice() {
  380. auto ret = rtCtxSetCurrent(rt_context_);
  381. if (ret != RT_ERROR_NONE) {
  382. MS_LOG(ERROR) << "call rtCtxSetCurrent failed";
  383. return false;
  384. }
  385. if (stream_ != nullptr) {
  386. ret = rtStreamDestroy(stream_);
  387. if (ret != RT_ERROR_NONE) {
  388. MS_LOG(EXCEPTION) << "rtStreamDestroy, ret[" << ret << "]";
  389. }
  390. stream_ = nullptr;
  391. }
  392. if (rt_context_ != nullptr) {
  393. ret = rtCtxDestroy(rt_context_);
  394. if (ret != RT_ERROR_NONE) {
  395. MS_EXCEPTION(DeviceProcessError) << "rtCtxDestroy, ret[" << ret << "]";
  396. }
  397. rt_context_ = nullptr;
  398. }
  399. return true;
  400. }
  401. bool AscendKernelRuntime::HcclInit() {
  402. auto context_ptr = MsContext::GetInstance();
  403. MS_EXCEPTION_IF_NULL(context_ptr);
  404. if (!context_ptr->IsTsdOpened()) {
  405. MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
  406. }
  407. MS_LOG(INFO) << "do hcom init";
  408. const char *config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
  409. if (config_path_str == nullptr) {
  410. MS_LOG(ERROR) << "get hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH";
  411. return false;
  412. }
  413. auto full_path = realpath(config_path_str, nullptr);
  414. if (full_path == nullptr) {
  415. MS_LOG(ERROR) << "file path " << config_path_str << " does not exist";
  416. return false;
  417. }
  418. const char *identify = std::getenv("RANK_ID");
  419. if (identify == nullptr) {
  420. MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
  421. free(full_path);
  422. return false;
  423. }
  424. MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << identify;
  425. hcclResult_t res = hcom_init(full_path, identify);
  426. free(full_path);
  427. if (res != HCCL_SUCCESS) {
  428. MS_LOG(ERROR) << "hcom init failed, res is " << static_cast<int>(res);
  429. return false;
  430. }
  431. return true;
  432. }
  433. bool AscendKernelRuntime::DestroyHccl() {
  434. auto context_ptr = MsContext::GetInstance();
  435. MS_EXCEPTION_IF_NULL(context_ptr);
  436. if (!NeedDestroyHccl()) {
  437. MS_LOG(INFO) << "hccl is not enable, no need to close.";
  438. return true;
  439. }
  440. hcclResult_t res = hcom_destroy();
  441. if (res != HCCL_SUCCESS) {
  442. MS_LOG(ERROR) << "hccl destroy failed";
  443. return false;
  444. }
  445. MS_LOG(INFO) << "hccl destroy successful, status = " << res << ".";
  446. context_ptr->set_enable_hccl(false);
  447. return true;
  448. }
  449. bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const {
  450. auto iter = task_map_.find(graph->graph_id());
  451. if (iter == task_map_.end()) {
  452. MS_LOG(EXCEPTION) << "Unknown graph ptr";
  453. }
  454. return iter->second.empty();
  455. }
  456. bool AscendKernelRuntime::CheckGraphIdValid(GraphId graph_id) const {
  457. return task_map_.find(graph_id) != task_map_.end() && graph_model_map_.find(graph_id) != graph_model_map_.end();
  458. }
  459. } // namespace ascend
  460. } // namespace device
  461. } // namespace mindspore