You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger.cc 46 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <dirent.h>
  #include <stdio.h>

  #include <algorithm>
  #include <cctype>
  #include <chrono>
  #include <cstdlib>
  #include <cstring>
  #include <fstream>
  #include <iostream>
  #include <map>
  #include <memory>
  #include <regex>
  #include <stdexcept>
  #include <string>
  #include <thread>
  #include <tuple>
  #include <utility>
  #include <vector>

  #include "debug/debugger/debugger.h"
  #include "debug/data_dump/dump_json_parser.h"
  #include "pipeline/jit/pipeline.h"
  #include "backend/session/anf_runtime_algorithm.h"
  #include "runtime/device/kernel_runtime_manager.h"
  #include "runtime/device/kernel_runtime.h"
  #include "debug/data_dump/e2e_dump_util.h"
  #include "utils/config_manager.h"
  35. using debugger::Chunk;
  36. using debugger::EventReply;
  37. using debugger::GraphProto;
  38. using debugger::ModelProto;
  39. using debugger::TensorProto;
  40. using debugger::WatchCondition;
  41. using debugger::WatchCondition_Condition_inf;
  42. using debugger::WatchCondition_Condition_nan;
  43. using debugger::WatchCondition_Parameter;
  44. using debugger::WatchNode;
  45. using debugger::WatchpointHit;
  // Size threshold in bytes (3 MiB) above which a serialized graph is split into chunks
  // before being sent to MindInsight. Parenthesized so the macro expands correctly
  // inside larger expressions (e.g. `x / CHUNK_SIZE`, `x % CHUNK_SIZE`).
  #define CHUNK_SIZE (1024 * 1024 * 3)
  47. namespace mindspore {
  48. DebuggerPtr Debugger::debugger_ = nullptr;
  49. std::mutex Debugger::instance_lock_;
  50. static const size_t PARAMETER_OUTPUT_INDEX = 0;
  51. static const size_t VALUE_NODE_OUTPUT_INDEX = 0;
  52. Debugger::Debugger()
  53. : grpc_client_(nullptr),
  54. debug_services_(nullptr),
  55. device_id_(0),
  56. device_target_(""),
  57. num_step_(0),
  58. debugger_enabled_(false),
  59. run_level_(""),
  60. node_name_(""),
  61. cur_name_(""),
  62. training_done_(false),
  63. is_dataset_graph_(false),
  64. partial_memory_(false),
  65. last_overflow_bin_(0),
  66. initial_suspend_(true),
  67. not_dataset_graph_sum_(0),
  68. version_("") {
  69. CheckDebuggerEnabledParam();
  70. auto ms_context = MsContext::GetInstance();
  71. MS_EXCEPTION_IF_NULL(ms_context);
  72. device_target_ = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  73. MS_LOG(INFO) << "Debugger got device_target: " << device_target_;
  74. CheckDebuggerEnabledParam();
  75. if (device_target_ == kCPUDevice) {
  76. MS_LOG(WARNING) << "Not enabling debugger. Debugger does not support CPU.";
  77. } else if (CheckDebuggerEnabled()) {
  78. // configure partial memory reuse
  79. partial_memory_ = CheckDebuggerPartialMemoryEnabled();
  80. auto context_ptr = MsContext::GetInstance();
  81. MS_EXCEPTION_IF_NULL(context_ptr);
  82. // switch memory reuse on or off
  83. context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_);
  84. // print some message about memory reuse to user
  85. if (partial_memory_) {
  86. MS_LOG(WARNING)
  87. << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
  88. "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
  89. } else {
  90. MS_LOG(WARNING)
  91. << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
  92. "usage for large models.";
  93. }
  94. }
  95. }
  96. void Debugger::Init(const uint32_t device_id) {
  97. // access lock for public method
  98. std::lock_guard<std::mutex> a_lock(access_lock_);
  99. // save device_id
  100. MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  101. device_id_ = device_id;
  102. version_ = "1.2.0";
  103. }
  104. void Debugger::EnableDebugger() {
  105. // reset some of the class members
  106. num_step_ = 0;
  107. debugger_enabled_ = false;
  108. partial_memory_ = false;
  109. grpc_client_ = nullptr;
  110. debug_services_ = nullptr;
  111. // see if dump using debugger backend is enabled
  112. bool dump_enabled = CheckDebuggerDumpEnabled();
  113. MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;
  114. // check if debugger enabled
  115. debugger_enabled_ = CheckDebuggerEnabled();
  116. MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;
  117. if (!debugger_enabled_ && !dump_enabled) {
  118. MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
  119. return;
  120. }
  121. if (debugger_enabled_) {
  122. // configure grpc host
  123. const char *env_host_str = std::getenv("MS_DEBUGGER_HOST");
  124. std::string host;
  125. if (env_host_str != nullptr) {
  126. if (CheckIp(env_host_str)) {
  127. MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
  128. host = std::string(env_host_str);
  129. } else {
  130. debugger_enabled_ = false;
  131. MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
  132. "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
  133. }
  134. } else {
  135. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
  136. host = "localhost";
  137. }
  138. // configure grpc port
  139. const char *env_port_str = std::getenv("MS_DEBUGGER_PORT");
  140. std::string port;
  141. if (env_port_str != nullptr) {
  142. if (CheckPort(env_port_str)) {
  143. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
  144. port = std::string(env_port_str);
  145. } else {
  146. debugger_enabled_ = false;
  147. MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to "
  148. "65535";
  149. }
  150. } else {
  151. port = "50051";
  152. if (!CheckPort(port.c_str())) {
  153. MS_EXCEPTION(ValueError) << "Default MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to 65535";
  154. }
  155. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
  156. }
  157. // initialize grpc client
  158. grpc_client_ = std::make_unique<GrpcClient>(host, port);
  159. }
  160. debug_services_ = std::make_unique<DebugServices>();
  161. }
  162. void Debugger::SetOpOverflowBinPath(uint32_t graph_id) {
  163. #ifdef ENABLE_D
  164. // set operation overflow info
  165. overflow_bin_path_.insert(std::pair<uint32_t, std::string>(
  166. graph_id, DumpJsonParser::GetInstance().GetOpOverflowBinPath(graph_id, device_id_)));
  167. // new overflow dump files will have a timestamp greater than last_overflow_bin_
  168. auto overflow_bin_path = overflow_bin_path_.find(graph_id)->second;
  169. DIR *d;
  170. d = opendir(overflow_bin_path.c_str());
  171. if (d != nullptr) {
  172. struct dirent *dir;
  173. while ((dir = readdir(d)) != NULL) {
  174. if (dir->d_type == DT_REG) {
  175. std::string file_path = overflow_bin_path;
  176. file_path.append(dir->d_name);
  177. std::size_t found = file_path.find_last_of(".");
  178. if (found == std::string::npos) {
  179. continue;
  180. }
  181. std::string overflow_time = file_path.substr(found + 1);
  182. if (stod(overflow_time) <= last_overflow_bin_) {
  183. MS_LOG(INFO) << "Old op overflow bin folder" << file_path;
  184. continue;
  185. }
  186. last_overflow_bin_ = stod(overflow_time);
  187. }
  188. }
  189. MS_LOG(INFO) << "last op overflow bin folder" << last_overflow_bin_;
  190. closedir(d);
  191. }
  192. #endif
  193. }
  194. void Debugger::CheckDatasetSinkMode() {
  195. if (CheckDebuggerDumpEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  196. MS_EXCEPTION(NotSupportError)
  197. << "e2e_dump not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  198. }
  199. if (CheckDebuggerEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  200. MS_EXCEPTION(NotSupportError)
  201. << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  202. }
  203. }
  204. bool Debugger::CheckDebuggerDumpEnabled() {
  205. // see if dump is enabled
  206. if (device_target_ == kGPUDevice) {
  207. return device::KernelRuntime::DumpDataEnabled();
  208. }
  209. return false;
  210. }
  211. bool Debugger::CheckDebuggerEnabled() {
  212. // get env variables to configure debugger
  213. const char *env_enable_char = std::getenv("ENABLE_MS_DEBUGGER");
  214. if (env_enable_char != nullptr) {
  215. std::string env_enable_str = env_enable_char;
  216. (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
  217. if ((env_enable_str == "1" || env_enable_str == "true") && device_target_ != kCPUDevice) {
  218. return true;
  219. }
  220. }
  221. return false;
  222. }
  223. void Debugger::CheckDebuggerEnabledParam() {
  224. // check the value of env variable ENABLE_MS_DEBUGGER
  225. const char *env_enable_char = std::getenv("ENABLE_MS_DEBUGGER");
  226. if (env_enable_char != nullptr) {
  227. std::string env_enable_str = env_enable_char;
  228. (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
  229. if (env_enable_str != "0" && env_enable_str != "1" && env_enable_str != "false" && env_enable_str != "true") {
  230. MS_LOG(WARNING) << "Env variable ENABLE_MS_DEBUGGER should be True/False/1/0 (case insensitive), but get: "
  231. << env_enable_str;
  232. }
  233. }
  234. }
  235. bool Debugger::CheckDebuggerPartialMemoryEnabled() {
  236. const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
  237. if (env_partial_mem_str != nullptr) {
  238. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
  239. if (std::strcmp(env_partial_mem_str, "1") == 0) {
  240. return true;
  241. }
  242. }
  243. return false;
  244. }
  245. bool Debugger::DebuggerBackendEnabled() { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }
  246. void Debugger::Reset() {
  247. // access lock for public method
  248. std::lock_guard<std::mutex> a_lock(access_lock_);
  249. // reset components
  250. device_id_ = 0;
  251. device_target_ = "";
  252. num_step_ = 0;
  253. debugger_enabled_ = false;
  254. is_dataset_graph_ = false;
  255. partial_memory_ = false;
  256. graph_ptr_ = nullptr;
  257. grpc_client_ = nullptr;
  258. debug_services_ = nullptr;
  259. last_overflow_bin_ = 0;
  260. overflow_bin_path_.clear();
  261. stream_task_to_opname_.clear();
  262. }
  263. void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) {
  264. // access lock for public method
  265. std::lock_guard<std::mutex> a_lock(access_lock_);
  266. CheckDatasetSinkMode();
  267. auto graph_id = graph_ptr->graph_id();
  268. // collect rungrap_ids to update step number in multigraph case
  269. if (!rungraph_id_list_.size()) {
  270. rungraph_id_list_.push_back(graph_id);
  271. } else {
  272. if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
  273. rungraph_id_list_.push_back(graph_id);
  274. }
  275. }
  276. // check and save graph_ptr, suspend if graph is new
  277. MS_LOG(INFO) << "total number graph: " << graph_sum;
  278. // multiple graphs
  279. if (graph_sum > 1) {
  280. // there are more than one graphs are not dataset_graph
  281. if (not_dataset_graph_sum_ > 0) {
  282. // only try to enable debugger if they are not all dataset graphs
  283. if (!debugger_enabled_) {
  284. EnableDebugger();
  285. }
  286. if (debugger_enabled_) {
  287. if (graph_proto_list_.size()) {
  288. // only send compiled graphs once.
  289. auto dbg_graph_ptr = graph_ptr_;
  290. // use current graph ptr to load parameters
  291. graph_ptr_ = graph_ptr;
  292. LoadParametersAndConst();
  293. // revert graph ptr to original value
  294. graph_ptr_ = dbg_graph_ptr;
  295. SendMultiGraphsAndSuspend(graph_proto_list_, graph_sum);
  296. graph_proto_list_.clear();
  297. } else if (graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
  298. // stop only when receive the first sub run graph for each step
  299. CommandLoop();
  300. }
  301. }
  302. }
  303. } else if (graph_proto_list_.size() == 1) {
  304. // In single graph case, reset graph_ptr_ to be nullptr for the initial step
  305. if (num_step_ == 0) {
  306. graph_ptr_ = nullptr;
  307. }
  308. CheckGraphPtr(graph_ptr);
  309. }
  310. }
  311. void Debugger::PostExecute() {
  312. // access lock for public method
  313. std::lock_guard<std::mutex> a_lock(access_lock_);
  314. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  315. return;
  316. }
  317. if (debugger_->DebuggerBackendEnabled()) {
  318. // analyze tensor data and send the watchpoints been hit
  319. if (debugger_enabled_ && !is_dataset_graph_) {
  320. if (device_target_ != kGPUDevice) {
  321. num_step_++;
  322. }
  323. MS_LOG(INFO) << "Debugger suspend at end of step; number of steps executed: " << num_step_;
  324. SendWatchpoints(CheckWatchpoints());
  325. CommandLoop();
  326. }
  327. // Only keep parameters in the current map
  328. debug_services_->ResetLoadedTensors();
  329. }
  330. }
  331. bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) {
  332. if (debugger_enabled_ && !is_dataset_graph_) {
  333. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  334. // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data
  335. if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
  336. return true;
  337. }
  338. }
  339. return false;
  340. }
  341. void Debugger::PostExecuteNode(const CNodePtr &kernel, bool last_kernel) {
  342. // access lock for public method
  343. std::lock_guard<std::mutex> a_lock(access_lock_);
  344. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  345. return;
  346. }
  347. if (debugger_enabled_ && !is_dataset_graph_) {
  348. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  349. // if kernel is watchpoint,and get hit. suspend.
  350. bool hit_empty_flag = true;
  351. if (is_watchpoint) {
  352. auto hits = CheckWatchpoints(cur_name_, kernel);
  353. if (!hits.empty()) {
  354. SendWatchpoints(hits);
  355. CommandLoop();
  356. hit_empty_flag = false;
  357. }
  358. }
  359. if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_) && !last_kernel) {
  360. // if kernel is not watchpoint and is next_to or continue_to node, suspend
  361. // No need to suspend if this is the last node in graph since PostExecute suspends at the end of graph
  362. CommandLoop();
  363. }
  364. return;
  365. }
  366. }
  367. void Debugger::PostDebugOp() {
  368. // access lock for public method
  369. std::lock_guard<std::mutex> a_lock(access_lock_);
  370. // suspend if debugger is enabled
  371. if (debugger_enabled_ && !is_dataset_graph_) {
  372. MS_LOG(INFO) << "Debugger suspend at debug_op";
  373. CommandLoop();
  374. }
  375. }
  376. void Debugger::SetStreamTaskToOpnameMap(const std::map<std::pair<uint32_t, uint32_t>, std::string> &mapping) {
  377. stream_task_to_opname_ = mapping;
  378. }
  379. void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  380. if (graph_ptr_ != graph_ptr) {
  381. MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
  382. // save new graph_ptr
  383. graph_ptr_ = graph_ptr;
  384. CheckDatasetGraph();
  385. if (!is_dataset_graph_) {
  386. // get proto for new graph_ptr
  387. auto graph_proto = GetGraphProto(graph_ptr);
  388. // add new graph proto to graph_proto_list_
  389. graph_proto_list_.push_back(graph_proto);
  390. graph_ptr_list_.push_back(graph_ptr);
  391. #ifdef ENABLE_D
  392. SetOpOverflowBinPath(graph_ptr->graph_id());
  393. #endif
  394. not_dataset_graph_sum_++;
  395. }
  396. // reset is_dataset_graph to be false
  397. is_dataset_graph_ = false;
  398. }
  399. }
  400. // In single graph cases, check single graph ptr
  401. void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  402. if (graph_ptr_ != graph_ptr) {
  403. MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
  404. // save new graph_ptr
  405. graph_ptr_ = graph_ptr;
  406. if (!is_dataset_graph_) {
  407. // only try to enable debugger if it is not a dataset graph
  408. EnableDebugger();
  409. if (debugger_enabled_) {
  410. LoadParametersAndConst();
  411. // get graph proto and send to Mindinsight
  412. auto graph_proto = graph_proto_list_.front();
  413. SendGraphAndSuspend(graph_proto);
  414. }
  415. }
  416. }
  417. }
  418. void Debugger::CheckDatasetGraph() {
  419. // print parameter node names
  420. const auto &params = graph_ptr_->inputs();
  421. for (const auto &param : params) {
  422. MS_LOG(INFO) << "param: " << param->fullname_with_scope();
  423. }
  424. // check if there is GetNext or InitDataSetQueue node
  425. const auto &nodes = graph_ptr_->execution_order();
  426. for (const auto &node : nodes) {
  427. auto node_name = AnfAlgo::GetCNodeName(node);
  428. MS_LOG(INFO) << "node: " << node->fullname_with_scope();
  429. if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
  430. MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
  431. << node_name;
  432. is_dataset_graph_ = true;
  433. return;
  434. }
  435. }
  436. is_dataset_graph_ = false;
  437. }
  438. GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  439. // convert kernel graph to debugger modelproto
  440. ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_);
  441. return model.graph();
  442. }
  443. void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  444. if (SendMetadata(true)) {
  445. // send graph to Mindinsight server
  446. EventReply reply = grpc_client_->SendGraph(graph_proto);
  447. if (reply.status() != reply.OK) {
  448. MS_LOG(ERROR) << "Error: SendGraph failed";
  449. }
  450. // enter command loop, wait and process commands
  451. CommandLoop();
  452. }
  453. }
  454. bool Debugger::SendMetadata(bool version_check) {
  455. // prepare metadata
  456. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  457. Metadata metadata;
  458. metadata.set_device_name(device_name);
  459. metadata.set_cur_step(num_step_);
  460. metadata.set_backend(device_target_);
  461. metadata.set_cur_node(cur_name_);
  462. metadata.set_training_done(training_done_);
  463. metadata.set_ms_version(version_);
  464. MS_LOG(INFO) << "Is training done?" << training_done_;
  465. // set graph munber to not_dataset_graph_sum_
  466. metadata.set_graph_num(not_dataset_graph_sum_);
  467. EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  468. bool ret = false;
  469. if (reply_metadata.status() == reply_metadata.OK) {
  470. if (version_check) {
  471. // get type of the command in meta data reply, it should be version matched
  472. DebuggerCommand cmd = GetCommand(reply_metadata);
  473. if (cmd != DebuggerCommand::kVersionMatchedCMD) {
  474. MS_LOG(ERROR) << "MindInsight version is too old, Mindspore version is " << version_;
  475. Exit();
  476. } else {
  477. if (GetMiVersionMatched(reply_metadata)) {
  478. MS_LOG(INFO) << "MindSpore version is " << version_ << " matches MindInsight version.";
  479. ret = true;
  480. } else {
  481. MS_LOG(ERROR) << "MindSpore version " << version_ << ", did not match MindInsight version.";
  482. CommandLoop();
  483. }
  484. }
  485. } else {
  486. // version check is done before so we can just return true here
  487. ret = true;
  488. }
  489. } else {
  490. MS_LOG(ERROR) << "Error: SendMetadata failed";
  491. }
  492. return ret;
  493. }
  494. void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list, uint32_t graph_sum) {
  495. if (!SendMetadata(true)) {
  496. return;
  497. }
  498. // send multiple graphs to mindinght server
  499. // split graph into chunks if one graph is larger than chunk size
  500. std::list<Chunk> chunked_graph_proto_list;
  501. Chunk chunk;
  502. for (auto graph : graph_proto_list) {
  503. std::string str = graph.SerializeAsString();
  504. auto graph_size = graph.ByteSize();
  505. if (graph_size > CHUNK_SIZE) {
  506. auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);
  507. for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
  508. chunk.set_buffer(sub_graph_str[i]);
  509. chunked_graph_proto_list.push_back(chunk);
  510. if (i < sub_graph_str.size() - 1) {
  511. chunk.set_finished(false);
  512. } else {
  513. chunk.set_finished(true);
  514. chunked_graph_proto_list.push_back(chunk);
  515. }
  516. }
  517. } else {
  518. chunk.set_buffer(str);
  519. chunk.set_finished(true);
  520. chunked_graph_proto_list.push_back(chunk);
  521. }
  522. }
  523. EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  524. if (reply.status() != reply.OK) {
  525. MS_LOG(ERROR) << "Error: SendGraph failed";
  526. }
  527. // enter command loop, wait and process commands
  528. CommandLoop();
  529. }
  530. void Debugger::CommandLoop() {
  531. // prepare metadata
  532. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  533. Metadata metadata;
  534. metadata.set_device_name(device_name);
  535. metadata.set_cur_step(num_step_);
  536. metadata.set_backend(device_target_);
  537. metadata.set_cur_node(cur_name_);
  538. metadata.set_training_done(training_done_);
  539. // loop exit flag
  540. bool run = false;
  541. int num_wait_fail = 0;
  542. const int max_num_wait_fail = 5;
  543. while (!run) {
  544. // wait for command
  545. EventReply reply = grpc_client_->WaitForCommand(metadata);
  546. if (reply.status() != reply.OK) {
  547. MS_LOG(ERROR) << "Error: WaitForCommand failed";
  548. num_wait_fail++;
  549. if (num_wait_fail > max_num_wait_fail) {
  550. MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session.";
  551. MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
  552. "of debugger host and port.";
  553. Exit();
  554. run = true;
  555. } else {
  556. MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after "
  557. << num_wait_fail << "s";
  558. std::this_thread::sleep_for(std::chrono::milliseconds(1000 * num_wait_fail));
  559. }
  560. continue;
  561. }
  562. // get type of the command in reply
  563. DebuggerCommand cmd = GetCommand(reply);
  564. if (cmd == DebuggerCommand::kUnknownCMD) {
  565. MS_LOG(DEBUG) << "Debug: debugger received unknown command";
  566. continue;
  567. }
  568. MS_LOG(INFO) << "received command: ";
  569. switch (cmd) {
  570. case DebuggerCommand::kUnknownCMD:
  571. MS_LOG(INFO) << "UnknownCMD";
  572. break;
  573. case DebuggerCommand::kExitCMD:
  574. MS_LOG(INFO) << "ExitCMD";
  575. Exit();
  576. // Used for debugger termination
  577. run = true;
  578. break;
  579. case DebuggerCommand::kRunCMD:
  580. ProcessRunCMD(reply);
  581. // exit loop
  582. run = true;
  583. break;
  584. case DebuggerCommand::kSetCMD:
  585. ProcessKSetCMD(reply);
  586. break;
  587. case DebuggerCommand::kViewCMD:
  588. ProcessKViewCMD(reply);
  589. break;
  590. case DebuggerCommand::kVersionMatchedCMD:
  591. MS_LOG(ERROR) << "Received unexpected Version Matched CMD from Mindinsight.";
  592. Exit();
  593. break;
  594. default:
  595. MS_LOG(ERROR) << "Received unknown CMD from Mindinsight";
  596. Exit();
  597. break;
  598. }
  599. }
  600. }
  601. void Debugger::ProcessRunCMD(const EventReply &reply) {
  602. MS_LOG(INFO) << "RunCMD";
  603. if (GetRunLevel(reply) == "recheck") {
  604. MS_LOG(INFO) << "rechecking all watchpoints";
  605. SendWatchpoints(CheckWatchpoints("", nullptr, true));
  606. } else {
  607. // no longer the initial suspension.
  608. initial_suspend_ = false;
  609. // print run cmd content
  610. // get run_level and node_name
  611. run_level_ = GetRunLevel(reply);
  612. node_name_ = GetNodeName(reply);
  613. MS_LOG(INFO) << "run_level: " << run_level_;
  614. MS_LOG(INFO) << "node_name_: " << node_name_;
  615. }
  616. }
  617. void Debugger::ProcessKSetCMD(const EventReply &reply) {
  618. MS_LOG(INFO) << "SetCMD";
  619. MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  620. MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  621. if (GetWatchpointDelete(reply)) {
  622. MS_LOG(INFO) << "Deleting watchpoint";
  623. RemoveWatchpoint(GetWatchpointID(reply));
  624. } else {
  625. MS_LOG(INFO) << "Setting watchpoint";
  626. MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
  627. ProtoVector<WatchNode> recieved_nodes = GetWatchnodes(reply);
  628. for (const auto &node : recieved_nodes) {
  629. MS_LOG(INFO) << "node name: " << node.node_name();
  630. MS_LOG(INFO) << "node type: " << node.node_type();
  631. }
  632. ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
  633. for (const auto &parameter : parameters) {
  634. MS_LOG(INFO) << "parameter name: " << parameter.name();
  635. MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
  636. MS_LOG(INFO) << "parameter value: " << parameter.value();
  637. }
  638. SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
  639. }
  640. }
  641. void Debugger::ProcessKViewCMD(const EventReply &reply) {
  642. MS_LOG(INFO) << "ViewCMD";
  643. // print view cmd content
  644. ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  645. for (auto received_tensor : received_tensors) {
  646. MS_LOG(INFO) << "tensor node name: " << received_tensor.node_name();
  647. MS_LOG(INFO) << "tensor slot: " << received_tensor.slot();
  648. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << received_tensor.finished() << std::noboolalpha;
  649. MS_LOG(INFO) << "tensor iter: " << received_tensor.iter();
  650. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << received_tensor.truncate() << std::noboolalpha;
  651. }
  652. MS_LOG(INFO) << "Sending tensors";
  653. std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  654. // print view cmd reply
  655. for (auto tensor : tensors) {
  656. MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
  657. MS_LOG(INFO) << "tensor slot: " << tensor.slot();
  658. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
  659. MS_LOG(INFO) << "tensor iter: " << tensor.iter();
  660. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
  661. MS_LOG(INFO) << "tensor dims: ";
  662. for (auto dim : tensor.dims()) {
  663. MS_LOG(INFO) << dim << ",";
  664. }
  665. MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  666. }
  667. EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  668. if (send_tensors_reply.status() != send_tensors_reply.OK) {
  669. MS_LOG(ERROR) << "Error: SendTensors failed";
  670. }
  671. }
  672. void AddTensorProtoInfo(TensorProto *tensor_item, TensorProto tensor) {
  673. tensor_item->set_node_name(tensor.node_name());
  674. tensor_item->set_slot(tensor.slot());
  675. tensor_item->set_iter(tensor.iter());
  676. tensor_item->set_truncate(tensor.truncate());
  677. tensor_item->clear_tensor_content();
  678. tensor_item->clear_data_type();
  679. tensor_item->clear_dims();
  680. }
  681. void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
  682. const ProtoVector<WatchCondition_Parameter> &parameters) {
  683. std::vector<std::tuple<std::string, bool>> check_node_list;
  684. std::vector<DebugServices::parameter_t> parameter_list;
  685. std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
  686. [](const WatchNode &node) -> std::tuple<std::string, bool> {
  687. return make_tuple(node.node_name(), node.node_type() == "scope");
  688. });
  689. std::transform(
  690. parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
  691. [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
  692. return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
  693. });
  694. debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
  695. }
  696. void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
  697. std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  698. std::vector<std::string> name;
  699. std::vector<std::string> ret_name;
  700. std::vector<char *> data_ptr;
  701. std::vector<ssize_t> data_size;
  702. std::vector<TypePtr> dtype;
  703. std::vector<std::vector<int64_t>> shape;
  704. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  705. // ret_name will contain tensor names that are found in TensorLoader
  706. // items in ret_name will be in the same order with tensors if found
  707. debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  708. std::list<TensorProto> tensor_list;
  709. unsigned int result_index = 0;
  710. for (auto tensor : tensors) {
  711. ssize_t size_iter = 0;
  712. if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
  713. TensorProto tensor_item;
  714. tensor_item.set_finished(true);
  715. AddTensorProtoInfo(&tensor_item, tensor);
  716. tensor_list.push_back(tensor_item);
  717. continue;
  718. }
  719. ssize_t tensor_size = data_size[result_index];
  720. while (size_iter < tensor_size) {
  721. ssize_t chunk_size = CHUNK_SIZE;
  722. TensorProto tensor_item;
  723. tensor_item.set_finished(false);
  724. if (tensor_size - size_iter <= CHUNK_SIZE) {
  725. chunk_size = tensor_size - size_iter;
  726. tensor_item.set_finished(true);
  727. }
  728. AddTensorProtoInfo(&tensor_item, tensor);
  729. // return empty tensor if didn't find the requested tensor
  730. tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
  731. tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index]));
  732. for (auto &elem : shape[result_index]) {
  733. tensor_item.add_dims(elem);
  734. }
  735. // add tensor to result list and increment result_index to check next item in ret_name
  736. tensor_list.push_back(tensor_item);
  737. size_iter += CHUNK_SIZE;
  738. }
  739. result_index++;
  740. }
  741. return tensor_list;
  742. }
  743. void Debugger::Exit() {
  744. // clear resource before exit
  745. // debugger will notify main thread to exit because main thread can only exit at step boundary
  746. pipeline::ExecutorPy::DebugTerminate(true);
  747. }
  748. std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel,
  749. bool recheck) {
  750. std::vector<std::string> name;
  751. std::vector<std::string> slot;
  752. std::vector<int> condition;
  753. std::vector<unsigned int> watchpoint_id;
  754. std::vector<std::string> overflow_ops;
  755. std::vector<std::vector<DebugServices::parameter_t>> parameters;
  756. std::vector<int32_t> error_codes;
  757. #ifdef ENABLE_D
  758. overflow_ops = CheckOpOverflow();
  759. #endif
  760. std::vector<std::shared_ptr<TensorData>> tensor_list;
  761. if (watchnode.empty()) {
  762. tensor_list = debug_services_->GetTensor();
  763. } else {
  764. tensor_list = debug_services_->GetNodeTensor(kernel);
  765. }
  766. debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, &error_codes, overflow_ops,
  767. tensor_list, initial_suspend_, watchnode.empty(), recheck);
  768. std::list<WatchpointHit> hits;
  769. for (unsigned int i = 0; i < name.size(); i++) {
  770. WatchpointHit hit;
  771. std::vector<DebugServices::parameter_t> &parameter = parameters[i];
  772. hit.set_id(watchpoint_id[i]);
  773. hit.set_error_code(error_codes[i]);
  774. // here TensorProto act as a tensor indicator, not sending tensor content
  775. TensorProto *tensor_item = hit.mutable_tensor();
  776. tensor_item->set_node_name(name[i]);
  777. tensor_item->set_slot(slot[i]);
  778. tensor_item->set_finished(true);
  779. WatchCondition *condition_item = hit.mutable_watch_condition();
  780. condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
  781. for (const auto &p : parameter) {
  782. auto x = condition_item->mutable_params()->Add();
  783. x->set_name(p.name);
  784. x->set_disabled(p.disabled);
  785. x->set_value(p.value);
  786. x->set_hit(p.hit);
  787. x->set_actual_value(p.actual_value);
  788. }
  789. hits.push_back(hit);
  790. }
  791. return hits;
  792. }
  793. void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  794. // send info about watchpoint
  795. if (!points.empty()) {
  796. EventReply reply = grpc_client_->SendWatchpointHits(points);
  797. if (reply.status() != reply.OK) {
  798. MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
  799. }
  800. }
  801. }
  802. bool Debugger::DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
  803. const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
  804. TypeId addr_type_id, const std::string &addr_format, size_t slot) const {
  805. return debug_services_.get()->DumpTensorToFile(tensor_name, trans_flag, filepath, host_fmt, host_shape, host_type,
  806. addr_type_id, addr_format, slot);
  807. }
  808. bool Debugger::DebugServicesIsWatchPoint(const std::string &kernel_name, const CNodePtr &kernel) const {
  809. return debug_services_.get()->IsWatchPoint(kernel_name, kernel);
  810. }
  811. void Debugger::EmptyTensor() { debug_services_.get()->EmptyTensor(); }
  812. void Debugger::SetTensorLoaderIterNum(uint32_t iter_num) { debug_services_.get()->SetTensorLoaderIterNum(iter_num); }
  813. void Debugger::EmptyPrevTensor() { debug_services_.get()->EmptyPrevTensor(); }
  814. uint32_t Debugger::GetTensorLoaderIterNum() const { return debug_services_.get()->GetTensorLoaderIterNum(); }
  815. bool Debugger::LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev) {
  816. return debug_services_.get()->LoadNewTensor(tensor, keep_prev);
  817. }
  818. bool Debugger::debugger_enabled() const { return debugger_enabled_; }
  819. DebuggerCommand GetCommand(const EventReply &reply) {
  820. DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  821. switch (reply.cmd_case()) {
  822. case debugger::EventReply::CmdCase::kExit:
  823. cmd = DebuggerCommand::kExitCMD;
  824. break;
  825. case debugger::EventReply::CmdCase::kRunCmd:
  826. cmd = DebuggerCommand::kRunCMD;
  827. break;
  828. case debugger::EventReply::CmdCase::kSetCmd:
  829. cmd = DebuggerCommand::kSetCMD;
  830. break;
  831. case debugger::EventReply::CmdCase::kViewCmd:
  832. cmd = DebuggerCommand::kViewCMD;
  833. break;
  834. case debugger::EventReply::CmdCase::kVersionMatched:
  835. cmd = DebuggerCommand::kVersionMatchedCMD;
  836. break;
  837. default:
  838. MS_LOG(DEBUG) << "Debug: UnknownCMD";
  839. break;
  840. }
  841. return cmd;
  842. }
  843. ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  844. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  845. MS_LOG(ERROR) << "Error: Can not get Parameters from command. Returning default value: ProtoVector<Parameter>().";
  846. return ProtoVector<WatchCondition_Parameter>();
  847. }
  848. return reply.set_cmd().watch_condition().params();
  849. }
  850. ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  851. if (!reply.has_set_cmd()) {
  852. MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
  853. return ProtoVector<WatchNode>();
  854. }
  855. return reply.set_cmd().watch_nodes();
  856. }
  857. std::string GetRunLevel(const EventReply &reply) {
  858. if (!reply.has_run_cmd()) {
  859. MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: "
  860. "";
  861. return "";
  862. }
  863. return reply.run_cmd().run_level();
  864. }
  865. std::string GetNodeName(const EventReply &reply) {
  866. if (!reply.has_run_cmd()) {
  867. MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: "
  868. "";
  869. return "";
  870. }
  871. return reply.run_cmd().node_name();
  872. }
  873. WatchCondition GetWatchcondition(const EventReply &reply) {
  874. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  875. MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
  876. return WatchCondition();
  877. }
  878. return reply.set_cmd().watch_condition();
  879. }
  880. int32_t GetWatchpointID(const EventReply &reply) {
  881. if (!reply.has_set_cmd()) {
  882. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
  883. return 0;
  884. }
  885. return reply.set_cmd().id();
  886. }
  887. bool GetWatchpointDelete(const EventReply &reply) {
  888. if (!reply.has_set_cmd()) {
  889. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
  890. return false;
  891. }
  892. return reply.set_cmd().delete_();
  893. }
  894. ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  895. if (!reply.has_view_cmd()) {
  896. MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
  897. return ProtoVector<TensorProto>();
  898. }
  899. return reply.view_cmd().tensors();
  900. }
  901. std::string GetTensorFullName(const TensorProto &tensor) {
  902. string node_name = tensor.node_name();
  903. if (tensor.truncate()) {
  904. // scopes in node name are separated by '/'
  905. // use the name without scope if truncate is true
  906. std::size_t found = node_name.find_last_of("/");
  907. node_name = node_name.substr(found + 1);
  908. }
  909. return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
  910. }
  911. bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }
  912. bool Debugger::partial_memory() { return partial_memory_; }
  913. void Debugger::SetCurNode(std::string cur_name) {
  914. // access lock for public method
  915. std::lock_guard<std::mutex> a_lock(access_lock_);
  916. cur_name_ = cur_name;
  917. }
  918. std::string Debugger::run_level() const { return run_level_; }
  919. void Debugger::SetStepNum(int32_t cur_num_step) {
  920. // access lock for public method
  921. std::lock_guard<std::mutex> a_lock(access_lock_);
  922. num_step_ = cur_num_step;
  923. }
  924. int32_t Debugger::step_num() const { return num_step_; }
  925. uint64_t BytestoInt64(const std::vector<char> &buffer) {
  926. uint64_t ret;
  927. ret = ((uint64_t)buffer[7] << 56) | ((uint64_t)buffer[6] << 48) | ((uint64_t)buffer[5] << 40) |
  928. ((uint64_t)buffer[4] << 32) | ((uint64_t)buffer[3] << 24) | ((uint64_t)buffer[2] << 16) |
  929. ((uint64_t)buffer[1] << 8) | ((uint64_t)buffer[0]);
  930. return ret;
  931. }
  932. #define BUF_SIZ 256
  933. std::vector<std::string> Debugger::CheckOpOverflow() {
  934. std::vector<double> bin_list;
  935. std::vector<std::string> op_names;
  936. for (const auto &[graph_id, overflow_bin_path] : overflow_bin_path_) {
  937. DIR *d;
  938. d = opendir(overflow_bin_path.c_str());
  939. MS_LOG(INFO) << "processing bin file path " << overflow_bin_path << ", graph id " << graph_id;
  940. if (d != nullptr) {
  941. struct dirent *dir = nullptr;
  942. while ((dir = readdir(d)) != NULL) {
  943. if (dir->d_type == DT_REG) {
  944. std::string file_path = overflow_bin_path;
  945. file_path.append(dir->d_name);
  946. std::string file_name = dir->d_name;
  947. std::size_t found = file_name.find_last_of(".");
  948. if (found == std::string::npos) {
  949. continue;
  950. }
  951. std::string overflow_time = file_name.substr(found + 1);
  952. if (stod(overflow_time) <= last_overflow_bin_) {
  953. MS_LOG(INFO) << "File already processed " << file_name;
  954. continue;
  955. }
  956. bin_list.push_back(stod(overflow_time));
  957. std::fstream infile;
  958. infile.open(file_path.c_str(), std::ios::binary | std::ios::in);
  959. if (!infile.is_open()) {
  960. MS_LOG(ERROR) << "Failed to open overflow bin file " << file_name;
  961. continue;
  962. }
  963. infile.seekg(313, std::ios::beg);
  964. std::vector<char> buffer;
  965. buffer.resize(BUF_SIZ);
  966. infile.read(buffer.data(), BUF_SIZ);
  967. uint64_t stream_id = BytestoInt64(std::vector<char>(buffer.begin() + 8, buffer.end()));
  968. uint64_t task_id = BytestoInt64(std::vector<char>(buffer.begin() + 16, buffer.end()));
  969. MS_LOG(INFO) << "Overflow stream_id " << stream_id << ", task_id " << task_id << ".";
  970. auto op = debugger_->stream_task_to_opname_.find(std::make_pair(stream_id, task_id));
  971. if (op != debugger_->stream_task_to_opname_.end()) {
  972. MS_LOG(ERROR) << "Overflow detected on node " << op->second << std::endl;
  973. op_names.push_back(op->second);
  974. } else {
  975. MS_LOG(INFO) << "No overflow is detected " << std::endl;
  976. }
  977. infile.close();
  978. }
  979. }
  980. } else {
  981. MS_LOG(INFO) << "OverFlow bin directory does not exist!";
  982. }
  983. closedir(d);
  984. }
  985. if (!op_names.empty()) {
  986. MS_LOG(ERROR) << "These operation overflows are detected " << op_names;
  987. }
  988. for (auto &i : bin_list) {
  989. if (i > last_overflow_bin_) {
  990. last_overflow_bin_ = i;
  991. }
  992. }
  993. auto iter_op_names = overflow_ops_.find(num_step_);
  994. if (iter_op_names == overflow_ops_.end()) {
  995. overflow_ops_.insert(std::pair<uint32_t, std::vector<std::string>>(num_step_, op_names));
  996. return op_names;
  997. }
  998. iter_op_names->second.insert(std::end(iter_op_names->second), std::begin(op_names), std::end(op_names));
  999. return iter_op_names->second;
  1000. }
  1001. void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }
  1002. bool Debugger::CheckPort(const char *port) {
  1003. char *p = const_cast<char *>(port);
  1004. int num = 0;
  1005. if (*p == '0' && *(p + 1) != '\0') return false;
  1006. while (*p != '\0') {
  1007. if (*p < '0' || *p > '9') return false;
  1008. num = num * 10 + (*p) - '0';
  1009. if (num < 1 || num > 65535) return false;
  1010. p++;
  1011. }
  1012. return true;
  1013. }
  1014. bool Debugger::CheckIp(const char *host) {
  1015. std::regex reg_ip(
  1016. "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
  1017. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  1018. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  1019. "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
  1020. std::smatch smat;
  1021. std::string host_str = std::string(host);
  1022. return std::regex_match(host_str, smat, reg_ip);
  1023. }
  1024. uint32_t Debugger::GetFirstRunGraphId() { return rungraph_id_list_.front(); }
  1025. void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  1026. MS_EXCEPTION_IF_NULL(anf_node);
  1027. if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
  1028. return;
  1029. }
  1030. // for parameters and value nodes, set its execution order to be 0;
  1031. int exec_order = 0;
  1032. std::string node_name = anf_node->fullname_with_scope();
  1033. E2eDumpUtil::GetFileKernelName(NOT_NULL(&node_name));
  1034. // check if output adde exists, if not, return;
  1035. if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
  1036. return;
  1037. }
  1038. auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  1039. MS_EXCEPTION_IF_NULL(addr);
  1040. auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  1041. if (type == kObjectTypeUMonad) {
  1042. return;
  1043. }
  1044. auto format = kOpFormat_DEFAULT;
  1045. string tensor_name = node_name + ':' + "0";
  1046. ShapeVector int_shapes;
  1047. auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
  1048. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  1049. [](size_t inner_item) { return SizeToInt(inner_item); });
  1050. bool keep_prev;
  1051. if (anf_node->isa<Parameter>()) {
  1052. keep_prev = true;
  1053. debug_services_->MoveTensorCurrentToPrev(tensor_name);
  1054. } else {
  1055. keep_prev = false;
  1056. }
  1057. bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  1058. if (!ret) {
  1059. MS_LOG(ERROR) << "LoadMemToHost:"
  1060. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1061. }
  1062. }
  1063. void Debugger::LoadParametersAndConst() {
  1064. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  1065. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1066. // load parameters
  1067. MS_LOG(INFO) << "Start to load Parameters!";
  1068. const auto &parameters = graph_ptr_->inputs();
  1069. for (auto &item : parameters) {
  1070. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  1071. }
  1072. // load value nodes
  1073. // get all constant avlues from the graph
  1074. MS_LOG(INFO) << "Start to load value nodes!";
  1075. const auto value_nodes = graph_ptr_->graph_value_nodes();
  1076. for (auto &item : value_nodes) {
  1077. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  1078. }
  1079. }
  1080. void Debugger::LoadGraphOutputs() {
  1081. if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  1082. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1083. const auto &apply_kernels = graph_ptr_->execution_order();
  1084. // for kernels, execution order starts from 1
  1085. int exec_order = 1;
  1086. for (const auto &node : apply_kernels) {
  1087. MS_EXCEPTION_IF_NULL(node);
  1088. auto node_name = AnfAlgo::GetCNodeName(node);
  1089. std::string kernel_name = node->fullname_with_scope();
  1090. auto output_size = AnfAlgo::GetOutputTensorNum(node);
  1091. if (partial_memory_) {
  1092. if (!debug_services_->IsWatchPoint(kernel_name, node)) {
  1093. continue;
  1094. }
  1095. }
  1096. for (size_t j = 0; j < output_size; ++j) {
  1097. if (!AnfAlgo::OutputAddrExist(node, j)) {
  1098. MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << node->fullname_with_scope();
  1099. continue;
  1100. }
  1101. auto addr = AnfAlgo::GetOutputAddr(node, j);
  1102. MS_EXCEPTION_IF_NULL(addr);
  1103. auto type = AnfAlgo::GetOutputInferDataType(node, j);
  1104. if (type == kObjectTypeUMonad) {
  1105. continue;
  1106. }
  1107. auto format = kOpFormat_DEFAULT;
  1108. string tensor_name = kernel_name + ':' + std::to_string(j);
  1109. ShapeVector int_shapes;
  1110. auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
  1111. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  1112. [](size_t inner_item) { return SizeToInt(inner_item); });
  1113. auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
  1114. if (!ret) {
  1115. MS_LOG(ERROR) << "LoadMemToHost:"
  1116. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1117. }
  1118. }
  1119. exec_order = exec_order + 1;
  1120. }
  1121. }
  1122. void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  1123. // update step number if we are processing the first graph (to support multigraph)
  1124. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
  1125. (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
  1126. // access lock for public method
  1127. std::lock_guard<std::mutex> a_lock(access_lock_);
  1128. ++num_step_;
  1129. }
  1130. }
  1131. void Debugger::ClearCurrentData() {
  1132. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()))
  1133. debug_services_->EmptyCurrentTensor();
  1134. }
  1135. bool Debugger::TensorExistsInCurrent(std::string tensor_name) {
  1136. return debug_services_->TensorExistsInCurrent(tensor_name);
  1137. }
  1138. } // namespace mindspore