debugger.cc

/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <dirent.h>
#include <cstdio>
#include <fstream>
#include <tuple>
#include <vector>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <utility>
#include <map>
#include <regex>
#include "debug/debugger/debugger.h"
#include "debug/data_dump/dump_json_parser.h"
#include "pipeline/jit/pipeline.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/kernel_runtime.h"
#include "debug/data_dump/e2e_dump.h"
#include "utils/config_manager.h"
#include "debug/env_config_parser.h"
#include "utils/comm_manager.h"
#include "runtime/hardware/device_context_manager.h"
#include "debug/anf_ir_dump.h"
#include "debug/anf_ir_utils.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/proto_exporter.h"
#else
#include "debug/debugger/proto_exporter_stub.h"
#endif

using debugger::Chunk;
using debugger::EventReply;
using debugger::GraphProto;
using debugger::ModelProto;
using debugger::TensorProto;
using debugger::WatchCondition;
using debugger::WatchCondition_Condition_inf;
using debugger::WatchCondition_Condition_nan;
using debugger::WatchCondition_Parameter;
using debugger::WatchNode;
using debugger::WatchpointHit;

namespace mindspore {
static constexpr auto g_chunk_size = 1024 * 1024 * 3;
DebuggerPtr Debugger::debugger_ = nullptr;
std::mutex Debugger::instance_lock_;

Debugger::Debugger()
    : grpc_client_(nullptr),
      debug_services_(nullptr),
      device_id_(0),
      device_target_(""),
      num_step_(0),
      debugger_enabled_(false),
      run_level_(""),
      node_name_(""),
      cur_name_(""),
      training_done_(false),
      is_dataset_graph_(false),
      partial_memory_(false),
      last_overflow_bin_(0),
      initial_suspend_(true),
      not_dataset_graph_sum_(0),
      version_("") {
  CheckDebuggerEnabledParam();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  if (device_target == kCPUDevice) {
    MS_LOG(WARNING) << "Not enabling debugger. Debugger does not support CPU.";
  } else if (CheckDebuggerEnabled()) {
    // configure partial memory reuse
    partial_memory_ = CheckDebuggerPartialMemoryEnabled();
    // switch memory reuse on or off
    EnvConfigParser::GetInstance().SetSysMemreuse(partial_memory_);
    // print a message about memory reuse to the user
    if (partial_memory_) {
      MS_LOG(WARNING)
        << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
           "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
    } else {
      MS_LOG(WARNING)
        << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
           "usage for large models.";
    }
  }
}

void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  // save device_id
  MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  device_id_ = device_id;
  MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  device_target_ = device_target;
  version_ = "1.3.0";
}

bool IsTypeDebuggerSupported(TypeId type) {
  if (type < TypeId::kNumberTypeEnd && type > TypeId::kNumberTypeBegin && type != kNumberTypeComplex64) {
    return true;
  }
  MS_LOG(INFO) << "Debugger does not support type: " << TypeIdLabel(type);
  return false;
}
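
// EnableDebugger (re)initializes the debugger session from environment variables:
// MS_DEBUGGER_HOST and MS_DEBUGGER_PORT select the MindInsight endpoint (defaulting
// to localhost:50051), and the gRPC client is only created when the debugger itself
// is enabled; a DebugServices instance is created either way so that dump-only runs
// still work.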
void Debugger::EnableDebugger() {
  // reset some of the class members
  num_step_ = 0;
  debugger_enabled_ = false;
  partial_memory_ = false;
  grpc_client_ = nullptr;
  debug_services_ = nullptr;
  // see if dump using the debugger backend is enabled
  bool dump_enabled = CheckDebuggerDumpEnabled();
  MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;
  // check if the debugger is enabled
  debugger_enabled_ = CheckDebuggerEnabled();
  MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;
  if (!debugger_enabled_ && !dump_enabled) {
    MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
    return;
  }
  if (debugger_enabled_) {
    // configure grpc host
    std::string env_host_str = common::GetEnv("MS_DEBUGGER_HOST");
    std::string host;
    if (!env_host_str.empty()) {
      if (CheckIp(env_host_str)) {
        MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
        host = env_host_str;
      } else {
        debugger_enabled_ = false;
        MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_HOST is not a valid IP address. "
                                    "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
      }
    } else {
      MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
      host = "localhost";
    }
    // configure grpc port
    std::string env_port_str = common::GetEnv("MS_DEBUGGER_PORT");
    std::string port;
    if (!env_port_str.empty()) {
      if (CheckPort(env_port_str)) {
        MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
        port = env_port_str;
      } else {
        debugger_enabled_ = false;
        MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_PORT is not valid. Please set a port in the "
                                    "range 1 to 65535";
      }
    } else {
      port = "50051";
      if (!CheckPort(port)) {
        MS_EXCEPTION(ValueError) << "Default MS_DEBUGGER_PORT is not valid. Please set a port in the range 1 to 65535";
      }
      MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
    }
    // initialize grpc client
    grpc_client_ = std::make_unique<GrpcClient>(host, port);
  }
  debug_services_ = std::make_unique<DebugServices>();
}
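
// Overflow dump files produced on Ascend are named with a trailing timestamp
// ("<name>.<timestamp>"); last_overflow_bin_ remembers the newest timestamp seen
// so far, so later scans only consider files created since then.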
void Debugger::SetOpOverflowBinPath(uint32_t graph_id) {
#ifdef ENABLE_D
  // set operation overflow info
  overflow_bin_path_.insert(std::pair<uint32_t, std::string>(
    graph_id, DumpJsonParser::GetInstance().GetOpOverflowBinPath(graph_id, device_id_)));
  // new overflow dump files will have a timestamp greater than last_overflow_bin_
  auto overflow_bin_path = overflow_bin_path_.find(graph_id)->second;
  MS_LOG(INFO) << "overflow_bin_path = " << overflow_bin_path;
  DIR *d = opendir(overflow_bin_path.c_str());
  if (d != nullptr) {
    struct dirent *dir;
    while ((dir = readdir(d)) != nullptr) {
      if (dir->d_type == DT_REG) {
        std::string file_path = overflow_bin_path;
        file_path.append(dir->d_name);
        std::size_t found = file_path.find_last_of(".");
        if (found == std::string::npos) {
          continue;
        }
        std::string overflow_time = file_path.substr(found + 1);
        if (stod(overflow_time) <= last_overflow_bin_) {
          MS_LOG(INFO) << "Old op overflow bin file " << file_path;
          continue;
        }
        last_overflow_bin_ = stod(overflow_time);
      }
    }
    MS_LOG(INFO) << "Last op overflow bin timestamp " << last_overflow_bin_;
    closedir(d);
  }
#endif
}

void Debugger::CheckDatasetSinkMode() {
  if (CheckDebuggerDumpEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
    MS_EXCEPTION(NotSupportError)
      << "e2e_dump is not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  }
  if (CheckDebuggerEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
    MS_EXCEPTION(NotSupportError)
      << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  }
}

bool Debugger::CheckDebuggerDumpEnabled() const {
  // see if dump is enabled
  if (device_target_ == kGPUDevice) {
    return device::KernelRuntime::DumpDataEnabled();
  }
  return false;
}

bool Debugger::CheckDebuggerEnabled() const {
  // get env variables to configure debugger
  std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  if (!env_enable_str.empty()) {
    (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
    if ((env_enable_str == "1" || env_enable_str == "true") && device_target_ != kCPUDevice) {
      return true;
    }
  }
  return false;
}

void Debugger::CheckDebuggerEnabledParam() const {
  // check the value of env variable ENABLE_MS_DEBUGGER
  std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  if (!env_enable_str.empty()) {
    (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
    if (env_enable_str != "0" && env_enable_str != "1" && env_enable_str != "false" && env_enable_str != "true") {
      MS_LOG(WARNING) << "Env variable ENABLE_MS_DEBUGGER should be True/False/1/0 (case insensitive), but got: "
                      << env_enable_str;
    }
  }
}

bool Debugger::CheckDebuggerPartialMemoryEnabled() const {
  std::string env_partial_mem_str = common::GetEnv("MS_DEBUGGER_PARTIAL_MEM");
  if (!env_partial_mem_str.empty()) {
    MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
    if (env_partial_mem_str == "1") {
      return true;
    }
  }
  return false;
}

bool Debugger::DebuggerBackendEnabled() const { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }

void Debugger::Reset() {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  // reset components
  device_id_ = 0;
  device_target_ = "";
  num_step_ = 0;
  debugger_enabled_ = false;
  is_dataset_graph_ = false;
  partial_memory_ = false;
  graph_ptr_ = nullptr;
  grpc_client_ = nullptr;
  debug_services_ = nullptr;
  last_overflow_bin_ = 0;
  overflow_bin_path_.clear();
  stream_task_to_opname_.clear();
  graph_proto_list_.clear();
  graph_ptr_list_.clear();
}
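
// Called once per step, before any graph runs, when the MindRT backend is used
// (GPU only here): it gives the debugger a chance to suspend on each graph and
// prepares the e2e dump bookkeeping for that graph via DumpSetup.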
void Debugger::PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs) {
  // Only GPU is supported for MindRTBackend
  if (device_target_ != kGPUDevice) {
    return;
  }
  uint32_t graph_sum = graphs.size();
  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
    const auto &graph = graphs[graph_index];
    if (debugger_) {
      debugger_->PreExecute(graph, graph_sum);
    }
    DumpSetup(graph);
  }
}
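
// PreExecute decides whether to suspend before a graph runs. In the multigraph
// case the compiled graph protos are sent to MindInsight once, on the first step;
// afterwards execution only stops on the first sub-graph of each step, so the
// user sees one suspension per step rather than one per sub-graph.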
void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  CheckDatasetSinkMode();
  auto graph_id = graph_ptr->graph_id();
  // collect rungraph_ids to update the step number in the multigraph case
  if (rungraph_id_list_.empty()) {
    rungraph_id_list_.push_back(graph_id);
  } else {
    if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
      rungraph_id_list_.push_back(graph_id);
    }
  }
  // check and save graph_ptr, suspend if the graph is new
  MS_LOG(INFO) << "total number of graphs: " << graph_sum;
  // multiple graphs
  if (graph_sum > 1) {
    // at least one graph is not a dataset graph
    if (not_dataset_graph_sum_ > 0) {
      // only try to enable the debugger if not all graphs are dataset graphs
      if (!debugger_enabled_) {
        EnableDebugger();
      }
      if (debugger_enabled_) {
        if (!graph_proto_list_.empty()) {
          // only send compiled graphs once
          auto dbg_graph_ptr = graph_ptr_;
          // use current graph ptr to load parameters
          graph_ptr_ = graph_ptr;
          LoadParametersAndConst();
          // revert graph ptr to original value
          graph_ptr_ = dbg_graph_ptr;
          SendMultiGraphsAndSuspend(graph_proto_list_);
          graph_proto_list_.clear();
        } else if (graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
          // stop only when receiving the first sub run graph for each step
          // if we have stopped for the last kernel before, no need to stop again
          if (pipeline::ExecutorPy::GetDebugTerminate()) {
            return;
          }
          if (!(run_level_ == "node" && suspended_at_last_kernel_)) {
            CommandLoop();
          }
          debug_services_->ResetLoadedTensors();
        }
      }
    }
  } else if (graph_proto_list_.size() == 1) {
    if (device_target_ == kGPUDevice && num_step_ != 0) {
      if (debugger_enabled_ && !(run_level_ == "node" && suspended_at_last_kernel_)) {
        CommandLoop();
      }
      debug_services_->ResetLoadedTensors();
    }
    // In the single graph case, reset graph_ptr_ to nullptr for the initial step
    if (num_step_ == 0) {
      graph_ptr_ = nullptr;
      CheckGraphPtr(graph_ptr);
    }
  }
  // resets for the new graph
  suspended_at_last_kernel_ = 0;
}

bool Debugger::DumpDataEnabledIteration() const {
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (!dump_json_parser.e2e_dump_enabled()) {
    return false;
  }
  auto cur_iter = dump_json_parser.cur_dump_iter();
  if (dump_json_parser.IsDumpIter(cur_iter)) {
    return true;
  }
  return false;
}

uint32_t Debugger::GetRankID() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
  uint32_t rank_id = device_context->GetRankID();
  return rank_id;
}

void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
  uint32_t rank_id = GetRankID();
  if (debugger_->DebuggerBackendEnabled()) {
    MS_EXCEPTION_IF_NULL(kernel_graph);
    E2eDump::DumpParametersAndConstData(kernel_graph.get(), rank_id, debugger_.get());
  } else {
    DumpJsonParser::GetInstance().UpdateDumpIter();
  }
}

void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id) {
  if (debugger_->DebuggerBackendEnabled()) {
    uint32_t rank_id = GetRankID();
    E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get());
  }
}

void Debugger::DumpSetup(const KernelGraphPtr &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  uint32_t rank_id = GetRankID();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  E2eDump::DumpSetup(kernel_graph.get(), rank_id);
  MS_LOG(INFO) << "Finish!";
}

void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
  // This function will be called for the new GPU runtime using MindRTBackend
  auto &json_parser = DumpJsonParser::GetInstance();
  if (json_parser.e2e_dump_enabled()) {
    uint32_t rank_id = GetRankID();
    kernel_graph->set_root_graph_id(kernel_graph->graph_id());
    std::string final_graph = "trace_code_graph_" + std::to_string(kernel_graph->graph_id());
    std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
    std::string target_dir = root_dir + "/graphs";
    std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
    DumpIRProtoWithSrcInfo(kernel_graph, final_graph, target_dir, kDebugWholeStack);
    DumpIR("trace_code_graph", kernel_graph, true, kWholeStack, ir_file_path);
    DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(kernel_graph->graph_id()) + ".csv", root_dir,
                      kernel_graph->execution_order());
  }
}

void Debugger::PostExecuteGraphDebugger() {
  // Only GPU is supported for MindRTBackend
  if (device_target_ != kGPUDevice) {
    return;
  }
  // LoadParametersAndConst for all the graphs
  for (auto graph : graph_ptr_list_) {
    debugger_->LoadParametersAndConst(graph);
  }
  // debug used for dump
  if (debugger_ && debugger_->CheckDebuggerDumpEnabled()) {
    // dump parameters and consts
    for (auto graph : graph_ptr_list_) {
      debugger_->Dump(graph);
      if (!debugger_->debugger_enabled()) {
        debugger_->ClearCurrentData();
      }
    }
  }
  if (debugger_) {
    debugger_->PostExecute();
  }
}

void Debugger::PostExecute() {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  if (pipeline::ExecutorPy::GetDebugTerminate()) {
    return;
  }
  if (debugger_->DebuggerBackendEnabled()) {
    // analyze tensor data and send the watchpoints that have been hit
    if (debugger_enabled_ && !is_dataset_graph_) {
      if (device_target_ != kGPUDevice) {
        num_step_++;
      }
      SendWatchpoints(CheckWatchpoints());
      // no need to suspend at each graph for GPU; suspension happens in PreExecute
      if (device_target_ != kGPUDevice) {
        CommandLoop();
      }
    }
    // Only keep parameters in the current map
    // GPU ResetLoadedTensors happens in PreExecute
    if (device_target_ != kGPUDevice) {
      debug_services_->ResetLoadedTensors();
    }
  }
}

bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) const {
  if (debugger_enabled_ && !is_dataset_graph_) {
    auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
    // if the node has a watchpoint on it, or is a next_to or continue_to node, read the kernel tensor data
    if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
      return true;
    }
  }
  return false;
}

void Debugger::PostExecuteNode(const CNodePtr &kernel, bool last_kernel) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  if (pipeline::ExecutorPy::GetDebugTerminate()) {
    return;
  }
  if (debugger_enabled_ && !is_dataset_graph_) {
    auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
    // if the kernel is a watchpoint and it was hit, suspend
    bool hit_empty_flag = true;
    if (is_watchpoint) {
      auto hits = CheckWatchpoints(cur_name_, kernel);
      if (!hits.empty()) {
        SendWatchpoints(hits);
        CommandLoop();
        hit_empty_flag = false;
      }
    }
    if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
      // if the kernel is not a watchpoint but is a next_to or continue_to node, suspend
      // set a flag checked in PreExecute to avoid stopping twice at the last kernel of the last graph
      if (last_kernel) {
        suspended_at_last_kernel_ = 1;
      }
      CommandLoop();
    }
    return;
  }
}

void Debugger::PostDebugOp() {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  // suspend if the debugger is enabled
  if (debugger_enabled_ && !is_dataset_graph_) {
    MS_LOG(INFO) << "Debugger suspend at debug_op";
    CommandLoop();
  }
}

void Debugger::SetStreamTaskToOpnameMap(const std::map<std::pair<uint32_t, uint32_t>, std::string> &mapping) {
  MS_LOG(INFO) << "SetStreamTaskToOpnameMap start";
  for (auto const &item : mapping) {
    MS_LOG(INFO) << "stream = " << item.first.first << ", task = " << item.first.second
                 << ", op_name = " << item.second << std::endl;
  }
  MS_LOG(INFO) << "SetStreamTaskToOpnameMap end";
  stream_task_to_opname_ = mapping;
}

void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  if (graph_ptr_ != graph_ptr) {
    MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
    // save new graph_ptr
    graph_ptr_ = graph_ptr;
    CheckDatasetGraph();
    if (!is_dataset_graph_) {
      // get proto for new graph_ptr
      auto graph_proto = GetGraphProto(graph_ptr);
      // add new graph proto to graph_proto_list_
      graph_proto_list_.push_back(graph_proto);
      graph_ptr_list_.push_back(graph_ptr);
#ifdef ENABLE_D
      SetOpOverflowBinPath(graph_ptr->graph_id());
#endif
      not_dataset_graph_sum_++;
    }
    // reset is_dataset_graph to false
    is_dataset_graph_ = false;
  }
}

// In single graph cases, check the single graph ptr
void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  if (graph_ptr_ != graph_ptr) {
    MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
    // save new graph_ptr
    graph_ptr_ = graph_ptr;
    if (!is_dataset_graph_) {
      // only try to enable the debugger if it is not a dataset graph
      EnableDebugger();
      if (debugger_enabled_) {
        LoadParametersAndConst();
        // get graph proto and send it to MindInsight
        auto graph_proto = graph_proto_list_.front();
        SendGraphAndSuspend(graph_proto);
      }
    }
  }
}

void Debugger::CheckDatasetGraph() {
  // print parameter node names
  const auto &params = graph_ptr_->inputs();
  for (const auto &param : params) {
    MS_LOG(INFO) << "param: " << GetKernelNodeName(param);
  }
  // check if there is a GetNext or InitDataSetQueue node
  const auto &nodes = graph_ptr_->execution_order();
  for (const auto &node : nodes) {
    auto node_name = AnfAlgo::GetCNodeName(node);
    MS_LOG(INFO) << "node: " << GetKernelNodeName(node);
    if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
      MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
                   << node_name;
      is_dataset_graph_ = true;
      return;
    }
  }
  is_dataset_graph_ = false;
}

GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  // convert kernel graph to debugger ModelProto
  ModelProto model = GetDebuggerFuncGraphProto(graph_ptr);
  return model.graph();
}

void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  if (SendMetadata(true)) {
    // send graph to MindInsight server
    EventReply reply = grpc_client_->SendGraph(graph_proto);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: SendGraph failed";
    }
    // enter command loop, wait for and process commands
    CommandLoop();
  }
}

bool Debugger::SendMetadata(bool version_check) {
  // prepare metadata
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;
  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);
  metadata.set_ms_version(version_);
  MS_LOG(INFO) << "Is training done? " << training_done_;
  // set graph number to not_dataset_graph_sum_
  metadata.set_graph_num(not_dataset_graph_sum_);
  EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  bool ret = false;
  if (reply_metadata.status() == reply_metadata.OK) {
    if (version_check) {
      // get the type of the command in the metadata reply; it should be version matched
      DebuggerCommand cmd = GetCommand(reply_metadata);
      if (cmd != DebuggerCommand::kVersionMatchedCMD) {
        MS_LOG(ERROR) << "MindInsight version is too old, MindSpore version is " << version_;
        Exit();
      } else {
        if (GetMiVersionMatched(reply_metadata)) {
          MS_LOG(INFO) << "MindSpore version " << version_ << " matches the MindInsight version.";
          ret = true;
        } else {
          MS_LOG(ERROR) << "MindSpore version " << version_ << " did not match the MindInsight version.";
          CommandLoop();
        }
      }
    } else {
      // version check was done before, so we can just return true here
      ret = true;
    }
  } else {
    MS_LOG(ERROR) << "Error: SendMetadata failed";
  }
  return ret;
}
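
// Graph protos are streamed to MindInsight as Chunk messages of at most
// g_chunk_size (3 MB) each; chunk.finished() marks the last chunk of a graph, so
// the receiver knows where one serialized graph ends and the next begins.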
void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list) {
  if (!SendMetadata(true)) {
    return;
  }
  // send multiple graphs to the MindInsight server
  // split a graph into chunks if it is larger than the chunk size
  std::list<Chunk> chunked_graph_proto_list;
  Chunk chunk;
  for (auto graph : graph_proto_list) {
    std::string str = graph.SerializeAsString();
    auto graph_size = graph.ByteSize();
    if (graph_size > g_chunk_size) {
      auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);
      for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
        chunk.set_buffer(sub_graph_str[i]);
        if (i < sub_graph_str.size() - 1) {
          chunk.set_finished(false);
        } else {
          chunk.set_finished(true);
        }
        chunked_graph_proto_list.push_back(chunk);
      }
    } else {
      chunk.set_buffer(str);
      chunk.set_finished(true);
      chunked_graph_proto_list.push_back(chunk);
    }
  }
  EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  if (reply.status() != reply.OK) {
    MS_LOG(ERROR) << "Error: SendMultiGraphs failed";
  }
  // enter command loop, wait for and process commands
  CommandLoop();
}
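
// CommandLoop blocks training while it polls MindInsight for the next command.
// WaitForCommand failures are retried up to max_num_wait_fail times with a
// linearly growing sleep; after that the session is terminated via Exit().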
void Debugger::CommandLoop() {
  // prepare metadata
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;
  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);
  // loop exit flag
  bool run = false;
  int num_wait_fail = 0;
  const int max_num_wait_fail = 5;
  while (!run) {
    // wait for command
    EventReply reply = grpc_client_->WaitForCommand(metadata);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: WaitForCommand failed";
      num_wait_fail++;
      if (num_wait_fail > max_num_wait_fail) {
        MS_LOG(ERROR) << "Maximum number of WaitForCommand retries reached: exiting training session.";
        MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
                         "of debugger host and port.";
        Exit();
        run = true;
      } else {
        MS_LOG(ERROR) << "Number of consecutive WaitForCommand failures: " << num_wait_fail << "; retry after "
                      << num_wait_fail << "s";
        std::this_thread::sleep_for(std::chrono::seconds(num_wait_fail));
      }
      continue;
    }
    // get the type of the command in the reply
    DebuggerCommand cmd = GetCommand(reply);
    if (cmd == DebuggerCommand::kUnknownCMD) {
      MS_LOG(DEBUG) << "Debug: debugger received unknown command";
      continue;
    }
    MS_LOG(INFO) << "received command: ";
    switch (cmd) {
      case DebuggerCommand::kUnknownCMD:
        MS_LOG(INFO) << "UnknownCMD";
        break;
      case DebuggerCommand::kExitCMD:
        MS_LOG(INFO) << "ExitCMD";
        Exit();
        // Used for debugger termination
        run = true;
        break;
      case DebuggerCommand::kRunCMD:
        ProcessRunCMD(reply);
        if (GetRunLevel(reply) != "recheck") {
          // exit loop
          run = true;
        }
        break;
      case DebuggerCommand::kSetCMD:
        ProcessKSetCMD(reply);
        break;
      case DebuggerCommand::kViewCMD:
        ProcessKViewCMD(reply);
        break;
      case DebuggerCommand::kVersionMatchedCMD:
        MS_LOG(ERROR) << "Received unexpected VersionMatched CMD from MindInsight.";
        Exit();
        break;
      default:
        MS_LOG(ERROR) << "Received unknown CMD from MindInsight";
        Exit();
        break;
    }
  }
}

void Debugger::ProcessRunCMD(const EventReply &reply) {
  MS_LOG(INFO) << "RunCMD";
  if (GetRunLevel(reply) == "recheck") {
    MS_LOG(INFO) << "rechecking all watchpoints";
    SendWatchpoints(CheckWatchpoints("", nullptr, true));
  } else {
    // no longer the initial suspension
    initial_suspend_ = false;
    // print run cmd content
    // get run_level and node_name
    run_level_ = GetRunLevel(reply);
    node_name_ = GetNodeName(reply);
    MS_LOG(INFO) << "run_level: " << run_level_;
    MS_LOG(INFO) << "node_name_: " << node_name_;
  }
}

void Debugger::ProcessKSetCMD(const EventReply &reply) {
  MS_LOG(INFO) << "SetCMD";
  MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  if (GetWatchpointDelete(reply)) {
    MS_LOG(INFO) << "Deleting watchpoint";
    RemoveWatchpoint(GetWatchpointID(reply));
  } else {
    MS_LOG(INFO) << "Setting watchpoint";
    MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
    ProtoVector<WatchNode> received_nodes = GetWatchnodes(reply);
    for (const auto &node : received_nodes) {
      MS_LOG(INFO) << "node name: " << node.node_name();
      MS_LOG(INFO) << "node type: " << node.node_type();
    }
    ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
    for (const auto &parameter : parameters) {
      MS_LOG(INFO) << "parameter name: " << parameter.name();
      MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
      MS_LOG(INFO) << "parameter value: " << parameter.value();
    }
    SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
  }
}

void Debugger::ProcessKViewCMD(const EventReply &reply) {
  MS_LOG(INFO) << "ViewCMD";
  // print view cmd content
  ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  for (auto received_tensor : received_tensors) {
    MS_LOG(INFO) << "tensor node name: " << received_tensor.node_name();
    MS_LOG(INFO) << "tensor slot: " << received_tensor.slot();
    MS_LOG(INFO) << "tensor finished: " << std::boolalpha << received_tensor.finished() << std::noboolalpha;
    MS_LOG(INFO) << "tensor iter: " << received_tensor.iter();
    MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << received_tensor.truncate() << std::noboolalpha;
  }
  MS_LOG(INFO) << "Sending tensors";
  std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  // print view cmd reply
  for (auto tensor : tensors) {
    MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
    MS_LOG(INFO) << "tensor slot: " << tensor.slot();
    MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
    MS_LOG(INFO) << "tensor iter: " << tensor.iter();
    MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
    MS_LOG(INFO) << "tensor dims: ";
    for (auto dim : tensor.dims()) {
      MS_LOG(INFO) << dim << ",";
    }
    MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  }
  EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  if (send_tensors_reply.status() != debugger::EventReply::OK) {
    MS_LOG(ERROR) << "Error: SendTensors failed";
  }
}

void AddTensorProtoInfo(TensorProto *tensor_item, const TensorProto &tensor) {
  tensor_item->set_node_name(tensor.node_name());
  tensor_item->set_slot(tensor.slot());
  tensor_item->set_iter(tensor.iter());
  tensor_item->set_truncate(tensor.truncate());
  tensor_item->clear_tensor_content();
  tensor_item->clear_data_type();
  tensor_item->clear_dims();
}

void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
                             const ProtoVector<WatchCondition_Parameter> &parameters) {
  std::vector<std::tuple<std::string, bool>> check_node_list;
  std::vector<DebugServices::parameter_t> parameter_list;
  std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
                 [](const WatchNode &node) -> std::tuple<std::string, bool> {
                   return make_tuple(node.node_name(), node.node_type() == "scope");
                 });
  std::transform(
    parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
    [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
      return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
    });
  debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
}

void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
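
// LoadTensors answers a ViewCMD: it looks each requested tensor up in the
// TensorLoader and streams its contents back in g_chunk_size pieces. A tensor
// that is not found is returned as a single empty proto with finished=true.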
std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  std::vector<std::string> name;
  std::vector<std::string> ret_name;
  std::vector<char *> data_ptr;
  std::vector<ssize_t> data_size;
  std::vector<unsigned int> dtype;
  std::vector<std::vector<int64_t>> shape;
  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  // ret_name will contain the tensor names that are found in TensorLoader
  // items in ret_name will be in the same order as in tensors, if found
  debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  std::list<TensorProto> tensor_list;
  unsigned int result_index = 0;
  for (auto tensor : tensors) {
    ssize_t size_iter = 0;
    if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
      // return an empty tensor if the requested tensor was not found
      TensorProto tensor_item;
      tensor_item.set_finished(true);
      AddTensorProtoInfo(&tensor_item, tensor);
      tensor_list.push_back(tensor_item);
      continue;
    }
    ssize_t tensor_size = data_size[result_index];
    while (size_iter < tensor_size) {
      ssize_t chunk_size = g_chunk_size;
      TensorProto tensor_item;
      tensor_item.set_finished(false);
      if (tensor_size - size_iter <= g_chunk_size) {
        chunk_size = tensor_size - size_iter;
        tensor_item.set_finished(true);
      }
      AddTensorProtoInfo(&tensor_item, tensor);
      tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
      tensor_item.set_data_type((debugger::DataType)dtype[result_index]);
      for (auto &elem : shape[result_index]) {
        tensor_item.add_dims(elem);
      }
      // add the chunk to the result list and advance within the tensor
      tensor_list.push_back(tensor_item);
      size_iter += g_chunk_size;
    }
    // increment result_index to check the next item in ret_name
    result_index++;
  }
  return tensor_list;
}

void Debugger::Exit() {
  // clear resources before exit
  // the debugger notifies the main thread to exit, because the main thread can only exit at a step boundary
  pipeline::ExecutorPy::DebugTerminate(true);
}
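
// CheckWatchpoints evaluates the registered watchpoints against the loaded
// tensors (all of them, or only those of one kernel when watchnode is given)
// and converts each hit reported by DebugServices into a WatchpointHit proto.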
std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel,
                                                    bool recheck) {
  std::vector<std::string> name;
  std::vector<std::string> slot;
  std::vector<int> condition;
  std::vector<unsigned int> watchpoint_id;
  std::vector<std::string> overflow_ops;
  std::vector<std::vector<DebugServices::parameter_t>> parameters;
  std::vector<int32_t> error_codes;
#ifdef ENABLE_D
  overflow_ops = CheckOpOverflow();
  for (auto const &item : overflow_ops) {
    MS_LOG(DEBUG) << "overflow_ops item = " << item << std::endl;
  }
#endif
  std::vector<std::shared_ptr<TensorData>> tensor_list;
  if (watchnode.empty()) {
    tensor_list = debug_services_->GetTensor();
  } else {
    tensor_list = debug_services_->GetNodeTensor(kernel);
  }
  std::vector<std::string> file_list;
  MS_LOG(INFO) << "CheckWatchpoints call for step " << num_step_;
  debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, &error_codes, overflow_ops,
                                    file_list, &tensor_list, initial_suspend_, watchnode.empty(), recheck);
  std::list<WatchpointHit> hits;
  for (unsigned int i = 0; i < name.size(); i++) {
    WatchpointHit hit;
    std::vector<DebugServices::parameter_t> &parameter = parameters[i];
    hit.set_id(watchpoint_id[i]);
    hit.set_error_code(error_codes[i]);
    // here TensorProto acts as a tensor indicator; it does not carry tensor content
    TensorProto *tensor_item = hit.mutable_tensor();
    tensor_item->set_node_name(name[i]);
    tensor_item->set_slot(slot[i]);
    tensor_item->set_finished(true);
    WatchCondition *condition_item = hit.mutable_watch_condition();
    condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
    for (const auto &p : parameter) {
      auto x = condition_item->mutable_params()->Add();
      x->set_name(p.name);
      x->set_disabled(p.disabled);
      x->set_value(p.value);
      x->set_hit(p.hit);
      x->set_actual_value(p.actual_value);
    }
    hits.push_back(hit);
  }
  return hits;
}

void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  // send info about watchpoint hits
  if (!points.empty()) {
    EventReply reply = grpc_client_->SendWatchpointHits(points);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
    }
  }
}

bool Debugger::DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
                                const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
                                TypeId device_type, const std::string &addr_format, size_t slot) const {
  return debug_services_->DumpTensorToFile(tensor_name, trans_flag, filepath, host_fmt, host_shape, host_type,
                                           device_type, addr_format, slot);
}

bool Debugger::DebugServicesIsWatchPoint(const std::string &kernel_name, const CNodePtr &kernel) const {
  return debug_services_->IsWatchPoint(kernel_name, kernel);
}

void Debugger::EmptyTensor() { debug_services_->EmptyTensor(); }

void Debugger::SetTensorLoaderIterNum(uint32_t iter_num) { debug_services_->SetTensorLoaderIterNum(iter_num); }

void Debugger::EmptyPrevTensor() { debug_services_->EmptyPrevTensor(); }

uint32_t Debugger::GetTensorLoaderIterNum() const { return debug_services_->GetTensorLoaderIterNum(); }

bool Debugger::LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev) {
  return debug_services_->LoadNewTensor(tensor, keep_prev);
}

bool Debugger::debugger_enabled() const { return debugger_enabled_; }

DebuggerCommand GetCommand(const EventReply &reply) {
  DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  switch (reply.cmd_case()) {
    case debugger::EventReply::CmdCase::kExit:
      cmd = DebuggerCommand::kExitCMD;
      break;
    case debugger::EventReply::CmdCase::kRunCmd:
      cmd = DebuggerCommand::kRunCMD;
      break;
    case debugger::EventReply::CmdCase::kSetCmd:
      cmd = DebuggerCommand::kSetCMD;
      break;
    case debugger::EventReply::CmdCase::kViewCmd:
      cmd = DebuggerCommand::kViewCMD;
      break;
    case debugger::EventReply::CmdCase::kVersionMatched:
      cmd = DebuggerCommand::kVersionMatchedCMD;
      break;
    default:
      MS_LOG(DEBUG) << "Debug: UnknownCMD";
      break;
  }
  return cmd;
}

ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
    MS_LOG(ERROR) << "Error: Cannot get Parameters from command. Returning default value: ProtoVector<Parameter>().";
    return ProtoVector<WatchCondition_Parameter>();
  }
  return reply.set_cmd().watch_condition().params();
}

ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  if (!reply.has_set_cmd()) {
    MS_LOG(ERROR) << "Error: Not SetCMD, cannot get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
    return ProtoVector<WatchNode>();
  }
  return reply.set_cmd().watch_nodes();
}

std::string GetRunLevel(const EventReply &reply) {
  if (!reply.has_run_cmd()) {
    MS_LOG(ERROR) << "Error: Not RunCMD, cannot get RunLevel. Returning default value: \"\".";
    return "";
  }
  return reply.run_cmd().run_level();
}

std::string GetNodeName(const EventReply &reply) {
  if (!reply.has_run_cmd()) {
    MS_LOG(ERROR) << "Error: Not RunCMD, cannot get NodeName. Returning default value: \"\".";
    return "";
  }
  return reply.run_cmd().node_name();
}

WatchCondition GetWatchcondition(const EventReply &reply) {
  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
    MS_LOG(ERROR) << "Error: Cannot get WatchCondition from command. Returning default value: WatchCondition().";
    return WatchCondition();
  }
  return reply.set_cmd().watch_condition();
}

int32_t GetWatchpointID(const EventReply &reply) {
  if (!reply.has_set_cmd()) {
    MS_LOG(ERROR) << "Error: Not SetCMD, cannot get Watchpoint ID. Returning default value: 0.";
    return 0;
  }
  return reply.set_cmd().id();
}

bool GetWatchpointDelete(const EventReply &reply) {
  if (!reply.has_set_cmd()) {
    MS_LOG(ERROR) << "Error: Not SetCMD, cannot get Watchpoint delete flag. Returning default value: false.";
    return false;
  }
  return reply.set_cmd().delete_();
}

ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  if (!reply.has_view_cmd()) {
    MS_LOG(ERROR) << "Error: Not ViewCMD, cannot get Tensors. Returning default value: ProtoVector<TensorProto>().";
    return ProtoVector<TensorProto>();
  }
  return reply.view_cmd().tensors();
}

std::string GetTensorFullName(const TensorProto &tensor) {
  string node_name = tensor.node_name();
  if (tensor.truncate()) {
    // scopes in the node name are separated by '/'
    // use the name without the scope if truncate is true
    std::size_t found = node_name.find_last_of("/");
    node_name = node_name.substr(found + 1);
  }
  return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
}

bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }

bool Debugger::partial_memory() const { return partial_memory_; }

void Debugger::SetCurNode(const std::string &cur_name) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  cur_name_ = cur_name;
}

std::string Debugger::run_level() const { return run_level_; }

void Debugger::SetStepNum(int32_t cur_num_step) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  num_step_ = cur_num_step;
}

int32_t Debugger::step_num() const { return num_step_; }

uint64_t BytestoUInt64(const std::vector<char> &buffer) {
  return le64toh(*reinterpret_cast<const uint64_t *>(buffer.data()));
}
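
// CheckOpOverflow scans the Ascend overflow bin directories for dump files newer
// than last_overflow_bin_, reads the little-endian stream_id/task_id fields out
// of each file header, and maps them back to an operator name through
// stream_task_to_opname_. The header layout (record at byte offset 321, ids at
// offsets 16/24 within the 256-byte record) is specific to this dump format.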
std::vector<std::string> Debugger::CheckOpOverflow() {
  std::vector<double> bin_list;
  std::vector<std::string> op_names;
  for (const auto &[graph_id, overflow_bin_path] : overflow_bin_path_) {
    DIR *d = opendir(overflow_bin_path.c_str());
    MS_LOG(INFO) << "processing bin file path " << overflow_bin_path << ", graph id " << graph_id;
    if (d == nullptr) {
      MS_LOG(INFO) << "Overflow bin directory does not exist!";
      continue;
    }
    struct dirent *dir = nullptr;
    while ((dir = readdir(d)) != nullptr) {
      if (dir->d_type == DT_REG) {
        std::string file_path = overflow_bin_path;
        file_path.append(dir->d_name);
        std::string file_name = dir->d_name;
        std::size_t found = file_name.find_last_of(".");
        if (found == std::string::npos) {
          continue;
        }
        std::string overflow_time = file_name.substr(found + 1);
        if (stod(overflow_time) <= last_overflow_bin_) {
          MS_LOG(INFO) << "File already processed " << file_name;
          continue;
        }
        bin_list.push_back(stod(overflow_time));
        std::fstream infile;
        infile.open(file_path.c_str(), std::ios::binary | std::ios::in);
        if (!infile.is_open()) {
          MS_LOG(ERROR) << "Failed to open overflow bin file " << file_name;
          continue;
        }
        MS_LOG(INFO) << "Open overflow bin file " << file_name;
        const uint32_t offset = 321;
        (void)infile.seekg(offset, std::ios::beg);
        std::vector<char> buffer;
        const size_t buf_size = 256;
        buffer.resize(buf_size);
        (void)infile.read(buffer.data(), buf_size);
        const uint8_t stream_id_offset = 16;
        const uint8_t task_id_offset = 24;
        // The stream_id and task_id in the dump file are 8-byte fields for extensibility purposes, but only hold
        // 4-byte values currently.
        uint64_t stream_id = BytestoUInt64(std::vector<char>(buffer.begin() + stream_id_offset, buffer.end()));
        uint64_t task_id = BytestoUInt64(std::vector<char>(buffer.begin() + task_id_offset, buffer.end()));
        MS_LOG(INFO) << "Overflow stream_id " << stream_id << ", task_id " << task_id << ".";
        auto op = debugger_->stream_task_to_opname_.find(std::make_pair(stream_id, task_id));
        if (op != debugger_->stream_task_to_opname_.end()) {
          MS_LOG(ERROR) << "Overflow detected on node " << op->second << std::endl;
          op_names.push_back(op->second);
        } else {
          MS_LOG(INFO) << "No overflow is detected " << std::endl;
        }
        infile.close();
      }
    }
    // only close a directory that was successfully opened
    closedir(d);
  }
  if (!op_names.empty()) {
    MS_LOG(ERROR) << "These operation overflows are detected " << op_names;
  }
  for (auto &i : bin_list) {
    if (i > last_overflow_bin_) {
      last_overflow_bin_ = i;
    }
  }
  auto iter_op_names = overflow_ops_.find(num_step_);
  if (iter_op_names == overflow_ops_.end()) {
    overflow_ops_.insert(std::pair<uint32_t, std::vector<std::string>>(num_step_, op_names));
    return op_names;
  }
  iter_op_names->second.insert(std::end(iter_op_names->second), std::begin(op_names), std::end(op_names));
  return iter_op_names->second;
}

void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }

bool Debugger::CheckPort(const std::string &port) const {
  int num = 0;
  const int min_port_num = 1;
  const int max_port_num = 65535;
  const int decimal = 10;
  if (port[0] == '0' && port[1] != '\0') return false;
  int i = 0;
  while (port[i] != '\0') {
    if (port[i] < '0' || port[i] > '9') return false;
    num = num * decimal + (port[i] - '0');
    if (num > max_port_num) return false;
    i++;
  }
  if (num < min_port_num) return false;
  return true;
}

bool Debugger::CheckIp(const std::string &host) const {
  std::regex reg_ip(
    "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
    "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
    "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
    "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
  std::smatch smat;
  std::string host_str = host;
  return std::regex_match(host_str, smat, reg_ip);
}

uint32_t Debugger::GetFirstRunGraphId() const { return rungraph_id_list_.front(); }
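
// LoadSingleAnfnode copies the output of one Parameter or ValueNode from device
// to host through its device address. Parameters are loaded with keep_prev = true
// after moving the current value to "prev", so the previous step's value remains
// available; value nodes are reloaded fresh each time.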
void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
    return;
  }
  // When MindRT is used, only ValueNodes and ParameterWeights can be loaded from device to host
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) && (device_target_ == kGPUDevice)) {
    if (!anf_node->isa<ValueNode>() &&
        !(anf_node->isa<Parameter>() && AnfAlgo::IsParameterWeight(anf_node->cast<ParameterPtr>()))) {
      return;
    }
  }
  // for parameters and value nodes, set the execution order to 0
  int exec_order = 0;
  std::string node_name = GetKernelNodeName(anf_node);
  GetFileKernelName(NOT_NULL(&node_name));
  // check if the output address exists; if not, return
  if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
    return;
  }
  auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  MS_EXCEPTION_IF_NULL(addr);
  auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  if (!IsTypeDebuggerSupported(type)) {
    return;
  }
  auto format = kOpFormat_DEFAULT;
  string tensor_name = node_name + ':' + "0";
  ShapeVector int_shapes = trans::GetRuntimePaddingShape(anf_node, output_index);
  bool keep_prev;
  if (anf_node->isa<Parameter>()) {
    keep_prev = true;
    debug_services_->MoveTensorCurrentToPrev(tensor_name);
  } else {
    keep_prev = false;
  }
  bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  if (!ret) {
    MS_LOG(ERROR) << "LoadMemToHost failed: tensor_name: " << tensor_name << ", host_format: " << format << ".";
  }
}

void Debugger::LoadParametersAndConst() {
  if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  // load parameters
  MS_LOG(INFO) << "Start to load Parameters!";
  const auto &parameters = graph_ptr_->inputs();
  for (auto &item : parameters) {
    LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  }
  // load value nodes
  // get all constant values from the graph
  MS_LOG(INFO) << "Start to load value nodes!";
  const auto value_nodes = graph_ptr_->graph_value_nodes();
  for (auto &item : value_nodes) {
    LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  }
}

void Debugger::LoadParametersAndConst(const KernelGraphPtr &graph) {
  if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  MS_EXCEPTION_IF_NULL(graph);
  // load parameters
  MS_LOG(INFO) << "Start to load Parameters for graph " << graph->graph_id();
  const auto &parameters = graph->inputs();
  for (auto &item : parameters) {
    LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  }
  // load value nodes
  // get all constant values from the graph
  MS_LOG(INFO) << "Start to load value nodes for graph " << graph->graph_id();
  const auto value_nodes = graph->graph_value_nodes();
  for (auto &item : value_nodes) {
    LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  }
}
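
// LoadGraphOutputs walks the kernels of the current graph in execution order
// (starting at 1, since 0 is reserved for parameters and value nodes) and loads
// every kernel output to host; with partial memory reuse enabled, only watched
// nodes are loaded because other device buffers may already have been reused.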
void Debugger::LoadGraphOutputs() {
  if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  const auto &apply_kernels = graph_ptr_->execution_order();
  // for kernels, execution order starts from 1
  int exec_order = 1;
  for (const auto &node : apply_kernels) {
    MS_EXCEPTION_IF_NULL(node);
    std::string kernel_name = GetKernelNodeName(node);
    auto output_size = AnfAlgo::GetOutputTensorNum(node);
    if (partial_memory_) {
      if (!debug_services_->IsWatchPoint(kernel_name, node)) {
        continue;
      }
    }
    for (size_t j = 0; j < output_size; ++j) {
      if (!AnfAlgo::OutputAddrExist(node, j)) {
        MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
        continue;
      }
      auto addr = AnfAlgo::GetOutputAddr(node, j);
      MS_EXCEPTION_IF_NULL(addr);
      auto type = AnfAlgo::GetOutputInferDataType(node, j);
      if (!IsTypeDebuggerSupported(type)) {
        continue;
      }
      auto format = kOpFormat_DEFAULT;
      string tensor_name = kernel_name + ':' + std::to_string(j);
      ShapeVector int_shapes = trans::GetRuntimePaddingShape(node, j);
      auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
      if (!ret) {
        MS_LOG(ERROR) << "LoadMemToHost failed: tensor_name: " << tensor_name << ", host_format: " << format << ".";
      }
    }
    exec_order = exec_order + 1;
  }
}

void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  // update the step number if we are processing the first graph (to support multigraph)
  if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
      (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
    // access lock for public method
    std::lock_guard<std::mutex> a_lock(access_lock_);
    ++num_step_;
  }
}

void Debugger::UpdateStepNumGPU() {
  // UpdateStepNum with DebugActor::DebugOnStepEnd
  if (device_target_ == kGPUDevice && (debugger_enabled_ || DumpDataEnabledIteration())) {
    // access lock for public method
    std::lock_guard<std::mutex> a_lock(access_lock_);
    ++num_step_;
  }
}

void Debugger::ClearCurrentData() {
  if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration())) {
    debug_services_->EmptyCurrentTensor();
  }
}

bool Debugger::TensorExistsInCurrent(const std::string &tensor_name) {
  return debug_services_->TensorExistsInCurrent(tensor_name);
}
}  // namespace mindspore