You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

debugger.cc 66 kB

4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767
  1. /**
  2. * Copyright 2020-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <dirent.h>
  17. #include <cstdio>
  18. #include <fstream>
  19. #include <tuple>
  20. #include <vector>
  21. #include <algorithm>
  22. #include <iostream>
  23. #include <cstring>
  24. #include <utility>
  25. #include <map>
  26. #include <regex>
  27. #include "debug/debugger/debugger.h"
  28. #include "debug/data_dump/dump_json_parser.h"
  29. #include "pipeline/jit/pipeline.h"
  30. #include "backend/session/anf_runtime_algorithm.h"
  31. #include "runtime/device/kernel_runtime_manager.h"
  32. #include "runtime/device/kernel_runtime.h"
  33. #include "debug/data_dump/e2e_dump.h"
  34. #include "utils/config_manager.h"
  35. #include "debug/env_config_parser.h"
  36. #include "utils/comm_manager.h"
  37. #include "runtime/hardware/device_context_manager.h"
  38. #include "debug/anf_ir_dump.h"
  39. #include "debug/anf_ir_utils.h"
  40. #ifdef ENABLE_DEBUGGER
  41. #include "debug/debugger/proto_exporter.h"
  42. #else
  43. #include "debug/debugger/proto_exporter_stub.h"
  44. #endif
  45. using debugger::Chunk;
  46. using debugger::EventReply;
  47. using debugger::GraphProto;
  48. using debugger::ModelProto;
  49. using debugger::Statistics;
  50. using debugger::TensorProto;
  51. using debugger::WatchCondition;
  52. using debugger::WatchCondition_Condition_inf;
  53. using debugger::WatchCondition_Condition_nan;
  54. using debugger::WatchCondition_Parameter;
  55. using debugger::WatchNode;
  56. using debugger::WatchpointHit;
  57. namespace mindspore {
  // Chunk size constant (1024 * 1024 * 3).
  // NOTE(review): presumably a byte count used when splitting large payloads — confirm at the use site.
  static constexpr auto g_chunk_size = 1024 * 1024 * 3;
  // Period, in seconds, passed to SendHeartbeat() by the heartbeat thread created in EnableDebugger().
  static constexpr int32_t heartbeat_period_second = 30;
  // Out-of-class definitions of Debugger's static members: the shared instance pointer (starts out null)
  // and the mutex named instance_lock_.
  DebuggerPtr Debugger::debugger_ = nullptr;
  std::mutex Debugger::instance_lock_;
  62. Debugger::Debugger()
  63. : grpc_client_(nullptr),
  64. debug_services_(nullptr),
  65. heartbeat_thread_(nullptr),
  66. device_id_(0),
  67. device_target_(""),
  68. num_step_(0),
  69. debugger_enabled_(false),
  70. suspended_at_last_kernel_(false),
  71. run_level_(""),
  72. node_name_(""),
  73. cur_name_(""),
  74. training_done_(false),
  75. send_metadata_done_(false),
  76. received_new_graph_(false),
  77. is_dataset_graph_(false),
  78. partial_memory_(false),
  79. initial_suspend_(true),
  80. enable_heartbeat_(false),
  81. not_dataset_graph_sum_(0),
  82. ascend_kernel_by_kernel_(false),
  83. version_("") {
  84. CheckDebuggerEnabledParam();
  85. auto ms_context = MsContext::GetInstance();
  86. MS_EXCEPTION_IF_NULL(ms_context);
  87. std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  88. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  89. if (device_target == kCPUDevice) {
  90. MS_LOG(WARNING) << "Not enabling debugger. Debugger does not support CPU.";
  91. } else if (CheckDebuggerEnabled()) {
  92. // configure partial memory reuse
  93. partial_memory_ = CheckDebuggerPartialMemoryEnabled();
  94. // switch memory reuse on or off
  95. EnvConfigParser::GetInstance().SetSysMemreuse(partial_memory_);
  96. // print some message about memory reuse to user
  97. if (partial_memory_) {
  98. MS_LOG(WARNING)
  99. << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
  100. "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
  101. } else {
  102. MS_LOG(WARNING)
  103. << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
  104. "usage for large models.";
  105. }
  106. }
  107. }
  108. void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  109. // access lock for public method
  110. std::lock_guard<std::mutex> a_lock(access_lock_);
  111. // save device_id
  112. MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  113. device_id_ = device_id;
  114. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  115. device_target_ = device_target;
  116. version_ = MSVERSION;
  117. }
  118. bool IsTypeDebuggerSupported(TypeId type) {
  119. if (type < TypeId::kNumberTypeEnd && type > TypeId::kNumberTypeBegin && type != kNumberTypeComplex64) {
  120. return true;
  121. }
  122. MS_LOG(INFO) << "Debugger does not support type: " << TypeIdLabel(type);
  123. return false;
  124. }
  125. void Debugger::EnableDebugger() {
  126. // reset some of the class members
  127. num_step_ = 0;
  128. debugger_enabled_ = false;
  129. enable_heartbeat_ = false;
  130. partial_memory_ = false;
  131. grpc_client_ = nullptr;
  132. debug_services_ = nullptr;
  133. heartbeat_thread_ = nullptr;
  134. // see if dump using debugger backend is enabled
  135. bool dump_enabled = CheckDebuggerDumpEnabled();
  136. MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;
  137. // check if debugger enabled
  138. debugger_enabled_ = CheckDebuggerEnabled();
  139. MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;
  140. if (!debugger_enabled_ && !dump_enabled) {
  141. MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
  142. return;
  143. }
  144. if (debugger_enabled_) {
  145. // configure grpc host
  146. std::string env_host_str = common::GetEnv("MS_DEBUGGER_HOST");
  147. std::string host;
  148. if (!env_host_str.empty()) {
  149. if (CheckIp(env_host_str)) {
  150. MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
  151. host = env_host_str;
  152. } else {
  153. debugger_enabled_ = false;
  154. MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
  155. "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
  156. }
  157. } else {
  158. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
  159. host = "localhost";
  160. }
  161. // configure grpc port
  162. std::string env_port_str = common::GetEnv("MS_DEBUGGER_PORT");
  163. std::string port;
  164. if (!env_port_str.empty()) {
  165. if (CheckPort(env_port_str)) {
  166. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
  167. port = env_port_str;
  168. } else {
  169. debugger_enabled_ = false;
  170. MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to "
  171. "65535";
  172. }
  173. } else {
  174. port = "50051";
  175. if (!CheckPort(port)) {
  176. MS_EXCEPTION(ValueError) << "Default MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to 65535";
  177. }
  178. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
  179. }
  180. // initialize grpc client
  181. grpc_client_ = std::make_unique<GrpcClient>(host, port);
  182. // initialize sending heartbeat
  183. heartbeat_thread_ = std::make_unique<std::thread>([this]() { SendHeartbeat(heartbeat_period_second); });
  184. }
  185. debug_services_ = std::make_unique<DebugServices>();
  186. }
  187. void Debugger::CheckDatasetSinkMode(const KernelGraphPtr &graph_ptr) {
  188. bool sink_mode = ConfigManager::GetInstance().dataset_mode() || graph_ptr->IsDatasetGraph();
  189. if (CheckDebuggerDumpEnabled() && sink_mode && device_target_ == kGPUDevice) {
  190. MS_EXCEPTION(NotSupportError)
  191. << "e2e_dump is not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  192. }
  193. if (CheckDebuggerEnabled() && sink_mode) {
  194. MS_EXCEPTION(NotSupportError)
  195. << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  196. }
  197. }
  198. bool Debugger::CheckDebuggerDumpEnabled() const {
  199. // see if dump is enabled
  200. auto &dump_json_parser = DumpJsonParser::GetInstance();
  201. if (device_target_ == kGPUDevice) {
  202. return dump_json_parser.e2e_dump_enabled();
  203. } else if (device_target_ == kAscendDevice) {
  204. return dump_json_parser.async_dump_enabled() || dump_json_parser.e2e_dump_enabled();
  205. }
  206. return false;
  207. }
  208. bool Debugger::CheckDebuggerEnabled() const {
  209. // get env variables to configure debugger
  210. std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  211. if (!env_enable_str.empty()) {
  212. (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
  213. if ((env_enable_str == "1" || env_enable_str == "true") && device_target_ != kCPUDevice) {
  214. return true;
  215. }
  216. }
  217. return false;
  218. }
  219. void Debugger::CheckDebuggerEnabledParam() const {
  220. // check the value of env variable ENABLE_MS_DEBUGGER
  221. std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  222. if (!env_enable_str.empty()) {
  223. (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
  224. if (env_enable_str != "0" && env_enable_str != "1" && env_enable_str != "false" && env_enable_str != "true") {
  225. MS_LOG(WARNING) << "Env variable ENABLE_MS_DEBUGGER should be True/False/1/0 (case insensitive), but get: "
  226. << env_enable_str;
  227. }
  228. }
  229. }
  230. bool Debugger::CheckDebuggerPartialMemoryEnabled() const {
  231. std::string env_partial_mem_str = common::GetEnv("MS_DEBUGGER_PARTIAL_MEM");
  232. if (!env_partial_mem_str.empty()) {
  233. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
  234. if (env_partial_mem_str == "1") {
  235. return true;
  236. }
  237. }
  238. return false;
  239. }
  240. /*
  241. * Feature group: Dump, Online debugger.
  242. * Target device group: Ascend, GPU.
  243. * Runtime category: Old runtime, MindRT
  244. * Description: Returns true if online debugger or dump is enabled.
  245. */
  246. bool Debugger::DebuggerBackendEnabled() const { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }
  247. void Debugger::Reset() {
  248. // access lock for public method
  249. std::lock_guard<std::mutex> a_lock(access_lock_);
  250. // reset components
  251. if (heartbeat_thread_ && heartbeat_thread_->joinable()) {
  252. SetEnableHeartbeat(false);
  253. heartbeat_thread_->join();
  254. MS_LOG(INFO) << "Join Heartbeat thread.";
  255. }
  256. heartbeat_thread_ = nullptr;
  257. device_id_ = 0;
  258. device_target_ = "";
  259. num_step_ = 0;
  260. debugger_enabled_ = false;
  261. is_dataset_graph_ = false;
  262. partial_memory_ = false;
  263. graph_ptr_ = nullptr;
  264. grpc_client_ = nullptr;
  265. debug_services_ = nullptr;
  266. graph_proto_list_.clear();
  267. graph_ptr_list_.clear();
  268. graph_ptr_step_vec_.clear();
  269. MS_LOG(INFO) << "Release Debugger resource.";
  270. }
  271. /*
  272. * Feature group: Dump, Online debugger.
  273. * Target device group: Ascend, GPU.
  274. * Runtime category: MindRT.
  275. * Description: Sets root_graph_id for all the graphs in the compiled graph list. Sets cur_root_graph_id_ and
  276. * prev_root_graph_id_ and calls PreExecute function for all the graphs.
  277. */
  278. void Debugger::PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs) {
  279. // MindRTBackend for GPU and Ascend
  280. if (device_target_ == kCPUDevice) {
  281. return;
  282. }
  283. // Store graphs that are run in one step.
  284. graph_ptr_step_vec_ = graphs;
  285. prev_root_graph_id_ = cur_root_graph_id_;
  286. // set first run graph as the root graph
  287. cur_root_graph_id_ = graph_ptr_step_vec_[0]->graph_id();
  288. MS_LOG(DEBUG) << "Current root graph id: " << cur_root_graph_id_ << " prev_root_graph_id_: " << prev_root_graph_id_
  289. << " for step: " << num_step_ << ".";
  290. MS_LOG(DEBUG) << "Set root graph for all the subgraphs:";
  291. for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
  292. const auto &graph = graphs[graph_index];
  293. // set root graph id for GPU mindrt runtime.
  294. MS_LOG(DEBUG) << "Set root graph for graph: " << graph->graph_id() << " to: " << cur_root_graph_id_ << ".";
  295. graph->set_root_graph_id(cur_root_graph_id_);
  296. if (debugger_) {
  297. debugger_->PreExecute(graph);
  298. }
  299. }
  300. }
  301. /*
  302. * Feature group: Dump.
  303. * Target device group: Ascend.
  304. * Runtime category: Old runtime, MindRT.
  305. * Description: When async dump is enabled and dataset_sink_mode is true, graph_iter_num_map_ stores the number of
  306. * iterations per epoch for each running graph.
  307. */
  308. void Debugger::UpdateGraphIterMap(uint32_t graph_id, int32_t iter_num) {
  309. if (graph_iter_num_map_.find(graph_id) == graph_iter_num_map_.end()) {
  310. graph_iter_num_map_[graph_id] = iter_num;
  311. }
  312. }
  313. /*
  314. * Feature group: Dump, Online debugger.
  315. * Target device group: Ascend.
  316. * Runtime category: Old runtime.
  317. * Description: For Ascend old runtime, this function sets the current and previous root graph id.
  318. */
  319. void Debugger::SetCurrentAndPrevRootGraph(uint32_t root_graph_id) {
  320. // for GPU and ascend MindRT root graphs are set in PreExecuteGraphDebugger.
  321. if (device_target_ != kAscendDevice || MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
  322. return;
  323. }
  324. prev_root_graph_id_ = cur_root_graph_id_;
  325. cur_root_graph_id_ = root_graph_id;
  326. MS_LOG(DEBUG) << "Current root graph id: " << cur_root_graph_id_ << " prev_root_graph_id_: " << prev_root_graph_id_
  327. << " for step: " << num_step_ << ".";
  328. }
  329. /*
  330. * Feature group: Dump, Online debugger.
  331. * Target device group: GPU.
  332. * Runtime category: Old runtime.
  333. * Description: In the case of GPU old runtime and when we have multiple subgraphs, we use the first run graph id to
  334. * update the step number.
  335. */
  336. void Debugger::StoreRunGraphIdList(uint32_t graph_id) {
  337. // collect rungrap_ids to update step number in multigraph case for GPU old runtime
  338. if (!rungraph_id_list_.size()) {
  339. rungraph_id_list_.push_back(graph_id);
  340. } else {
  341. if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
  342. rungraph_id_list_.push_back(graph_id);
  343. }
  344. }
  345. }
  346. /*
  347. * Feature group: Dump, Online debugger.
  348. * Target device group: Ascend, GPU.
  349. * Runtime category: Old runtime, MindRT.
  350. * Description: Sets previous and current root_graph_id for Ascend old runtime, sends graphs to online debugger when
  351. * debugger_enabled_ is true.
  352. */
  353. void Debugger::PreExecute(const KernelGraphPtr &graph_ptr) {
  354. MS_EXCEPTION_IF_NULL(graph_ptr);
  355. // access lock for public method
  356. std::lock_guard<std::mutex> a_lock(access_lock_);
  357. if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
  358. // Checking dataset_sink_mode for mindRT is done in debug_actor
  359. CheckDatasetSinkMode(graph_ptr);
  360. }
  361. auto graph_id = graph_ptr->graph_id();
  362. MS_LOG(DEBUG) << "PreExecute for graph: " << graph_id << " in step: " << num_step_ << ".";
  363. StoreRunGraphIdList(graph_id);
  364. SetCurrentAndPrevRootGraph(graph_ptr->root_graph_id());
  365. // multiple graphs
  366. if (graph_proto_list_.size() > 1) {
  367. // there are more than one graphs are not dataset_graph
  368. if (not_dataset_graph_sum_ > 0) {
  369. SendMultiGraphsAndClear(graph_ptr);
  370. }
  371. } else if (graph_proto_list_.size() == 1) {
  372. // single graph, and not the initial step
  373. if (device_target_ == kGPUDevice && !MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) &&
  374. num_step_ != 0) {
  375. if (debugger_enabled_ && !(run_level_ == "node" && suspended_at_last_kernel_)) {
  376. CommandLoop();
  377. }
  378. debug_services_->ResetLoadedTensors();
  379. }
  380. // In single graph case, reset graph_ptr_ to be nullptr when debugger receives a new graph
  381. if (received_new_graph_) {
  382. graph_ptr_ = nullptr;
  383. CheckGraphPtr(graph_ptr);
  384. }
  385. } else if (debugger_enabled_ && graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice &&
  386. !MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
  387. // Multiple graph, and not the initial step,
  388. // stop only when receive the first sub run graph for each step for old runtime
  389. // if we have stopped for the last kernel before, no need to stop again
  390. if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
  391. return;
  392. }
  393. if (!(run_level_ == "node" && suspended_at_last_kernel_)) {
  394. CommandLoop();
  395. }
  396. debug_services_->ResetLoadedTensors();
  397. }
  398. // resets for the new graph
  399. suspended_at_last_kernel_ = false;
  400. }
  401. /*
  402. * Feature group: Online debugger.
  403. * Target device group: Ascend, GPU.
  404. * Runtime category: Old runtime, MindRT.
  405. * Description: Sends all the subgraphs to online debugger when debugger_enabled_ is true.
  406. */
  407. void Debugger::SendMultiGraphsAndClear(const KernelGraphPtr &graph_ptr) {
  408. // only try to enable debugger if they are not all dataset graphs
  409. if (!debugger_enabled_) {
  410. EnableDebugger();
  411. }
  412. if (debugger_enabled_) {
  413. // only send compiled graphs once at the initial step.
  414. auto dbg_graph_ptr = graph_ptr_;
  415. // use current graph ptr to load parameters
  416. graph_ptr_ = graph_ptr;
  417. LoadParametersAndConst();
  418. // revert graph ptr to original value
  419. graph_ptr_ = dbg_graph_ptr;
  420. SendMultiGraphsAndSuspend(graph_proto_list_);
  421. graph_proto_list_.clear();
  422. received_new_graph_ = false;
  423. }
  424. }
  425. /*
  426. * Feature group: Dump.
  427. * Target device group: Ascend, GPU.
  428. * Runtime category: Old runtime, MindRT.
  429. * Description: Returns true for e2e dump if dump is enabled for the current iteration.
  430. */
  431. bool Debugger::DumpDataEnabledIteration() const {
  432. auto &dump_json_parser = DumpJsonParser::GetInstance();
  433. if (!dump_json_parser.e2e_dump_enabled()) {
  434. return false;
  435. }
  436. auto cur_iter = dump_json_parser.cur_dump_iter();
  437. if (dump_json_parser.IsDumpIter(cur_iter)) {
  438. return true;
  439. }
  440. return false;
  441. }
  442. /*
  443. * Feature group: Dump.
  444. * Target device group: Ascend, GPU.
  445. * Runtime category: MindRT.
  446. * Description: Returns the rank_id for GPU and Ascend kernel-bykernel mindRT.
  447. */
  448. uint32_t Debugger::GetRankID() {
  449. auto ms_context = MsContext::GetInstance();
  450. MS_EXCEPTION_IF_NULL(ms_context);
  451. std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  452. uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  453. const auto &device_context =
  454. device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
  455. uint32_t rank_id = device_context->GetRankID();
  456. return rank_id;
  457. }
  458. /*
  459. * Feature group: Dump.
  460. * Target device group: Ascend, GPU.
  461. * Runtime category: MindRT.
  462. * Description: Dumps graph history and parameters for GPU and Ascend kernel-by-kernel MindRT. DumpConstantData for GPU.
  463. */
  464. void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
  465. if (!(ascend_kernel_by_kernel_ || device_target_ == kGPUDevice)) {
  466. return;
  467. }
  468. uint32_t rank_id = GetRankID();
  469. E2eDump::DumpRunIter(kernel_graph, rank_id);
  470. if (debugger_ && debugger_->DebuggerBackendEnabled()) {
  471. MS_EXCEPTION_IF_NULL(kernel_graph);
  472. (void)E2eDump::DumpParametersData(kernel_graph.get(), rank_id, debugger_.get());
  473. // Dump constant data for GPU mindRT.
  474. E2eDump::DumpConstantData(kernel_graph.get(), rank_id, debugger_.get());
  475. } else {
  476. DumpJsonParser::GetInstance().UpdateDumpIter();
  477. }
  478. }
  479. void Debugger::DumpConstantDataAscend(const KernelGraphPtr &graph) {
  480. if (device_target_ != kAscendDevice) {
  481. return;
  482. }
  483. auto &json_parser = DumpJsonParser::GetInstance();
  484. if (json_parser.e2e_dump_enabled() || json_parser.async_dump_enabled()) {
  485. // Dump constant data for ascend mindRT, for old runtime constant data is dumped in session_basic.
  486. uint32_t rank_id = GetRankID();
  487. std::string cst_file_dir = GenerateDumpPath(graph->root_graph_id(), rank_id, true);
  488. DumpConstantInfo(graph, cst_file_dir);
  489. }
  490. }
  491. /*
  492. * Feature group: Dump.
  493. * Target device group: Ascend, GPU.
  494. * Runtime category: MindRT.
  495. * Description: Dumps a single node for given graph_id.
  496. */
  497. void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id) {
  498. if (debugger_ && debugger_->DebuggerBackendEnabled()) {
  499. uint32_t rank_id = GetRankID();
  500. (void)E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get());
  501. }
  502. }
  503. /*
  504. * Feature group: Dump.
  505. * Target device group: GPU.
  506. * Runtime category: MindRT.
  507. * Description: This function is used for new GPU runtime using MindRTBackend, on Ascend platform, graphs are saved in
  508. * session_basic.
  509. */
  510. void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
  511. if (device_target_ == kAscendDevice) {
  512. return;
  513. }
  514. auto &json_parser = DumpJsonParser::GetInstance();
  515. if (json_parser.e2e_dump_enabled()) {
  516. uint32_t rank_id = GetRankID();
  517. kernel_graph->set_root_graph_id(kernel_graph->graph_id());
  518. std::string final_graph = "trace_code_graph_" + std::to_string(kernel_graph->graph_id());
  519. std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
  520. std::string target_dir = root_dir + "/graphs";
  521. std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
  522. DumpIRProtoWithSrcInfo(kernel_graph, final_graph, target_dir, kDebugWholeStack);
  523. DumpIR("trace_code_graph", kernel_graph, true, kWholeStack, ir_file_path);
  524. DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(kernel_graph->graph_id()) + ".csv", root_dir,
  525. kernel_graph->execution_order());
  526. }
  527. }
  528. /*
  529. * Feature group: Dump, Online debugger.
  530. * Target device group: Ascend, GPU and CPU.
  531. * Runtime category: MindRT.
  532. * Description: Load and dump parameters and constant data, call postExecute and update dump iter.
  533. */
  534. void Debugger::PostExecuteGraphDebugger() {
  535. // On CPU, update dump iteration, Parameters and consts are not dumped here
  536. if (device_target_ == kCPUDevice) {
  537. DumpJsonParser::GetInstance().UpdateDumpIter();
  538. return;
  539. }
  540. // LoadParametersAndConst for all the graphs that have been run in the current step
  541. if (debugger_ && device_target_ == kGPUDevice) {
  542. for (auto graph : graph_ptr_step_vec_) {
  543. debugger_->LoadParametersAndConst(graph);
  544. }
  545. }
  546. // debug used for dump
  547. if (debugger_ && debugger_->CheckDebuggerDumpEnabled()) {
  548. // Dump Parameters and consts
  549. for (auto graph : graph_ptr_step_vec_) {
  550. debugger_->Dump(graph);
  551. DumpConstantDataAscend(graph);
  552. if (!debugger_->debugger_enabled()) {
  553. debugger_->ClearCurrentData();
  554. }
  555. }
  556. }
  557. if (debugger_) {
  558. debugger_->PostExecute();
  559. }
  560. if (ascend_kernel_by_kernel_ || device_target_ == kGPUDevice) {
  561. E2eDump::UpdateIterMindRTDump();
  562. }
  563. }
  564. /*
  565. * Feature group: Online debugger.
  566. * Target device group: Ascend, GPU.
  567. * Runtime category: Old runtime, MindRT.
  568. * Description: Send hit watchpoints, update the step number and reset loaded tensors.
  569. */
  570. void Debugger::PostExecute() {
  571. // access lock for public method
  572. std::lock_guard<std::mutex> a_lock(access_lock_);
  573. if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
  574. return;
  575. }
  576. if (debugger_ && debugger_->DebuggerBackendEnabled()) {
  577. // analyze tensor data and send the watchpoints been hit
  578. if (debugger_enabled_ && !is_dataset_graph_) {
  579. SendWatchpoints(CheckWatchpoints());
  580. // no need to suspend at each graph for GPU old runtime, suspension happens in preExecute
  581. if (device_target_ == kAscendDevice) {
  582. CommandLoop();
  583. } else if (device_target_ == kGPUDevice && MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
  584. if (!(run_level_ == "node" && suspended_at_last_kernel_)) {
  585. CommandLoop();
  586. }
  587. }
  588. if (device_target_ != kGPUDevice) {
  589. num_step_++;
  590. }
  591. }
  592. // Only keep parameters in th current map
  593. // GPU ResetLoadedTensors for old runtime happens in preExecute
  594. if ((device_target_ == kGPUDevice && MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) ||
  595. device_target_ == kAscendDevice) {
  596. if (debug_services_ != nullptr) {
  597. debug_services_->ResetLoadedTensors();
  598. } else {
  599. MS_LOG(DEBUG) << "debug_services_ is nullptr";
  600. }
  601. }
  602. }
  603. }
  604. bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) const {
  605. if (debugger_enabled_ && !is_dataset_graph_) {
  606. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  607. // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data
  608. if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
  609. return true;
  610. }
  611. }
  612. return false;
  613. }
  614. /*
  615. * Feature group: Online debugger.
  616. * Target device group: GPU.
  617. * Runtime category: Old runtime, MindRT.
  618. * Description: Check and send watchpoint hit for a single node, suspend if a watchpoint is hit or we are continuing
  619. * in node level.
  620. */
  621. void Debugger::PostExecuteNode(const CNodePtr &kernel, bool last_kernel) {
  622. // access lock for public method
  623. std::lock_guard<std::mutex> a_lock(access_lock_);
  624. if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
  625. return;
  626. }
  627. if (debugger_enabled_ && !is_dataset_graph_) {
  628. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  629. // if kernel is watchpoint,and get hit. suspend.
  630. bool hit_empty_flag = true;
  631. if (is_watchpoint) {
  632. auto hits = CheckWatchpoints(cur_name_, kernel);
  633. if (!hits.empty()) {
  634. SendWatchpoints(hits);
  635. CommandLoop();
  636. hit_empty_flag = false;
  637. }
  638. }
  639. if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
  640. // if kernel is not watchpoint and is next_to or continue_to node, suspend
  641. // sets a bool to be checked in preExecute to avoid double stopping at last kernel in the last graph
  642. if (last_kernel) {
  643. suspended_at_last_kernel_ = true;
  644. }
  645. CommandLoop();
  646. }
  647. return;
  648. }
  649. }
  650. /*
  651. * Feature group: Dump, Online debugger.
  652. * Target device group: Ascend, GPU.
  653. * Runtime category: Old runtime, MindRT.
  654. * Description: Get graph proto and add it to graph proto list and add loaded graph pointers to a list.
  655. */
  656. void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  657. MS_EXCEPTION_IF_NULL(graph_ptr);
  658. if (graph_ptr_ != graph_ptr) {
  659. MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
  660. received_new_graph_ = true;
  661. // save new graph_ptr
  662. graph_ptr_ = graph_ptr;
  663. CheckDatasetGraph();
  664. if (!is_dataset_graph_) {
  665. // get proto for new graph_ptr
  666. auto graph_proto = GetGraphProto(graph_ptr);
  667. // add new graph proto to graph_proto_list_
  668. graph_proto_list_.push_back(graph_proto);
  669. graph_ptr_list_.push_back(graph_ptr);
  670. not_dataset_graph_sum_++;
  671. }
  672. // reset is_dataset_graph to be false
  673. is_dataset_graph_ = false;
  674. }
  675. }
  676. // In single graph cases, check single graph ptr
  677. void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  678. MS_EXCEPTION_IF_NULL(graph_ptr);
  679. if (graph_ptr_ != graph_ptr) {
  680. MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
  681. // save new graph_ptr
  682. graph_ptr_ = graph_ptr;
  683. if (!is_dataset_graph_) {
  684. // only try to enable debugger if it is not a dataset graph
  685. if (!debugger_enabled_) {
  686. EnableDebugger();
  687. }
  688. if (debugger_enabled_) {
  689. LoadParametersAndConst();
  690. // get graph proto and send to MindInsight
  691. auto graph_proto = graph_proto_list_.front();
  692. SendGraphAndSuspend(graph_proto);
  693. graph_proto_list_.clear();
  694. received_new_graph_ = false;
  695. }
  696. }
  697. }
  698. }
  699. void Debugger::CheckDatasetGraph() {
  700. // print parameter node names
  701. MS_EXCEPTION_IF_NULL(graph_ptr_);
  702. const auto &params = graph_ptr_->inputs();
  703. for (const auto &param : params) {
  704. MS_LOG(INFO) << "param: " << GetKernelNodeName(param);
  705. }
  706. // check if there is GetNext or InitDataSetQueue node
  707. const auto &nodes = graph_ptr_->execution_order();
  708. for (const auto &node : nodes) {
  709. auto node_name = AnfAlgo::GetCNodeName(node);
  710. MS_LOG(INFO) << "node: " << GetKernelNodeName(node);
  711. if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
  712. MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
  713. << node_name;
  714. is_dataset_graph_ = true;
  715. return;
  716. }
  717. }
  718. is_dataset_graph_ = false;
  719. }
  720. GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  721. // convert kernel graph to debugger modelproto
  722. ModelProto model = GetDebuggerFuncGraphProto(graph_ptr);
  723. return model.graph();
  724. }
  725. /*
  726. * Feature group: Online debugger.
  727. * Target device group: Ascend, GPU.
  728. * Runtime category: Old runtime, MindRT.
  729. * Description: Send debugger backend heartbeat to online debugger every few seconds.
  730. */
  731. void Debugger::SendHeartbeat(int32_t period) {
  732. int num_heartbeat_fail = 0;
  733. const int max_num_heartbeat_fail = 5;
  734. const int retry_milliseconds = 500;
  735. Heartbeat heartbeat;
  736. heartbeat.set_message("Debugger is alive");
  737. heartbeat.set_period(heartbeat_period_second);
  738. SetEnableHeartbeat(CheckDebuggerEnabled());
  739. while (enable_heartbeat_) {
  740. MS_EXCEPTION_IF_NULL(grpc_client_);
  741. EventReply reply = grpc_client_->SendHeartbeat(heartbeat);
  742. if (reply.status() != reply.OK) {
  743. MS_LOG(ERROR) << "Error: SendHeartbeat failed";
  744. num_heartbeat_fail++;
  745. if (num_heartbeat_fail >= max_num_heartbeat_fail) {
  746. MS_LOG(ERROR) << "Maximum number of failure for SendHeartbeat reached : exiting training session.";
  747. SetEnableHeartbeat(false);
  748. break;
  749. } else {
  750. MS_LOG(ERROR) << "Number of consecutive SendHeartbeat fail:" << num_heartbeat_fail;
  751. std::this_thread::sleep_for(std::chrono::milliseconds(retry_milliseconds));
  752. }
  753. } else {
  754. int recheck_period_ms = 200;
  755. for (int i = 0; i < (period * 1000 / recheck_period_ms); i++) {
  756. if (enable_heartbeat_) {
  757. std::this_thread::sleep_for(std::chrono::milliseconds(recheck_period_ms));
  758. } else {
  759. break;
  760. }
  761. }
  762. }
  763. }
  764. }
  765. void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  766. if (!CheckSendMetadata()) {
  767. return;
  768. }
  769. // send graph to MindInsight server
  770. MS_EXCEPTION_IF_NULL(grpc_client_);
  771. EventReply reply = grpc_client_->SendGraph(graph_proto);
  772. if (reply.status() != reply.OK) {
  773. MS_LOG(ERROR) << "Error: SendGraph failed";
  774. }
  775. // enter command loop, wait and process commands
  776. CommandLoop();
  777. }
  778. bool Debugger::SendMetadata(bool version_check) {
  779. // prepare metadata
  780. MS_EXCEPTION_IF_NULL(graph_ptr_);
  781. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  782. Metadata metadata;
  783. metadata.set_device_name(device_name);
  784. metadata.set_cur_step(num_step_);
  785. metadata.set_backend(device_target_);
  786. metadata.set_cur_node(cur_name_);
  787. metadata.set_training_done(training_done_);
  788. metadata.set_ms_version(version_);
  789. MS_LOG(INFO) << "Is training done?" << training_done_;
  790. // set graph number to not_dataset_graph_sum_
  791. metadata.set_graph_num(not_dataset_graph_sum_);
  792. MS_EXCEPTION_IF_NULL(grpc_client_);
  793. EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  794. bool ret = false;
  795. if (reply_metadata.status() == reply_metadata.OK) {
  796. if (version_check) {
  797. // get type of the command in meta data reply, it should be version matched
  798. DebuggerCommand cmd = GetCommand(reply_metadata);
  799. if (cmd != DebuggerCommand::kVersionMatchedCMD) {
  800. MS_LOG(ERROR) << "MindInsight version is too old, Mindspore version is " << version_;
  801. Exit();
  802. } else {
  803. if (GetMiVersionMatched(reply_metadata)) {
  804. MS_LOG(INFO) << "MindSpore version is " << version_ << " matches MindInsight version.";
  805. ret = true;
  806. } else {
  807. MS_LOG(ERROR) << "MindSpore version " << version_ << ", did not match MindInsight version.";
  808. CommandLoop();
  809. }
  810. }
  811. } else {
  812. // version check is done before so we can just return true here
  813. ret = true;
  814. }
  815. } else {
  816. MS_LOG(ERROR) << "Error: SendMetadata failed";
  817. }
  818. return ret;
  819. }
  820. void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list) {
  821. if (!CheckSendMetadata()) {
  822. return;
  823. }
  824. MS_EXCEPTION_IF_NULL(grpc_client_);
  825. // send multiple graphs to mindinght server
  826. // split graph into chunks if one graph is larger than chunk size
  827. std::list<Chunk> chunked_graph_proto_list;
  828. Chunk chunk;
  829. for (auto graph : graph_proto_list) {
  830. std::string str = graph.SerializeAsString();
  831. auto graph_size = graph.ByteSize();
  832. if (graph_size > g_chunk_size) {
  833. auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);
  834. for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
  835. chunk.set_buffer(sub_graph_str[i]);
  836. if (i < sub_graph_str.size() - 1) {
  837. chunk.set_finished(false);
  838. } else {
  839. chunk.set_finished(true);
  840. }
  841. chunked_graph_proto_list.push_back(chunk);
  842. }
  843. } else {
  844. chunk.set_buffer(str);
  845. chunk.set_finished(true);
  846. chunked_graph_proto_list.push_back(chunk);
  847. }
  848. }
  849. EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  850. if (reply.status() != reply.OK) {
  851. MS_LOG(ERROR) << "Error: SendGraph failed";
  852. }
  853. // enter command loop, wait and process commands
  854. CommandLoop();
  855. }
  856. bool Debugger::CheckSendMetadata() {
  857. if (!send_metadata_done_) {
  858. if (!SendMetadata(true)) {
  859. return false;
  860. }
  861. send_metadata_done_ = true;
  862. }
  863. return true;
  864. }
  865. void Debugger::CommandLoop() {
  866. // prepare metadata
  867. MS_EXCEPTION_IF_NULL(graph_ptr_);
  868. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(cur_root_graph_id_);
  869. Metadata metadata;
  870. metadata.set_device_name(device_name);
  871. metadata.set_cur_step(num_step_);
  872. metadata.set_backend(device_target_);
  873. metadata.set_cur_node(cur_name_);
  874. metadata.set_training_done(training_done_);
  875. // loop exit flag
  876. bool run = false;
  877. int num_wait_fail = 0;
  878. const int max_num_wait_fail = 5;
  879. while (!run) {
  880. // wait for command
  881. MS_EXCEPTION_IF_NULL(grpc_client_);
  882. EventReply reply = grpc_client_->WaitForCommand(metadata);
  883. if (reply.status() != reply.OK) {
  884. MS_LOG(ERROR) << "Error: WaitForCommand failed";
  885. num_wait_fail++;
  886. if (num_wait_fail > max_num_wait_fail) {
  887. MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session.";
  888. MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
  889. "of debugger host and port.";
  890. Exit();
  891. run = true;
  892. } else {
  893. MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after "
  894. << num_wait_fail << "s";
  895. std::this_thread::sleep_for(std::chrono::seconds(num_wait_fail));
  896. }
  897. continue;
  898. }
  899. // get type of the command in reply
  900. DebuggerCommand cmd = GetCommand(reply);
  901. if (cmd == DebuggerCommand::kUnknownCMD) {
  902. MS_LOG(DEBUG) << "Debug: debugger received unknown command";
  903. continue;
  904. }
  905. MS_LOG(INFO) << "received command: ";
  906. switch (cmd) {
  907. case DebuggerCommand::kUnknownCMD:
  908. MS_LOG(INFO) << "UnknownCMD";
  909. break;
  910. case DebuggerCommand::kExitCMD:
  911. MS_LOG(INFO) << "ExitCMD";
  912. Exit(true);
  913. // Used for debugger termination
  914. run = true;
  915. break;
  916. case DebuggerCommand::kRunCMD:
  917. ProcessRunCMD(reply);
  918. if (GetRunLevel(reply) != "recheck") {
  919. // exit loop
  920. run = true;
  921. }
  922. break;
  923. case DebuggerCommand::kSetCMD:
  924. ProcessKSetCMD(reply);
  925. break;
  926. case DebuggerCommand::kViewCMD:
  927. ProcessKViewCMD(reply);
  928. break;
  929. case DebuggerCommand::kVersionMatchedCMD:
  930. MS_LOG(ERROR) << "Received unexpected Version Matched CMD from MindInsight.";
  931. Exit();
  932. break;
  933. default:
  934. MS_LOG(ERROR) << "Received unknown CMD from MindInsight";
  935. Exit();
  936. break;
  937. }
  938. }
  939. }
  940. void Debugger::ProcessRunCMD(const EventReply &reply) {
  941. MS_LOG(INFO) << "RunCMD";
  942. if (GetRunLevel(reply) == "recheck") {
  943. MS_LOG(INFO) << "rechecking all watchpoints";
  944. SendWatchpoints(CheckWatchpoints("", nullptr, true));
  945. } else {
  946. // no longer the initial suspension.
  947. initial_suspend_ = false;
  948. // print run cmd content
  949. // get run_level and node_name
  950. run_level_ = GetRunLevel(reply);
  951. node_name_ = GetNodeName(reply);
  952. MS_LOG(INFO) << "run_level: " << run_level_;
  953. MS_LOG(INFO) << "node_name_: " << node_name_;
  954. }
  955. }
  956. void Debugger::ProcessKSetCMD(const EventReply &reply) {
  957. MS_LOG(INFO) << "SetCMD";
  958. MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  959. MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  960. if (GetWatchpointDelete(reply)) {
  961. MS_LOG(INFO) << "Deleting watchpoint";
  962. RemoveWatchpoint(GetWatchpointID(reply));
  963. } else {
  964. MS_LOG(INFO) << "Setting watchpoint";
  965. MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
  966. ProtoVector<WatchNode> recieved_nodes = GetWatchnodes(reply);
  967. for (const auto &node : recieved_nodes) {
  968. MS_LOG(INFO) << "node name: " << node.node_name();
  969. MS_LOG(INFO) << "node type: " << node.node_type();
  970. }
  971. ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
  972. for (const auto &parameter : parameters) {
  973. MS_LOG(INFO) << "parameter name: " << parameter.name();
  974. MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
  975. MS_LOG(INFO) << "parameter value: " << parameter.value();
  976. }
  977. SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
  978. }
  979. }
  980. void Debugger::ProcessKViewCMD(const EventReply &reply) {
  981. MS_LOG(INFO) << "ViewCMD";
  982. // print view cmd content
  983. ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  984. for (auto received_tensor : received_tensors) {
  985. MS_LOG(INFO) << "tensor node name: " << received_tensor.node_name();
  986. MS_LOG(INFO) << "tensor slot: " << received_tensor.slot();
  987. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << received_tensor.finished() << std::noboolalpha;
  988. MS_LOG(INFO) << "tensor iter: " << received_tensor.iter();
  989. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << received_tensor.truncate() << std::noboolalpha;
  990. }
  991. switch (reply.view_cmd().level()) {
  992. case debugger::ViewCMD_Level::ViewCMD_Level_base:
  993. MS_LOG(INFO) << "Tensor base request.";
  994. ViewBaseLevel(reply);
  995. break;
  996. case debugger::ViewCMD_Level::ViewCMD_Level_statistics:
  997. MS_LOG(INFO) << "Tensor statistics request.";
  998. ViewStatLevel(reply);
  999. break;
  1000. case debugger::ViewCMD_Level::ViewCMD_Level_value:
  1001. MS_LOG(INFO) << "Tensor value request.";
  1002. ViewValueLevel(reply);
  1003. break;
  1004. default:
  1005. MS_LOG(DEBUG) << "Debug: Unknown tensor info level";
  1006. break;
  1007. }
  1008. }
  1009. void Debugger::ViewValueLevel(const EventReply &reply) {
  1010. MS_LOG(INFO) << "Sending tensors";
  1011. std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  1012. // print view cmd reply
  1013. for (auto tensor : tensors) {
  1014. MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
  1015. MS_LOG(INFO) << "tensor slot: " << tensor.slot();
  1016. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
  1017. MS_LOG(INFO) << "tensor iter: " << tensor.iter();
  1018. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
  1019. MS_LOG(INFO) << "tensor dims: ";
  1020. for (auto dim : tensor.dims()) {
  1021. MS_LOG(INFO) << dim << ",";
  1022. }
  1023. MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  1024. }
  1025. MS_EXCEPTION_IF_NULL(grpc_client_);
  1026. EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  1027. if (send_tensors_reply.status() != debugger::EventReply::OK) {
  1028. MS_LOG(ERROR) << "Error: SendTensors failed";
  1029. }
  1030. }
  1031. void Debugger::ViewStatLevel(const EventReply &reply) {
  1032. std::list<TensorSummary> tensor_stats_list = LoadTensorsStat(GetTensors(reply));
  1033. EventReply send_tensors_stat_reply = grpc_client_->SendTensorStats(tensor_stats_list);
  1034. if (send_tensors_stat_reply.status() != debugger::EventReply::OK) {
  1035. MS_LOG(ERROR) << "Error: SendTensorsStats failed.";
  1036. }
  1037. }
  1038. void Debugger::ViewBaseLevel(const EventReply &reply) {
  1039. std::list<TensorBase> tensor_base_list = LoadTensorsBase(GetTensors(reply));
  1040. EventReply send_tensor_base_reply = grpc_client_->SendTensorBase(tensor_base_list);
  1041. if (send_tensor_base_reply.status() != debugger::EventReply::OK) {
  1042. MS_LOG(ERROR) << "Error: SendTensorsBase failed.";
  1043. }
  1044. }
  1045. void AddTensorProtoInfo(TensorProto *tensor_item, const TensorProto &tensor) {
  1046. tensor_item->set_node_name(tensor.node_name());
  1047. tensor_item->set_slot(tensor.slot());
  1048. tensor_item->set_iter(tensor.iter());
  1049. tensor_item->set_truncate(tensor.truncate());
  1050. tensor_item->clear_tensor_content();
  1051. tensor_item->clear_data_type();
  1052. tensor_item->clear_dims();
  1053. }
  1054. void AddTensorStatInfo(const DebugServices::TensorStat &tensor_stat,
  1055. std::list<TensorSummary> *const tensor_summary_list) {
  1056. if (tensor_summary_list == nullptr) {
  1057. MS_LOG(DEBUG) << "tensor_summary_list is nullptr.";
  1058. return;
  1059. }
  1060. TensorSummary tensor_summary_item;
  1061. TensorBase *tensor_base = tensor_summary_item.mutable_tensor_base();
  1062. tensor_base->set_data_type(tensor_stat.dtype);
  1063. tensor_base->set_data_size((int64_t)tensor_stat.data_size);
  1064. for (auto elem : tensor_stat.shape) {
  1065. tensor_base->add_shape(elem);
  1066. }
  1067. Statistics *tensor_statistics = tensor_summary_item.mutable_statistics();
  1068. tensor_statistics->set_is_bool(tensor_stat.is_bool);
  1069. tensor_statistics->set_max_value(static_cast<float>(tensor_stat.max_value));
  1070. tensor_statistics->set_min_value(static_cast<float>(tensor_stat.min_value));
  1071. tensor_statistics->set_avg_value(static_cast<float>(tensor_stat.avg_value));
  1072. tensor_statistics->set_count(tensor_stat.count);
  1073. tensor_statistics->set_neg_zero_count(tensor_stat.neg_zero_count);
  1074. tensor_statistics->set_pos_zero_count(tensor_stat.pos_zero_count);
  1075. tensor_statistics->set_nan_count(tensor_stat.nan_count);
  1076. tensor_statistics->set_neg_inf_count(tensor_stat.neg_inf_count);
  1077. tensor_statistics->set_pos_inf_count(tensor_stat.pos_inf_count);
  1078. tensor_statistics->set_zero_count(tensor_stat.zero_count);
  1079. tensor_summary_list->push_back(tensor_summary_item);
  1080. }
  1081. void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
  1082. const ProtoVector<WatchCondition_Parameter> &parameters) {
  1083. std::vector<std::tuple<std::string, bool>> check_node_list;
  1084. std::vector<DebugServices::parameter_t> parameter_list;
  1085. std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
  1086. [](const WatchNode &node) -> std::tuple<std::string, bool> {
  1087. return make_tuple(node.node_name(), node.node_type() == "scope");
  1088. });
  1089. std::transform(
  1090. parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
  1091. [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
  1092. return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
  1093. });
  1094. debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
  1095. }
  1096. void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
  1097. std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  1098. std::vector<std::string> name;
  1099. std::vector<std::string> ret_name;
  1100. std::vector<const char *> data_ptr;
  1101. std::vector<ssize_t> data_size;
  1102. std::vector<unsigned int> dtype;
  1103. std::vector<std::vector<int64_t>> shape;
  1104. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  1105. // ret_name will contain tensor names that are found in TensorLoader
  1106. // items in ret_name will be in the same order with tensors if found
  1107. debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  1108. std::list<TensorProto> tensor_list;
  1109. size_t result_index = 0;
  1110. for (auto tensor : tensors) {
  1111. ssize_t size_iter = 0;
  1112. if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
  1113. TensorProto tensor_item;
  1114. tensor_item.set_finished(true);
  1115. AddTensorProtoInfo(&tensor_item, tensor);
  1116. tensor_list.push_back(tensor_item);
  1117. continue;
  1118. }
  1119. ssize_t tensor_size = data_size[result_index];
  1120. while (size_iter < tensor_size) {
  1121. ssize_t chunk_size = g_chunk_size;
  1122. TensorProto tensor_item;
  1123. tensor_item.set_finished(false);
  1124. if (tensor_size - size_iter <= g_chunk_size) {
  1125. chunk_size = tensor_size - size_iter;
  1126. tensor_item.set_finished(true);
  1127. }
  1128. AddTensorProtoInfo(&tensor_item, tensor);
  1129. // return empty tensor if didn't find the requested tensor
  1130. tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
  1131. tensor_item.set_data_type((debugger::DataType)dtype[result_index]);
  1132. for (auto &elem : shape[result_index]) {
  1133. tensor_item.add_dims(elem);
  1134. }
  1135. // add tensor to result list and increment result_index to check next item in ret_name
  1136. tensor_list.push_back(tensor_item);
  1137. if (size_iter > INT_MAX - g_chunk_size) {
  1138. MS_EXCEPTION(ValueError) << size_iter << " + " << g_chunk_size << " would lead to integer overflow!";
  1139. }
  1140. size_iter += g_chunk_size;
  1141. }
  1142. result_index++;
  1143. }
  1144. return tensor_list;
  1145. }
  1146. std::list<TensorBase> Debugger::LoadTensorsBase(const ProtoVector<TensorProto> &tensors) const {
  1147. std::list<TensorBase> tensor_base_list;
  1148. std::vector<std::string> name;
  1149. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  1150. std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
  1151. debug_services_->SearchNodesTensors(name, &result_list);
  1152. for (auto result : result_list) {
  1153. auto tensor = std::get<1>(result);
  1154. if (!tensor || ((cur_root_graph_id_ != tensor->GetRootGraphId()) &&
  1155. MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT))) {
  1156. // tensor was not found or tensor's graph was not executed in the current step, creating empty tensor base.
  1157. TensorBase tensor_base_item;
  1158. tensor_base_item.set_data_size(0);
  1159. tensor_base_item.set_data_type(0);
  1160. tensor_base_item.add_shape(0);
  1161. tensor_base_list.push_back(tensor_base_item);
  1162. continue;
  1163. }
  1164. // tensor was found creating tensor base object.
  1165. TensorBase tensor_base_item;
  1166. tensor_base_item.set_data_size((int64_t)tensor->GetByteSize());
  1167. tensor_base_item.set_data_type((int32_t)tensor->GetType());
  1168. for (auto elem : tensor->GetShape()) {
  1169. tensor_base_item.add_shape(elem);
  1170. }
  1171. tensor_base_list.push_back(tensor_base_item);
  1172. }
  1173. return tensor_base_list;
  1174. }
  1175. std::list<TensorSummary> Debugger::LoadTensorsStat(const ProtoVector<TensorProto> &tensors) const {
  1176. std::list<TensorSummary> tensor_summary_list;
  1177. std::vector<std::string> name;
  1178. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  1179. std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
  1180. debug_services_->SearchNodesTensors(name, &result_list);
  1181. for (auto result : result_list) {
  1182. auto tensor = std::get<1>(result);
  1183. if (!tensor || ((cur_root_graph_id_ != tensor->GetRootGraphId()) &&
  1184. MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT))) {
  1185. // tensor was not found or tensor's graph was not executed in the current step, creating empty tensor summary.
  1186. DebugServices::TensorStat tensor_stat;
  1187. AddTensorStatInfo(tensor_stat, &tensor_summary_list);
  1188. continue;
  1189. }
  1190. // tensor was found creating tensor summary object.
  1191. DebugServices::TensorStat tensor_stat = DebugServices::GetTensorStatistics(tensor);
  1192. AddTensorStatInfo(tensor_stat, &tensor_summary_list);
  1193. }
  1194. return tensor_summary_list;
  1195. }
  1196. std::shared_ptr<TensorData> Debugger::GetTensor(const std::string &tensor_name) const {
  1197. return debug_services_->GetTensor(tensor_name);
  1198. }
  1199. void Debugger::Exit(bool exit_success) {
  1200. // debugger will notify main thread to exit because main thread can only exit at step boundary.
  1201. MS_LOG(INFO) << "Exit Debugger";
  1202. SetEnableHeartbeat(false);
  1203. pipeline::GraphExecutorPy::DebugTerminate(true, exit_success);
  1204. }
  1205. std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel,
  1206. bool recheck) {
  1207. std::vector<std::string> name;
  1208. std::vector<std::string> slot;
  1209. std::vector<int> condition;
  1210. std::vector<unsigned int> watchpoint_id;
  1211. std::vector<std::string> overflow_ops;
  1212. std::vector<std::vector<DebugServices::parameter_t>> parameters;
  1213. std::vector<int32_t> error_codes;
  1214. std::vector<std::shared_ptr<TensorData>> tensor_list;
  1215. if (watchnode.empty()) {
  1216. tensor_list = debug_services_->GetTensor();
  1217. } else {
  1218. tensor_list = debug_services_->GetNodeTensor(kernel);
  1219. }
  1220. DebugServices::AsyncFilePool file_list;
  1221. MS_LOG(INFO) << "checkwatchpoints call for step " << num_step_;
  1222. debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, &error_codes, overflow_ops,
  1223. file_list, &tensor_list, initial_suspend_, watchnode.empty(), recheck);
  1224. std::list<WatchpointHit> hits;
  1225. for (unsigned int i = 0; i < name.size(); i++) {
  1226. WatchpointHit hit;
  1227. std::vector<DebugServices::parameter_t> &parameter = parameters[i];
  1228. hit.set_id(watchpoint_id[i]);
  1229. hit.set_error_code(error_codes[i]);
  1230. // here TensorProto act as a tensor indicator, not sending tensor content
  1231. TensorProto *tensor_item = hit.mutable_tensor();
  1232. tensor_item->set_node_name(name[i]);
  1233. tensor_item->set_slot(slot[i]);
  1234. tensor_item->set_finished(true);
  1235. WatchCondition *condition_item = hit.mutable_watch_condition();
  1236. condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
  1237. for (const auto &p : parameter) {
  1238. auto x = condition_item->mutable_params()->Add();
  1239. x->set_name(p.name);
  1240. x->set_disabled(p.disabled);
  1241. x->set_value(p.value);
  1242. x->set_hit(p.hit);
  1243. x->set_actual_value(p.actual_value);
  1244. }
  1245. hits.push_back(hit);
  1246. }
  1247. return hits;
  1248. }
  1249. void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  1250. // send info about watchpoint
  1251. if (!points.empty()) {
  1252. MS_EXCEPTION_IF_NULL(grpc_client_);
  1253. EventReply reply = grpc_client_->SendWatchpointHits(points);
  1254. if (reply.status() != reply.OK) {
  1255. MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
  1256. }
  1257. }
  1258. }
  1259. bool Debugger::DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
  1260. const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
  1261. TypeId device_type, const std::string &addr_format, size_t slot) const {
  1262. return debug_services_.get()->DumpTensorToFile(tensor_name, trans_flag, filepath, host_fmt, host_shape, host_type,
  1263. device_type, addr_format, slot);
  1264. }
  1265. bool Debugger::LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev) {
  1266. return debug_services_.get()->LoadNewTensor(tensor, keep_prev);
  1267. }
  1268. bool Debugger::debugger_enabled() const { return debugger_enabled_; }
  1269. DebuggerCommand GetCommand(const EventReply &reply) {
  1270. DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  1271. switch (reply.cmd_case()) {
  1272. case debugger::EventReply::CmdCase::kExit:
  1273. cmd = DebuggerCommand::kExitCMD;
  1274. break;
  1275. case debugger::EventReply::CmdCase::kRunCmd:
  1276. cmd = DebuggerCommand::kRunCMD;
  1277. break;
  1278. case debugger::EventReply::CmdCase::kSetCmd:
  1279. cmd = DebuggerCommand::kSetCMD;
  1280. break;
  1281. case debugger::EventReply::CmdCase::kViewCmd:
  1282. cmd = DebuggerCommand::kViewCMD;
  1283. break;
  1284. case debugger::EventReply::CmdCase::kVersionMatched:
  1285. cmd = DebuggerCommand::kVersionMatchedCMD;
  1286. break;
  1287. default:
  1288. MS_LOG(DEBUG) << "Debug: UnknownCMD";
  1289. break;
  1290. }
  1291. return cmd;
  1292. }
  1293. ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  1294. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  1295. MS_LOG(ERROR) << "Error: Can not get Parameters from command. Returning default value: ProtoVector<Parameter>().";
  1296. return ProtoVector<WatchCondition_Parameter>();
  1297. }
  1298. return reply.set_cmd().watch_condition().params();
  1299. }
  1300. ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  1301. if (!reply.has_set_cmd()) {
  1302. MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
  1303. return ProtoVector<WatchNode>();
  1304. }
  1305. return reply.set_cmd().watch_nodes();
  1306. }
  1307. std::string GetRunLevel(const EventReply &reply) {
  1308. if (!reply.has_run_cmd()) {
  1309. MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: "
  1310. "";
  1311. return "";
  1312. }
  1313. return reply.run_cmd().run_level();
  1314. }
  1315. std::string GetNodeName(const EventReply &reply) {
  1316. if (!reply.has_run_cmd()) {
  1317. MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: "
  1318. "";
  1319. return "";
  1320. }
  1321. return reply.run_cmd().node_name();
  1322. }
  1323. WatchCondition GetWatchcondition(const EventReply &reply) {
  1324. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  1325. MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
  1326. return WatchCondition();
  1327. }
  1328. return reply.set_cmd().watch_condition();
  1329. }
  1330. int32_t GetWatchpointID(const EventReply &reply) {
  1331. if (!reply.has_set_cmd()) {
  1332. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
  1333. return 0;
  1334. }
  1335. return reply.set_cmd().id();
  1336. }
  1337. bool GetWatchpointDelete(const EventReply &reply) {
  1338. if (!reply.has_set_cmd()) {
  1339. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
  1340. return false;
  1341. }
  1342. return reply.set_cmd().delete_();
  1343. }
  1344. ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  1345. if (!reply.has_view_cmd()) {
  1346. MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
  1347. return ProtoVector<TensorProto>();
  1348. }
  1349. return reply.view_cmd().tensors();
  1350. }
  1351. std::string GetTensorFullName(const TensorProto &tensor) {
  1352. string node_name = tensor.node_name();
  1353. if (tensor.truncate()) {
  1354. // scopes in node name are separated by '/'
  1355. // use the name without scope if truncate is true
  1356. std::size_t found = node_name.find_last_of("/");
  1357. node_name = node_name.substr(found + 1);
  1358. }
  1359. return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
  1360. }
  1361. bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }
  1362. bool Debugger::partial_memory() const { return partial_memory_; }
  1363. void Debugger::SetEnableHeartbeat(bool enabled) { enable_heartbeat_ = enabled; }
  1364. void Debugger::SetCurNode(const std::string &cur_name) {
  1365. // access lock for public method
  1366. std::lock_guard<std::mutex> a_lock(access_lock_);
  1367. cur_name_ = cur_name;
  1368. }
  1369. std::string Debugger::run_level() const { return run_level_; }
  1370. void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }
  1371. bool Debugger::CheckPort(const std::string &port) const {
  1372. int num = 0;
  1373. const int min_port_num = 1;
  1374. const int max_port_num = 65535;
  1375. const int decimal = 10;
  1376. if (port[0] == '0' && port[1] != '\0') return false;
  1377. int i = 0;
  1378. while (port[i] != '\0') {
  1379. if (port[i] < '0' || port[i] > '9') return false;
  1380. num = num * decimal + (port[i] - '0');
  1381. if (num > max_port_num) return false;
  1382. i++;
  1383. }
  1384. if (num < min_port_num) return false;
  1385. return true;
  1386. }
  1387. bool Debugger::CheckIp(const std::string &host) const {
  1388. std::regex reg_ip(
  1389. "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
  1390. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  1391. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  1392. "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
  1393. std::smatch smat;
  1394. std::string host_str = host;
  1395. return std::regex_match(host_str, smat, reg_ip);
  1396. }
  1397. uint32_t Debugger::GetFirstRunGraphId() const { return rungraph_id_list_.front(); }
  1398. /*
  1399. * Feature group: Dump.
  1400. * Target device group: Ascend, GPU.
  1401. * Runtime category: Old runtime, MindRT.
  1402. * Description: Load a single parameter or value node.
  1403. */
  1404. void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index, uint32_t root_graph_id) {
  1405. MS_EXCEPTION_IF_NULL(anf_node);
  1406. if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
  1407. return;
  1408. }
  1409. // When MindRT is used, only ValueNodes and ParameterWeights can be loaded from device to host
  1410. if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
  1411. if (!anf_node->isa<ValueNode>() &&
  1412. !(anf_node->isa<Parameter>() && AnfAlgo::IsParameterWeight(anf_node->cast<ParameterPtr>()))) {
  1413. return;
  1414. }
  1415. }
  1416. // for parameters and value nodes, set its execution order to be 0;
  1417. int exec_order = 0;
  1418. std::string node_name = GetKernelNodeName(anf_node);
  1419. GetFileKernelName(NOT_NULL(&node_name));
  1420. // check if output adde exists, if not, return;
  1421. if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
  1422. return;
  1423. }
  1424. auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  1425. MS_EXCEPTION_IF_NULL(addr);
  1426. auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  1427. if (!IsTypeDebuggerSupported(type)) {
  1428. return;
  1429. }
  1430. auto format = kOpFormat_DEFAULT;
  1431. string tensor_name = node_name + ':' + "0";
  1432. ShapeVector int_shapes = trans::GetRuntimePaddingShape(anf_node, output_index);
  1433. bool keep_prev;
  1434. if (anf_node->isa<Parameter>()) {
  1435. keep_prev = true;
  1436. debug_services_->MoveTensorCurrentToPrev(tensor_name);
  1437. } else {
  1438. keep_prev = false;
  1439. }
  1440. bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev, root_graph_id);
  1441. if (!ret) {
  1442. MS_LOG(ERROR) << "LoadMemToHost:"
  1443. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1444. }
  1445. }
  1446. /*
  1447. * Feature group: Dump.
  1448. * Target device group: Ascend, GPU.
  1449. * Runtime category: Old runtime, MindRT.
  1450. * Description: Load all the parameters and value nodes for the last loaded graph.
  1451. */
  1452. void Debugger::LoadParametersAndConst() {
  1453. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  1454. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1455. // load parameters
  1456. MS_LOG(INFO) << "Start to load Parameters for graph " << graph_ptr_->graph_id() << ".";
  1457. auto root_graph_id = graph_ptr_->root_graph_id();
  1458. const auto &parameters = graph_ptr_->inputs();
  1459. for (auto &item : parameters) {
  1460. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX, root_graph_id);
  1461. }
  1462. // load value nodes
  1463. // get all constant values from the graph
  1464. MS_LOG(INFO) << "Start to load value nodes for graph " << graph_ptr_->graph_id() << ".";
  1465. const auto value_nodes = graph_ptr_->graph_value_nodes();
  1466. for (auto &item : value_nodes) {
  1467. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX, root_graph_id);
  1468. }
  1469. }
  1470. /*
  1471. * Feature group: Dump.
  1472. * Target device group: Ascend, GPU.
  1473. * Runtime category: Old runtime, MindRT.
  1474. * Description: Load all the parameters and value nodes for the given graph.
  1475. */
  1476. void Debugger::LoadParametersAndConst(const KernelGraphPtr &graph) {
  1477. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  1478. MS_EXCEPTION_IF_NULL(graph);
  1479. // load parameters
  1480. MS_LOG(INFO) << "Start to load Parameters for graph " << graph->graph_id() << ".";
  1481. auto root_graph_id = graph->root_graph_id();
  1482. const auto &parameters = graph->inputs();
  1483. for (auto &item : parameters) {
  1484. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX, root_graph_id);
  1485. }
  1486. // load value nodes
  1487. // get all constant values from the graph
  1488. MS_LOG(INFO) << "Start to load value nodes for graph " << graph->graph_id() << ".";
  1489. const auto value_nodes = graph->graph_value_nodes();
  1490. for (auto &item : value_nodes) {
  1491. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX, root_graph_id);
  1492. }
  1493. }
  1494. /*
  1495. * Feature group: Online debugger.
  1496. * Target device group: Ascend.
  1497. * Runtime category: Old runtime, MindRT.
  1498. * Description: Load all the kernels for the last loaded graph.
  1499. */
  1500. void Debugger::LoadGraphOutputs() {
  1501. if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  1502. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1503. const auto &apply_kernels = graph_ptr_->execution_order();
  1504. auto root_graph_id = graph_ptr_->root_graph_id();
  1505. // for kernels, execution order starts from 1
  1506. int exec_order = 1;
  1507. for (const auto &node : apply_kernels) {
  1508. MS_EXCEPTION_IF_NULL(node);
  1509. std::string kernel_name = GetKernelNodeName(node);
  1510. auto output_size = AnfAlgo::GetOutputTensorNum(node);
  1511. if (partial_memory_) {
  1512. if (!debug_services_->IsWatchPoint(kernel_name, node)) {
  1513. continue;
  1514. }
  1515. }
  1516. for (size_t j = 0; j < output_size; ++j) {
  1517. if (!AnfAlgo::OutputAddrExist(node, j)) {
  1518. MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
  1519. continue;
  1520. }
  1521. auto addr = AnfAlgo::GetOutputAddr(node, j);
  1522. MS_EXCEPTION_IF_NULL(addr);
  1523. auto type = AnfAlgo::GetOutputInferDataType(node, j);
  1524. if (!IsTypeDebuggerSupported(type)) {
  1525. continue;
  1526. }
  1527. auto format = kOpFormat_DEFAULT;
  1528. string tensor_name = kernel_name + ':' + std::to_string(j);
  1529. ShapeVector int_shapes = trans::GetRuntimePaddingShape(node, j);
  1530. auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false, root_graph_id);
  1531. if (!ret) {
  1532. MS_LOG(ERROR) << "LoadMemToHost:"
  1533. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1534. }
  1535. }
  1536. exec_order = exec_order + 1;
  1537. }
  1538. }
  1539. /*
  1540. * Feature group: Dump.
  1541. * Target device group: Ascend.
  1542. * Runtime category: MindRT.
  1543. * Description: Load a single node for kernel-by-kernel ascend mindRT dump.
  1544. */
  1545. void Debugger::LoadNodeOutputs(const CNodePtr &node, uint32_t exec_order, uint32_t root_graph_id) {
  1546. if (device_target_ != kAscendDevice) {
  1547. return;
  1548. }
  1549. MS_EXCEPTION_IF_NULL(node);
  1550. std::string kernel_name = GetKernelNodeName(node);
  1551. auto output_size = AnfAlgo::GetOutputTensorNum(node);
  1552. if (partial_memory_) {
  1553. if (!debug_services_->IsWatchPoint(kernel_name, node)) {
  1554. return;
  1555. }
  1556. }
  1557. for (size_t j = 0; j < output_size; ++j) {
  1558. if (!AnfAlgo::OutputAddrExist(node, j)) {
  1559. MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
  1560. continue;
  1561. }
  1562. auto addr = AnfAlgo::GetOutputAddr(node, j);
  1563. MS_EXCEPTION_IF_NULL(addr);
  1564. auto type = AnfAlgo::GetOutputInferDataType(node, j);
  1565. if (!IsTypeDebuggerSupported(type)) {
  1566. return;
  1567. }
  1568. auto format = kOpFormat_DEFAULT;
  1569. string tensor_name = kernel_name + ':' + std::to_string(j);
  1570. ShapeVector int_shapes = trans::GetRuntimePaddingShape(node, j);
  1571. auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false, root_graph_id);
  1572. if (!ret) {
  1573. MS_LOG(ERROR) << "LoadMemToHost:"
  1574. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1575. }
  1576. }
  1577. }
  1578. /*
  1579. * Feature group: Online debugger.
  1580. * Target device group: GPU.
  1581. * Runtime category: Old runtime.
  1582. * Description: Update step number if we are processing the first graph (to support multigraph).
  1583. */
  1584. void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  1585. MS_EXCEPTION_IF_NULL(graph);
  1586. MS_EXCEPTION_IF_NULL(debugger_);
  1587. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
  1588. (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
  1589. // access lock for public method
  1590. std::lock_guard<std::mutex> a_lock(access_lock_);
  1591. ++num_step_;
  1592. }
  1593. }
  1594. /*
  1595. * Feature group: Online debugger.
  1596. * Target device group: GPU.
  1597. * Runtime category: MindRT.
  1598. * Description: Update step number when DebugActor::DebugOnStepEnd is called at the end of each step.
  1599. */
  1600. void Debugger::UpdateStepNumGPU() {
  1601. if (device_target_ == kGPUDevice && (debugger_enabled_ || DumpDataEnabledIteration())) {
  1602. // access lock for public method
  1603. std::lock_guard<std::mutex> a_lock(access_lock_);
  1604. ++num_step_;
  1605. MS_LOG(DEBUG) << "Update step for GPU, current step: " << num_step_;
  1606. }
  1607. }
  1608. void Debugger::ClearCurrentData() {
  1609. if ((device_target_ == kGPUDevice) && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration())) {
  1610. if (debug_services_) {
  1611. debug_services_->EmptyCurrentTensor();
  1612. } else {
  1613. MS_LOG(ERROR) << "debug_services_ is nullptr";
  1614. }
  1615. }
  1616. }
  1617. bool Debugger::TensorExistsInCurrent(const std::string &tensor_name) {
  1618. return debug_services_->TensorExistsInCurrent(tensor_name);
  1619. }
  1620. #ifdef ENABLE_D
  1621. /*
  1622. * Feature group: Dump.
  1623. * Target device group: Ascend.
  1624. * Runtime category: Old runtime, MindRT.
  1625. * Description: Load DumpDataBuilder object from dump_data_construct_map_ for tracking data chunks of node_name. It's
  1626. * for Ascend a + m dump. If not found, create a new one for it and add to dump_data_construct_map_.
  1627. */
  1628. std::shared_ptr<DumpDataBuilder> Debugger::LoadDumpDataBuilder(const std::string &node_name) {
  1629. auto iter = dump_data_construct_map_.find(node_name);
  1630. if (iter == dump_data_construct_map_.end()) {
  1631. dump_data_construct_map_[node_name] = std::make_shared<DumpDataBuilder>();
  1632. }
  1633. return dump_data_construct_map_[node_name];
  1634. }
  1635. void Debugger::ClearDumpDataBuilder(const std::string &node_name) { dump_data_construct_map_.erase(node_name); }
  1636. #endif
  1637. } // namespace mindspore