You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

debugger.cc 34 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <dirent.h>
  17. #include <stdio.h>
  18. #include <fstream>
  19. #include <tuple>
  20. #include <vector>
  21. #include <algorithm>
  22. #include <iostream>
  23. #include <cstring>
  24. #include <utility>
  25. #include <map>
  26. #include <regex>
  27. #include "debug/debugger/debugger.h"
  28. #include "debug/data_dump/dump_json_parser.h"
  29. #include "pipeline/jit/pipeline.h"
  30. #include "backend/session/anf_runtime_algorithm.h"
  31. #include "runtime/device/kernel_runtime_manager.h"
  32. #include "runtime/device/kernel_runtime.h"
  33. #include "debug/data_dump/e2e_dump_util.h"
  34. using debugger::EventReply;
  35. using debugger::GraphProto;
  36. using debugger::ModelProto;
  37. using debugger::TensorProto;
  38. using debugger::WatchCondition;
  39. using debugger::WatchCondition_Condition_inf;
  40. using debugger::WatchCondition_Condition_nan;
  41. using debugger::WatchNode;
  42. using debugger::WatchpointHit;
  43. #define CHUNK_SIZE 1024 * 1024 * 3
  44. namespace mindspore {
// Singleton instance and the lock that guards its lazy creation.
DebuggerPtr Debugger::debugger_ = nullptr;
std::mutex Debugger::instance_lock_;
// Parameters and value nodes expose their data at output index 0.
static const size_t PARAMETER_OUTPUT_INDEX = 0;
static const size_t VALUE_NODE_OUTPUT_INDEX = 0;
// Constructor: initializes all members to safe defaults. If the debugger is
// enabled (ENABLE_MS_DEBUGGER=1), additionally reads MS_DEBUGGER_PARTIAL_MEM
// to decide whether memory reuse stays on, pushes that choice into the
// MsContext, and warns the user about the trade-offs of each mode.
Debugger::Debugger()
    : grpc_client_(nullptr),
      debug_services_(nullptr),
      device_id_(0),
      device_target_(""),
      num_step_(0),
      debugger_enabled_(false),
      run_level_(""),
      node_name_(""),
      cur_name_(""),
      training_done_(false),
      is_dataset_graph_(false),
      partial_memory_(false),
      last_overflow_bin_(0),
      overflow_bin_path_("") {
  if (CheckDebuggerEnabled()) {
    // configure partial memory reuse
    partial_memory_ = CheckDebuggerPartialMemoryEnabled();
    // switch memory reuse on or off
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_);
    // print some message about memory reuse to user
    if (partial_memory_) {
      MS_LOG(WARNING)
        << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
           "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
    } else {
      MS_LOG(INFO) << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
                      "usage for large models.";
    }
  }
}
// Records the device id and target backend this debugger instance works on.
// Thread-safe: serialized with the other public entry points via access_lock_.
void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  // save device_id
  MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  device_id_ = device_id;
  MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  device_target_ = device_target;
}
// (Re)configures the debugger for the current graph:
//  - resets per-run state,
//  - reads MS_DEBUGGER_HOST / MS_DEBUGGER_PORT to locate the MindInsight
//    debugger server (defaults: localhost:50051),
//  - on Ascend builds (ENABLE_D), records the op-overflow bin directory and
//    the newest overflow-file timestamp seen so far,
//  - creates the gRPC client and the DebugServices backend.
// An invalid host or port disables the debugger with an error log.
void Debugger::EnableDebugger() {
  // reset some of the class members
  num_step_ = 0;
  debugger_enabled_ = false;
  partial_memory_ = false;
  grpc_client_ = nullptr;
  debug_services_ = nullptr;
  // see if dump using debugger backend is enabled
  bool dump_enabled = CheckDebuggerDumpEnabled();
  MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;
  // check if debugger enabled
  debugger_enabled_ = CheckDebuggerEnabled();
  MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;
  if (!debugger_enabled_ && !dump_enabled) {
    MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
    return;
  }
  // configure grpc host
  const char *env_host_str = std::getenv("MS_DEBUGGER_HOST");
  std::string host;
  if (env_host_str != nullptr) {
    // dotted-quad IPv4 validation; first and last octets are restricted to 1-254
    std::regex reg_ip(
      "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
      "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
      "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
      "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
    std::smatch smat;
    std::string host_str = std::string(env_host_str);
    if (std::regex_match(host_str, smat, reg_ip)) {
      MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
      host = std::string(env_host_str);
    } else {
      MS_LOG(ERROR) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
                       "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
      debugger_enabled_ = false;
    }
  } else {
    MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
    host = "localhost";
  }
  // configure grpc port
  const char *env_port_str = std::getenv("MS_DEBUGGER_PORT");
  std::string port;
  if (env_port_str != nullptr) {
    if (CheckPort(env_port_str)) {
      MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
      port = std::string(env_port_str);
    } else {
      MS_LOG(ERROR) << "Environment variable MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to 65535";
      debugger_enabled_ = false;
    }
  } else {
    MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
    port = "50051";
  }
#ifdef ENABLE_D
  // set operation overflow info
  overflow_bin_path_ = DumpJsonParser::GetInstance().GetOpOverflowBinPath(graph_ptr_->graph_id(), device_id_);
  // new overflow dump files will have a timestamp greater than last_overflow_bin_
  last_overflow_bin_ = 0;
  DIR *d;
  d = opendir(overflow_bin_path_.c_str());
  if (d != nullptr) {
    struct dirent *dir;
    while ((dir = readdir(d)) != NULL) {
      if (dir->d_type == DT_REG) {
        // overflow files end in ".<timestamp>"; keep the newest timestamp seen
        std::string file_path = overflow_bin_path_;
        file_path.append(dir->d_name);
        std::size_t found = file_path.find_last_of(".");
        if (found == std::string::npos) {
          continue;
        }
        std::string overflow_time = file_path.substr(found + 1);
        // NOTE(review): stod throws std::invalid_argument on a non-numeric
        // suffix — assumes only well-formed overflow files live here; confirm.
        if (stod(overflow_time) <= last_overflow_bin_) {
          MS_LOG(INFO) << "Old op overflow bin folder" << file_path;
          continue;
        }
        last_overflow_bin_ = stod(overflow_time);
      }
    }
    MS_LOG(INFO) << "last op overflow bin folder" << last_overflow_bin_;
    closedir(d);
  }
#endif
  // initialize grpc client
  if (debugger_enabled_) {
    grpc_client_ = std::make_unique<GrpcClient>(host, port);
  }
  debug_services_ = std::make_unique<DebugServices>();
}
  181. bool Debugger::CheckDebuggerDumpEnabled() {
  182. // see if dump is enabled
  183. if (device_target_ == kGPUDevice) {
  184. return device::KernelRuntime::DumpDataEnabled();
  185. }
  186. return false;
  187. }
  188. bool Debugger::CheckDebuggerEnabled() {
  189. // get env variables to configure debugger
  190. const char *env_enable_str = std::getenv("ENABLE_MS_DEBUGGER");
  191. if (env_enable_str != nullptr) {
  192. if (std::strcmp(env_enable_str, "1") == 0) {
  193. return true;
  194. }
  195. }
  196. return false;
  197. }
  198. bool Debugger::CheckDebuggerPartialMemoryEnabled() {
  199. const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
  200. if (env_partial_mem_str != nullptr) {
  201. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
  202. if (std::strcmp(env_partial_mem_str, "1") == 0) {
  203. return true;
  204. }
  205. }
  206. return false;
  207. }
// A debugger backend is active if either debugger-driven dump or the interactive debugger is enabled.
bool Debugger::DebuggerBackendEnabled() { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }
  209. void Debugger::Reset() {
  210. // access lock for public method
  211. std::lock_guard<std::mutex> a_lock(access_lock_);
  212. // reset components
  213. device_id_ = 0;
  214. device_target_ = "";
  215. num_step_ = 0;
  216. debugger_enabled_ = false;
  217. is_dataset_graph_ = false;
  218. partial_memory_ = false;
  219. graph_ptr_ = nullptr;
  220. grpc_client_ = nullptr;
  221. debug_services_ = nullptr;
  222. last_overflow_bin_ = 0;
  223. overflow_bin_path_ = "";
  224. stream_task_to_opname_.clear();
  225. }
// Called before a graph executes. If any debugger backend (dump or
// interactive) is enabled, registers the graph and suspends when it is new.
void Debugger::PreExecute(const KernelGraphPtr &graph_ptr) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  // NOTE(review): queries the static singleton debugger_ rather than this —
  // presumably they are the same object; confirm against GetInstance().
  if (debugger_->DebuggerBackendEnabled()) {
    // check and save graph_ptr, suspend if graph is new
    CheckGraphPtr(graph_ptr);
  }
}
// Called at the end of a step. Unless termination was requested or the
// debugger runs in node-level mode, evaluates all watchpoints, reports any
// hits, and suspends in the command loop waiting for the user.
void Debugger::PostExecute() {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  if (pipeline::ExecutorPy::GetDebugTerminate()) {
    return;
  }
  if (debugger_->DebuggerBackendEnabled()) {
    // analyze tensor data and send the watchpoints been hit
    if (run_level_ == "node") {
      // node-level suspension is handled in PostExecuteNode instead
      MS_LOG(INFO) << "Debugger is in node level mode ";
      return;
    }
    if (debugger_enabled_ && !is_dataset_graph_) {
      if (device_target_ != kGPUDevice) {
        // non-GPU targets count the step and check watchpoints here
        num_step_++;
        MS_LOG(INFO) << "Debugger suspend at end of step; number of steps executed: " << num_step_;
        SendWatchpoints(CheckWatchpoints());
        CommandLoop();
      } else {
        CommandLoop();
      }
    }
  }
}
  258. bool Debugger::ReadNodeDataRequired() {
  259. if (debugger_enabled_ && !is_dataset_graph_) {
  260. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_);
  261. // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data
  262. if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
  263. return true;
  264. }
  265. }
  266. return false;
  267. }
// Called after each kernel executes (node-level debugging). If the kernel
// carries a watchpoint that was hit, reports the hit and suspends; otherwise
// suspends only when node-level stepping targets this kernel.
void Debugger::PostExecuteNode() {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  if (pipeline::ExecutorPy::GetDebugTerminate()) {
    return;
  }
  if (debugger_enabled_ && !is_dataset_graph_) {
    auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_);
    // if kernel is watchpoint,and get hit. suspend.
    bool hit_empty_flag = true;
    if (is_watchpoint) {
      auto hits = CheckWatchpoints(cur_name_);
      if (!hits.empty()) {
        SendWatchpoints(hits);
        CommandLoop();
        // a hit already suspended us; don't suspend a second time below
        hit_empty_flag = false;
      }
    }
    if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
      // if kernel is not watchpoint and is next_to or continue_to node, suspend
      CommandLoop();
    }
    return;
  }
}
  293. void Debugger::PostDebugOp() {
  294. // access lock for public method
  295. std::lock_guard<std::mutex> a_lock(access_lock_);
  296. // suspend if debugger is enabled
  297. if (debugger_enabled_ && !is_dataset_graph_) {
  298. MS_LOG(INFO) << "Debugger suspend at debug_op";
  299. CommandLoop();
  300. }
  301. }
// Replace the stored (stream id, task id) -> kernel name mapping.
void Debugger::SetStreamTaskToOpnameMap(const std::map<std::pair<uint32_t, uint32_t>, std::string> &mapping) {
  stream_task_to_opname_ = mapping;
}
// Registers a (possibly new) kernel graph. On a new non-dataset graph the
// debugger is (re)enabled, parameters and consts are loaded, and the graph
// proto is sent to MindInsight before suspending for user commands.
void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  if (graph_ptr_ != graph_ptr) {
    MS_LOG(INFO) << "Debugger got new graph: " << graph_ptr->graph_id();
    // save new graph_ptr
    graph_ptr_ = graph_ptr;
    // check if it is dataset graph
    CheckDatasetGraph();
    if (!is_dataset_graph_) {
      // only try to enable debugger if it is not a dataset graph
      EnableDebugger();
      if (debugger_enabled_) {
        LoadParametersAndConst();
        // get graph proto and send to mindinsight
        SendGraphAndSuspend(GetGraphProto());
      }
    }
  }
}
  323. void Debugger::CheckDatasetGraph() {
  324. // print parameter node names
  325. const auto &params = graph_ptr_->inputs();
  326. for (const auto &param : params) {
  327. MS_LOG(INFO) << "param: " << param->fullname_with_scope();
  328. }
  329. // check if there is GetNext or InitDataSetQueue node
  330. const auto &nodes = graph_ptr_->execution_order();
  331. for (const auto &node : nodes) {
  332. auto node_name = AnfAlgo::GetCNodeName(node);
  333. MS_LOG(INFO) << "node: " << node->fullname_with_scope();
  334. if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
  335. MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
  336. << node_name;
  337. is_dataset_graph_ = true;
  338. return;
  339. }
  340. }
  341. is_dataset_graph_ = false;
  342. }
  343. GraphProto Debugger::GetGraphProto() const {
  344. // convert kernel graph to debugger modelproto
  345. ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_);
  346. return model.graph();
  347. }
// Sends session metadata and the graph proto to the MindInsight server, then
// blocks in the command loop until the user resumes or exits.
void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  SendMetadata();
  // send graph to the MindInsight server
  EventReply reply = grpc_client_->SendGraph(graph_proto);
  if (reply.status() != reply.OK) {
    MS_LOG(ERROR) << "Error: SendGraph failed";
  }
  // enter command loop, wait and process commands
  CommandLoop();
}
// Sends the current session metadata (device "<id>:<graph_id>", step count,
// backend, current node, training-done flag) to the MindInsight server.
void Debugger::SendMetadata() {
  // prepare metadata
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;
  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);
  MS_LOG(INFO) << "Is training done?" << training_done_;
  EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  if (reply_metadata.status() != reply_metadata.OK) {
    MS_LOG(ERROR) << "Error: SendMetadata failed";
  }
}
// Blocks waiting for commands from the MindInsight server and dispatches
// them. Leaves the loop on a Run command (resume execution), after an Exit
// command, or after max_num_wait_fail consecutive WaitForCommand failures
// (with a linear backoff of num_wait_fail seconds between retries).
void Debugger::CommandLoop() {
  // prepare metadata
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;
  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);
  // loop exit flag
  bool run = false;
  int num_wait_fail = 0;
  const int max_num_wait_fail = 5;
  while (!run) {
    // wait for command
    EventReply reply = grpc_client_->WaitForCommand(metadata);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: WaitForCommand failed";
      num_wait_fail++;
      if (num_wait_fail > max_num_wait_fail) {
        // give up on the server and terminate the session
        MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session.";
        MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
                         "of debugger host and port.";
        Exit();
        run = true;
      } else {
        // linear backoff: wait num_wait_fail seconds before the next attempt
        MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after "
                      << num_wait_fail << "s";
        std::this_thread::sleep_for(std::chrono::milliseconds(1000 * num_wait_fail));
      }
      continue;
    }
    // get type of the command in reply
    DebuggerCommand cmd = GetCommand(reply);
    if (cmd == DebuggerCommand::kUnknownCMD) {
      MS_LOG(DEBUG) << "Debug: debugger received unknown command";
      continue;
    }
    MS_LOG(INFO) << "received command: ";
    switch (cmd) {
      case DebuggerCommand::kUnknownCMD:
        // unreachable: filtered out above, kept for switch completeness
        MS_LOG(INFO) << "UnknownCMD";
        break;
      case DebuggerCommand::kExitCMD:
        MS_LOG(INFO) << "ExitCMD";
        Exit();
        // Used for debugger termination
        run = true;
        break;
      case DebuggerCommand::kRunCMD:
        MS_LOG(INFO) << "RunCMD";
        if (GetRunLevel(reply) == "recheck") {
          // re-evaluate all watchpoints without resuming execution
          MS_LOG(INFO) << "rechecking all watchpoints";
          SendWatchpoints(CheckWatchpoints());
        } else {
          // print run cmd content
          // get run_level and node_name
          run_level_ = GetRunLevel(reply);
          node_name_ = GetNodeName(reply);
          MS_LOG(INFO) << "run_level: " << run_level_;
          MS_LOG(INFO) << "node_name_: " << node_name_;
          // exit loop
          run = true;
        }
        break;
      case DebuggerCommand::kSetCMD:
        MS_LOG(INFO) << "SetCMD";
        {
          // print set cmd content
          ProtoVector<WatchNode> recieved_nodes = GetWatchnodes(reply);
          for (auto node : recieved_nodes) {
            MS_LOG(INFO) << "node name: " << node.node_name();
            MS_LOG(INFO) << "node type: " << node.node_type();
          }
          MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
          MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
          MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
        }
        MS_LOG(INFO) << "Setting watchpoint";
        // a SetCMD either deletes an existing watchpoint or registers a new one
        if (GetWatchpointDelete(reply)) {
          RemoveWatchpoint(GetWatchpointID(reply));
        } else {
          SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply));
        }
        break;
      case DebuggerCommand::kViewCMD:
        MS_LOG(INFO) << "ViewCMD";
        {
          // print view cmd content
          ProtoVector<TensorProto> received_tensors = GetTensors(reply);
          for (auto tensor : received_tensors) {
            MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
            MS_LOG(INFO) << "tensor slot: " << tensor.slot();
            MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
            MS_LOG(INFO) << "tensor iter: " << tensor.iter();
            MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
          }
        }
        MS_LOG(INFO) << "Sending tensors";
        // load the requested tensors (chunked) and ship them back
        std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
        {
          // print view cmd reply
          for (auto tensor : tensors) {
            MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
            MS_LOG(INFO) << "tensor slot: " << tensor.slot();
            MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
            MS_LOG(INFO) << "tensor iter: " << tensor.iter();
            MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
            MS_LOG(INFO) << "tensor dims: ";
            for (auto dim : tensor.dims()) {
              MS_LOG(INFO) << dim << ",";
            }
            MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
          }
        }
        EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
        if (send_tensors_reply.status() != send_tensors_reply.OK) {
          MS_LOG(ERROR) << "Error: SendTensors failed";
        }
        break;
    }
  }
}
  496. void AddTensorProtoInfo(TensorProto *tensor_item, TensorProto tensor) {
  497. tensor_item->set_node_name(tensor.node_name());
  498. tensor_item->set_slot(tensor.slot());
  499. tensor_item->set_iter(tensor.iter());
  500. tensor_item->set_truncate(tensor.truncate());
  501. tensor_item->clear_tensor_content();
  502. tensor_item->clear_data_type();
  503. tensor_item->clear_dims();
  504. }
  505. void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id) {
  506. std::vector<std::tuple<std::string, bool>> check_node_list;
  507. std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
  508. [](WatchNode node) -> std::tuple<std::string, bool> {
  509. return make_tuple(node.node_name(), node.node_type() == "scope");
  510. });
  511. debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list);
  512. }
// Deletes the watchpoint with the given id from DebugServices.
void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
// Fetches the requested tensors from the TensorLoader and converts each into
// one or more TensorProto chunks of at most CHUNK_SIZE bytes (only the last
// chunk of a tensor carries finished=true). A tensor that is not found in the
// loader is returned as an empty proto marked finished.
std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  std::vector<std::string> name;
  std::vector<std::string> ret_name;
  std::vector<char *> data_ptr;
  std::vector<unsigned int> data_size;
  std::vector<TypePtr> dtype;
  std::vector<std::vector<int>> shape;
  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  // ret_name will contain tensor names that are found in TensorLoader
  // items in ret_name will be in the same order with tensors if found
  debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  std::list<TensorProto> tensor_list;
  unsigned int result_index = 0;
  for (auto tensor : tensors) {
    int size_iter = 0;
    if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
      // return empty tensor if didn't find the requested tensor
      TensorProto tensor_item;
      tensor_item.set_finished(true);
      AddTensorProtoInfo(&tensor_item, tensor);
      tensor_list.push_back(tensor_item);
      continue;
    }
    // NOTE(review): data_size is unsigned int, tensor_size is int — assumes
    // tensor byte sizes stay below INT_MAX; confirm.
    int tensor_size = data_size[result_index];
    while (size_iter < tensor_size) {
      // slice the tensor content into CHUNK_SIZE pieces for transmission
      int chunk_size = CHUNK_SIZE;
      TensorProto tensor_item;
      tensor_item.set_finished(false);
      if (tensor_size - size_iter <= CHUNK_SIZE) {
        // final chunk: trim to the remaining bytes and mark complete
        chunk_size = tensor_size - size_iter;
        tensor_item.set_finished(true);
      }
      AddTensorProtoInfo(&tensor_item, tensor);
      tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
      tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index]));
      for (auto &elem : shape[result_index]) {
        tensor_item.add_dims(elem);
      }
      // add tensor to result list and increment result_index to check next item in ret_name
      tensor_list.push_back(tensor_item);
      size_iter += CHUNK_SIZE;
    }
    result_index++;
  }
  return tensor_list;
}
// Terminates the training session at the debugger's request.
// For node level, debugger has to exit itself because main thread can only exit in step boundary;
// For step level, debugger will notify main thread to exit.
void Debugger::Exit() {
  // clear resource before exit
  if (run_level_ == "node") {
    pipeline::ClearResAtexit();
    exit(1);
  } else if (run_level_ == "step" || device_target_ == kAscendDevice) {
    // Notify main thread to terminate
    pipeline::ExecutorPy::DebugTerminate(true);
  } else {
    // default path: hard-exit like the node-level case
    pipeline::ClearResAtexit();
    exit(1);
  }
}
// Evaluates watchpoints against loaded tensors. With an empty watchnode all
// loaded tensors are checked (end-of-step case); otherwise only the tensors
// belonging to that node. On Ascend builds, op-overflow files are folded in
// as overflow conditions. Returns one WatchpointHit per triggered watchpoint.
std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode) {
  // parallel output vectors filled by DebugServices::CheckWatchpoints
  std::vector<std::string> name;
  std::vector<std::string> slot;
  std::vector<int> condition;
  std::vector<unsigned int> watchpoint_id;
  std::vector<std::string> overflow_ops;
#ifdef ENABLE_D
  overflow_ops = CheckOpOverflow();
#endif
  auto tensor_loader = debug_services_->tensor_loader();
  std::vector<std::shared_ptr<TensorData>> tensor_list;
  if (watchnode.empty()) {
    tensor_list = tensor_loader->GetTensor();
  } else {
    tensor_list = tensor_loader->GetNodeTensorMap(watchnode);
  }
  debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, overflow_ops, tensor_list);
  std::list<WatchpointHit> hits;
  for (unsigned int i = 0; i < name.size(); i++) {
    WatchpointHit hit;
    hit.set_id(watchpoint_id[i]);
    // here TensorProto act as a tensor indicator, not sending tensor content
    TensorProto *tensor_item = hit.mutable_tensor();
    tensor_item->set_node_name(name[i]);
    tensor_item->set_slot(slot[i]);
    tensor_item->set_finished(true);
    WatchCondition *condition_item = hit.mutable_watch_condition();
    condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
    hits.push_back(hit);
  }
  return hits;
}
  607. void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  608. // send info about watchpoint
  609. if (!points.empty()) {
  610. EventReply reply = grpc_client_->SendWatchpointHits(points);
  611. if (reply.status() != reply.OK) {
  612. MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
  613. }
  614. }
  615. }
// Non-owning pointer to the DebugServices backend (may be null before EnableDebugger).
DebugServices *Debugger::debug_services() const { return debug_services_.get(); }
// Whether the interactive debugger is active for the current graph.
bool Debugger::debugger_enabled() const { return debugger_enabled_; }
  618. DebuggerCommand GetCommand(const EventReply &reply) {
  619. DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  620. switch (reply.cmd_case()) {
  621. case debugger::EventReply::CmdCase::kExit:
  622. cmd = DebuggerCommand::kExitCMD;
  623. break;
  624. case debugger::EventReply::CmdCase::kRunCmd:
  625. cmd = DebuggerCommand::kRunCMD;
  626. break;
  627. case debugger::EventReply::CmdCase::kSetCmd:
  628. cmd = DebuggerCommand::kSetCMD;
  629. break;
  630. case debugger::EventReply::CmdCase::kViewCmd:
  631. cmd = DebuggerCommand::kViewCMD;
  632. break;
  633. default:
  634. MS_LOG(DEBUG) << "Debug: UnknownCMD";
  635. break;
  636. }
  637. return cmd;
  638. }
  639. ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  640. if (!reply.has_set_cmd()) {
  641. MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
  642. return ProtoVector<WatchNode>();
  643. }
  644. return reply.set_cmd().watch_nodes();
  645. }
  646. std::string GetRunLevel(const EventReply &reply) {
  647. if (!reply.has_run_cmd()) {
  648. MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: "
  649. "";
  650. return "";
  651. }
  652. return reply.run_cmd().run_level();
  653. }
  654. std::string GetNodeName(const EventReply &reply) {
  655. if (!reply.has_run_cmd()) {
  656. MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: "
  657. "";
  658. return "";
  659. }
  660. return reply.run_cmd().node_name();
  661. }
  662. WatchCondition GetWatchcondition(const EventReply &reply) {
  663. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  664. MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
  665. return WatchCondition();
  666. }
  667. return reply.set_cmd().watch_condition();
  668. }
  669. int32_t GetWatchpointID(const EventReply &reply) {
  670. if (!reply.has_set_cmd()) {
  671. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
  672. return 0;
  673. }
  674. return reply.set_cmd().id();
  675. }
  676. bool GetWatchpointDelete(const EventReply &reply) {
  677. if (!reply.has_set_cmd()) {
  678. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
  679. return false;
  680. }
  681. return reply.set_cmd().delete_();
  682. }
  683. ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  684. if (!reply.has_view_cmd()) {
  685. MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
  686. return ProtoVector<TensorProto>();
  687. }
  688. return reply.view_cmd().tensors();
  689. }
  690. std::string GetTensorFullName(const TensorProto &tensor) {
  691. string node_name = tensor.node_name();
  692. if (tensor.truncate()) {
  693. // scopes in node name are seperated by '/'
  694. // use the name without scope if truncate is true
  695. std::size_t found = node_name.find_last_of("/");
  696. node_name = node_name.substr(found + 1);
  697. }
  698. return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
  699. }
// Whether partial memory reuse was enabled for this debugging session.
bool Debugger::partial_memory() { return partial_memory_; }
  701. void Debugger::SetCurNode(std::string cur_name) {
  702. // access lock for public method
  703. std::lock_guard<std::mutex> a_lock(access_lock_);
  704. cur_name_ = cur_name;
  705. }
// Current run level as set by the last Run command ("node", "step", or "").
std::string Debugger::run_level() const { return run_level_; }
  707. void Debugger::SetStepNum(int32_t cur_num_step) {
  708. // access lock for public method
  709. std::lock_guard<std::mutex> a_lock(access_lock_);
  710. num_step_ = cur_num_step;
  711. }
  712. int32_t Debugger::step_num() const { return num_step_; }
  713. uint64_t BytestoInt64(const std::vector<char> &buffer) {
  714. uint64_t ret;
  715. ret = ((uint64_t)buffer[7] << 56) | ((uint64_t)buffer[6] << 48) | ((uint64_t)buffer[5] << 40) |
  716. ((uint64_t)buffer[4] << 32) | ((uint64_t)buffer[3] << 24) | ((uint64_t)buffer[2] << 16) |
  717. ((uint64_t)buffer[1] << 8) | ((uint64_t)buffer[0]);
  718. return ret;
  719. }
  720. #define BUF_SIZ 256
  721. std::vector<std::string> Debugger::CheckOpOverflow() {
  722. std::vector<double> bin_list;
  723. std::vector<std::string> op_names;
  724. DIR *d;
  725. struct dirent *dir = nullptr;
  726. d = opendir(overflow_bin_path_.c_str());
  727. if (d != nullptr) {
  728. while ((dir = readdir(d)) != NULL) {
  729. if (dir->d_type == DT_REG) {
  730. std::string file_path = overflow_bin_path_;
  731. file_path.append(dir->d_name);
  732. std::string file_name = dir->d_name;
  733. std::size_t found = file_name.find_last_of(".");
  734. if (found == std::string::npos) {
  735. continue;
  736. }
  737. std::string overflow_time = file_name.substr(found + 1);
  738. if (stod(overflow_time) <= last_overflow_bin_) {
  739. MS_LOG(INFO) << "File already processed " << file_name;
  740. continue;
  741. }
  742. bin_list.push_back(stod(overflow_time));
  743. std::fstream infile;
  744. infile.open(file_path.c_str(), std::ios::binary | std::ios::in);
  745. if (!infile.is_open()) {
  746. MS_LOG(ERROR) << "Failed to open overflow bin file " << file_name;
  747. continue;
  748. }
  749. infile.seekg(313, std::ios::beg);
  750. std::vector<char> buffer;
  751. buffer.resize(BUF_SIZ);
  752. infile.read(buffer.data(), BUF_SIZ);
  753. uint64_t stream_id = BytestoInt64(std::vector<char>(buffer.begin() + 8, buffer.end()));
  754. uint64_t task_id = BytestoInt64(std::vector<char>(buffer.begin() + 16, buffer.end()));
  755. MS_LOG(INFO) << "Overflow stream_id " << stream_id << ", task_id " << task_id << ".";
  756. auto op = debugger_->stream_task_to_opname_.find(std::make_pair(stream_id, task_id));
  757. if (op != debugger_->stream_task_to_opname_.end()) {
  758. MS_LOG(ERROR) << "Overflow detected on node " << op->second << std::endl;
  759. op_names.push_back(op->second);
  760. } else {
  761. MS_LOG(INFO) << "No overflow is detected " << std::endl;
  762. }
  763. infile.close();
  764. }
  765. }
  766. } else {
  767. MS_LOG(INFO) << "OverFlow bin directory does not exist!";
  768. }
  769. closedir(d);
  770. if (op_names.size()) {
  771. MS_LOG(ERROR) << "These operation overflows are detected " << op_names;
  772. }
  773. for (auto &i : bin_list) {
  774. if (i > last_overflow_bin_) {
  775. last_overflow_bin_ = i;
  776. }
  777. }
  778. return op_names;
  779. }
  780. void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }
  781. bool Debugger::CheckPort(const char *port) {
  782. char *p = const_cast<char *>(port);
  783. int num = 0;
  784. if (*p == '0' && *(p + 1) != '\0') return false;
  785. while (*p != '\0') {
  786. if (*p < '0' || *p > '9') return false;
  787. num = num * 10 + (*p) - '0';
  788. if (num < 1 || num > 65535) return false;
  789. p++;
  790. }
  791. return true;
  792. }
  793. void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  794. MS_EXCEPTION_IF_NULL(anf_node);
  795. if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
  796. return;
  797. }
  798. bool keep_prev;
  799. if (anf_node->isa<Parameter>()) {
  800. keep_prev = true;
  801. } else {
  802. keep_prev = false;
  803. }
  804. // for parameters and value nodes, set its execution order to be 0;
  805. int exec_order = 0;
  806. std::string node_name = anf_node->fullname_with_scope();
  807. E2eDumpUtil::GetFileKernelName(NOT_NULL(&node_name));
  808. // check if output adde exists, if not, return;
  809. if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
  810. return;
  811. }
  812. auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  813. MS_EXCEPTION_IF_NULL(addr);
  814. auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  815. auto format = kOpFormat_DEFAULT;
  816. string tensor_name = node_name + "_output:" + "0";
  817. ShapeVector int_shapes;
  818. auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
  819. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  820. [](size_t inner_item) { return SizeToInt(inner_item); });
  821. bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  822. if (!ret) {
  823. MS_LOG(ERROR) << "LoadMemToHost:"
  824. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  825. }
  826. }
  827. void Debugger::LoadParametersAndConst() {
  828. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  829. if (!(num_step_ == 0 || device_target_ == kAscendDevice ||
  830. (device_target_ == kGPUDevice && device::KernelRuntime::DumpDataEnabledIteration())))
  831. return;
  832. MS_EXCEPTION_IF_NULL(graph_ptr_);
  833. // load parameters
  834. MS_LOG(INFO) << "Start to load Parameters!";
  835. const auto &parameters = graph_ptr_->inputs();
  836. for (auto &item : parameters) {
  837. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  838. }
  839. // load value nodes
  840. // get all constant avlues from the graph
  841. MS_LOG(INFO) << "Start to load value nodes!";
  842. const auto value_nodes = graph_ptr_->graph_value_nodes();
  843. for (auto &item : value_nodes) {
  844. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  845. }
  846. }
  847. void Debugger::LoadGraphOutputs() {
  848. if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  849. MS_EXCEPTION_IF_NULL(graph_ptr_);
  850. const auto &apply_kernels = graph_ptr_->execution_order();
  851. // for kernels, execution order starts from 1
  852. int exec_order = 1;
  853. for (const auto &node : apply_kernels) {
  854. MS_EXCEPTION_IF_NULL(node);
  855. auto node_name = AnfAlgo::GetCNodeName(node);
  856. std::string kernel_name = node->fullname_with_scope();
  857. auto output_size = AnfAlgo::GetOutputTensorNum(node);
  858. if (partial_memory_) {
  859. if (!debug_services_->IsWatchPoint(kernel_name)) {
  860. continue;
  861. }
  862. }
  863. for (size_t j = 0; j < output_size; ++j) {
  864. auto addr = AnfAlgo::GetOutputAddr(node, j);
  865. MS_EXCEPTION_IF_NULL(addr);
  866. auto type = AnfAlgo::GetOutputInferDataType(node, j);
  867. auto format = kOpFormat_DEFAULT;
  868. string tensor_name = kernel_name + ':' + std::to_string(j);
  869. ShapeVector int_shapes;
  870. auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
  871. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
  872. [](size_t inner_item) { return SizeToInt(inner_item); });
  873. auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
  874. if (!ret) {
  875. MS_LOG(ERROR) << "LoadMemToHost:"
  876. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  877. }
  878. }
  879. exec_order = exec_order + 1;
  880. }
  881. }
  882. void Debugger::UpdateStepNum() {
  883. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()))
  884. ++num_step_;
  885. }
  886. void Debugger::ClearCurrentData() {
  887. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()))
  888. debug_services_->tensor_loader()->EmptyCurrentTensor();
  889. }
  890. } // namespace mindspore