You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

debugger.cc 43 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <dirent.h>
  17. #include <stdio.h>
  18. #include <fstream>
  19. #include <tuple>
  20. #include <vector>
  21. #include <algorithm>
  22. #include <iostream>
  23. #include <cstring>
  24. #include <utility>
  25. #include <map>
  26. #include <regex>
  27. #include "debug/debugger/debugger.h"
  28. #include "debug/data_dump/dump_json_parser.h"
  29. #include "pipeline/jit/pipeline.h"
  30. #include "backend/session/anf_runtime_algorithm.h"
  31. #include "runtime/device/kernel_runtime_manager.h"
  32. #include "runtime/device/kernel_runtime.h"
  33. #include "debug/data_dump/e2e_dump_util.h"
  34. #include "utils/config_manager.h"
  35. using debugger::Chunk;
  36. using debugger::EventReply;
  37. using debugger::GraphProto;
  38. using debugger::ModelProto;
  39. using debugger::TensorProto;
  40. using debugger::WatchCondition;
  41. using debugger::WatchCondition_Condition_inf;
  42. using debugger::WatchCondition_Condition_nan;
  43. using debugger::WatchCondition_Parameter;
  44. using debugger::WatchNode;
  45. using debugger::WatchpointHit;
  46. #define CHUNK_SIZE 1024 * 1024 * 3
namespace mindspore {
// Singleton instance and the lock guarding its lazy creation.
DebuggerPtr Debugger::debugger_ = nullptr;
std::mutex Debugger::instance_lock_;
// Output index used when reading tensors of parameter / value nodes.
static const size_t PARAMETER_OUTPUT_INDEX = 0;
static const size_t VALUE_NODE_OUTPUT_INDEX = 0;
// Construct the debugger with every member in its disabled/empty state.
// When the debugger is enabled via ENABLE_MS_DEBUGGER, memory reuse is
// configured from MS_DEBUGGER_PARTIAL_MEM so that tensor values remain
// addressable for inspection (full reuse would overwrite them).
Debugger::Debugger()
    : grpc_client_(nullptr),
      debug_services_(nullptr),
      device_id_(0),
      device_target_(""),
      num_step_(0),
      debugger_enabled_(false),
      run_level_(""),
      node_name_(""),
      cur_name_(""),
      training_done_(false),
      is_dataset_graph_(false),
      partial_memory_(false),
      last_overflow_bin_(0),
      overflow_bin_path_(""),
      initial_suspend_(true),
      not_dataset_graph_sum_(0),
      version_("") {
  if (CheckDebuggerEnabled()) {
    // configure partial memory reuse
    partial_memory_ = CheckDebuggerPartialMemoryEnabled();
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    // switch memory reuse on or off
    context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_);
    // print some message about memory reuse to user
    if (partial_memory_) {
      MS_LOG(WARNING)
        << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
           "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
    } else {
      MS_LOG(INFO) << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
                      "usage for large models.";
    }
  }
}
  88. void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  89. // access lock for public method
  90. std::lock_guard<std::mutex> a_lock(access_lock_);
  91. // save device_id
  92. MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  93. device_id_ = device_id;
  94. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  95. device_target_ = device_target;
  96. version_ = "1.1.0";
  97. }
// (Re)initialize the debugger session: read configuration from the
// environment (ENABLE_MS_DEBUGGER, MS_DEBUGGER_HOST, MS_DEBUGGER_PORT),
// scan existing op-overflow dump files on Ascend, and create the gRPC
// client / debug services when the debugger ends up enabled.
void Debugger::EnableDebugger() {
  // reset some of the class members
  num_step_ = 0;
  debugger_enabled_ = false;
  partial_memory_ = false;
  grpc_client_ = nullptr;
  debug_services_ = nullptr;
  // see if dump using debugger backend is enabled
  bool dump_enabled = CheckDebuggerDumpEnabled();
  MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;
  // check if debugger enabled
  debugger_enabled_ = CheckDebuggerEnabled();
  MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;
  if (!debugger_enabled_ && !dump_enabled) {
    MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
    return;
  }
  // configure grpc host: must be a dotted-quad IPv4 address; the regex
  // rejects 0 and 255 in the first/last octet, so an invalid value disables
  // the debugger rather than connecting to a bad address
  const char *env_host_str = std::getenv("MS_DEBUGGER_HOST");
  std::string host;
  if (env_host_str != nullptr) {
    std::regex reg_ip(
      "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
      "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
      "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
      "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
    std::smatch smat;
    std::string host_str = std::string(env_host_str);
    if (std::regex_match(host_str, smat, reg_ip)) {
      MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
      host = std::string(env_host_str);
    } else {
      MS_LOG(ERROR) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
                       "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
      debugger_enabled_ = false;
    }
  } else {
    MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
    host = "localhost";
  }
  // configure grpc port (validated by CheckPort: 1-65535)
  const char *env_port_str = std::getenv("MS_DEBUGGER_PORT");
  std::string port;
  if (env_port_str != nullptr) {
    if (CheckPort(env_port_str)) {
      MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
      port = std::string(env_port_str);
    } else {
      MS_LOG(ERROR) << "Environment variable MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to 65535";
      debugger_enabled_ = false;
    }
  } else {
    MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
    port = "50051";
  }
#ifdef ENABLE_D
  // set operation overflow info
  overflow_bin_path_ = DumpJsonParser::GetInstance().GetOpOverflowBinPath(graph_ptr_->graph_id(), device_id_);
  // new overflow dump files will have a timestamp greater than last_overflow_bin_
  last_overflow_bin_ = 0;
  DIR *d;
  d = opendir(overflow_bin_path_.c_str());
  if (d != nullptr) {
    struct dirent *dir;
    while ((dir = readdir(d)) != NULL) {
      if (dir->d_type == DT_REG) {
        // the timestamp is taken from the text after the last '.' in the
        // file name; files without an extension are skipped
        std::string file_path = overflow_bin_path_;
        file_path.append(dir->d_name);
        std::size_t found = file_path.find_last_of(".");
        if (found == std::string::npos) {
          continue;
        }
        std::string overflow_time = file_path.substr(found + 1);
        // NOTE(review): stod throws std::invalid_argument when the extension
        // is not numeric — assumes only <op>.<timestamp> files appear in this
        // directory; TODO confirm.
        if (stod(overflow_time) <= last_overflow_bin_) {
          MS_LOG(INFO) << "Old op overflow bin folder" << file_path;
          continue;
        }
        last_overflow_bin_ = stod(overflow_time);
      }
    }
    MS_LOG(INFO) << "last op overflow bin folder" << last_overflow_bin_;
    closedir(d);
  }
#endif
  // initialize grpc client (only if the online debugger itself is enabled;
  // debug_services_ is always created because dump also needs it)
  if (debugger_enabled_) {
    grpc_client_ = std::make_unique<GrpcClient>(host, port);
  }
  debug_services_ = std::make_unique<DebugServices>();
}
  188. void Debugger::CheckDatasetSinkMode() {
  189. if (CheckDebuggerDumpEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  190. MS_EXCEPTION(NotSupportError)
  191. << "e2e_dump not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  192. }
  193. if (CheckDebuggerEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  194. MS_EXCEPTION(NotSupportError)
  195. << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  196. }
  197. }
  198. bool Debugger::CheckDebuggerDumpEnabled() {
  199. // see if dump is enabled
  200. if (device_target_ == kGPUDevice) {
  201. return device::KernelRuntime::DumpDataEnabled();
  202. }
  203. return false;
  204. }
  205. bool Debugger::CheckDebuggerEnabled() {
  206. // get env variables to configure debugger
  207. const char *env_enable_str = std::getenv("ENABLE_MS_DEBUGGER");
  208. if (env_enable_str != nullptr) {
  209. if (std::strcmp(env_enable_str, "1") == 0) {
  210. return true;
  211. }
  212. }
  213. return false;
  214. }
  215. bool Debugger::CheckDebuggerPartialMemoryEnabled() {
  216. const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
  217. if (env_partial_mem_str != nullptr) {
  218. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
  219. if (std::strcmp(env_partial_mem_str, "1") == 0) {
  220. return true;
  221. }
  222. }
  223. return false;
  224. }
  225. bool Debugger::DebuggerBackendEnabled() { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }
  226. void Debugger::Reset() {
  227. // access lock for public method
  228. std::lock_guard<std::mutex> a_lock(access_lock_);
  229. // reset components
  230. device_id_ = 0;
  231. device_target_ = "";
  232. num_step_ = 0;
  233. debugger_enabled_ = false;
  234. is_dataset_graph_ = false;
  235. partial_memory_ = false;
  236. graph_ptr_ = nullptr;
  237. grpc_client_ = nullptr;
  238. debug_services_ = nullptr;
  239. last_overflow_bin_ = 0;
  240. overflow_bin_path_ = "";
  241. stream_task_to_opname_.clear();
  242. }
// Called before a (sub)graph executes.  Tracks the ids of graphs seen this
// session and decides whether to connect/suspend: in the multi-graph case the
// compiled graphs are sent to MindInsight once and execution stops on the
// first sub-graph of each step (GPU); in the single-graph case the cached
// graph pointer is (re)checked.
void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  CheckDatasetSinkMode();
  auto graph_id = graph_ptr->graph_id();
  // collect rungrap_ids to update step number in multigraph case
  if (!rungraph_id_list_.size()) {
    rungraph_id_list_.push_back(graph_id);
  } else {
    // append only when this graph id has not been recorded yet
    if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
      rungraph_id_list_.push_back(graph_id);
    }
  }
  // check and save graph_ptr, suspend if graph is new
  MS_LOG(INFO) << "total number graph: " << graph_sum;
  // multiple graphs
  if (graph_sum > 1) {
    // there are more than one graphs are not dataset_graph
    if (not_dataset_graph_sum_ > 0) {
      // only try to enable debugger if they are not all dataset graphs
      if (!debugger_enabled_) {
        EnableDebugger();
      }
      if (debugger_enabled_) {
        if (graph_proto_list_.size()) {
          // only send compiled graphs once.
          SendMultiGraphsAndSuspend(graph_proto_list_, graph_sum);
          graph_proto_list_.clear();
        } else if (graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
          // stop only when receive the first sub run graph for each step
          CommandLoop();
        }
      }
    }
  } else if (graph_proto_list_.size() == 1) {
    // In single graph case, reset graph_ptr_ to be nullptr for the initial step
    if (num_step_ == 0) {
      graph_ptr_ = nullptr;
    }
    CheckGraphPtr(graph_ptr);
  }
}
  285. void Debugger::PostExecute() {
  286. // access lock for public method
  287. std::lock_guard<std::mutex> a_lock(access_lock_);
  288. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  289. return;
  290. }
  291. if (debugger_->DebuggerBackendEnabled()) {
  292. // analyze tensor data and send the watchpoints been hit
  293. if (run_level_ == "node") {
  294. MS_LOG(INFO) << "Debugger is in node level mode ";
  295. return;
  296. }
  297. if (debugger_enabled_ && !is_dataset_graph_) {
  298. if (device_target_ != kGPUDevice) {
  299. num_step_++;
  300. MS_LOG(INFO) << "Debugger suspend at end of step; number of steps executed: " << num_step_;
  301. SendWatchpoints(CheckWatchpoints());
  302. CommandLoop();
  303. } else {
  304. CommandLoop();
  305. }
  306. }
  307. }
  308. }
  309. bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) {
  310. if (debugger_enabled_ && !is_dataset_graph_) {
  311. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  312. // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data
  313. if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
  314. return true;
  315. }
  316. }
  317. return false;
  318. }
  319. void Debugger::PostExecuteNode(const CNodePtr &kernel) {
  320. // access lock for public method
  321. std::lock_guard<std::mutex> a_lock(access_lock_);
  322. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  323. return;
  324. }
  325. if (debugger_enabled_ && !is_dataset_graph_) {
  326. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  327. // if kernel is watchpoint,and get hit. suspend.
  328. bool hit_empty_flag = true;
  329. if (is_watchpoint) {
  330. auto hits = CheckWatchpoints(cur_name_, kernel);
  331. if (!hits.empty()) {
  332. SendWatchpoints(hits);
  333. CommandLoop();
  334. hit_empty_flag = false;
  335. }
  336. }
  337. if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
  338. // if kernel is not watchpoint and is next_to or continue_to node, suspend
  339. CommandLoop();
  340. }
  341. return;
  342. }
  343. }
  344. void Debugger::PostDebugOp() {
  345. // access lock for public method
  346. std::lock_guard<std::mutex> a_lock(access_lock_);
  347. // suspend if debugger is enabled
  348. if (debugger_enabled_ && !is_dataset_graph_) {
  349. MS_LOG(INFO) << "Debugger suspend at debug_op";
  350. CommandLoop();
  351. }
  352. }
// Store the (stream id, task id) -> op name mapping supplied by the runtime;
// replaces any previous mapping wholesale.
void Debugger::SetStreamTaskToOpnameMap(const std::map<std::pair<uint32_t, uint32_t>, std::string> &mapping) {
  stream_task_to_opname_ = mapping;
}
// Called when a newly compiled graph arrives.  Caches the graph pointer and,
// unless CheckDatasetGraph() classifies it as a dataset graph, records its
// debugger proto for later transmission to MindInsight.
void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  if (graph_ptr_ != graph_ptr) {
    MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
    // save new graph_ptr
    graph_ptr_ = graph_ptr;
    // sets is_dataset_graph_ for the graph just saved
    CheckDatasetGraph();
    if (!is_dataset_graph_) {
      // get proto for new graph_ptr
      auto graph_proto = GetGraphProto(graph_ptr);
      // add new graph proto to graph_proto_list_
      graph_proto_list_.push_back(graph_proto);
      graph_ptr_list_.push_back(graph_ptr);
      not_dataset_graph_sum_++;
    }
    // reset is_dataset_graph to be false
    is_dataset_graph_ = false;
  }
}
// In single graph cases, check single graph ptr
// If the executed graph differs from the cached one, cache it and — when it
// is not a dataset graph — enable the debugger, load parameters/consts, send
// the stored graph proto to MindInsight and suspend for commands.
void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  if (graph_ptr_ != graph_ptr) {
    MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
    // save new graph_ptr
    graph_ptr_ = graph_ptr;
    if (!is_dataset_graph_) {
      // only try to enable debugger if it is not a dataset graph
      EnableDebugger();
      if (debugger_enabled_) {
        LoadParametersAndConst();
        // get graph proto and send to Mindinsight
        auto graph_proto = graph_proto_list_.front();
        SendGraphAndSuspend(graph_proto);
      }
    }
  }
}
  392. void Debugger::CheckDatasetGraph() {
  393. // print parameter node names
  394. const auto &params = graph_ptr_->inputs();
  395. for (const auto &param : params) {
  396. MS_LOG(INFO) << "param: " << param->fullname_with_scope();
  397. }
  398. // check if there is GetNext or InitDataSetQueue node
  399. const auto &nodes = graph_ptr_->execution_order();
  400. for (const auto &node : nodes) {
  401. auto node_name = AnfAlgo::GetCNodeName(node);
  402. MS_LOG(INFO) << "node: " << node->fullname_with_scope();
  403. if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
  404. MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
  405. << node_name;
  406. is_dataset_graph_ = true;
  407. return;
  408. }
  409. }
  410. is_dataset_graph_ = false;
  411. }
  412. GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  413. // convert kernel graph to debugger modelproto
  414. ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_);
  415. return model.graph();
  416. }
  417. void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  418. if (SendMetadata(true)) {
  419. // send graph to Mindinsight server
  420. EventReply reply = grpc_client_->SendGraph(graph_proto);
  421. if (reply.status() != reply.OK) {
  422. MS_LOG(ERROR) << "Error: SendGraph failed";
  423. }
  424. // enter command loop, wait and process commands
  425. CommandLoop();
  426. }
  427. }
// Send session metadata (device, step, backend, current node, training-done
// flag, MindSpore version, graph count) to MindInsight.  When version_check
// is true the reply must carry a version-matched command: an old server
// terminates the session, a reported mismatch suspends in the command loop.
// Returns true when the metadata was accepted (and, if checked, versions match).
bool Debugger::SendMetadata(bool version_check) {
  // prepare metadata
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;
  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);
  metadata.set_ms_version(version_);
  MS_LOG(INFO) << "Is training done?" << training_done_;
  // set graph munber to not_dataset_graph_sum_
  metadata.set_graph_num(not_dataset_graph_sum_);
  EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  bool ret = false;
  if (reply_metadata.status() == reply_metadata.OK) {
    if (version_check) {
      // get type of the command in meta data reply, it should be version matched
      DebuggerCommand cmd = GetCommand(reply_metadata);
      if (cmd != DebuggerCommand::kVersionMatchedCMD) {
        // servers that predate the version handshake never send this command
        MS_LOG(ERROR) << "MindInsight version is too old, Mindspore version is " << version_;
        Exit();
      } else {
        if (GetMiVersionMatched(reply_metadata)) {
          MS_LOG(INFO) << "MindSpore version is " << version_ << " matches MindInsight version.";
          ret = true;
        } else {
          MS_LOG(ERROR) << "MindSpore version " << version_ << ", did not match MindInsight version.";
          // suspend so the user sees the mismatch on the front end
          CommandLoop();
        }
      }
    } else {
      // version check is done before so we can just return true here
      ret = true;
    }
  } else {
    MS_LOG(ERROR) << "Error: SendMetadata failed";
  }
  return ret;
}
  468. void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list, uint32_t graph_sum) {
  469. if (!SendMetadata(true)) {
  470. return;
  471. }
  472. // send multiple graphs to mindinght server
  473. // split graph into chunks if one graph is larger than chunk size
  474. std::list<Chunk> chunked_graph_proto_list;
  475. Chunk chunk;
  476. for (auto graph : graph_proto_list) {
  477. std::string str = graph.SerializeAsString();
  478. auto graph_size = graph.ByteSize();
  479. if (graph_size > CHUNK_SIZE) {
  480. auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);
  481. for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
  482. chunk.set_buffer(sub_graph_str[i]);
  483. chunked_graph_proto_list.push_back(chunk);
  484. if (i < sub_graph_str.size() - 1) {
  485. chunk.set_finished(false);
  486. } else {
  487. chunk.set_finished(true);
  488. chunked_graph_proto_list.push_back(chunk);
  489. }
  490. }
  491. } else {
  492. chunk.set_buffer(str);
  493. chunk.set_finished(true);
  494. chunked_graph_proto_list.push_back(chunk);
  495. }
  496. }
  497. EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  498. if (reply.status() != reply.OK) {
  499. MS_LOG(ERROR) << "Error: SendGraph failed";
  500. }
  501. // enter command loop, wait and process commands
  502. CommandLoop();
  503. }
// Block and service commands from MindInsight until a run command (or a
// fatal error / exit) releases execution.  Handles exit, run/recheck,
// set/delete watchpoint, view tensor and (unexpected) version-matched
// commands.  WaitForCommand failures are retried with a linear backoff and
// the session is terminated after max_num_wait_fail consecutive failures.
void Debugger::CommandLoop() {
  // prepare metadata
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;
  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);
  // loop exit flag
  bool run = false;
  int num_wait_fail = 0;
  const int max_num_wait_fail = 5;
  while (!run) {
    // wait for command
    EventReply reply = grpc_client_->WaitForCommand(metadata);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: WaitForCommand failed";
      num_wait_fail++;
      if (num_wait_fail > max_num_wait_fail) {
        // too many consecutive failures: give up on the server and terminate
        MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session.";
        MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
                         "of debugger host and port.";
        Exit();
        run = true;
      } else {
        // linear backoff: wait num_wait_fail seconds before retrying
        MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after "
                      << num_wait_fail << "s";
        std::this_thread::sleep_for(std::chrono::milliseconds(1000 * num_wait_fail));
      }
      continue;
    }
    // get type of the command in reply
    DebuggerCommand cmd = GetCommand(reply);
    if (cmd == DebuggerCommand::kUnknownCMD) {
      MS_LOG(DEBUG) << "Debug: debugger received unknown command";
      continue;
    }
    MS_LOG(INFO) << "received command: ";
    switch (cmd) {
      case DebuggerCommand::kUnknownCMD:
        // unreachable: filtered out by the check above
        MS_LOG(INFO) << "UnknownCMD";
        break;
      case DebuggerCommand::kExitCMD:
        MS_LOG(INFO) << "ExitCMD";
        Exit();
        // Used for debugger termination
        run = true;
        break;
      case DebuggerCommand::kRunCMD:
        MS_LOG(INFO) << "RunCMD";
        if (GetRunLevel(reply) == "recheck") {
          // re-evaluate every watchpoint without resuming execution
          MS_LOG(INFO) << "rechecking all watchpoints";
          SendWatchpoints(CheckWatchpoints());
        } else {
          // no longer the initial suspension.
          initial_suspend_ = false;
          // print run cmd content
          // get run_level and node_name
          run_level_ = GetRunLevel(reply);
          node_name_ = GetNodeName(reply);
          MS_LOG(INFO) << "run_level: " << run_level_;
          MS_LOG(INFO) << "node_name_: " << node_name_;
          // exit loop
          run = true;
        }
        break;
      case DebuggerCommand::kSetCMD:
        // set or delete a watchpoint, depending on the delete flag
        MS_LOG(INFO) << "SetCMD";
        MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
        MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
        if (GetWatchpointDelete(reply)) {
          MS_LOG(INFO) << "Deleting watchpoint";
          RemoveWatchpoint(GetWatchpointID(reply));
        } else {
          MS_LOG(INFO) << "Setting watchpoint";
          MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
          ProtoVector<WatchNode> recieved_nodes = GetWatchnodes(reply);
          for (const auto &node : recieved_nodes) {
            MS_LOG(INFO) << "node name: " << node.node_name();
            MS_LOG(INFO) << "node type: " << node.node_type();
          }
          ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
          for (const auto &parameter : parameters) {
            MS_LOG(INFO) << "parameter name: " << parameter.name();
            MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
            MS_LOG(INFO) << "parameter value: " << parameter.value();
          }
          SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
        }
        break;
      case DebuggerCommand::kViewCMD: {
        MS_LOG(INFO) << "ViewCMD";
        // print view cmd content
        ProtoVector<TensorProto> received_tensors = GetTensors(reply);
        for (auto received_tensor : received_tensors) {
          MS_LOG(INFO) << "tensor node name: " << received_tensor.node_name();
          MS_LOG(INFO) << "tensor slot: " << received_tensor.slot();
          MS_LOG(INFO) << "tensor finished: " << std::boolalpha << received_tensor.finished() << std::noboolalpha;
          MS_LOG(INFO) << "tensor iter: " << received_tensor.iter();
          MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << received_tensor.truncate() << std::noboolalpha;
        }
        MS_LOG(INFO) << "Sending tensors";
        // load the requested tensors (chunked) and send them back
        std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
        // print view cmd reply
        for (auto tensor : tensors) {
          MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
          MS_LOG(INFO) << "tensor slot: " << tensor.slot();
          MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
          MS_LOG(INFO) << "tensor iter: " << tensor.iter();
          MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
          MS_LOG(INFO) << "tensor dims: ";
          for (auto dim : tensor.dims()) {
            MS_LOG(INFO) << dim << ",";
          }
          MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
        }
        EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
        if (send_tensors_reply.status() != send_tensors_reply.OK) {
          MS_LOG(ERROR) << "Error: SendTensors failed";
        }
      } break;
      case DebuggerCommand::kVersionMatchedCMD:
        // only valid as a reply to SendMetadata, never inside the loop
        MS_LOG(ERROR) << "Received unexpected Version Matched CMD from Mindinsight.";
        Exit();
        break;
      default:
        MS_LOG(ERROR) << "Received unknown CMD from Mindinsight";
        Exit();
        break;
    }
  }
}
  637. void AddTensorProtoInfo(TensorProto *tensor_item, TensorProto tensor) {
  638. tensor_item->set_node_name(tensor.node_name());
  639. tensor_item->set_slot(tensor.slot());
  640. tensor_item->set_iter(tensor.iter());
  641. tensor_item->set_truncate(tensor.truncate());
  642. tensor_item->clear_tensor_content();
  643. tensor_item->clear_data_type();
  644. tensor_item->clear_dims();
  645. }
  646. void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
  647. const ProtoVector<WatchCondition_Parameter> &parameters) {
  648. std::vector<std::tuple<std::string, bool>> check_node_list;
  649. std::vector<DebugServices::parameter_t> parameter_list;
  650. std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
  651. [](const WatchNode &node) -> std::tuple<std::string, bool> {
  652. return make_tuple(node.node_name(), node.node_type() == "scope");
  653. });
  654. std::transform(
  655. parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
  656. [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
  657. return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
  658. });
  659. debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
  660. if (initial_suspend_ &&
  661. static_cast<DebugServices::CONDITION_TYPE>(condition.condition()) == DebugServices::CONDITION_TYPE::INIT)
  662. SendWatchpoints(CheckWatchpoints());
  663. }
  664. void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
// Look up the requested tensors in the tensor loader and convert each into
// one or more TensorProto items whose content is at most CHUNK_SIZE bytes;
// only the last chunk of a tensor has finished=true.  Tensors that are not
// found are returned as empty protos with finished=true.
std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  std::vector<std::string> name;
  std::vector<std::string> ret_name;
  std::vector<char *> data_ptr;
  std::vector<unsigned int> data_size;
  std::vector<TypePtr> dtype;
  std::vector<std::vector<int64_t>> shape;
  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  // ret_name will contain tensor names that are found in TensorLoader
  // items in ret_name will be in the same order with tensors if found
  debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  std::list<TensorProto> tensor_list;
  unsigned int result_index = 0;
  for (auto tensor : tensors) {
    int size_iter = 0;
    // requested tensor not found: emit an empty, finished placeholder and
    // do NOT advance result_index (ret_name only holds the found subset)
    if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
      TensorProto tensor_item;
      tensor_item.set_finished(true);
      AddTensorProtoInfo(&tensor_item, tensor);
      tensor_list.push_back(tensor_item);
      continue;
    }
    int tensor_size = data_size[result_index];
    // walk the raw buffer in CHUNK_SIZE steps
    while (size_iter < tensor_size) {
      int chunk_size = CHUNK_SIZE;
      TensorProto tensor_item;
      tensor_item.set_finished(false);
      // final (possibly short) chunk
      if (tensor_size - size_iter <= CHUNK_SIZE) {
        chunk_size = tensor_size - size_iter;
        tensor_item.set_finished(true);
      }
      AddTensorProtoInfo(&tensor_item, tensor);
      // return empty tensor if didn't find the requested tensor
      tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
      tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index]));
      for (auto &elem : shape[result_index]) {
        tensor_item.add_dims(elem);
      }
      // add tensor to result list and increment result_index to check next item in ret_name
      tensor_list.push_back(tensor_item);
      size_iter += CHUNK_SIZE;
    }
    result_index++;
  }
  return tensor_list;
}
  711. void Debugger::Exit() {
  712. // clear resource before exit
  713. // For node level, debugger has to exit itself because main thread can only exit in step bundary;
  714. // For step level, debugger will notify main thread to exit;
  715. if (run_level_ == "node") {
  716. pipeline::ClearResAtexit();
  717. exit(1);
  718. } else if (run_level_ == "step" || device_target_ == kAscendDevice) {
  719. // Notify main thread to terminate
  720. pipeline::ExecutorPy::DebugTerminate(true);
  721. } else {
  722. pipeline::ClearResAtexit();
  723. exit(1);
  724. }
  725. }
// Evaluate the registered watchpoints against loaded tensors and return the
// resulting hits for reporting to the client.
// When `watchnode` is empty, every tensor currently held by the TensorLoader
// is checked; otherwise only that node's tensors (plus its weights, biases and
// inputs added by AddWeightsBiasInputs) are checked.
// On Ascend builds (ENABLE_D) the overflow bin files are scanned first so
// overflow conditions can be evaluated.
// The vectors below are filled in parallel by DebugServices::CheckWatchpoints:
// entry i of each describes the i-th hit.
std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel) {
  std::vector<std::string> name;
  std::vector<std::string> slot;
  std::vector<int> condition;
  std::vector<unsigned int> watchpoint_id;
  std::vector<std::string> overflow_ops;
  std::vector<std::vector<DebugServices::parameter_t>> parameters;
  std::vector<int32_t> error_codes;
#ifdef ENABLE_D
  overflow_ops = CheckOpOverflow();
#endif
  auto tensor_loader = debug_services_->tensor_loader();
  std::vector<std::shared_ptr<TensorData>> tensor_list;
  if (watchnode.empty()) {
    tensor_list = tensor_loader->GetTensor();
  } else {
    tensor_list = tensor_loader->GetNodeTensorMap(watchnode);
    debug_services_->AddWeightsBiasInputs(&tensor_list, kernel);
  }
  debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, &error_codes, overflow_ops,
                                    tensor_list, initial_suspend_);
  // translate the parallel result vectors into WatchpointHit protos
  std::list<WatchpointHit> hits;
  for (unsigned int i = 0; i < name.size(); i++) {
    WatchpointHit hit;
    std::vector<DebugServices::parameter_t> &parameter = parameters[i];
    hit.set_id(watchpoint_id[i]);
    hit.set_error_code(error_codes[i]);
    // here TensorProto acts as a tensor indicator, not sending tensor content
    TensorProto *tensor_item = hit.mutable_tensor();
    tensor_item->set_node_name(name[i]);
    tensor_item->set_slot(slot[i]);
    tensor_item->set_finished(true);
    // echo back the condition and its parameters (including actual values)
    WatchCondition *condition_item = hit.mutable_watch_condition();
    condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
    for (const auto &p : parameter) {
      auto x = condition_item->mutable_params()->Add();
      x->set_name(p.name);
      x->set_disabled(p.disabled);
      x->set_value(p.value);
      x->set_hit(p.hit);
      x->set_actual_value(p.actual_value);
    }
    hits.push_back(hit);
  }
  return hits;
}
  772. void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  773. // send info about watchpoint
  774. if (!points.empty()) {
  775. EventReply reply = grpc_client_->SendWatchpointHits(points);
  776. if (reply.status() != reply.OK) {
  777. MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
  778. }
  779. }
  780. }
  781. DebugServices *Debugger::debug_services() const { return debug_services_.get(); }
  782. bool Debugger::debugger_enabled() const { return debugger_enabled_; }
  783. DebuggerCommand GetCommand(const EventReply &reply) {
  784. DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  785. switch (reply.cmd_case()) {
  786. case debugger::EventReply::CmdCase::kExit:
  787. cmd = DebuggerCommand::kExitCMD;
  788. break;
  789. case debugger::EventReply::CmdCase::kRunCmd:
  790. cmd = DebuggerCommand::kRunCMD;
  791. break;
  792. case debugger::EventReply::CmdCase::kSetCmd:
  793. cmd = DebuggerCommand::kSetCMD;
  794. break;
  795. case debugger::EventReply::CmdCase::kViewCmd:
  796. cmd = DebuggerCommand::kViewCMD;
  797. break;
  798. case debugger::EventReply::CmdCase::kVersionMatched:
  799. cmd = DebuggerCommand::kVersionMatchedCMD;
  800. break;
  801. default:
  802. MS_LOG(DEBUG) << "Debug: UnknownCMD";
  803. break;
  804. }
  805. return cmd;
  806. }
  807. ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  808. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  809. MS_LOG(ERROR) << "Error: Can not get Parameters from command. Returning default value: ProtoVector<Parameter>().";
  810. return ProtoVector<WatchCondition_Parameter>();
  811. }
  812. return reply.set_cmd().watch_condition().params();
  813. }
  814. ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  815. if (!reply.has_set_cmd()) {
  816. MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
  817. return ProtoVector<WatchNode>();
  818. }
  819. return reply.set_cmd().watch_nodes();
  820. }
  821. std::string GetRunLevel(const EventReply &reply) {
  822. if (!reply.has_run_cmd()) {
  823. MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: "
  824. "";
  825. return "";
  826. }
  827. return reply.run_cmd().run_level();
  828. }
  829. std::string GetNodeName(const EventReply &reply) {
  830. if (!reply.has_run_cmd()) {
  831. MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: "
  832. "";
  833. return "";
  834. }
  835. return reply.run_cmd().node_name();
  836. }
  837. WatchCondition GetWatchcondition(const EventReply &reply) {
  838. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  839. MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
  840. return WatchCondition();
  841. }
  842. return reply.set_cmd().watch_condition();
  843. }
  844. int32_t GetWatchpointID(const EventReply &reply) {
  845. if (!reply.has_set_cmd()) {
  846. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
  847. return 0;
  848. }
  849. return reply.set_cmd().id();
  850. }
  851. bool GetWatchpointDelete(const EventReply &reply) {
  852. if (!reply.has_set_cmd()) {
  853. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
  854. return false;
  855. }
  856. return reply.set_cmd().delete_();
  857. }
  858. ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  859. if (!reply.has_view_cmd()) {
  860. MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
  861. return ProtoVector<TensorProto>();
  862. }
  863. return reply.view_cmd().tensors();
  864. }
  865. std::string GetTensorFullName(const TensorProto &tensor) {
  866. string node_name = tensor.node_name();
  867. if (tensor.truncate()) {
  868. // scopes in node name are seperated by '/'
  869. // use the name without scope if truncate is true
  870. std::size_t found = node_name.find_last_of("/");
  871. node_name = node_name.substr(found + 1);
  872. }
  873. return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
  874. }
  875. bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }
  876. bool Debugger::partial_memory() { return partial_memory_; }
  877. void Debugger::SetCurNode(std::string cur_name) {
  878. // access lock for public method
  879. std::lock_guard<std::mutex> a_lock(access_lock_);
  880. cur_name_ = cur_name;
  881. }
  882. std::string Debugger::run_level() const { return run_level_; }
  883. void Debugger::SetStepNum(int32_t cur_num_step) {
  884. // access lock for public method
  885. std::lock_guard<std::mutex> a_lock(access_lock_);
  886. num_step_ = cur_num_step;
  887. }
  888. int32_t Debugger::step_num() const { return num_step_; }
  889. uint64_t BytestoInt64(const std::vector<char> &buffer) {
  890. uint64_t ret;
  891. ret = ((uint64_t)buffer[7] << 56) | ((uint64_t)buffer[6] << 48) | ((uint64_t)buffer[5] << 40) |
  892. ((uint64_t)buffer[4] << 32) | ((uint64_t)buffer[3] << 24) | ((uint64_t)buffer[2] << 16) |
  893. ((uint64_t)buffer[1] << 8) | ((uint64_t)buffer[0]);
  894. return ret;
  895. }
  896. #define BUF_SIZ 256
  897. std::vector<std::string> Debugger::CheckOpOverflow() {
  898. std::vector<double> bin_list;
  899. std::vector<std::string> op_names;
  900. DIR *d;
  901. struct dirent *dir = nullptr;
  902. d = opendir(overflow_bin_path_.c_str());
  903. if (d != nullptr) {
  904. while ((dir = readdir(d)) != NULL) {
  905. if (dir->d_type == DT_REG) {
  906. std::string file_path = overflow_bin_path_;
  907. file_path.append(dir->d_name);
  908. std::string file_name = dir->d_name;
  909. std::size_t found = file_name.find_last_of(".");
  910. if (found == std::string::npos) {
  911. continue;
  912. }
  913. std::string overflow_time = file_name.substr(found + 1);
  914. if (stod(overflow_time) <= last_overflow_bin_) {
  915. MS_LOG(INFO) << "File already processed " << file_name;
  916. continue;
  917. }
  918. bin_list.push_back(stod(overflow_time));
  919. std::fstream infile;
  920. infile.open(file_path.c_str(), std::ios::binary | std::ios::in);
  921. if (!infile.is_open()) {
  922. MS_LOG(ERROR) << "Failed to open overflow bin file " << file_name;
  923. continue;
  924. }
  925. infile.seekg(313, std::ios::beg);
  926. std::vector<char> buffer;
  927. buffer.resize(BUF_SIZ);
  928. infile.read(buffer.data(), BUF_SIZ);
  929. uint64_t stream_id = BytestoInt64(std::vector<char>(buffer.begin() + 8, buffer.end()));
  930. uint64_t task_id = BytestoInt64(std::vector<char>(buffer.begin() + 16, buffer.end()));
  931. MS_LOG(INFO) << "Overflow stream_id " << stream_id << ", task_id " << task_id << ".";
  932. auto op = debugger_->stream_task_to_opname_.find(std::make_pair(stream_id, task_id));
  933. if (op != debugger_->stream_task_to_opname_.end()) {
  934. MS_LOG(ERROR) << "Overflow detected on node " << op->second << std::endl;
  935. op_names.push_back(op->second);
  936. } else {
  937. MS_LOG(INFO) << "No overflow is detected " << std::endl;
  938. }
  939. infile.close();
  940. }
  941. }
  942. } else {
  943. MS_LOG(INFO) << "OverFlow bin directory does not exist!";
  944. }
  945. closedir(d);
  946. if (op_names.size()) {
  947. MS_LOG(ERROR) << "These operation overflows are detected " << op_names;
  948. }
  949. for (auto &i : bin_list) {
  950. if (i > last_overflow_bin_) {
  951. last_overflow_bin_ = i;
  952. }
  953. }
  954. return op_names;
  955. }
  956. void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }
  957. bool Debugger::CheckPort(const char *port) {
  958. char *p = const_cast<char *>(port);
  959. int num = 0;
  960. if (*p == '0' && *(p + 1) != '\0') return false;
  961. while (*p != '\0') {
  962. if (*p < '0' || *p > '9') return false;
  963. num = num * 10 + (*p) - '0';
  964. if (num < 1 || num > 65535) return false;
  965. p++;
  966. }
  967. return true;
  968. }
  969. uint32_t Debugger::GetFirstRunGraphId() { return rungraph_id_list_.front(); }
  970. void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  971. MS_EXCEPTION_IF_NULL(anf_node);
  972. if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
  973. return;
  974. }
  975. bool keep_prev;
  976. if (anf_node->isa<Parameter>()) {
  977. keep_prev = true;
  978. } else {
  979. keep_prev = false;
  980. }
  981. // for parameters and value nodes, set its execution order to be 0;
  982. int exec_order = 0;
  983. std::string node_name = anf_node->fullname_with_scope();
  984. E2eDumpUtil::GetFileKernelName(NOT_NULL(&node_name));
  985. // check if output adde exists, if not, return;
  986. if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
  987. return;
  988. }
  989. auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  990. MS_EXCEPTION_IF_NULL(addr);
  991. auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  992. auto format = kOpFormat_DEFAULT;
  993. string tensor_name = node_name + ':' + "0";
  994. ShapeVector int_shapes;
  995. auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
  996. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  997. [](size_t inner_item) { return SizeToInt(inner_item); });
  998. bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  999. if (!ret) {
  1000. MS_LOG(ERROR) << "LoadMemToHost:"
  1001. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1002. }
  1003. }
  1004. void Debugger::LoadParametersAndConst() {
  1005. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  1006. if (!(num_step_ == 0 || device_target_ == kAscendDevice ||
  1007. (device_target_ == kGPUDevice && device::KernelRuntime::DumpDataEnabledIteration())))
  1008. return;
  1009. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1010. // load parameters
  1011. MS_LOG(INFO) << "Start to load Parameters!";
  1012. const auto &parameters = graph_ptr_->inputs();
  1013. for (auto &item : parameters) {
  1014. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  1015. }
  1016. // load value nodes
  1017. // get all constant avlues from the graph
  1018. MS_LOG(INFO) << "Start to load value nodes!";
  1019. const auto value_nodes = graph_ptr_->graph_value_nodes();
  1020. for (auto &item : value_nodes) {
  1021. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  1022. }
  1023. }
// Load every kernel output of the current graph from device to host so the
// debugger can inspect them. Ascend only. In partial-memory mode, only outputs
// of kernels that have a watchpoint are loaded.
void Debugger::LoadGraphOutputs() {
  if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  const auto &apply_kernels = graph_ptr_->execution_order();
  // for kernels, execution order starts from 1
  int exec_order = 1;
  for (const auto &node : apply_kernels) {
    MS_EXCEPTION_IF_NULL(node);
    // NOTE(review): node_name appears unused below — candidate for removal
    auto node_name = AnfAlgo::GetCNodeName(node);
    std::string kernel_name = node->fullname_with_scope();
    auto output_size = AnfAlgo::GetOutputTensorNum(node);
    if (partial_memory_) {
      // partial-memory mode: skip kernels that have no watchpoint set
      if (!debug_services_->IsWatchPoint(kernel_name, node)) {
        continue;
      }
    }
    for (size_t j = 0; j < output_size; ++j) {
      auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
      MS_EXCEPTION_IF_NULL(kernel_info);
      // skip output slots that have no device address assigned
      auto addr_test = kernel_info->GetOutputAddr(j);
      if (addr_test == nullptr) {
        MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
        continue;
      }
      auto addr = AnfAlgo::GetOutputAddr(node, j);
      MS_EXCEPTION_IF_NULL(addr);
      auto type = AnfAlgo::GetOutputInferDataType(node, j);
      auto format = kOpFormat_DEFAULT;
      // tensor key is "<kernel>:<slot>"
      string tensor_name = kernel_name + ':' + std::to_string(j);
      ShapeVector int_shapes;
      auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
      (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
                           [](size_t inner_item) { return SizeToInt(inner_item); });
      auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
      if (!ret) {
        MS_LOG(ERROR) << "LoadMemToHost:"
                      << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
      }
    }
    exec_order = exec_order + 1;
  }
}
// Increment the step counter when the first run graph of the network is
// processed again (multigraph support: only the first graph marks a new step).
// GPU only, and only while the debugger or e2e dump is collecting data.
void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  // update step number if we are processing the first graph (to support multigraph)
  if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
      (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
    // access lock for public method
    std::lock_guard<std::mutex> a_lock(access_lock_);
    ++num_step_;
  }
}
  1075. void Debugger::ClearCurrentData() {
  1076. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()))
  1077. debug_services_->tensor_loader()->EmptyCurrentTensor();
  1078. }
  1079. } // namespace mindspore