You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger.cc 31 kB

5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <dirent.h>
  17. #include <stdio.h>
  18. #include <fstream>
  19. #include <tuple>
  20. #include <vector>
  21. #include <algorithm>
  22. #include <iostream>
  23. #include <cstring>
  24. #include <utility>
  25. #include <map>
  26. #include "debug/debugger/debugger.h"
  27. #include "debug/data_dump_parser.h"
  28. #include "pipeline/jit/pipeline.h"
  29. #include "backend/session/anf_runtime_algorithm.h"
  30. #include "runtime/device/kernel_runtime_manager.h"
  31. #include "utils/system/file_system.h"
  32. #include "utils/system/env.h"
  33. #include <nlohmann/json.hpp>
  34. using json = nlohmann::json;
  35. using debugger::EventReply;
  36. using debugger::GraphProto;
  37. using debugger::ModelProto;
  38. using debugger::TensorProto;
  39. using debugger::WatchCondition;
  40. using debugger::WatchCondition_Condition_inf;
  41. using debugger::WatchCondition_Condition_nan;
  42. using debugger::WatchNode;
  43. using debugger::WatchpointHit;
  44. #define CHUNK_SIZE 1024 * 1024 * 3
  45. namespace mindspore {
  46. DebuggerPtr Debugger::debugger_ = nullptr;
  47. std::mutex Debugger::instance_lock_;
  48. Debugger::Debugger()
  49. : grpc_client_(nullptr),
  50. debug_services_(nullptr),
  51. device_id_(0),
  52. device_target_(""),
  53. num_step_(0),
  54. debugger_enabled_(false),
  55. run_level_(""),
  56. node_name_(""),
  57. cur_name_(""),
  58. is_dataset_graph_(false),
  59. partial_memory_(false),
  60. last_overflow_bin_(0),
  61. overflow_bin_path_(""),
  62. ssl_certificate(true),
  63. certificate_dir(""),
  64. certificate_passphrase("") {}
  65. void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  66. // access lock for public method
  67. std::lock_guard<std::mutex> a_lock(access_lock_);
  68. // save device_id
  69. MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  70. device_id_ = device_id;
  71. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  72. device_target_ = device_target;
  73. }
  74. void Debugger::EnableDebugger() {
  75. // reset some of the class members
  76. num_step_ = 0;
  77. debugger_enabled_ = false;
  78. partial_memory_ = false;
  79. grpc_client_ = nullptr;
  80. debug_services_ = nullptr;
  81. // see if dump is enabled
  82. bool dump_enabled = false;
  83. if (device_target_ == kGPUDevice) {
  84. auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
  85. MS_EXCEPTION_IF_NULL(runtime_instance);
  86. dump_enabled = runtime_instance->DumpDataEnabled();
  87. }
  88. // get env variables to configure debugger
  89. const char *env_enable_str = std::getenv("ENABLE_MS_DEBUGGER");
  90. if (env_enable_str != nullptr) {
  91. MS_LOG(INFO) << "Getenv ENABLE_MS_DEBUGGER: " << env_enable_str;
  92. if (std::strcmp(env_enable_str, "1") == 0) {
  93. debugger_enabled_ = true;
  94. }
  95. }
  96. if (!debugger_enabled_ && !dump_enabled) {
  97. MS_LOG(WARNING) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
  98. return;
  99. }
  100. // configure grpc host
  101. const char *env_host_str = std::getenv("MS_DEBUGGER_HOST");
  102. std::string host;
  103. if (env_host_str != nullptr) {
  104. MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
  105. host = std::string(env_host_str);
  106. } else {
  107. MS_LOG(WARNING) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
  108. host = "localhost";
  109. }
  110. // configure grpc port
  111. const char *env_port_str = std::getenv("MS_DEBUGGER_PORT");
  112. std::string port;
  113. if (env_port_str != nullptr) {
  114. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
  115. port = std::string(env_port_str);
  116. } else {
  117. MS_LOG(WARNING) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
  118. port = "50051";
  119. }
  120. // configure partial memory reuse
  121. const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
  122. if (env_partial_mem_str != nullptr) {
  123. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
  124. if (std::strcmp(env_partial_mem_str, "1") == 0) {
  125. partial_memory_ = true;
  126. }
  127. }
  128. // switch memory reuse on or off
  129. auto context_ptr = MsContext::GetInstance();
  130. MS_EXCEPTION_IF_NULL(context_ptr);
  131. context_ptr->set_enable_mem_reuse(partial_memory_);
  132. // print some message about memory reuse to user
  133. if (partial_memory_) {
  134. MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
  135. "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
  136. } else {
  137. MS_LOG(WARNING) << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
  138. "usage for large models.";
  139. }
  140. #ifdef ENABLE_D
  141. // set operation overflow info
  142. overflow_bin_path_ = DataDumpParser::GetInstance().GetOpOverflowBinPath(graph_ptr_->graph_id(), device_id_);
  143. // new overflow dump files will have a timestamp greater than last_overflow_bin_
  144. last_overflow_bin_ = 0;
  145. DIR *d;
  146. d = opendir(overflow_bin_path_.c_str());
  147. if (d != nullptr) {
  148. struct dirent *dir;
  149. while ((dir = readdir(d)) != NULL) {
  150. if (dir->d_type == DT_REG) {
  151. std::string file_path = overflow_bin_path_;
  152. file_path.append(dir->d_name);
  153. std::size_t found = file_path.find_last_of(".");
  154. if (found == std::string::npos) {
  155. continue;
  156. }
  157. std::string overflow_time = file_path.substr(found + 1);
  158. if (stod(overflow_time) <= last_overflow_bin_) {
  159. MS_LOG(INFO) << "Old op overflow bin folder" << file_path;
  160. continue;
  161. }
  162. last_overflow_bin_ = stod(overflow_time);
  163. }
  164. }
  165. MS_LOG(INFO) << "last op overflow bin folder" << last_overflow_bin_;
  166. }
  167. #endif
  168. // initialize grpc client
  169. if (debugger_enabled_) {
  170. SetDebuggerConfFromJsonFile();
  171. grpc_client_ = std::make_unique<GrpcClient>(host, port, ssl_certificate, certificate_dir, certificate_passphrase);
  172. }
  173. debug_services_ = std::make_unique<DebugServices>();
  174. }
  175. void Debugger::Reset() {
  176. // access lock for public method
  177. std::lock_guard<std::mutex> a_lock(access_lock_);
  178. // reset components
  179. device_id_ = 0;
  180. device_target_ = "";
  181. num_step_ = 0;
  182. debugger_enabled_ = false;
  183. is_dataset_graph_ = false;
  184. partial_memory_ = false;
  185. graph_ptr_ = nullptr;
  186. grpc_client_ = nullptr;
  187. debug_services_ = nullptr;
  188. last_overflow_bin_ = 0;
  189. overflow_bin_path_ = "";
  190. stream_task_to_opname_.clear();
  191. }
  192. void Debugger::PreExecute(const KernelGraphPtr &graph_ptr) {
  193. // access lock for public method
  194. std::lock_guard<std::mutex> a_lock(access_lock_);
  195. // check and save graph_ptr, suspend if graph is new
  196. CheckGraphPtr(graph_ptr);
  197. }
  198. void Debugger::PostExecute() {
  199. // access lock for public method
  200. std::lock_guard<std::mutex> a_lock(access_lock_);
  201. // analyze tensor data and send the watchpoints been hit
  202. if (run_level_ == "node") {
  203. MS_LOG(INFO) << "Debugger is in node level mode ";
  204. return;
  205. }
  206. if (debugger_enabled_ && !is_dataset_graph_) {
  207. if (device_target_ != kGPUDevice) {
  208. num_step_++;
  209. MS_LOG(INFO) << "Debugger suspend at end of step; number of steps executed: " << num_step_;
  210. SendWatchpointsAndSuspend(CheckWatchpoints());
  211. } else {
  212. CommandLoop();
  213. }
  214. }
  215. }
  216. bool Debugger::ReadNodeDataRequired() {
  217. if (debugger_enabled_ && !is_dataset_graph_) {
  218. auto watchpoint_table = debug_services_->GetWatchpointTable();
  219. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, watchpoint_table);
  220. // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data
  221. if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
  222. return true;
  223. }
  224. }
  225. return false;
  226. }
  227. void Debugger::PostExecuteNode() {
  228. // access lock for public method
  229. std::lock_guard<std::mutex> a_lock(access_lock_);
  230. if (debugger_enabled_ && !is_dataset_graph_) {
  231. auto watchpoint_table = debug_services_->GetWatchpointTable();
  232. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, watchpoint_table);
  233. // if kernel is watchpoint,and get hit. suspend.
  234. if (is_watchpoint) {
  235. auto hits = CheckSingleWatchpoint(cur_name_);
  236. if (!hits.empty()) {
  237. SendWatchpointsAndSuspend(hits);
  238. }
  239. }
  240. // if kernel is not watchpoint and is next_to or continue_to node, suspend.
  241. if (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
  242. CommandLoop();
  243. }
  244. return;
  245. }
  246. }
  247. void Debugger::PostDebugOp() {
  248. // access lock for public method
  249. std::lock_guard<std::mutex> a_lock(access_lock_);
  250. // suspend if debugger is enabled
  251. if (debugger_enabled_ && !is_dataset_graph_) {
  252. MS_LOG(INFO) << "Debugger suspend at debug_op";
  253. CommandLoop();
  254. }
  255. }
  256. std::map<std::pair<uint32_t, uint32_t>, std::string> &Debugger::GetStreamTaskToOpnameMap() {
  257. return stream_task_to_opname_;
  258. }
  259. void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  260. if (graph_ptr_ != graph_ptr) {
  261. MS_LOG(INFO) << "Debugger got new graph: " << graph_ptr->graph_id();
  262. // save new graph_ptr
  263. graph_ptr_ = graph_ptr;
  264. // check if it is dataset graph
  265. CheckDatasetGraph();
  266. if (!is_dataset_graph_) {
  267. // only try to enable debugger if it is not a dataset graph
  268. EnableDebugger();
  269. if (debugger_enabled_) {
  270. // get graph proto and send to mindinsight
  271. SendGraphAndSuspend(GetGraphProto());
  272. }
  273. }
  274. }
  275. }
  276. void Debugger::CheckDatasetGraph() {
  277. // print parameter node names
  278. const auto &params = graph_ptr_->inputs();
  279. for (const auto &param : params) {
  280. MS_LOG(INFO) << "param: " << param->fullname_with_scope();
  281. }
  282. // check if there is GetNext or InitDataSetQueue node
  283. const auto &nodes = graph_ptr_->execution_order();
  284. for (const auto &node : nodes) {
  285. auto node_name = AnfAlgo::GetCNodeName(node);
  286. MS_LOG(INFO) << "node: " << node->fullname_with_scope();
  287. if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
  288. MS_LOG(WARNING) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
  289. << node_name;
  290. is_dataset_graph_ = true;
  291. return;
  292. }
  293. }
  294. is_dataset_graph_ = false;
  295. }
  296. GraphProto Debugger::GetGraphProto() const {
  297. // convert kernel graph to debugger modelproto
  298. ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_);
  299. return model.graph();
  300. }
  301. void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  302. // prepare metadata
  303. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  304. Metadata metadata;
  305. metadata.set_device_name(device_name);
  306. metadata.set_cur_step(num_step_);
  307. metadata.set_backend(device_target_);
  308. metadata.set_cur_node(cur_name_);
  309. EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  310. if (reply_metadata.status() != reply_metadata.OK) {
  311. MS_LOG(ERROR) << "Error: SendMetadata failed";
  312. }
  313. // send graph to mindinght server
  314. EventReply reply = grpc_client_->SendGraph(graph_proto);
  315. if (reply.status() != reply.OK) {
  316. MS_LOG(ERROR) << "Error: SendGraph failed";
  317. }
  318. // enter command loop, wait and process commands
  319. CommandLoop();
  320. }
  321. void Debugger::CommandLoop() {
  322. // prepare metadata
  323. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  324. Metadata metadata;
  325. metadata.set_device_name(device_name);
  326. metadata.set_cur_step(num_step_);
  327. metadata.set_backend(device_target_);
  328. metadata.set_cur_node(cur_name_);
  329. // loop exit flag
  330. bool run = false;
  331. int num_wait_fail = 0;
  332. const int max_num_wait_fail = 5;
  333. while (!run) {
  334. // wait for command
  335. EventReply reply = grpc_client_->WaitForCommand(metadata);
  336. if (reply.status() != reply.OK) {
  337. MS_LOG(ERROR) << "Error: WaitForCommand failed";
  338. num_wait_fail++;
  339. if (num_wait_fail > max_num_wait_fail) {
  340. MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session";
  341. Exit();
  342. }
  343. MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after "
  344. << num_wait_fail << "s";
  345. std::this_thread::sleep_for(std::chrono::milliseconds(1000 * num_wait_fail));
  346. continue;
  347. }
  348. // get type of the command in reply
  349. DebuggerCommand cmd = GetCommand(reply);
  350. if (cmd == DebuggerCommand::kUnknownCMD) {
  351. MS_LOG(ERROR) << "Error: debugger recieved unknown command";
  352. continue;
  353. }
  354. MS_LOG(INFO) << "recieved command: ";
  355. switch (cmd) {
  356. case DebuggerCommand::kUnknownCMD:
  357. MS_LOG(INFO) << "UnknownCMD";
  358. break;
  359. case DebuggerCommand::kExitCMD:
  360. MS_LOG(INFO) << "ExitCMD";
  361. Exit();
  362. break;
  363. case DebuggerCommand::kRunCMD:
  364. MS_LOG(INFO) << "RunCMD";
  365. {
  366. // print run cmd content
  367. // get run_level and node_name
  368. run_level_ = GetRunLevel(reply);
  369. node_name_ = GetNodeName(reply);
  370. MS_LOG(INFO) << "run_level: " << run_level_;
  371. MS_LOG(INFO) << "node_name_: " << node_name_;
  372. }
  373. // exit loop
  374. run = true;
  375. break;
  376. case DebuggerCommand::kSetCMD:
  377. MS_LOG(INFO) << "SetCMD";
  378. {
  379. // print set cmd content
  380. ProtoVector<WatchNode> recieved_nodes = GetWatchnodes(reply);
  381. for (auto node : recieved_nodes) {
  382. MS_LOG(INFO) << "node name: " << node.node_name();
  383. MS_LOG(INFO) << "node type: " << node.node_type();
  384. }
  385. MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
  386. MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  387. MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  388. }
  389. MS_LOG(INFO) << "Setting watchpoint";
  390. if (GetWatchpointDelete(reply)) {
  391. RemoveWatchpoint(GetWatchpointID(reply));
  392. } else {
  393. SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply));
  394. }
  395. break;
  396. case DebuggerCommand::kViewCMD:
  397. MS_LOG(INFO) << "ViewCMD";
  398. {
  399. // print view cmd content
  400. ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  401. for (auto tensor : received_tensors) {
  402. MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
  403. MS_LOG(INFO) << "tensor slot: " << tensor.slot();
  404. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
  405. MS_LOG(INFO) << "tensor iter: " << tensor.iter();
  406. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
  407. }
  408. }
  409. MS_LOG(INFO) << "Sending tensors";
  410. std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  411. {
  412. // print view cmd reply
  413. for (auto tensor : tensors) {
  414. MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
  415. MS_LOG(INFO) << "tensor slot: " << tensor.slot();
  416. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
  417. MS_LOG(INFO) << "tensor iter: " << tensor.iter();
  418. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
  419. MS_LOG(INFO) << "tensor dims: ";
  420. for (auto dim : tensor.dims()) {
  421. MS_LOG(INFO) << dim << ",";
  422. }
  423. MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  424. }
  425. }
  426. EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  427. if (send_tensors_reply.status() != send_tensors_reply.OK) {
  428. MS_LOG(ERROR) << "Error: SendTensors failed";
  429. }
  430. break;
  431. }
  432. }
  433. }
  434. void AddTensorProtoInfo(TensorProto *tensor_item, TensorProto tensor) {
  435. tensor_item->set_node_name(tensor.node_name());
  436. tensor_item->set_slot(tensor.slot());
  437. tensor_item->set_iter(tensor.iter());
  438. tensor_item->set_truncate(tensor.truncate());
  439. tensor_item->clear_tensor_content();
  440. tensor_item->clear_data_type();
  441. tensor_item->clear_dims();
  442. }
  443. void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id) {
  444. std::vector<std::tuple<std::string, bool>> check_node_list;
  445. std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
  446. [](WatchNode node) -> std::tuple<std::string, bool> {
  447. return make_tuple(node.node_name(), node.node_type() == "scope");
  448. });
  449. debug_services_->AddWatchpoint(id, condition.condition(), check_node_list);
  450. }
  451. void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
  452. std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  453. std::vector<std::string> name;
  454. std::vector<std::string> ret_name;
  455. std::vector<char *> data_ptr;
  456. std::vector<unsigned int> data_size;
  457. std::vector<TypePtr> dtype;
  458. std::vector<std::vector<int>> shape;
  459. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  460. // ret_name will contain tensor names that are found in TensorLoader
  461. // items in ret_name will be in the same order with tensors if found
  462. debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  463. std::list<TensorProto> tensor_list;
  464. unsigned int result_index = 0;
  465. for (auto tensor : tensors) {
  466. int size_iter = 0;
  467. if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
  468. TensorProto tensor_item;
  469. tensor_item.set_finished(true);
  470. AddTensorProtoInfo(&tensor_item, tensor);
  471. tensor_list.push_back(tensor_item);
  472. continue;
  473. }
  474. int tensor_size = data_size[result_index];
  475. while (size_iter < tensor_size) {
  476. int chunk_size = CHUNK_SIZE;
  477. TensorProto tensor_item;
  478. tensor_item.set_finished(false);
  479. if (tensor_size - size_iter <= CHUNK_SIZE) {
  480. chunk_size = tensor_size - size_iter;
  481. tensor_item.set_finished(true);
  482. }
  483. AddTensorProtoInfo(&tensor_item, tensor);
  484. // return empty tensor if didn't find the requested tensor
  485. tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
  486. tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index]));
  487. for (auto &elem : shape[result_index]) {
  488. tensor_item.add_dims(elem);
  489. }
  490. // add tensor to result list and increment result_index to check next item in ret_name
  491. tensor_list.push_back(tensor_item);
  492. size_iter += CHUNK_SIZE;
  493. }
  494. result_index++;
  495. }
  496. return tensor_list;
  497. }
  498. void Debugger::Exit() {
  499. // clear resource before exit
  500. pipeline::ClearResAtexit();
  501. std::exit(EXIT_FAILURE);
  502. }
  503. std::list<WatchpointHit> Debugger::CheckWatchpoints() {
  504. std::vector<std::string> name;
  505. std::vector<std::string> slot;
  506. std::vector<int> condition;
  507. std::vector<unsigned int> watchpoint_id;
  508. std::vector<std::string> overflow_ops;
  509. #ifdef ENABLE_D
  510. overflow_ops = CheckOpOverflow();
  511. #endif
  512. debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, overflow_ops);
  513. std::list<WatchpointHit> hits;
  514. for (unsigned int i = 0; i < name.size(); i++) {
  515. WatchpointHit hit;
  516. hit.set_id(watchpoint_id[i]);
  517. // here TensorProto act as a tensor indicator, not sending tensor content
  518. TensorProto *tensor_item = hit.mutable_tensor();
  519. tensor_item->set_node_name(name[i]);
  520. tensor_item->set_slot(slot[i]);
  521. tensor_item->set_finished(true);
  522. WatchCondition *condition_item = hit.mutable_watch_condition();
  523. condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
  524. hits.push_back(hit);
  525. }
  526. return hits;
  527. }
  528. std::list<WatchpointHit> Debugger::CheckSingleWatchpoint(std::string watchnode) const {
  529. auto tensor_loader = debug_services_->tensor_loader();
  530. auto tensors = tensor_loader->GetNodeTensorMap(watchnode);
  531. std::list<WatchpointHit> hits;
  532. for (std::vector<std::shared_ptr<TensorData>>::iterator it = tensors.begin(); it != tensors.end(); ++it) {
  533. auto cur_tensor = *it;
  534. std::string name = "";
  535. std::string slot = "";
  536. char *data_ptr = nullptr;
  537. unsigned int data_size = 0;
  538. int condition = -1;
  539. unsigned int watchpoint_id = -1;
  540. WatchpointHit hit;
  541. debug_services_->CheckSingleWatchpoint(cur_tensor, &name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id);
  542. if (name != "") {
  543. hit.set_id(watchpoint_id);
  544. // here TensorProto act as a tensor indicator, not sending tensor content
  545. TensorProto *tensor_item = hit.mutable_tensor();
  546. tensor_item->set_node_name(name);
  547. tensor_item->set_slot(slot);
  548. tensor_item->set_finished(true);
  549. WatchCondition *condition_item = hit.mutable_watch_condition();
  550. condition_item->set_condition(debugger::WatchCondition_Condition(condition));
  551. hits.push_back(hit);
  552. }
  553. }
  554. return hits;
  555. }
  556. void Debugger::SendWatchpointsAndSuspend(const std::list<WatchpointHit> &points) {
  557. // send info about watchpoint
  558. if (!points.empty()) {
  559. EventReply reply = grpc_client_->SendWatchpointHits(points);
  560. if (reply.status() != reply.OK) {
  561. MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
  562. }
  563. }
  564. // enter command loop
  565. CommandLoop();
  566. }
  567. DebugServices *Debugger::debug_services() const { return debug_services_.get(); }
  568. bool Debugger::debugger_enabled() const { return debugger_enabled_; }
  569. DebuggerCommand GetCommand(const EventReply &reply) {
  570. DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  571. switch (reply.cmd_case()) {
  572. case debugger::EventReply::CmdCase::kExit:
  573. cmd = DebuggerCommand::kExitCMD;
  574. break;
  575. case debugger::EventReply::CmdCase::kRunCmd:
  576. cmd = DebuggerCommand::kRunCMD;
  577. break;
  578. case debugger::EventReply::CmdCase::kSetCmd:
  579. cmd = DebuggerCommand::kSetCMD;
  580. break;
  581. case debugger::EventReply::CmdCase::kViewCmd:
  582. cmd = DebuggerCommand::kViewCMD;
  583. break;
  584. default:
  585. MS_LOG(ERROR) << "Error: UnknownCMD";
  586. break;
  587. }
  588. return cmd;
  589. }
  590. ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  591. if (!reply.has_set_cmd()) {
  592. MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
  593. return ProtoVector<WatchNode>();
  594. }
  595. return reply.set_cmd().watch_nodes();
  596. }
  597. std::string GetRunLevel(const EventReply &reply) {
  598. if (!reply.has_run_cmd()) {
  599. MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: "
  600. "";
  601. return "";
  602. }
  603. return reply.run_cmd().run_level();
  604. }
  605. std::string GetNodeName(const EventReply &reply) {
  606. if (!reply.has_run_cmd()) {
  607. MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: "
  608. "";
  609. return "";
  610. }
  611. return reply.run_cmd().node_name();
  612. }
  613. WatchCondition GetWatchcondition(const EventReply &reply) {
  614. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  615. MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
  616. return WatchCondition();
  617. }
  618. return reply.set_cmd().watch_condition();
  619. }
  620. int32_t GetWatchpointID(const EventReply &reply) {
  621. if (!reply.has_set_cmd()) {
  622. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
  623. return 0;
  624. }
  625. return reply.set_cmd().id();
  626. }
  627. bool GetWatchpointDelete(const EventReply &reply) {
  628. if (!reply.has_set_cmd()) {
  629. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
  630. return false;
  631. }
  632. return reply.set_cmd().delete_();
  633. }
  634. ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  635. if (!reply.has_view_cmd()) {
  636. MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
  637. return ProtoVector<TensorProto>();
  638. }
  639. return reply.view_cmd().tensors();
  640. }
  641. std::string GetTensorFullName(const TensorProto &tensor) {
  642. string node_name = tensor.node_name();
  643. if (tensor.truncate()) {
  644. // scopes in node name are seperated by '/'
  645. // use the name without scope if truncate is true
  646. std::size_t found = node_name.find_last_of("/");
  647. node_name = node_name.substr(found + 1);
  648. }
  649. return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
  650. }
  651. bool Debugger::partial_memory() { return partial_memory_; }
  652. void Debugger::SetCurNode(std::string cur_name) {
  653. // access lock for public method
  654. std::lock_guard<std::mutex> a_lock(access_lock_);
  655. cur_name_ = cur_name;
  656. }
  657. std::string Debugger::run_level() const { return run_level_; }
  658. void Debugger::SetStepNum(int32_t cur_num_step) {
  659. // access lock for public method
  660. std::lock_guard<std::mutex> a_lock(access_lock_);
  661. num_step_ = cur_num_step;
  662. }
  663. int32_t Debugger::step_num() const { return num_step_; }
  664. uint64_t BytestoInt64(const std::vector<char> &buffer) {
  665. uint64_t ret;
  666. ret = ((uint64_t)buffer[7] << 56) | ((uint64_t)buffer[6] << 48) | ((uint64_t)buffer[5] << 40) |
  667. ((uint64_t)buffer[4] << 32) | (buffer[3] << 24) | (buffer[2] << 16) | (buffer[1] << 8) | buffer[0];
  668. return ret;
  669. }
  670. #define BUF_SIZ 256
  671. std::vector<std::string> Debugger::CheckOpOverflow() {
  672. std::vector<double> bin_list;
  673. std::vector<std::string> op_names;
  674. DIR *d;
  675. struct dirent *dir = nullptr;
  676. d = opendir(overflow_bin_path_.c_str());
  677. if (d != nullptr) {
  678. while ((dir = readdir(d)) != NULL) {
  679. if (dir->d_type == DT_REG) {
  680. std::string file_path = overflow_bin_path_;
  681. file_path.append(dir->d_name);
  682. std::string file_name = dir->d_name;
  683. std::size_t found = file_name.find_last_of(".");
  684. if (found == std::string::npos) {
  685. continue;
  686. }
  687. std::string overflow_time = file_name.substr(found + 1);
  688. if (stod(overflow_time) <= last_overflow_bin_) {
  689. MS_LOG(INFO) << "File already processed " << file_name;
  690. continue;
  691. }
  692. bin_list.push_back(stod(overflow_time));
  693. std::fstream infile;
  694. infile.open(file_path.c_str(), std::ios::binary | std::ios::in);
  695. infile.seekg(313, std::ios::beg);
  696. std::vector<char> buffer;
  697. buffer.resize(BUF_SIZ);
  698. infile.read(buffer.data(), BUF_SIZ);
  699. uint64_t stream_id = BytestoInt64(std::vector<char>(buffer.begin() + 8, buffer.end()));
  700. uint64_t task_id = BytestoInt64(std::vector<char>(buffer.begin() + 16, buffer.end()));
  701. MS_LOG(INFO) << "Overflow stream_id " << stream_id << ", task_id " << task_id << ".";
  702. auto op = debugger_->stream_task_to_opname_.find(std::make_pair(stream_id, task_id));
  703. if (op != debugger_->stream_task_to_opname_.end()) {
  704. MS_LOG(ERROR) << "Overflow detected on node " << op->second << std::endl;
  705. op_names.push_back(op->second);
  706. } else {
  707. MS_LOG(INFO) << "No overflow is detected " << std::endl;
  708. }
  709. infile.close();
  710. }
  711. }
  712. } else {
  713. MS_LOG(INFO) << "OverFlow bin directory does not exist!";
  714. }
  715. closedir(d);
  716. MS_LOG(ERROR) << "These operation overflows are detected " << op_names;
  717. for (auto &i : bin_list) {
  718. if (i > last_overflow_bin_) {
  719. last_overflow_bin_ = i;
  720. }
  721. }
  722. return op_names;
  723. }
  724. bool Debugger::SetDebuggerConfFromJsonFile() {
  725. const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
  726. if (config_path_str != nullptr) {
  727. MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
  728. } else {
  729. MS_LOG(INFO) << "No need debugger config path. please export MINDSPORE_CONFIG_PATH eg: MINDSPORE_CONFIG_PATH=/etc";
  730. ssl_certificate = false;
  731. return false;
  732. }
  733. char real_path[4096] = {0};
  734. if (nullptr == realpath(config_path_str, real_path)) {
  735. MS_LOG(ERROR) << "Env debugger config path error, " << config_path_str;
  736. ssl_certificate = false;
  737. return false;
  738. }
  739. std::string debugger_config_file = std::string(real_path) + "/debugger_config.json";
  740. std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
  741. MS_EXCEPTION_IF_NULL(fs);
  742. if (!fs->FileExist(debugger_config_file)) {
  743. MS_LOG(ERROR) << debugger_config_file << " not exist.";
  744. ssl_certificate = false;
  745. return false;
  746. }
  747. return ParseDebuggerConfig(debugger_config_file);
  748. }
  749. bool Debugger::ParseDebuggerConfig(const std::string &debugger_config_file) {
  750. std::ifstream jsonFile(debugger_config_file);
  751. if (!jsonFile.is_open()) {
  752. MS_LOG(ERROR) << debugger_config_file << " open failed.";
  753. ssl_certificate = false;
  754. return false;
  755. }
  756. json j;
  757. jsonFile >> j;
  758. if (j.find("DebuggerSettings") == j.end()) {
  759. MS_LOG(ERROR) << "DebuggerSettings is not exist.";
  760. ssl_certificate = false;
  761. return false;
  762. } else {
  763. json debuggerSettings = j.at("DebuggerSettings");
  764. // convert json to string
  765. std::stringstream ss;
  766. ss << debuggerSettings;
  767. std::string cfg = ss.str();
  768. MS_LOG(INFO) << "Debugger Settings Json: " << cfg;
  769. if (!IsConfigExist(debuggerSettings)) {
  770. return false;
  771. }
  772. if (!IsConfigValid(debuggerSettings)) {
  773. return false;
  774. }
  775. }
  776. return true;
  777. }
  778. bool Debugger::IsConfigExist(const nlohmann::json &debuggerSettings) {
  779. if (debuggerSettings.find("ssl_certificate") == debuggerSettings.end() ||
  780. debuggerSettings.find("certificate_dir") == debuggerSettings.end() ||
  781. debuggerSettings.find("certificate_passphrase") == debuggerSettings.end()) {
  782. MS_LOG(ERROR) << "DebuggerSettings keys is not exist.";
  783. ssl_certificate = false;
  784. return false;
  785. }
  786. return true;
  787. }
  788. bool Debugger::IsConfigValid(const nlohmann::json &debuggerSettings) {
  789. auto enable_secure = debuggerSettings.at("ssl_certificate");
  790. auto certificate_dir_ = debuggerSettings.at("certificate_dir");
  791. auto certificate_passphrase_ = debuggerSettings.at("certificate_passphrase");
  792. if (!(enable_secure.is_boolean() && certificate_dir_.is_string() && certificate_passphrase_.is_string())) {
  793. MS_LOG(ERROR) << "Element's type in Debugger config json is invalid.";
  794. ssl_certificate = false;
  795. return false;
  796. }
  797. ssl_certificate = enable_secure;
  798. certificate_dir = certificate_dir_;
  799. certificate_passphrase = certificate_passphrase_;
  800. return true;
  801. }
  802. } // namespace mindspore