You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger.cc 41 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <dirent.h>
  #include <stdio.h>
  #include <algorithm>
  #include <cstring>
  #include <fstream>
  #include <iostream>
  #include <map>
  #include <regex>
  #include <stdexcept>
  #include <tuple>
  #include <utility>
  #include <vector>
  #include "debug/debugger/debugger.h"
  #include "debug/data_dump/dump_json_parser.h"
  #include "pipeline/jit/pipeline.h"
  #include "backend/session/anf_runtime_algorithm.h"
  #include "runtime/device/kernel_runtime_manager.h"
  #include "runtime/device/kernel_runtime.h"
  #include "debug/data_dump/e2e_dump_util.h"
  #include "utils/config_manager.h"
  using debugger::Chunk;
  using debugger::EventReply;
  using debugger::GraphProto;
  using debugger::ModelProto;
  using debugger::TensorProto;
  using debugger::WatchCondition;
  using debugger::WatchCondition_Condition_inf;
  using debugger::WatchCondition_Condition_nan;
  using debugger::WatchCondition_Parameter;
  using debugger::WatchNode;
  using debugger::WatchpointHit;
  // Maximum payload size in bytes (3 MiB) of a single chunk when graphs and tensor data are split for
  // transmission. Parenthesized so the macro behaves as one value inside any expression
  // (e.g. `x % CHUNK_SIZE`, `x / CHUNK_SIZE`).
  #define CHUNK_SIZE (1024 * 1024 * 3)
  47. namespace mindspore {
  48. DebuggerPtr Debugger::debugger_ = nullptr;
  49. std::mutex Debugger::instance_lock_;
  50. static const size_t PARAMETER_OUTPUT_INDEX = 0;
  51. static const size_t VALUE_NODE_OUTPUT_INDEX = 0;
  52. Debugger::Debugger()
  53. : grpc_client_(nullptr),
  54. debug_services_(nullptr),
  55. device_id_(0),
  56. device_target_(""),
  57. num_step_(0),
  58. debugger_enabled_(false),
  59. run_level_(""),
  60. node_name_(""),
  61. cur_name_(""),
  62. training_done_(false),
  63. is_dataset_graph_(false),
  64. partial_memory_(false),
  65. last_overflow_bin_(0),
  66. overflow_bin_path_(""),
  67. initial_suspend_(true),
  68. not_dataset_graph_sum_(0) {
  69. if (CheckDebuggerEnabled()) {
  70. // configure partial memory reuse
  71. partial_memory_ = CheckDebuggerPartialMemoryEnabled();
  72. auto context_ptr = MsContext::GetInstance();
  73. MS_EXCEPTION_IF_NULL(context_ptr);
  74. // switch memory reuse on or off
  75. context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_);
  76. // print some message about memory reuse to user
  77. if (partial_memory_) {
  78. MS_LOG(WARNING)
  79. << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
  80. "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
  81. } else {
  82. MS_LOG(INFO) << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
  83. "usage for large models.";
  84. }
  85. }
  86. }
  87. void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  88. // access lock for public method
  89. std::lock_guard<std::mutex> a_lock(access_lock_);
  90. // save device_id
  91. MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  92. device_id_ = device_id;
  93. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  94. device_target_ = device_target;
  95. }
  96. void Debugger::EnableDebugger() {
  97. // reset some of the class members
  98. num_step_ = 0;
  99. debugger_enabled_ = false;
  100. partial_memory_ = false;
  101. grpc_client_ = nullptr;
  102. debug_services_ = nullptr;
  103. // see if dump using debugger backend is enabled
  104. bool dump_enabled = CheckDebuggerDumpEnabled();
  105. MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;
  106. // check if debugger enabled
  107. debugger_enabled_ = CheckDebuggerEnabled();
  108. MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;
  109. if (!debugger_enabled_ && !dump_enabled) {
  110. MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
  111. return;
  112. }
  113. // configure grpc host
  114. const char *env_host_str = std::getenv("MS_DEBUGGER_HOST");
  115. std::string host;
  116. if (env_host_str != nullptr) {
  117. std::regex reg_ip(
  118. "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
  119. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  120. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  121. "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
  122. std::smatch smat;
  123. std::string host_str = std::string(env_host_str);
  124. if (std::regex_match(host_str, smat, reg_ip)) {
  125. MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
  126. host = std::string(env_host_str);
  127. } else {
  128. MS_LOG(ERROR) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
  129. "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
  130. debugger_enabled_ = false;
  131. }
  132. } else {
  133. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
  134. host = "localhost";
  135. }
  136. // configure grpc port
  137. const char *env_port_str = std::getenv("MS_DEBUGGER_PORT");
  138. std::string port;
  139. if (env_port_str != nullptr) {
  140. if (CheckPort(env_port_str)) {
  141. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
  142. port = std::string(env_port_str);
  143. } else {
  144. MS_LOG(ERROR) << "Environment variable MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to 65535";
  145. debugger_enabled_ = false;
  146. }
  147. } else {
  148. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
  149. port = "50051";
  150. }
  151. #ifdef ENABLE_D
  152. // set operation overflow info
  153. overflow_bin_path_ = DumpJsonParser::GetInstance().GetOpOverflowBinPath(graph_ptr_->graph_id(), device_id_);
  154. // new overflow dump files will have a timestamp greater than last_overflow_bin_
  155. last_overflow_bin_ = 0;
  156. DIR *d;
  157. d = opendir(overflow_bin_path_.c_str());
  158. if (d != nullptr) {
  159. struct dirent *dir;
  160. while ((dir = readdir(d)) != NULL) {
  161. if (dir->d_type == DT_REG) {
  162. std::string file_path = overflow_bin_path_;
  163. file_path.append(dir->d_name);
  164. std::size_t found = file_path.find_last_of(".");
  165. if (found == std::string::npos) {
  166. continue;
  167. }
  168. std::string overflow_time = file_path.substr(found + 1);
  169. if (stod(overflow_time) <= last_overflow_bin_) {
  170. MS_LOG(INFO) << "Old op overflow bin folder" << file_path;
  171. continue;
  172. }
  173. last_overflow_bin_ = stod(overflow_time);
  174. }
  175. }
  176. MS_LOG(INFO) << "last op overflow bin folder" << last_overflow_bin_;
  177. closedir(d);
  178. }
  179. #endif
  180. // initialize grpc client
  181. if (debugger_enabled_) {
  182. grpc_client_ = std::make_unique<GrpcClient>(host, port);
  183. }
  184. debug_services_ = std::make_unique<DebugServices>();
  185. }
  186. void Debugger::CheckDatasetSinkMode() {
  187. if (CheckDebuggerDumpEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  188. MS_EXCEPTION(NotSupportError)
  189. << "e2e_dump not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  190. }
  191. if (CheckDebuggerEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  192. MS_EXCEPTION(NotSupportError)
  193. << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  194. }
  195. }
  196. bool Debugger::CheckDebuggerDumpEnabled() {
  197. // see if dump is enabled
  198. if (device_target_ == kGPUDevice) {
  199. return device::KernelRuntime::DumpDataEnabled();
  200. }
  201. return false;
  202. }
  203. bool Debugger::CheckDebuggerEnabled() {
  204. // get env variables to configure debugger
  205. const char *env_enable_str = std::getenv("ENABLE_MS_DEBUGGER");
  206. if (env_enable_str != nullptr) {
  207. if (std::strcmp(env_enable_str, "1") == 0) {
  208. return true;
  209. }
  210. }
  211. return false;
  212. }
  213. bool Debugger::CheckDebuggerPartialMemoryEnabled() {
  214. const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
  215. if (env_partial_mem_str != nullptr) {
  216. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
  217. if (std::strcmp(env_partial_mem_str, "1") == 0) {
  218. return true;
  219. }
  220. }
  221. return false;
  222. }
  223. bool Debugger::DebuggerBackendEnabled() { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }
  224. void Debugger::Reset() {
  225. // access lock for public method
  226. std::lock_guard<std::mutex> a_lock(access_lock_);
  227. // reset components
  228. device_id_ = 0;
  229. device_target_ = "";
  230. num_step_ = 0;
  231. debugger_enabled_ = false;
  232. is_dataset_graph_ = false;
  233. partial_memory_ = false;
  234. graph_ptr_ = nullptr;
  235. grpc_client_ = nullptr;
  236. debug_services_ = nullptr;
  237. last_overflow_bin_ = 0;
  238. overflow_bin_path_ = "";
  239. stream_task_to_opname_.clear();
  240. }
  241. void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) {
  242. // access lock for public method
  243. std::lock_guard<std::mutex> a_lock(access_lock_);
  244. CheckDatasetSinkMode();
  245. auto graph_id = graph_ptr->graph_id();
  246. // collect rungrap_ids to update step number in multigraph case
  247. if (!rungraph_id_list_.size()) {
  248. rungraph_id_list_.push_back(graph_id);
  249. } else {
  250. if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
  251. rungraph_id_list_.push_back(graph_id);
  252. }
  253. }
  254. // check and save graph_ptr, suspend if graph is new
  255. MS_LOG(INFO) << "total number graph: " << graph_sum;
  256. // multiple graphs
  257. if (graph_sum > 1) {
  258. // there are more than one graphs are not dataset_graph
  259. if (not_dataset_graph_sum_ > 0) {
  260. // only try to enable debugger if they are not all dataset graphs
  261. if (!debugger_enabled_) {
  262. EnableDebugger();
  263. }
  264. if (debugger_enabled_) {
  265. if (graph_proto_list_.size()) {
  266. // only send compiled graphs once.
  267. SendMultiGraphsAndSuspend(graph_proto_list_, graph_sum);
  268. graph_proto_list_.clear();
  269. } else if (graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
  270. // stop only when receive the first sub run graph for each step
  271. CommandLoop();
  272. }
  273. }
  274. }
  275. } else if (graph_proto_list_.size() == 1) {
  276. // In single graph case, reset graph_ptr_ to be nullptr for the initial step
  277. if (num_step_ == 0) {
  278. graph_ptr_ = nullptr;
  279. }
  280. CheckGraphPtr(graph_ptr);
  281. }
  282. }
  283. void Debugger::PostExecute() {
  284. // access lock for public method
  285. std::lock_guard<std::mutex> a_lock(access_lock_);
  286. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  287. return;
  288. }
  289. if (debugger_->DebuggerBackendEnabled()) {
  290. // analyze tensor data and send the watchpoints been hit
  291. if (run_level_ == "node") {
  292. MS_LOG(INFO) << "Debugger is in node level mode ";
  293. return;
  294. }
  295. if (debugger_enabled_ && !is_dataset_graph_) {
  296. if (device_target_ != kGPUDevice) {
  297. num_step_++;
  298. MS_LOG(INFO) << "Debugger suspend at end of step; number of steps executed: " << num_step_;
  299. SendWatchpoints(CheckWatchpoints());
  300. CommandLoop();
  301. } else {
  302. CommandLoop();
  303. }
  304. }
  305. }
  306. }
  307. bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) {
  308. if (debugger_enabled_ && !is_dataset_graph_) {
  309. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  310. // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data
  311. if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
  312. return true;
  313. }
  314. }
  315. return false;
  316. }
  317. void Debugger::PostExecuteNode(const CNodePtr &kernel) {
  318. // access lock for public method
  319. std::lock_guard<std::mutex> a_lock(access_lock_);
  320. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  321. return;
  322. }
  323. if (debugger_enabled_ && !is_dataset_graph_) {
  324. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  325. // if kernel is watchpoint,and get hit. suspend.
  326. bool hit_empty_flag = true;
  327. if (is_watchpoint) {
  328. auto hits = CheckWatchpoints(cur_name_, kernel);
  329. if (!hits.empty()) {
  330. SendWatchpoints(hits);
  331. CommandLoop();
  332. hit_empty_flag = false;
  333. }
  334. }
  335. if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
  336. // if kernel is not watchpoint and is next_to or continue_to node, suspend
  337. CommandLoop();
  338. }
  339. return;
  340. }
  341. }
  342. void Debugger::PostDebugOp() {
  343. // access lock for public method
  344. std::lock_guard<std::mutex> a_lock(access_lock_);
  345. // suspend if debugger is enabled
  346. if (debugger_enabled_ && !is_dataset_graph_) {
  347. MS_LOG(INFO) << "Debugger suspend at debug_op";
  348. CommandLoop();
  349. }
  350. }
  351. void Debugger::SetStreamTaskToOpnameMap(const std::map<std::pair<uint32_t, uint32_t>, std::string> &mapping) {
  352. stream_task_to_opname_ = mapping;
  353. }
  354. void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  355. if (graph_ptr_ != graph_ptr) {
  356. MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
  357. // save new graph_ptr
  358. graph_ptr_ = graph_ptr;
  359. CheckDatasetGraph();
  360. if (!is_dataset_graph_) {
  361. // get proto for new graph_ptr
  362. auto graph_proto = GetGraphProto(graph_ptr);
  363. // add new graph proto to graph_proto_list_
  364. graph_proto_list_.push_back(graph_proto);
  365. graph_ptr_list_.push_back(graph_ptr);
  366. not_dataset_graph_sum_++;
  367. }
  368. // reset is_dataset_graph to be false
  369. is_dataset_graph_ = false;
  370. }
  371. }
  372. // In single graph cases, check single graph ptr
  373. void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  374. if (graph_ptr_ != graph_ptr) {
  375. MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
  376. // save new graph_ptr
  377. graph_ptr_ = graph_ptr;
  378. if (!is_dataset_graph_) {
  379. // only try to enable debugger if it is not a dataset graph
  380. EnableDebugger();
  381. if (debugger_enabled_) {
  382. LoadParametersAndConst();
  383. // get graph proto and send to mindinsight
  384. auto graph_proto = graph_proto_list_.front();
  385. SendGraphAndSuspend(graph_proto);
  386. }
  387. }
  388. }
  389. }
  390. void Debugger::CheckDatasetGraph() {
  391. // print parameter node names
  392. const auto &params = graph_ptr_->inputs();
  393. for (const auto &param : params) {
  394. MS_LOG(INFO) << "param: " << param->fullname_with_scope();
  395. }
  396. // check if there is GetNext or InitDataSetQueue node
  397. const auto &nodes = graph_ptr_->execution_order();
  398. for (const auto &node : nodes) {
  399. auto node_name = AnfAlgo::GetCNodeName(node);
  400. MS_LOG(INFO) << "node: " << node->fullname_with_scope();
  401. if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
  402. MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
  403. << node_name;
  404. is_dataset_graph_ = true;
  405. return;
  406. }
  407. }
  408. is_dataset_graph_ = false;
  409. }
  410. GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  411. // convert kernel graph to debugger modelproto
  412. ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_);
  413. return model.graph();
  414. }
  415. void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  416. SendMetadata();
  417. // send graph to mindinght server
  418. EventReply reply = grpc_client_->SendGraph(graph_proto);
  419. if (reply.status() != reply.OK) {
  420. MS_LOG(ERROR) << "Error: SendGraph failed";
  421. }
  422. // enter command loop, wait and process commands
  423. CommandLoop();
  424. }
  425. void Debugger::SendMetadata() {
  426. // prepare metadata
  427. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  428. Metadata metadata;
  429. metadata.set_device_name(device_name);
  430. metadata.set_cur_step(num_step_);
  431. metadata.set_backend(device_target_);
  432. metadata.set_cur_node(cur_name_);
  433. metadata.set_training_done(training_done_);
  434. MS_LOG(INFO) << "Is training done?" << training_done_;
  435. // set graph munber to not_dataset_graph_sum_
  436. metadata.set_graph_num(not_dataset_graph_sum_);
  437. EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  438. if (reply_metadata.status() != reply_metadata.OK) {
  439. MS_LOG(ERROR) << "Error: SendMetadata failed";
  440. }
  441. }
  442. void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list, uint32_t graph_sum) {
  443. SendMetadata();
  444. // send multiple graphs to mindinght server
  445. // split graph into chunks if one graph is larger than chunk size
  446. std::list<Chunk> chunked_graph_proto_list;
  447. Chunk chunk;
  448. for (auto graph : graph_proto_list) {
  449. std::string str = graph.SerializeAsString();
  450. auto graph_size = graph.ByteSize();
  451. if (graph_size > CHUNK_SIZE) {
  452. auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);
  453. for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
  454. chunk.set_buffer(sub_graph_str[i]);
  455. chunked_graph_proto_list.push_back(chunk);
  456. if (i < sub_graph_str.size() - 1) {
  457. chunk.set_finished(false);
  458. } else {
  459. chunk.set_finished(true);
  460. chunked_graph_proto_list.push_back(chunk);
  461. }
  462. }
  463. } else {
  464. chunk.set_buffer(str);
  465. chunk.set_finished(true);
  466. chunked_graph_proto_list.push_back(chunk);
  467. }
  468. }
  469. EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  470. if (reply.status() != reply.OK) {
  471. MS_LOG(ERROR) << "Error: SendGraph failed";
  472. }
  473. // enter command loop, wait and process commands
  474. CommandLoop();
  475. }
  476. void Debugger::CommandLoop() {
  477. // prepare metadata
  478. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  479. Metadata metadata;
  480. metadata.set_device_name(device_name);
  481. metadata.set_cur_step(num_step_);
  482. metadata.set_backend(device_target_);
  483. metadata.set_cur_node(cur_name_);
  484. metadata.set_training_done(training_done_);
  485. // loop exit flag
  486. bool run = false;
  487. int num_wait_fail = 0;
  488. const int max_num_wait_fail = 5;
  489. while (!run) {
  490. // wait for command
  491. EventReply reply = grpc_client_->WaitForCommand(metadata);
  492. if (reply.status() != reply.OK) {
  493. MS_LOG(ERROR) << "Error: WaitForCommand failed";
  494. num_wait_fail++;
  495. if (num_wait_fail > max_num_wait_fail) {
  496. MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session.";
  497. MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
  498. "of debugger host and port.";
  499. Exit();
  500. run = true;
  501. } else {
  502. MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after "
  503. << num_wait_fail << "s";
  504. std::this_thread::sleep_for(std::chrono::milliseconds(1000 * num_wait_fail));
  505. }
  506. continue;
  507. }
  508. // get type of the command in reply
  509. DebuggerCommand cmd = GetCommand(reply);
  510. if (cmd == DebuggerCommand::kUnknownCMD) {
  511. MS_LOG(DEBUG) << "Debug: debugger received unknown command";
  512. continue;
  513. }
  514. MS_LOG(INFO) << "received command: ";
  515. switch (cmd) {
  516. case DebuggerCommand::kUnknownCMD:
  517. MS_LOG(INFO) << "UnknownCMD";
  518. break;
  519. case DebuggerCommand::kExitCMD:
  520. MS_LOG(INFO) << "ExitCMD";
  521. Exit();
  522. // Used for debugger termination
  523. run = true;
  524. break;
  525. case DebuggerCommand::kRunCMD:
  526. MS_LOG(INFO) << "RunCMD";
  527. if (GetRunLevel(reply) == "recheck") {
  528. MS_LOG(INFO) << "rechecking all watchpoints";
  529. SendWatchpoints(CheckWatchpoints());
  530. } else {
  531. // no longer the initial suspension.
  532. initial_suspend_ = false;
  533. // print run cmd content
  534. // get run_level and node_name
  535. run_level_ = GetRunLevel(reply);
  536. node_name_ = GetNodeName(reply);
  537. MS_LOG(INFO) << "run_level: " << run_level_;
  538. MS_LOG(INFO) << "node_name_: " << node_name_;
  539. // exit loop
  540. run = true;
  541. }
  542. break;
  543. case DebuggerCommand::kSetCMD:
  544. MS_LOG(INFO) << "SetCMD";
  545. MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  546. MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  547. if (GetWatchpointDelete(reply)) {
  548. MS_LOG(INFO) << "Deleting watchpoint";
  549. RemoveWatchpoint(GetWatchpointID(reply));
  550. } else {
  551. MS_LOG(INFO) << "Setting watchpoint";
  552. MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
  553. ProtoVector<WatchNode> recieved_nodes = GetWatchnodes(reply);
  554. for (const auto &node : recieved_nodes) {
  555. MS_LOG(INFO) << "node name: " << node.node_name();
  556. MS_LOG(INFO) << "node type: " << node.node_type();
  557. }
  558. ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
  559. for (const auto &parameter : parameters) {
  560. MS_LOG(INFO) << "parameter name: " << parameter.name();
  561. MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
  562. MS_LOG(INFO) << "parameter value: " << parameter.value();
  563. }
  564. SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
  565. }
  566. break;
  567. case DebuggerCommand::kViewCMD:
  568. MS_LOG(INFO) << "ViewCMD";
  569. {
  570. // print view cmd content
  571. ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  572. for (auto tensor : received_tensors) {
  573. MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
  574. MS_LOG(INFO) << "tensor slot: " << tensor.slot();
  575. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
  576. MS_LOG(INFO) << "tensor iter: " << tensor.iter();
  577. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
  578. }
  579. }
  580. MS_LOG(INFO) << "Sending tensors";
  581. std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  582. {
  583. // print view cmd reply
  584. for (auto tensor : tensors) {
  585. MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
  586. MS_LOG(INFO) << "tensor slot: " << tensor.slot();
  587. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
  588. MS_LOG(INFO) << "tensor iter: " << tensor.iter();
  589. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
  590. MS_LOG(INFO) << "tensor dims: ";
  591. for (auto dim : tensor.dims()) {
  592. MS_LOG(INFO) << dim << ",";
  593. }
  594. MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  595. }
  596. }
  597. EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  598. if (send_tensors_reply.status() != send_tensors_reply.OK) {
  599. MS_LOG(ERROR) << "Error: SendTensors failed";
  600. }
  601. break;
  602. }
  603. }
  604. }
  605. void AddTensorProtoInfo(TensorProto *tensor_item, TensorProto tensor) {
  606. tensor_item->set_node_name(tensor.node_name());
  607. tensor_item->set_slot(tensor.slot());
  608. tensor_item->set_iter(tensor.iter());
  609. tensor_item->set_truncate(tensor.truncate());
  610. tensor_item->clear_tensor_content();
  611. tensor_item->clear_data_type();
  612. tensor_item->clear_dims();
  613. }
  614. void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
  615. const ProtoVector<WatchCondition_Parameter> &parameters) {
  616. std::vector<std::tuple<std::string, bool>> check_node_list;
  617. std::vector<DebugServices::parameter_t> parameter_list;
  618. std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
  619. [](const WatchNode &node) -> std::tuple<std::string, bool> {
  620. return make_tuple(node.node_name(), node.node_type() == "scope");
  621. });
  622. std::transform(
  623. parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
  624. [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
  625. return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
  626. });
  627. debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
  628. if (initial_suspend_ &&
  629. static_cast<DebugServices::CONDITION_TYPE>(condition.condition()) == DebugServices::CONDITION_TYPE::INIT)
  630. SendWatchpoints(CheckWatchpoints());
  631. }
  632. void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
  633. std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  634. std::vector<std::string> name;
  635. std::vector<std::string> ret_name;
  636. std::vector<char *> data_ptr;
  637. std::vector<unsigned int> data_size;
  638. std::vector<TypePtr> dtype;
  639. std::vector<std::vector<int64_t>> shape;
  640. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  641. // ret_name will contain tensor names that are found in TensorLoader
  642. // items in ret_name will be in the same order with tensors if found
  643. debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  644. std::list<TensorProto> tensor_list;
  645. unsigned int result_index = 0;
  646. for (auto tensor : tensors) {
  647. int size_iter = 0;
  648. if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
  649. TensorProto tensor_item;
  650. tensor_item.set_finished(true);
  651. AddTensorProtoInfo(&tensor_item, tensor);
  652. tensor_list.push_back(tensor_item);
  653. continue;
  654. }
  655. int tensor_size = data_size[result_index];
  656. while (size_iter < tensor_size) {
  657. int chunk_size = CHUNK_SIZE;
  658. TensorProto tensor_item;
  659. tensor_item.set_finished(false);
  660. if (tensor_size - size_iter <= CHUNK_SIZE) {
  661. chunk_size = tensor_size - size_iter;
  662. tensor_item.set_finished(true);
  663. }
  664. AddTensorProtoInfo(&tensor_item, tensor);
  665. // return empty tensor if didn't find the requested tensor
  666. tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
  667. tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index]));
  668. for (auto &elem : shape[result_index]) {
  669. tensor_item.add_dims(elem);
  670. }
  671. // add tensor to result list and increment result_index to check next item in ret_name
  672. tensor_list.push_back(tensor_item);
  673. size_iter += CHUNK_SIZE;
  674. }
  675. result_index++;
  676. }
  677. return tensor_list;
  678. }
  679. void Debugger::Exit() {
  680. // clear resource before exit
  681. // For node level, debugger has to exit itself because main thread can only exit in step bundary;
  682. // For step level, debugger will notify main thread to exit;
  683. if (run_level_ == "node") {
  684. pipeline::ClearResAtexit();
  685. exit(1);
  686. } else if (run_level_ == "step" || device_target_ == kAscendDevice) {
  687. // Notify main thread to terminate
  688. pipeline::ExecutorPy::DebugTerminate(true);
  689. } else {
  690. pipeline::ClearResAtexit();
  691. exit(1);
  692. }
  693. }
  694. std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel) {
  695. std::vector<std::string> name;
  696. std::vector<std::string> slot;
  697. std::vector<int> condition;
  698. std::vector<unsigned int> watchpoint_id;
  699. std::vector<std::string> overflow_ops;
  700. std::vector<std::vector<DebugServices::parameter_t>> parameters;
  701. #ifdef ENABLE_D
  702. overflow_ops = CheckOpOverflow();
  703. #endif
  704. auto tensor_loader = debug_services_->tensor_loader();
  705. std::vector<std::shared_ptr<TensorData>> tensor_list;
  706. if (watchnode.empty()) {
  707. tensor_list = tensor_loader->GetTensor();
  708. } else {
  709. tensor_list = tensor_loader->GetNodeTensorMap(watchnode);
  710. debug_services_->AddWeightsBiasInputs(&tensor_list, kernel);
  711. }
  712. debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, overflow_ops, tensor_list,
  713. initial_suspend_);
  714. std::list<WatchpointHit> hits;
  715. for (unsigned int i = 0; i < name.size(); i++) {
  716. WatchpointHit hit;
  717. std::vector<DebugServices::parameter_t> &parameter = parameters[i];
  718. hit.set_id(watchpoint_id[i]);
  719. // here TensorProto act as a tensor indicator, not sending tensor content
  720. TensorProto *tensor_item = hit.mutable_tensor();
  721. tensor_item->set_node_name(name[i]);
  722. tensor_item->set_slot(slot[i]);
  723. tensor_item->set_finished(true);
  724. WatchCondition *condition_item = hit.mutable_watch_condition();
  725. condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
  726. for (const auto &p : parameter) {
  727. auto x = condition_item->mutable_params()->Add();
  728. x->set_name(p.name);
  729. x->set_disabled(p.disabled);
  730. x->set_value(p.value);
  731. x->set_hit(p.hit);
  732. }
  733. hits.push_back(hit);
  734. }
  735. return hits;
  736. }
  737. void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  738. // send info about watchpoint
  739. if (!points.empty()) {
  740. EventReply reply = grpc_client_->SendWatchpointHits(points);
  741. if (reply.status() != reply.OK) {
  742. MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
  743. }
  744. }
  745. }
  746. DebugServices *Debugger::debug_services() const { return debug_services_.get(); }
  747. bool Debugger::debugger_enabled() const { return debugger_enabled_; }
  748. DebuggerCommand GetCommand(const EventReply &reply) {
  749. DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  750. switch (reply.cmd_case()) {
  751. case debugger::EventReply::CmdCase::kExit:
  752. cmd = DebuggerCommand::kExitCMD;
  753. break;
  754. case debugger::EventReply::CmdCase::kRunCmd:
  755. cmd = DebuggerCommand::kRunCMD;
  756. break;
  757. case debugger::EventReply::CmdCase::kSetCmd:
  758. cmd = DebuggerCommand::kSetCMD;
  759. break;
  760. case debugger::EventReply::CmdCase::kViewCmd:
  761. cmd = DebuggerCommand::kViewCMD;
  762. break;
  763. default:
  764. MS_LOG(DEBUG) << "Debug: UnknownCMD";
  765. break;
  766. }
  767. return cmd;
  768. }
  769. ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  770. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  771. MS_LOG(ERROR) << "Error: Can not get Parameters from command. Returning default value: ProtoVector<Parameter>().";
  772. return ProtoVector<WatchCondition_Parameter>();
  773. }
  774. return reply.set_cmd().watch_condition().params();
  775. }
  776. ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  777. if (!reply.has_set_cmd()) {
  778. MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
  779. return ProtoVector<WatchNode>();
  780. }
  781. return reply.set_cmd().watch_nodes();
  782. }
  783. std::string GetRunLevel(const EventReply &reply) {
  784. if (!reply.has_run_cmd()) {
  785. MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: "
  786. "";
  787. return "";
  788. }
  789. return reply.run_cmd().run_level();
  790. }
  791. std::string GetNodeName(const EventReply &reply) {
  792. if (!reply.has_run_cmd()) {
  793. MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: "
  794. "";
  795. return "";
  796. }
  797. return reply.run_cmd().node_name();
  798. }
  799. WatchCondition GetWatchcondition(const EventReply &reply) {
  800. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  801. MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
  802. return WatchCondition();
  803. }
  804. return reply.set_cmd().watch_condition();
  805. }
  806. int32_t GetWatchpointID(const EventReply &reply) {
  807. if (!reply.has_set_cmd()) {
  808. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
  809. return 0;
  810. }
  811. return reply.set_cmd().id();
  812. }
  813. bool GetWatchpointDelete(const EventReply &reply) {
  814. if (!reply.has_set_cmd()) {
  815. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
  816. return false;
  817. }
  818. return reply.set_cmd().delete_();
  819. }
  820. ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  821. if (!reply.has_view_cmd()) {
  822. MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
  823. return ProtoVector<TensorProto>();
  824. }
  825. return reply.view_cmd().tensors();
  826. }
  827. std::string GetTensorFullName(const TensorProto &tensor) {
  828. string node_name = tensor.node_name();
  829. if (tensor.truncate()) {
  830. // scopes in node name are seperated by '/'
  831. // use the name without scope if truncate is true
  832. std::size_t found = node_name.find_last_of("/");
  833. node_name = node_name.substr(found + 1);
  834. }
  835. return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
  836. }
  837. bool Debugger::partial_memory() { return partial_memory_; }
  838. void Debugger::SetCurNode(std::string cur_name) {
  839. // access lock for public method
  840. std::lock_guard<std::mutex> a_lock(access_lock_);
  841. cur_name_ = cur_name;
  842. }
  843. std::string Debugger::run_level() const { return run_level_; }
  844. void Debugger::SetStepNum(int32_t cur_num_step) {
  845. // access lock for public method
  846. std::lock_guard<std::mutex> a_lock(access_lock_);
  847. num_step_ = cur_num_step;
  848. }
  849. int32_t Debugger::step_num() const { return num_step_; }
  // Decodes a little-endian unsigned 64-bit integer from the first 8 bytes of `buffer`.
  // Each byte is reinterpreted as unsigned char before widening. Casting a (possibly signed) char straight to
  // uint64_t sign-extends bytes >= 0x80 and would set every higher bit of the result.
  // Bytes beyond the eighth are ignored. A shorter buffer is decoded from the bytes it has (missing high bytes
  // read as zero) instead of being indexed out of bounds.
  uint64_t BytestoInt64(const std::vector<char> &buffer) {
    constexpr std::size_t kNumBytes = sizeof(uint64_t);
    const std::size_t count = buffer.size() < kNumBytes ? buffer.size() : kNumBytes;
    uint64_t ret = 0;
    for (std::size_t i = 0; i < count; ++i) {
      const uint64_t byte = static_cast<unsigned char>(buffer[i]);
      ret |= byte << (8 * i);
    }
    return ret;
  }
  857. #define BUF_SIZ 256
  858. std::vector<std::string> Debugger::CheckOpOverflow() {
  859. std::vector<double> bin_list;
  860. std::vector<std::string> op_names;
  861. DIR *d;
  862. struct dirent *dir = nullptr;
  863. d = opendir(overflow_bin_path_.c_str());
  864. if (d != nullptr) {
  865. while ((dir = readdir(d)) != NULL) {
  866. if (dir->d_type == DT_REG) {
  867. std::string file_path = overflow_bin_path_;
  868. file_path.append(dir->d_name);
  869. std::string file_name = dir->d_name;
  870. std::size_t found = file_name.find_last_of(".");
  871. if (found == std::string::npos) {
  872. continue;
  873. }
  874. std::string overflow_time = file_name.substr(found + 1);
  875. if (stod(overflow_time) <= last_overflow_bin_) {
  876. MS_LOG(INFO) << "File already processed " << file_name;
  877. continue;
  878. }
  879. bin_list.push_back(stod(overflow_time));
  880. std::fstream infile;
  881. infile.open(file_path.c_str(), std::ios::binary | std::ios::in);
  882. if (!infile.is_open()) {
  883. MS_LOG(ERROR) << "Failed to open overflow bin file " << file_name;
  884. continue;
  885. }
  886. infile.seekg(313, std::ios::beg);
  887. std::vector<char> buffer;
  888. buffer.resize(BUF_SIZ);
  889. infile.read(buffer.data(), BUF_SIZ);
  890. uint64_t stream_id = BytestoInt64(std::vector<char>(buffer.begin() + 8, buffer.end()));
  891. uint64_t task_id = BytestoInt64(std::vector<char>(buffer.begin() + 16, buffer.end()));
  892. MS_LOG(INFO) << "Overflow stream_id " << stream_id << ", task_id " << task_id << ".";
  893. auto op = debugger_->stream_task_to_opname_.find(std::make_pair(stream_id, task_id));
  894. if (op != debugger_->stream_task_to_opname_.end()) {
  895. MS_LOG(ERROR) << "Overflow detected on node " << op->second << std::endl;
  896. op_names.push_back(op->second);
  897. } else {
  898. MS_LOG(INFO) << "No overflow is detected " << std::endl;
  899. }
  900. infile.close();
  901. }
  902. }
  903. } else {
  904. MS_LOG(INFO) << "OverFlow bin directory does not exist!";
  905. }
  906. closedir(d);
  907. if (op_names.size()) {
  908. MS_LOG(ERROR) << "These operation overflows are detected " << op_names;
  909. }
  910. for (auto &i : bin_list) {
  911. if (i > last_overflow_bin_) {
  912. last_overflow_bin_ = i;
  913. }
  914. }
  915. return op_names;
  916. }
  917. void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }
  918. bool Debugger::CheckPort(const char *port) {
  919. char *p = const_cast<char *>(port);
  920. int num = 0;
  921. if (*p == '0' && *(p + 1) != '\0') return false;
  922. while (*p != '\0') {
  923. if (*p < '0' || *p > '9') return false;
  924. num = num * 10 + (*p) - '0';
  925. if (num < 1 || num > 65535) return false;
  926. p++;
  927. }
  928. return true;
  929. }
  930. uint32_t Debugger::GetFirstRunGraphId() { return rungraph_id_list_.front(); }
  931. void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  932. MS_EXCEPTION_IF_NULL(anf_node);
  933. if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
  934. return;
  935. }
  936. bool keep_prev;
  937. if (anf_node->isa<Parameter>()) {
  938. keep_prev = true;
  939. } else {
  940. keep_prev = false;
  941. }
  942. // for parameters and value nodes, set its execution order to be 0;
  943. int exec_order = 0;
  944. std::string node_name = anf_node->fullname_with_scope();
  945. E2eDumpUtil::GetFileKernelName(NOT_NULL(&node_name));
  946. // check if output adde exists, if not, return;
  947. if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
  948. return;
  949. }
  950. auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  951. MS_EXCEPTION_IF_NULL(addr);
  952. auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  953. auto format = kOpFormat_DEFAULT;
  954. string tensor_name = node_name + ':' + "0";
  955. ShapeVector int_shapes;
  956. auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
  957. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  958. [](size_t inner_item) { return SizeToInt(inner_item); });
  959. bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  960. if (!ret) {
  961. MS_LOG(ERROR) << "LoadMemToHost:"
  962. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  963. }
  964. }
  965. void Debugger::LoadParametersAndConst() {
  966. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  967. if (!(num_step_ == 0 || device_target_ == kAscendDevice ||
  968. (device_target_ == kGPUDevice && device::KernelRuntime::DumpDataEnabledIteration())))
  969. return;
  970. MS_EXCEPTION_IF_NULL(graph_ptr_);
  971. // load parameters
  972. MS_LOG(INFO) << "Start to load Parameters!";
  973. const auto &parameters = graph_ptr_->inputs();
  974. for (auto &item : parameters) {
  975. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  976. }
  977. // load value nodes
  978. // get all constant avlues from the graph
  979. MS_LOG(INFO) << "Start to load value nodes!";
  980. const auto value_nodes = graph_ptr_->graph_value_nodes();
  981. for (auto &item : value_nodes) {
  982. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  983. }
  984. }
  985. void Debugger::LoadGraphOutputs() {
  986. if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  987. MS_EXCEPTION_IF_NULL(graph_ptr_);
  988. const auto &apply_kernels = graph_ptr_->execution_order();
  989. // for kernels, execution order starts from 1
  990. int exec_order = 1;
  991. for (const auto &node : apply_kernels) {
  992. MS_EXCEPTION_IF_NULL(node);
  993. auto node_name = AnfAlgo::GetCNodeName(node);
  994. std::string kernel_name = node->fullname_with_scope();
  995. auto output_size = AnfAlgo::GetOutputTensorNum(node);
  996. if (partial_memory_) {
  997. if (!debug_services_->IsWatchPoint(kernel_name, node)) {
  998. continue;
  999. }
  1000. }
  1001. for (size_t j = 0; j < output_size; ++j) {
  1002. auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
  1003. MS_EXCEPTION_IF_NULL(kernel_info);
  1004. auto addr_test = kernel_info->GetOutputAddr(j);
  1005. if (addr_test == nullptr) {
  1006. MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
  1007. continue;
  1008. }
  1009. auto addr = AnfAlgo::GetOutputAddr(node, j);
  1010. MS_EXCEPTION_IF_NULL(addr);
  1011. auto type = AnfAlgo::GetOutputInferDataType(node, j);
  1012. auto format = kOpFormat_DEFAULT;
  1013. string tensor_name = kernel_name + ':' + std::to_string(j);
  1014. ShapeVector int_shapes;
  1015. auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
  1016. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  1017. [](size_t inner_item) { return SizeToInt(inner_item); });
  1018. auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
  1019. if (!ret) {
  1020. MS_LOG(ERROR) << "LoadMemToHost:"
  1021. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1022. }
  1023. }
  1024. exec_order = exec_order + 1;
  1025. }
  1026. }
  1027. void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  1028. // update step number if we are processing the first graph (to support multigraph)
  1029. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
  1030. (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
  1031. // access lock for public method
  1032. std::lock_guard<std::mutex> a_lock(access_lock_);
  1033. ++num_step_;
  1034. }
  1035. }
  1036. void Debugger::ClearCurrentData() {
  1037. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()))
  1038. debug_services_->tensor_loader()->EmptyCurrentTensor();
  1039. }
  1040. } // namespace mindspore