You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger.cc 44 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <dirent.h>
  #include <stdio.h>
  #include <algorithm>
  #include <cstring>
  #include <fstream>
  #include <iostream>
  #include <map>
  #include <regex>
  #include <stdexcept>
  #include <string>
  #include <tuple>
  #include <utility>
  #include <vector>
  #include "debug/debugger/debugger.h"
  #include "debug/data_dump/dump_json_parser.h"
  #include "pipeline/jit/pipeline.h"
  #include "backend/session/anf_runtime_algorithm.h"
  #include "runtime/device/kernel_runtime_manager.h"
  #include "runtime/device/kernel_runtime.h"
  #include "debug/data_dump/e2e_dump_util.h"
  #include "utils/config_manager.h"
  35. using debugger::Chunk;
  36. using debugger::EventReply;
  37. using debugger::GraphProto;
  38. using debugger::ModelProto;
  39. using debugger::TensorProto;
  40. using debugger::WatchCondition;
  41. using debugger::WatchCondition_Condition_inf;
  42. using debugger::WatchCondition_Condition_nan;
  43. using debugger::WatchCondition_Parameter;
  44. using debugger::WatchNode;
  45. using debugger::WatchpointHit;
  // Maximum payload, in bytes, of one chunk sent to the debugger server (3 MiB).
  // The expansion is parenthesised so the macro behaves as a single value inside
  // larger expressions such as `x / CHUNK_SIZE` or `x % CHUNK_SIZE`.
  #define CHUNK_SIZE (1024 * 1024 * 3)
  47. namespace mindspore {
  48. DebuggerPtr Debugger::debugger_ = nullptr;
  49. std::mutex Debugger::instance_lock_;
  50. static const size_t PARAMETER_OUTPUT_INDEX = 0;
  51. static const size_t VALUE_NODE_OUTPUT_INDEX = 0;
  52. Debugger::Debugger()
  53. : grpc_client_(nullptr),
  54. debug_services_(nullptr),
  55. device_id_(0),
  56. device_target_(""),
  57. num_step_(0),
  58. debugger_enabled_(false),
  59. run_level_(""),
  60. node_name_(""),
  61. cur_name_(""),
  62. training_done_(false),
  63. is_dataset_graph_(false),
  64. partial_memory_(false),
  65. last_overflow_bin_(0),
  66. initial_suspend_(true),
  67. not_dataset_graph_sum_(0),
  68. version_("") {
  69. if (CheckDebuggerEnabled()) {
  70. // configure partial memory reuse
  71. partial_memory_ = CheckDebuggerPartialMemoryEnabled();
  72. auto context_ptr = MsContext::GetInstance();
  73. MS_EXCEPTION_IF_NULL(context_ptr);
  74. // switch memory reuse on or off
  75. context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_);
  76. // print some message about memory reuse to user
  77. if (partial_memory_) {
  78. MS_LOG(WARNING)
  79. << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
  80. "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
  81. } else {
  82. MS_LOG(INFO) << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
  83. "usage for large models.";
  84. }
  85. }
  86. }
  87. void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  88. // access lock for public method
  89. std::lock_guard<std::mutex> a_lock(access_lock_);
  90. // save device_id
  91. MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  92. device_id_ = device_id;
  93. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  94. device_target_ = device_target;
  95. version_ = "1.1.0";
  96. }
  97. void Debugger::EnableDebugger() {
  98. // reset some of the class members
  99. num_step_ = 0;
  100. debugger_enabled_ = false;
  101. partial_memory_ = false;
  102. grpc_client_ = nullptr;
  103. debug_services_ = nullptr;
  104. // see if dump using debugger backend is enabled
  105. bool dump_enabled = CheckDebuggerDumpEnabled();
  106. MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;
  107. // check if debugger enabled
  108. debugger_enabled_ = CheckDebuggerEnabled();
  109. MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;
  110. if (!debugger_enabled_ && !dump_enabled) {
  111. MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
  112. return;
  113. }
  114. if (debugger_enabled_) {
  115. // configure grpc host
  116. const char *env_host_str = std::getenv("MS_DEBUGGER_HOST");
  117. std::string host;
  118. if (env_host_str != nullptr) {
  119. if (CheckIp(env_host_str)) {
  120. MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
  121. host = std::string(env_host_str);
  122. } else {
  123. debugger_enabled_ = false;
  124. MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
  125. "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
  126. }
  127. } else {
  128. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
  129. host = "localhost";
  130. }
  131. // configure grpc port
  132. const char *env_port_str = std::getenv("MS_DEBUGGER_PORT");
  133. std::string port;
  134. if (env_port_str != nullptr) {
  135. if (CheckPort(env_port_str)) {
  136. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
  137. port = std::string(env_port_str);
  138. } else {
  139. debugger_enabled_ = false;
  140. MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to "
  141. "65535";
  142. }
  143. } else {
  144. port = "50051";
  145. if (!CheckPort(port.c_str())) {
  146. MS_EXCEPTION(ValueError) << "Default MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to 65535";
  147. }
  148. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
  149. }
  150. // initialize grpc client
  151. grpc_client_ = std::make_unique<GrpcClient>(host, port);
  152. }
  153. debug_services_ = std::make_unique<DebugServices>();
  154. }
  155. void Debugger::SetOpOverflowBinPath(uint32_t graph_id) {
  156. #ifdef ENABLE_D
  157. // set operation overflow info
  158. overflow_bin_path_.insert(std::pair<uint32_t, std::string>(
  159. graph_id, DumpJsonParser::GetInstance().GetOpOverflowBinPath(graph_id, device_id_)));
  160. // new overflow dump files will have a timestamp greater than last_overflow_bin_
  161. auto overflow_bin_path = overflow_bin_path_.find(graph_id)->second;
  162. DIR *d;
  163. d = opendir(overflow_bin_path.c_str());
  164. if (d != nullptr) {
  165. struct dirent *dir;
  166. while ((dir = readdir(d)) != NULL) {
  167. if (dir->d_type == DT_REG) {
  168. std::string file_path = overflow_bin_path;
  169. file_path.append(dir->d_name);
  170. std::size_t found = file_path.find_last_of(".");
  171. if (found == std::string::npos) {
  172. continue;
  173. }
  174. std::string overflow_time = file_path.substr(found + 1);
  175. if (stod(overflow_time) <= last_overflow_bin_) {
  176. MS_LOG(INFO) << "Old op overflow bin folder" << file_path;
  177. continue;
  178. }
  179. last_overflow_bin_ = stod(overflow_time);
  180. }
  181. }
  182. MS_LOG(INFO) << "last op overflow bin folder" << last_overflow_bin_;
  183. closedir(d);
  184. }
  185. #endif
  186. }
  187. void Debugger::CheckDatasetSinkMode() {
  188. if (CheckDebuggerDumpEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  189. MS_EXCEPTION(NotSupportError)
  190. << "e2e_dump not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  191. }
  192. if (CheckDebuggerEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  193. MS_EXCEPTION(NotSupportError)
  194. << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  195. }
  196. }
  197. bool Debugger::CheckDebuggerDumpEnabled() {
  198. // see if dump is enabled
  199. if (device_target_ == kGPUDevice) {
  200. return device::KernelRuntime::DumpDataEnabled();
  201. }
  202. return false;
  203. }
  204. bool Debugger::CheckDebuggerEnabled() {
  205. // get env variables to configure debugger
  206. const char *env_enable_str = std::getenv("ENABLE_MS_DEBUGGER");
  207. if (env_enable_str != nullptr) {
  208. if (std::strcmp(env_enable_str, "1") == 0) {
  209. return true;
  210. }
  211. }
  212. return false;
  213. }
  214. bool Debugger::CheckDebuggerPartialMemoryEnabled() {
  215. const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
  216. if (env_partial_mem_str != nullptr) {
  217. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
  218. if (std::strcmp(env_partial_mem_str, "1") == 0) {
  219. return true;
  220. }
  221. }
  222. return false;
  223. }
  224. bool Debugger::DebuggerBackendEnabled() { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }
  225. void Debugger::Reset() {
  226. // access lock for public method
  227. std::lock_guard<std::mutex> a_lock(access_lock_);
  228. // reset components
  229. device_id_ = 0;
  230. device_target_ = "";
  231. num_step_ = 0;
  232. debugger_enabled_ = false;
  233. is_dataset_graph_ = false;
  234. partial_memory_ = false;
  235. graph_ptr_ = nullptr;
  236. grpc_client_ = nullptr;
  237. debug_services_ = nullptr;
  238. last_overflow_bin_ = 0;
  239. overflow_bin_path_.clear();
  240. stream_task_to_opname_.clear();
  241. }
  242. void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) {
  243. // access lock for public method
  244. std::lock_guard<std::mutex> a_lock(access_lock_);
  245. CheckDatasetSinkMode();
  246. auto graph_id = graph_ptr->graph_id();
  247. // collect rungrap_ids to update step number in multigraph case
  248. if (!rungraph_id_list_.size()) {
  249. rungraph_id_list_.push_back(graph_id);
  250. } else {
  251. if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
  252. rungraph_id_list_.push_back(graph_id);
  253. }
  254. }
  255. // check and save graph_ptr, suspend if graph is new
  256. MS_LOG(INFO) << "total number graph: " << graph_sum;
  257. // multiple graphs
  258. if (graph_sum > 1) {
  259. // there are more than one graphs are not dataset_graph
  260. if (not_dataset_graph_sum_ > 0) {
  261. // only try to enable debugger if they are not all dataset graphs
  262. if (!debugger_enabled_) {
  263. EnableDebugger();
  264. }
  265. if (debugger_enabled_) {
  266. if (graph_proto_list_.size()) {
  267. // only send compiled graphs once.
  268. SendMultiGraphsAndSuspend(graph_proto_list_, graph_sum);
  269. graph_proto_list_.clear();
  270. } else if (graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
  271. // stop only when receive the first sub run graph for each step
  272. CommandLoop();
  273. }
  274. }
  275. }
  276. } else if (graph_proto_list_.size() == 1) {
  277. // In single graph case, reset graph_ptr_ to be nullptr for the initial step
  278. if (num_step_ == 0) {
  279. graph_ptr_ = nullptr;
  280. }
  281. CheckGraphPtr(graph_ptr);
  282. }
  283. }
  284. void Debugger::PostExecute() {
  285. // access lock for public method
  286. std::lock_guard<std::mutex> a_lock(access_lock_);
  287. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  288. return;
  289. }
  290. if (debugger_->DebuggerBackendEnabled()) {
  291. // analyze tensor data and send the watchpoints been hit
  292. if (debugger_enabled_ && !is_dataset_graph_) {
  293. if (device_target_ != kGPUDevice) {
  294. num_step_++;
  295. }
  296. MS_LOG(INFO) << "Debugger suspend at end of step; number of steps executed: " << num_step_;
  297. SendWatchpoints(CheckWatchpoints());
  298. CommandLoop();
  299. }
  300. // Only keep parameters in the current map
  301. debug_services_->ResetLoadedTensors();
  302. }
  303. }
  304. bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) {
  305. if (debugger_enabled_ && !is_dataset_graph_) {
  306. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  307. // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data
  308. if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
  309. return true;
  310. }
  311. }
  312. return false;
  313. }
  314. void Debugger::PostExecuteNode(const CNodePtr &kernel) {
  315. // access lock for public method
  316. std::lock_guard<std::mutex> a_lock(access_lock_);
  317. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  318. return;
  319. }
  320. if (debugger_enabled_ && !is_dataset_graph_) {
  321. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  322. // if kernel is watchpoint,and get hit. suspend.
  323. bool hit_empty_flag = true;
  324. if (is_watchpoint) {
  325. auto hits = CheckWatchpoints(cur_name_, kernel);
  326. if (!hits.empty()) {
  327. SendWatchpoints(hits);
  328. CommandLoop();
  329. hit_empty_flag = false;
  330. }
  331. }
  332. if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
  333. // if kernel is not watchpoint and is next_to or continue_to node, suspend
  334. CommandLoop();
  335. }
  336. return;
  337. }
  338. }
  339. void Debugger::PostDebugOp() {
  340. // access lock for public method
  341. std::lock_guard<std::mutex> a_lock(access_lock_);
  342. // suspend if debugger is enabled
  343. if (debugger_enabled_ && !is_dataset_graph_) {
  344. MS_LOG(INFO) << "Debugger suspend at debug_op";
  345. CommandLoop();
  346. }
  347. }
  348. void Debugger::SetStreamTaskToOpnameMap(const std::map<std::pair<uint32_t, uint32_t>, std::string> &mapping) {
  349. stream_task_to_opname_ = mapping;
  350. }
  351. void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  352. if (graph_ptr_ != graph_ptr) {
  353. MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
  354. // save new graph_ptr
  355. graph_ptr_ = graph_ptr;
  356. CheckDatasetGraph();
  357. if (!is_dataset_graph_) {
  358. // get proto for new graph_ptr
  359. auto graph_proto = GetGraphProto(graph_ptr);
  360. // add new graph proto to graph_proto_list_
  361. graph_proto_list_.push_back(graph_proto);
  362. graph_ptr_list_.push_back(graph_ptr);
  363. #ifdef ENABLE_D
  364. SetOpOverflowBinPath(graph_ptr->graph_id());
  365. #endif
  366. not_dataset_graph_sum_++;
  367. }
  368. // reset is_dataset_graph to be false
  369. is_dataset_graph_ = false;
  370. }
  371. }
  372. // In single graph cases, check single graph ptr
  373. void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  374. if (graph_ptr_ != graph_ptr) {
  375. MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
  376. // save new graph_ptr
  377. graph_ptr_ = graph_ptr;
  378. if (!is_dataset_graph_) {
  379. // only try to enable debugger if it is not a dataset graph
  380. EnableDebugger();
  381. if (debugger_enabled_) {
  382. LoadParametersAndConst();
  383. // get graph proto and send to Mindinsight
  384. auto graph_proto = graph_proto_list_.front();
  385. SendGraphAndSuspend(graph_proto);
  386. }
  387. }
  388. }
  389. }
  390. void Debugger::CheckDatasetGraph() {
  391. // print parameter node names
  392. const auto &params = graph_ptr_->inputs();
  393. for (const auto &param : params) {
  394. MS_LOG(INFO) << "param: " << param->fullname_with_scope();
  395. }
  396. // check if there is GetNext or InitDataSetQueue node
  397. const auto &nodes = graph_ptr_->execution_order();
  398. for (const auto &node : nodes) {
  399. auto node_name = AnfAlgo::GetCNodeName(node);
  400. MS_LOG(INFO) << "node: " << node->fullname_with_scope();
  401. if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
  402. MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
  403. << node_name;
  404. is_dataset_graph_ = true;
  405. return;
  406. }
  407. }
  408. is_dataset_graph_ = false;
  409. }
  410. GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  411. // convert kernel graph to debugger modelproto
  412. ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_);
  413. return model.graph();
  414. }
  415. void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  416. if (SendMetadata(true)) {
  417. // send graph to Mindinsight server
  418. EventReply reply = grpc_client_->SendGraph(graph_proto);
  419. if (reply.status() != reply.OK) {
  420. MS_LOG(ERROR) << "Error: SendGraph failed";
  421. }
  422. // enter command loop, wait and process commands
  423. CommandLoop();
  424. }
  425. }
  426. bool Debugger::SendMetadata(bool version_check) {
  427. // prepare metadata
  428. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  429. Metadata metadata;
  430. metadata.set_device_name(device_name);
  431. metadata.set_cur_step(num_step_);
  432. metadata.set_backend(device_target_);
  433. metadata.set_cur_node(cur_name_);
  434. metadata.set_training_done(training_done_);
  435. metadata.set_ms_version(version_);
  436. MS_LOG(INFO) << "Is training done?" << training_done_;
  437. // set graph munber to not_dataset_graph_sum_
  438. metadata.set_graph_num(not_dataset_graph_sum_);
  439. EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  440. bool ret = false;
  441. if (reply_metadata.status() == reply_metadata.OK) {
  442. if (version_check) {
  443. // get type of the command in meta data reply, it should be version matched
  444. DebuggerCommand cmd = GetCommand(reply_metadata);
  445. if (cmd != DebuggerCommand::kVersionMatchedCMD) {
  446. MS_LOG(ERROR) << "MindInsight version is too old, Mindspore version is " << version_;
  447. Exit();
  448. } else {
  449. if (GetMiVersionMatched(reply_metadata)) {
  450. MS_LOG(INFO) << "MindSpore version is " << version_ << " matches MindInsight version.";
  451. ret = true;
  452. } else {
  453. MS_LOG(ERROR) << "MindSpore version " << version_ << ", did not match MindInsight version.";
  454. CommandLoop();
  455. }
  456. }
  457. } else {
  458. // version check is done before so we can just return true here
  459. ret = true;
  460. }
  461. } else {
  462. MS_LOG(ERROR) << "Error: SendMetadata failed";
  463. }
  464. return ret;
  465. }
  466. void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list, uint32_t graph_sum) {
  467. if (!SendMetadata(true)) {
  468. return;
  469. }
  470. // send multiple graphs to mindinght server
  471. // split graph into chunks if one graph is larger than chunk size
  472. std::list<Chunk> chunked_graph_proto_list;
  473. Chunk chunk;
  474. for (auto graph : graph_proto_list) {
  475. std::string str = graph.SerializeAsString();
  476. auto graph_size = graph.ByteSize();
  477. if (graph_size > CHUNK_SIZE) {
  478. auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);
  479. for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
  480. chunk.set_buffer(sub_graph_str[i]);
  481. chunked_graph_proto_list.push_back(chunk);
  482. if (i < sub_graph_str.size() - 1) {
  483. chunk.set_finished(false);
  484. } else {
  485. chunk.set_finished(true);
  486. chunked_graph_proto_list.push_back(chunk);
  487. }
  488. }
  489. } else {
  490. chunk.set_buffer(str);
  491. chunk.set_finished(true);
  492. chunked_graph_proto_list.push_back(chunk);
  493. }
  494. }
  495. EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  496. if (reply.status() != reply.OK) {
  497. MS_LOG(ERROR) << "Error: SendGraph failed";
  498. }
  499. // enter command loop, wait and process commands
  500. CommandLoop();
  501. }
  502. void Debugger::CommandLoop() {
  503. // prepare metadata
  504. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  505. Metadata metadata;
  506. metadata.set_device_name(device_name);
  507. metadata.set_cur_step(num_step_);
  508. metadata.set_backend(device_target_);
  509. metadata.set_cur_node(cur_name_);
  510. metadata.set_training_done(training_done_);
  511. // loop exit flag
  512. bool run = false;
  513. int num_wait_fail = 0;
  514. const int max_num_wait_fail = 5;
  515. while (!run) {
  516. // wait for command
  517. EventReply reply = grpc_client_->WaitForCommand(metadata);
  518. if (reply.status() != reply.OK) {
  519. MS_LOG(ERROR) << "Error: WaitForCommand failed";
  520. num_wait_fail++;
  521. if (num_wait_fail > max_num_wait_fail) {
  522. MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session.";
  523. MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
  524. "of debugger host and port.";
  525. Exit();
  526. run = true;
  527. } else {
  528. MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after "
  529. << num_wait_fail << "s";
  530. std::this_thread::sleep_for(std::chrono::milliseconds(1000 * num_wait_fail));
  531. }
  532. continue;
  533. }
  534. // get type of the command in reply
  535. DebuggerCommand cmd = GetCommand(reply);
  536. if (cmd == DebuggerCommand::kUnknownCMD) {
  537. MS_LOG(DEBUG) << "Debug: debugger received unknown command";
  538. continue;
  539. }
  540. MS_LOG(INFO) << "received command: ";
  541. switch (cmd) {
  542. case DebuggerCommand::kUnknownCMD:
  543. MS_LOG(INFO) << "UnknownCMD";
  544. break;
  545. case DebuggerCommand::kExitCMD:
  546. MS_LOG(INFO) << "ExitCMD";
  547. Exit();
  548. // Used for debugger termination
  549. run = true;
  550. break;
  551. case DebuggerCommand::kRunCMD:
  552. MS_LOG(INFO) << "RunCMD";
  553. if (GetRunLevel(reply) == "recheck") {
  554. MS_LOG(INFO) << "rechecking all watchpoints";
  555. SendWatchpoints(CheckWatchpoints("", nullptr, true));
  556. } else {
  557. // no longer the initial suspension.
  558. initial_suspend_ = false;
  559. // print run cmd content
  560. // get run_level and node_name
  561. run_level_ = GetRunLevel(reply);
  562. node_name_ = GetNodeName(reply);
  563. MS_LOG(INFO) << "run_level: " << run_level_;
  564. MS_LOG(INFO) << "node_name_: " << node_name_;
  565. // exit loop
  566. run = true;
  567. }
  568. break;
  569. case DebuggerCommand::kSetCMD:
  570. MS_LOG(INFO) << "SetCMD";
  571. MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  572. MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  573. if (GetWatchpointDelete(reply)) {
  574. MS_LOG(INFO) << "Deleting watchpoint";
  575. RemoveWatchpoint(GetWatchpointID(reply));
  576. } else {
  577. MS_LOG(INFO) << "Setting watchpoint";
  578. MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
  579. ProtoVector<WatchNode> recieved_nodes = GetWatchnodes(reply);
  580. for (const auto &node : recieved_nodes) {
  581. MS_LOG(INFO) << "node name: " << node.node_name();
  582. MS_LOG(INFO) << "node type: " << node.node_type();
  583. }
  584. ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
  585. for (const auto &parameter : parameters) {
  586. MS_LOG(INFO) << "parameter name: " << parameter.name();
  587. MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
  588. MS_LOG(INFO) << "parameter value: " << parameter.value();
  589. }
  590. SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
  591. }
  592. break;
  593. case DebuggerCommand::kViewCMD: {
  594. MS_LOG(INFO) << "ViewCMD";
  595. // print view cmd content
  596. ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  597. for (auto received_tensor : received_tensors) {
  598. MS_LOG(INFO) << "tensor node name: " << received_tensor.node_name();
  599. MS_LOG(INFO) << "tensor slot: " << received_tensor.slot();
  600. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << received_tensor.finished() << std::noboolalpha;
  601. MS_LOG(INFO) << "tensor iter: " << received_tensor.iter();
  602. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << received_tensor.truncate() << std::noboolalpha;
  603. }
  604. MS_LOG(INFO) << "Sending tensors";
  605. std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  606. // print view cmd reply
  607. for (auto tensor : tensors) {
  608. MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
  609. MS_LOG(INFO) << "tensor slot: " << tensor.slot();
  610. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
  611. MS_LOG(INFO) << "tensor iter: " << tensor.iter();
  612. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
  613. MS_LOG(INFO) << "tensor dims: ";
  614. for (auto dim : tensor.dims()) {
  615. MS_LOG(INFO) << dim << ",";
  616. }
  617. MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  618. }
  619. EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  620. if (send_tensors_reply.status() != send_tensors_reply.OK) {
  621. MS_LOG(ERROR) << "Error: SendTensors failed";
  622. }
  623. } break;
  624. case DebuggerCommand::kVersionMatchedCMD:
  625. MS_LOG(ERROR) << "Received unexpected Version Matched CMD from Mindinsight.";
  626. Exit();
  627. break;
  628. default:
  629. MS_LOG(ERROR) << "Received unknown CMD from Mindinsight";
  630. Exit();
  631. break;
  632. }
  633. }
  634. }
  635. void AddTensorProtoInfo(TensorProto *tensor_item, TensorProto tensor) {
  636. tensor_item->set_node_name(tensor.node_name());
  637. tensor_item->set_slot(tensor.slot());
  638. tensor_item->set_iter(tensor.iter());
  639. tensor_item->set_truncate(tensor.truncate());
  640. tensor_item->clear_tensor_content();
  641. tensor_item->clear_data_type();
  642. tensor_item->clear_dims();
  643. }
  644. void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
  645. const ProtoVector<WatchCondition_Parameter> &parameters) {
  646. std::vector<std::tuple<std::string, bool>> check_node_list;
  647. std::vector<DebugServices::parameter_t> parameter_list;
  648. std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
  649. [](const WatchNode &node) -> std::tuple<std::string, bool> {
  650. return make_tuple(node.node_name(), node.node_type() == "scope");
  651. });
  652. std::transform(
  653. parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
  654. [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
  655. return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
  656. });
  657. debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
  658. }
  659. void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
  660. std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  661. std::vector<std::string> name;
  662. std::vector<std::string> ret_name;
  663. std::vector<char *> data_ptr;
  664. std::vector<unsigned int> data_size;
  665. std::vector<TypePtr> dtype;
  666. std::vector<std::vector<int64_t>> shape;
  667. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  668. // ret_name will contain tensor names that are found in TensorLoader
  669. // items in ret_name will be in the same order with tensors if found
  670. debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  671. std::list<TensorProto> tensor_list;
  672. unsigned int result_index = 0;
  673. for (auto tensor : tensors) {
  674. int size_iter = 0;
  675. if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
  676. TensorProto tensor_item;
  677. tensor_item.set_finished(true);
  678. AddTensorProtoInfo(&tensor_item, tensor);
  679. tensor_list.push_back(tensor_item);
  680. continue;
  681. }
  682. int tensor_size = data_size[result_index];
  683. while (size_iter < tensor_size) {
  684. int chunk_size = CHUNK_SIZE;
  685. TensorProto tensor_item;
  686. tensor_item.set_finished(false);
  687. if (tensor_size - size_iter <= CHUNK_SIZE) {
  688. chunk_size = tensor_size - size_iter;
  689. tensor_item.set_finished(true);
  690. }
  691. AddTensorProtoInfo(&tensor_item, tensor);
  692. // return empty tensor if didn't find the requested tensor
  693. tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
  694. tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index]));
  695. for (auto &elem : shape[result_index]) {
  696. tensor_item.add_dims(elem);
  697. }
  698. // add tensor to result list and increment result_index to check next item in ret_name
  699. tensor_list.push_back(tensor_item);
  700. size_iter += CHUNK_SIZE;
  701. }
  702. result_index++;
  703. }
  704. return tensor_list;
  705. }
  706. void Debugger::Exit() {
  707. // clear resource before exit
  708. // debugger will notify main thread to exit because main thread can only exit at step boundary
  709. pipeline::ExecutorPy::DebugTerminate(true);
  710. }
  711. std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel,
  712. bool recheck) {
  713. std::vector<std::string> name;
  714. std::vector<std::string> slot;
  715. std::vector<int> condition;
  716. std::vector<unsigned int> watchpoint_id;
  717. std::vector<std::string> overflow_ops;
  718. std::vector<std::vector<DebugServices::parameter_t>> parameters;
  719. std::vector<int32_t> error_codes;
  720. #ifdef ENABLE_D
  721. overflow_ops = CheckOpOverflow();
  722. #endif
  723. std::vector<std::shared_ptr<TensorData>> tensor_list;
  724. if (watchnode.empty()) {
  725. tensor_list = debug_services_->GetTensor();
  726. } else {
  727. tensor_list = debug_services_->GetNodeTensor(kernel);
  728. }
  729. debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, &error_codes, overflow_ops,
  730. tensor_list, initial_suspend_, watchnode.empty(), recheck);
  731. std::list<WatchpointHit> hits;
  732. for (unsigned int i = 0; i < name.size(); i++) {
  733. WatchpointHit hit;
  734. std::vector<DebugServices::parameter_t> &parameter = parameters[i];
  735. hit.set_id(watchpoint_id[i]);
  736. hit.set_error_code(error_codes[i]);
  737. // here TensorProto act as a tensor indicator, not sending tensor content
  738. TensorProto *tensor_item = hit.mutable_tensor();
  739. tensor_item->set_node_name(name[i]);
  740. tensor_item->set_slot(slot[i]);
  741. tensor_item->set_finished(true);
  742. WatchCondition *condition_item = hit.mutable_watch_condition();
  743. condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
  744. for (const auto &p : parameter) {
  745. auto x = condition_item->mutable_params()->Add();
  746. x->set_name(p.name);
  747. x->set_disabled(p.disabled);
  748. x->set_value(p.value);
  749. x->set_hit(p.hit);
  750. x->set_actual_value(p.actual_value);
  751. }
  752. hits.push_back(hit);
  753. }
  754. return hits;
  755. }
  756. void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  757. // send info about watchpoint
  758. if (!points.empty()) {
  759. EventReply reply = grpc_client_->SendWatchpointHits(points);
  760. if (reply.status() != reply.OK) {
  761. MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
  762. }
  763. }
  764. }
  765. bool Debugger::DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
  766. const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
  767. TypeId addr_type_id, const std::string &addr_format, size_t slot) const {
  768. return debug_services_.get()->DumpTensorToFile(tensor_name, trans_flag, filepath, host_fmt, host_shape, host_type,
  769. addr_type_id, addr_format, slot);
  770. }
  771. bool Debugger::DebugServicesIsWatchPoint(const std::string &kernel_name, const CNodePtr &kernel) const {
  772. return debug_services_.get()->IsWatchPoint(kernel_name, kernel);
  773. }
  774. void Debugger::EmptyTensor() { debug_services_.get()->EmptyTensor(); }
  775. void Debugger::SetTensorLoaderIterNum(uint32_t iter_num) { debug_services_.get()->SetTensorLoaderIterNum(iter_num); }
  776. void Debugger::EmptyPrevTensor() { debug_services_.get()->EmptyPrevTensor(); }
  777. uint32_t Debugger::GetTensorLoaderIterNum() const { return debug_services_.get()->GetTensorLoaderIterNum(); }
  778. bool Debugger::LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev) {
  779. return debug_services_.get()->LoadNewTensor(tensor, keep_prev);
  780. }
  781. bool Debugger::debugger_enabled() const { return debugger_enabled_; }
  782. DebuggerCommand GetCommand(const EventReply &reply) {
  783. DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  784. switch (reply.cmd_case()) {
  785. case debugger::EventReply::CmdCase::kExit:
  786. cmd = DebuggerCommand::kExitCMD;
  787. break;
  788. case debugger::EventReply::CmdCase::kRunCmd:
  789. cmd = DebuggerCommand::kRunCMD;
  790. break;
  791. case debugger::EventReply::CmdCase::kSetCmd:
  792. cmd = DebuggerCommand::kSetCMD;
  793. break;
  794. case debugger::EventReply::CmdCase::kViewCmd:
  795. cmd = DebuggerCommand::kViewCMD;
  796. break;
  797. case debugger::EventReply::CmdCase::kVersionMatched:
  798. cmd = DebuggerCommand::kVersionMatchedCMD;
  799. break;
  800. default:
  801. MS_LOG(DEBUG) << "Debug: UnknownCMD";
  802. break;
  803. }
  804. return cmd;
  805. }
  806. ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  807. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  808. MS_LOG(ERROR) << "Error: Can not get Parameters from command. Returning default value: ProtoVector<Parameter>().";
  809. return ProtoVector<WatchCondition_Parameter>();
  810. }
  811. return reply.set_cmd().watch_condition().params();
  812. }
  813. ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  814. if (!reply.has_set_cmd()) {
  815. MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
  816. return ProtoVector<WatchNode>();
  817. }
  818. return reply.set_cmd().watch_nodes();
  819. }
  820. std::string GetRunLevel(const EventReply &reply) {
  821. if (!reply.has_run_cmd()) {
  822. MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: "
  823. "";
  824. return "";
  825. }
  826. return reply.run_cmd().run_level();
  827. }
  828. std::string GetNodeName(const EventReply &reply) {
  829. if (!reply.has_run_cmd()) {
  830. MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: "
  831. "";
  832. return "";
  833. }
  834. return reply.run_cmd().node_name();
  835. }
  836. WatchCondition GetWatchcondition(const EventReply &reply) {
  837. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  838. MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
  839. return WatchCondition();
  840. }
  841. return reply.set_cmd().watch_condition();
  842. }
  843. int32_t GetWatchpointID(const EventReply &reply) {
  844. if (!reply.has_set_cmd()) {
  845. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
  846. return 0;
  847. }
  848. return reply.set_cmd().id();
  849. }
  850. bool GetWatchpointDelete(const EventReply &reply) {
  851. if (!reply.has_set_cmd()) {
  852. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
  853. return false;
  854. }
  855. return reply.set_cmd().delete_();
  856. }
  857. ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  858. if (!reply.has_view_cmd()) {
  859. MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
  860. return ProtoVector<TensorProto>();
  861. }
  862. return reply.view_cmd().tensors();
  863. }
  864. std::string GetTensorFullName(const TensorProto &tensor) {
  865. string node_name = tensor.node_name();
  866. if (tensor.truncate()) {
  867. // scopes in node name are seperated by '/'
  868. // use the name without scope if truncate is true
  869. std::size_t found = node_name.find_last_of("/");
  870. node_name = node_name.substr(found + 1);
  871. }
  872. return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
  873. }
  874. bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }
  875. bool Debugger::partial_memory() { return partial_memory_; }
  876. void Debugger::SetCurNode(std::string cur_name) {
  877. // access lock for public method
  878. std::lock_guard<std::mutex> a_lock(access_lock_);
  879. cur_name_ = cur_name;
  880. }
  881. std::string Debugger::run_level() const { return run_level_; }
  882. void Debugger::SetStepNum(int32_t cur_num_step) {
  883. // access lock for public method
  884. std::lock_guard<std::mutex> a_lock(access_lock_);
  885. num_step_ = cur_num_step;
  886. }
  887. int32_t Debugger::step_num() const { return num_step_; }
  888. uint64_t BytestoInt64(const std::vector<char> &buffer) {
  889. uint64_t ret;
  890. ret = ((uint64_t)buffer[7] << 56) | ((uint64_t)buffer[6] << 48) | ((uint64_t)buffer[5] << 40) |
  891. ((uint64_t)buffer[4] << 32) | ((uint64_t)buffer[3] << 24) | ((uint64_t)buffer[2] << 16) |
  892. ((uint64_t)buffer[1] << 8) | ((uint64_t)buffer[0]);
  893. return ret;
  894. }
  895. #define BUF_SIZ 256
  896. std::vector<std::string> Debugger::CheckOpOverflow() {
  897. std::vector<double> bin_list;
  898. std::vector<std::string> op_names;
  899. for (const auto &[graph_id, overflow_bin_path] : overflow_bin_path_) {
  900. DIR *d;
  901. d = opendir(overflow_bin_path.c_str());
  902. MS_LOG(INFO) << "processing bin file path " << overflow_bin_path << ", graph id " << graph_id;
  903. if (d != nullptr) {
  904. struct dirent *dir = nullptr;
  905. while ((dir = readdir(d)) != NULL) {
  906. if (dir->d_type == DT_REG) {
  907. std::string file_path = overflow_bin_path;
  908. file_path.append(dir->d_name);
  909. std::string file_name = dir->d_name;
  910. std::size_t found = file_name.find_last_of(".");
  911. if (found == std::string::npos) {
  912. continue;
  913. }
  914. std::string overflow_time = file_name.substr(found + 1);
  915. if (stod(overflow_time) <= last_overflow_bin_) {
  916. MS_LOG(INFO) << "File already processed " << file_name;
  917. continue;
  918. }
  919. bin_list.push_back(stod(overflow_time));
  920. std::fstream infile;
  921. infile.open(file_path.c_str(), std::ios::binary | std::ios::in);
  922. if (!infile.is_open()) {
  923. MS_LOG(ERROR) << "Failed to open overflow bin file " << file_name;
  924. continue;
  925. }
  926. infile.seekg(313, std::ios::beg);
  927. std::vector<char> buffer;
  928. buffer.resize(BUF_SIZ);
  929. infile.read(buffer.data(), BUF_SIZ);
  930. uint64_t stream_id = BytestoInt64(std::vector<char>(buffer.begin() + 8, buffer.end()));
  931. uint64_t task_id = BytestoInt64(std::vector<char>(buffer.begin() + 16, buffer.end()));
  932. MS_LOG(INFO) << "Overflow stream_id " << stream_id << ", task_id " << task_id << ".";
  933. auto op = debugger_->stream_task_to_opname_.find(std::make_pair(stream_id, task_id));
  934. if (op != debugger_->stream_task_to_opname_.end()) {
  935. MS_LOG(ERROR) << "Overflow detected on node " << op->second << std::endl;
  936. op_names.push_back(op->second);
  937. } else {
  938. MS_LOG(INFO) << "No overflow is detected " << std::endl;
  939. }
  940. infile.close();
  941. }
  942. }
  943. } else {
  944. MS_LOG(INFO) << "OverFlow bin directory does not exist!";
  945. }
  946. closedir(d);
  947. }
  948. if (!op_names.empty()) {
  949. MS_LOG(ERROR) << "These operation overflows are detected " << op_names;
  950. }
  951. for (auto &i : bin_list) {
  952. if (i > last_overflow_bin_) {
  953. last_overflow_bin_ = i;
  954. }
  955. }
  956. auto iter_op_names = overflow_ops_.find(num_step_);
  957. if (iter_op_names == overflow_ops_.end()) {
  958. overflow_ops_.insert(std::pair<uint32_t, std::vector<std::string>>(num_step_, op_names));
  959. return op_names;
  960. }
  961. iter_op_names->second.insert(std::end(iter_op_names->second), std::begin(op_names), std::end(op_names));
  962. return iter_op_names->second;
  963. }
  964. void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }
  965. bool Debugger::CheckPort(const char *port) {
  966. char *p = const_cast<char *>(port);
  967. int num = 0;
  968. if (*p == '0' && *(p + 1) != '\0') return false;
  969. while (*p != '\0') {
  970. if (*p < '0' || *p > '9') return false;
  971. num = num * 10 + (*p) - '0';
  972. if (num < 1 || num > 65535) return false;
  973. p++;
  974. }
  975. return true;
  976. }
  977. bool Debugger::CheckIp(const char *host) {
  978. std::regex reg_ip(
  979. "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
  980. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  981. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  982. "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
  983. std::smatch smat;
  984. std::string host_str = std::string(host);
  985. return std::regex_match(host_str, smat, reg_ip);
  986. }
  987. uint32_t Debugger::GetFirstRunGraphId() { return rungraph_id_list_.front(); }
  988. void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  989. MS_EXCEPTION_IF_NULL(anf_node);
  990. if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
  991. return;
  992. }
  993. // for parameters and value nodes, set its execution order to be 0;
  994. int exec_order = 0;
  995. std::string node_name = anf_node->fullname_with_scope();
  996. E2eDumpUtil::GetFileKernelName(NOT_NULL(&node_name));
  997. // check if output adde exists, if not, return;
  998. if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
  999. return;
  1000. }
  1001. auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  1002. MS_EXCEPTION_IF_NULL(addr);
  1003. auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  1004. auto format = kOpFormat_DEFAULT;
  1005. string tensor_name = node_name + ':' + "0";
  1006. ShapeVector int_shapes;
  1007. auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
  1008. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  1009. [](size_t inner_item) { return SizeToInt(inner_item); });
  1010. bool keep_prev;
  1011. if (anf_node->isa<Parameter>()) {
  1012. keep_prev = true;
  1013. debug_services_->MoveTensorCurrentToPrev(tensor_name);
  1014. } else {
  1015. keep_prev = false;
  1016. }
  1017. bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  1018. if (!ret) {
  1019. MS_LOG(ERROR) << "LoadMemToHost:"
  1020. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1021. }
  1022. }
  1023. void Debugger::LoadParametersAndConst() {
  1024. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  1025. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1026. // load parameters
  1027. MS_LOG(INFO) << "Start to load Parameters!";
  1028. const auto &parameters = graph_ptr_->inputs();
  1029. for (auto &item : parameters) {
  1030. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  1031. }
  1032. // load value nodes
  1033. // get all constant avlues from the graph
  1034. MS_LOG(INFO) << "Start to load value nodes!";
  1035. const auto value_nodes = graph_ptr_->graph_value_nodes();
  1036. for (auto &item : value_nodes) {
  1037. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  1038. }
  1039. }
  1040. void Debugger::LoadGraphOutputs() {
  1041. if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  1042. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1043. const auto &apply_kernels = graph_ptr_->execution_order();
  1044. // for kernels, execution order starts from 1
  1045. int exec_order = 1;
  1046. for (const auto &node : apply_kernels) {
  1047. MS_EXCEPTION_IF_NULL(node);
  1048. auto node_name = AnfAlgo::GetCNodeName(node);
  1049. std::string kernel_name = node->fullname_with_scope();
  1050. auto output_size = AnfAlgo::GetOutputTensorNum(node);
  1051. if (partial_memory_) {
  1052. if (!debug_services_->IsWatchPoint(kernel_name, node)) {
  1053. continue;
  1054. }
  1055. }
  1056. for (size_t j = 0; j < output_size; ++j) {
  1057. auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
  1058. MS_EXCEPTION_IF_NULL(kernel_info);
  1059. auto addr_test = kernel_info->GetOutputAddr(j);
  1060. if (addr_test == nullptr) {
  1061. MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
  1062. continue;
  1063. }
  1064. auto addr = AnfAlgo::GetOutputAddr(node, j);
  1065. MS_EXCEPTION_IF_NULL(addr);
  1066. auto type = AnfAlgo::GetOutputInferDataType(node, j);
  1067. auto format = kOpFormat_DEFAULT;
  1068. string tensor_name = kernel_name + ':' + std::to_string(j);
  1069. ShapeVector int_shapes;
  1070. auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
  1071. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  1072. [](size_t inner_item) { return SizeToInt(inner_item); });
  1073. auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
  1074. if (!ret) {
  1075. MS_LOG(ERROR) << "LoadMemToHost:"
  1076. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1077. }
  1078. }
  1079. exec_order = exec_order + 1;
  1080. }
  1081. }
  1082. void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  1083. // update step number if we are processing the first graph (to support multigraph)
  1084. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
  1085. (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
  1086. // access lock for public method
  1087. std::lock_guard<std::mutex> a_lock(access_lock_);
  1088. ++num_step_;
  1089. }
  1090. }
  1091. void Debugger::ClearCurrentData() {
  1092. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()))
  1093. debug_services_->EmptyCurrentTensor();
  1094. }
  1095. bool Debugger::TensorExistsInCurrent(std::string tensor_name) {
  1096. return debug_services_->TensorExistsInCurrent(tensor_name);
  1097. }
  1098. } // namespace mindspore