You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger.cc 55 kB

4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <dirent.h>
  17. #include <cstdio>
  18. #include <fstream>
  19. #include <tuple>
  20. #include <vector>
  21. #include <algorithm>
  22. #include <iostream>
  23. #include <cstring>
  24. #include <utility>
  25. #include <map>
  26. #include <regex>
  27. #include "debug/debugger/debugger.h"
  28. #include "debug/data_dump/dump_json_parser.h"
  29. #include "pipeline/jit/pipeline.h"
  30. #include "backend/session/anf_runtime_algorithm.h"
  31. #include "runtime/device/kernel_runtime_manager.h"
  32. #include "runtime/device/kernel_runtime.h"
  33. #include "debug/data_dump/e2e_dump.h"
  34. #include "utils/config_manager.h"
  35. #include "debug/env_config_parser.h"
  36. #include "utils/comm_manager.h"
  37. #include "runtime/hardware/device_context_manager.h"
  38. #include "debug/anf_ir_dump.h"
  39. #include "debug/anf_ir_utils.h"
  40. #ifdef ENABLE_DEBUGGER
  41. #include "debug/debugger/proto_exporter.h"
  42. #else
  43. #include "debug/debugger/proto_exporter_stub.h"
  44. #endif
  45. using debugger::Chunk;
  46. using debugger::EventReply;
  47. using debugger::GraphProto;
  48. using debugger::ModelProto;
  49. using debugger::Statistics;
  50. using debugger::TensorProto;
  51. using debugger::WatchCondition;
  52. using debugger::WatchCondition_Condition_inf;
  53. using debugger::WatchCondition_Condition_nan;
  54. using debugger::WatchCondition_Parameter;
  55. using debugger::WatchNode;
  56. using debugger::WatchpointHit;
  57. namespace mindspore {
  // Serialized graphs larger than this many bytes (3 MiB) are split into chunks before being sent.
  static constexpr auto g_chunk_size = 3 * 1024 * 1024;
  // Interval, in seconds, between heartbeat messages sent to the debugger server.
  static constexpr int32_t heartbeat_period_second = 30;
  60. DebuggerPtr Debugger::debugger_ = nullptr;
  61. std::mutex Debugger::instance_lock_;
  62. Debugger::Debugger()
  63. : grpc_client_(nullptr),
  64. debug_services_(nullptr),
  65. heartbeat_thread_(nullptr),
  66. device_id_(0),
  67. device_target_(""),
  68. num_step_(0),
  69. debugger_enabled_(false),
  70. suspended_at_last_kernel_(false),
  71. run_level_(""),
  72. node_name_(""),
  73. cur_name_(""),
  74. training_done_(false),
  75. is_dataset_graph_(false),
  76. partial_memory_(false),
  77. initial_suspend_(true),
  78. enable_heartbeat_(false),
  79. not_dataset_graph_sum_(0),
  80. version_("") {
  81. CheckDebuggerEnabledParam();
  82. auto ms_context = MsContext::GetInstance();
  83. MS_EXCEPTION_IF_NULL(ms_context);
  84. std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  85. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  86. if (device_target == kCPUDevice) {
  87. MS_LOG(WARNING) << "Not enabling debugger. Debugger does not support CPU.";
  88. } else if (CheckDebuggerEnabled()) {
  89. // configure partial memory reuse
  90. partial_memory_ = CheckDebuggerPartialMemoryEnabled();
  91. // switch memory reuse on or off
  92. EnvConfigParser::GetInstance().SetSysMemreuse(partial_memory_);
  93. // print some message about memory reuse to user
  94. if (partial_memory_) {
  95. MS_LOG(WARNING)
  96. << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
  97. "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
  98. } else {
  99. MS_LOG(WARNING)
  100. << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
  101. "usage for large models.";
  102. }
  103. }
  104. }
  105. void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  106. // access lock for public method
  107. std::lock_guard<std::mutex> a_lock(access_lock_);
  108. // save device_id
  109. MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  110. device_id_ = device_id;
  111. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  112. device_target_ = device_target;
  113. version_ = MSVERSION;
  114. }
  115. bool IsTypeDebuggerSupported(TypeId type) {
  116. if (type < TypeId::kNumberTypeEnd && type > TypeId::kNumberTypeBegin && type != kNumberTypeComplex64) {
  117. return true;
  118. }
  119. MS_LOG(INFO) << "Debugger does not support type: " << TypeIdLabel(type);
  120. return false;
  121. }
  122. void Debugger::EnableDebugger() {
  123. // reset some of the class members
  124. num_step_ = 0;
  125. debugger_enabled_ = false;
  126. enable_heartbeat_ = false;
  127. partial_memory_ = false;
  128. grpc_client_ = nullptr;
  129. debug_services_ = nullptr;
  130. heartbeat_thread_ = nullptr;
  131. // see if dump using debugger backend is enabled
  132. bool dump_enabled = CheckDebuggerDumpEnabled();
  133. MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;
  134. // check if debugger enabled
  135. debugger_enabled_ = CheckDebuggerEnabled();
  136. MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;
  137. if (!debugger_enabled_ && !dump_enabled) {
  138. MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
  139. return;
  140. }
  141. if (debugger_enabled_) {
  142. // configure grpc host
  143. std::string env_host_str = common::GetEnv("MS_DEBUGGER_HOST");
  144. std::string host;
  145. if (!env_host_str.empty()) {
  146. if (CheckIp(env_host_str)) {
  147. MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
  148. host = env_host_str;
  149. } else {
  150. debugger_enabled_ = false;
  151. MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
  152. "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
  153. }
  154. } else {
  155. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
  156. host = "localhost";
  157. }
  158. // configure grpc port
  159. std::string env_port_str = common::GetEnv("MS_DEBUGGER_PORT");
  160. std::string port;
  161. if (!env_port_str.empty()) {
  162. if (CheckPort(env_port_str)) {
  163. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
  164. port = env_port_str;
  165. } else {
  166. debugger_enabled_ = false;
  167. MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to "
  168. "65535";
  169. }
  170. } else {
  171. port = "50051";
  172. if (!CheckPort(port)) {
  173. MS_EXCEPTION(ValueError) << "Default MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to 65535";
  174. }
  175. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
  176. }
  177. // initialize grpc client
  178. grpc_client_ = std::make_unique<GrpcClient>(host, port);
  179. // initialize sending heartbeat
  180. heartbeat_thread_ = std::make_unique<std::thread>([this]() { SendHeartbeat(heartbeat_period_second); });
  181. }
  182. debug_services_ = std::make_unique<DebugServices>();
  183. }
  184. void Debugger::CheckDatasetSinkMode() {
  185. if (CheckDebuggerDumpEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  186. MS_EXCEPTION(NotSupportError)
  187. << "e2e_dump not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  188. }
  189. if (CheckDebuggerEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  190. MS_EXCEPTION(NotSupportError)
  191. << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  192. }
  193. }
  194. bool Debugger::CheckDebuggerDumpEnabled() const {
  195. // see if dump is enabled
  196. if (device_target_ == kGPUDevice) {
  197. return device::KernelRuntime::DumpDataEnabled();
  198. }
  199. return false;
  200. }
  201. bool Debugger::CheckDebuggerEnabled() const {
  202. // get env variables to configure debugger
  203. std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  204. if (!env_enable_str.empty()) {
  205. (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
  206. if ((env_enable_str == "1" || env_enable_str == "true") && device_target_ != kCPUDevice) {
  207. return true;
  208. }
  209. }
  210. return false;
  211. }
  212. void Debugger::CheckDebuggerEnabledParam() const {
  213. // check the value of env variable ENABLE_MS_DEBUGGER
  214. std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  215. if (!env_enable_str.empty()) {
  216. (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
  217. if (env_enable_str != "0" && env_enable_str != "1" && env_enable_str != "false" && env_enable_str != "true") {
  218. MS_LOG(WARNING) << "Env variable ENABLE_MS_DEBUGGER should be True/False/1/0 (case insensitive), but get: "
  219. << env_enable_str;
  220. }
  221. }
  222. }
  223. bool Debugger::CheckDebuggerPartialMemoryEnabled() const {
  224. std::string env_partial_mem_str = common::GetEnv("MS_DEBUGGER_PARTIAL_MEM");
  225. if (!env_partial_mem_str.empty()) {
  226. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
  227. if (env_partial_mem_str == "1") {
  228. return true;
  229. }
  230. }
  231. return false;
  232. }
  233. bool Debugger::DebuggerBackendEnabled() const { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }
  234. void Debugger::Reset() {
  235. // access lock for public method
  236. std::lock_guard<std::mutex> a_lock(access_lock_);
  237. // reset components
  238. if (heartbeat_thread_ && heartbeat_thread_->joinable()) {
  239. SetEnableHeartbeat(false);
  240. heartbeat_thread_->join();
  241. MS_LOG(INFO) << "Join Heartbeat thread.";
  242. }
  243. heartbeat_thread_ = nullptr;
  244. device_id_ = 0;
  245. device_target_ = "";
  246. num_step_ = 0;
  247. debugger_enabled_ = false;
  248. is_dataset_graph_ = false;
  249. partial_memory_ = false;
  250. graph_ptr_ = nullptr;
  251. grpc_client_ = nullptr;
  252. debug_services_ = nullptr;
  253. graph_proto_list_.clear();
  254. graph_ptr_list_.clear();
  255. graph_ptr_step_vec_.clear();
  256. MS_LOG(INFO) << "Release Debugger resource.";
  257. }
  258. void Debugger::PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs) {
  259. // Only GPU is supported for MindRTBackend
  260. if (device_target_ != kGPUDevice) {
  261. return;
  262. }
  263. // Store graphs that are run in one step.
  264. graph_ptr_step_vec_ = graphs;
  265. for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
  266. const auto &graph = graphs[graph_index];
  267. if (debugger_) {
  268. debugger_->PreExecute(graph);
  269. }
  270. }
  271. }
  272. void Debugger::PreExecute(const KernelGraphPtr &graph_ptr) {
  273. MS_EXCEPTION_IF_NULL(graph_ptr);
  274. // access lock for public method
  275. std::lock_guard<std::mutex> a_lock(access_lock_);
  276. CheckDatasetSinkMode();
  277. auto graph_id = graph_ptr->graph_id();
  278. // collect rungrap_ids to update step number in multigraph case
  279. if (!rungraph_id_list_.size()) {
  280. rungraph_id_list_.push_back(graph_id);
  281. } else {
  282. if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
  283. rungraph_id_list_.push_back(graph_id);
  284. }
  285. }
  286. // multiple graphs
  287. if (graph_proto_list_.size() > 1) {
  288. // there are more than one graphs are not dataset_graph
  289. if (not_dataset_graph_sum_ > 0) {
  290. SendMultiGraphsAndClear(graph_ptr);
  291. }
  292. } else if (graph_proto_list_.size() == 1) {
  293. // single graph, and not the initial step
  294. if (device_target_ == kGPUDevice && num_step_ != 0) {
  295. if (debugger_enabled_ && !(run_level_ == "node" && suspended_at_last_kernel_)) {
  296. CommandLoop();
  297. }
  298. debug_services_->ResetLoadedTensors();
  299. }
  300. // In single graph case, reset graph_ptr_ to be nullptr for the initial step
  301. if (num_step_ == 0) {
  302. graph_ptr_ = nullptr;
  303. CheckGraphPtr(graph_ptr);
  304. }
  305. } else if (debugger_enabled_ && graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
  306. // Multiple graph, and not the initial step,
  307. // stop only when receive the first sub run graph for each step
  308. // if we have stopped for the last kernel before, no need to stop again
  309. if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
  310. return;
  311. }
  312. if (!(run_level_ == "node" && suspended_at_last_kernel_)) {
  313. CommandLoop();
  314. }
  315. debug_services_->ResetLoadedTensors();
  316. }
  317. // resets for the new graph
  318. suspended_at_last_kernel_ = false;
  319. }
  320. void Debugger::SendMultiGraphsAndClear(const KernelGraphPtr &graph_ptr) {
  321. // only try to enable debugger if they are not all dataset graphs
  322. if (!debugger_enabled_) {
  323. EnableDebugger();
  324. }
  325. if (debugger_enabled_) {
  326. // only send compiled graphs once at the initial step.
  327. auto dbg_graph_ptr = graph_ptr_;
  328. // use current graph ptr to load parameters
  329. graph_ptr_ = graph_ptr;
  330. LoadParametersAndConst();
  331. // revert graph ptr to original value
  332. graph_ptr_ = dbg_graph_ptr;
  333. SendMultiGraphsAndSuspend(graph_proto_list_);
  334. graph_proto_list_.clear();
  335. }
  336. }
  337. bool Debugger::DumpDataEnabledIteration() const {
  338. auto &dump_json_parser = DumpJsonParser::GetInstance();
  339. if (!dump_json_parser.e2e_dump_enabled()) {
  340. return false;
  341. }
  342. auto cur_iter = dump_json_parser.cur_dump_iter();
  343. if (dump_json_parser.IsDumpIter(cur_iter)) {
  344. return true;
  345. }
  346. return false;
  347. }
  348. uint32_t Debugger::GetRankID() {
  349. auto ms_context = MsContext::GetInstance();
  350. MS_EXCEPTION_IF_NULL(ms_context);
  351. std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  352. uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  353. const auto &device_context =
  354. device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
  355. uint32_t rank_id = device_context->GetRankID();
  356. return rank_id;
  357. }
  358. void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
  359. uint32_t rank_id = GetRankID();
  360. E2eDump::DumpRunIter(kernel_graph, rank_id);
  361. if (debugger_ && debugger_->DebuggerBackendEnabled()) {
  362. MS_EXCEPTION_IF_NULL(kernel_graph);
  363. (void)E2eDump::DumpParametersAndConstData(kernel_graph.get(), rank_id, debugger_.get());
  364. } else {
  365. DumpJsonParser::GetInstance().UpdateDumpIter();
  366. }
  367. }
  368. void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id) {
  369. if (debugger_ && debugger_->DebuggerBackendEnabled()) {
  370. uint32_t rank_id = GetRankID();
  371. (void)E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get());
  372. }
  373. }
  374. void Debugger::DumpSetup(const KernelGraphPtr &kernel_graph) const {
  375. MS_LOG(INFO) << "Start!";
  376. MS_EXCEPTION_IF_NULL(kernel_graph);
  377. E2eDump::DumpSetup(kernel_graph.get());
  378. MS_LOG(INFO) << "Finish!";
  379. }
  380. void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
  381. // This function will be called for new GPU runtime using MindRTBackend
  382. auto &json_parser = DumpJsonParser::GetInstance();
  383. if (json_parser.e2e_dump_enabled()) {
  384. uint32_t rank_id = GetRankID();
  385. kernel_graph->set_root_graph_id(kernel_graph->graph_id());
  386. std::string final_graph = "trace_code_graph_" + std::to_string(kernel_graph->graph_id());
  387. std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
  388. std::string target_dir = root_dir + "/graphs";
  389. std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
  390. DumpIRProtoWithSrcInfo(kernel_graph, final_graph, target_dir, kDebugWholeStack);
  391. DumpIR("trace_code_graph", kernel_graph, true, kWholeStack, ir_file_path);
  392. DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(kernel_graph->graph_id()) + ".csv", root_dir,
  393. kernel_graph->execution_order());
  394. }
  395. }
  396. void Debugger::PostExecuteGraphDebugger() {
  397. // On CPU, update dump iteration, Parameters and consts are not dumped here
  398. if (device_target_ == kCPUDevice) {
  399. DumpJsonParser::GetInstance().UpdateDumpIter();
  400. return;
  401. }
  402. // Only GPU is supported for MindRTBackend
  403. if (device_target_ != kGPUDevice) {
  404. return;
  405. }
  406. // LoadParametersAndConst for all the graphs that have been run in the current step
  407. if (debugger_) {
  408. for (auto graph : graph_ptr_step_vec_) {
  409. debugger_->LoadParametersAndConst(graph);
  410. }
  411. }
  412. // debug used for dump
  413. if (debugger_ && debugger_->CheckDebuggerDumpEnabled()) {
  414. // Dump Parameters and consts
  415. for (auto graph : graph_ptr_step_vec_) {
  416. debugger_->Dump(graph);
  417. if (!debugger_->debugger_enabled()) {
  418. debugger_->ClearCurrentData();
  419. }
  420. }
  421. }
  422. if (debugger_) {
  423. debugger_->PostExecute();
  424. }
  425. E2eDump::UpdateIterGPUDump();
  426. }
  427. void Debugger::PostExecute() {
  428. // access lock for public method
  429. std::lock_guard<std::mutex> a_lock(access_lock_);
  430. if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
  431. return;
  432. }
  433. if (debugger_ && debugger_->DebuggerBackendEnabled()) {
  434. // analyze tensor data and send the watchpoints been hit
  435. if (debugger_enabled_ && !is_dataset_graph_) {
  436. if (device_target_ != kGPUDevice) {
  437. num_step_++;
  438. }
  439. SendWatchpoints(CheckWatchpoints());
  440. // no need to suspend at each graph for GPU, suspension happens in preExecute
  441. if (device_target_ != kGPUDevice) {
  442. CommandLoop();
  443. }
  444. }
  445. // Only keep parameters in the current map
  446. // GPU ResetLoadedTensors happens in preExecute
  447. if (device_target_ != kGPUDevice) {
  448. debug_services_->ResetLoadedTensors();
  449. }
  450. }
  451. }
  452. bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) const {
  453. if (debugger_enabled_ && !is_dataset_graph_) {
  454. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  455. // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data
  456. if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
  457. return true;
  458. }
  459. }
  460. return false;
  461. }
  462. void Debugger::PostExecuteNode(const CNodePtr &kernel, bool last_kernel) {
  463. // access lock for public method
  464. std::lock_guard<std::mutex> a_lock(access_lock_);
  465. if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
  466. return;
  467. }
  468. if (debugger_enabled_ && !is_dataset_graph_) {
  469. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  470. // if kernel is watchpoint,and get hit. suspend.
  471. bool hit_empty_flag = true;
  472. if (is_watchpoint) {
  473. auto hits = CheckWatchpoints(cur_name_, kernel);
  474. if (!hits.empty()) {
  475. SendWatchpoints(hits);
  476. CommandLoop();
  477. hit_empty_flag = false;
  478. }
  479. }
  480. if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
  481. // if kernel is not watchpoint and is next_to or continue_to node, suspend
  482. // sets a bool to be checked in preExecute to avoid double stopping at last kernel in the last graph
  483. if (last_kernel) {
  484. suspended_at_last_kernel_ = true;
  485. }
  486. CommandLoop();
  487. }
  488. return;
  489. }
  490. }
  491. void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  492. MS_EXCEPTION_IF_NULL(graph_ptr);
  493. if (graph_ptr_ != graph_ptr) {
  494. MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
  495. // save new graph_ptr
  496. graph_ptr_ = graph_ptr;
  497. CheckDatasetGraph();
  498. if (!is_dataset_graph_) {
  499. // get proto for new graph_ptr
  500. auto graph_proto = GetGraphProto(graph_ptr);
  501. // add new graph proto to graph_proto_list_
  502. graph_proto_list_.push_back(graph_proto);
  503. graph_ptr_list_.push_back(graph_ptr);
  504. not_dataset_graph_sum_++;
  505. }
  506. // reset is_dataset_graph to be false
  507. is_dataset_graph_ = false;
  508. }
  509. }
  510. // In single graph cases, check single graph ptr
  511. void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  512. MS_EXCEPTION_IF_NULL(graph_ptr);
  513. if (graph_ptr_ != graph_ptr) {
  514. MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
  515. // save new graph_ptr
  516. graph_ptr_ = graph_ptr;
  517. if (!is_dataset_graph_) {
  518. // only try to enable debugger if it is not a dataset graph
  519. EnableDebugger();
  520. if (debugger_enabled_) {
  521. LoadParametersAndConst();
  522. // get graph proto and send to Mindinsight
  523. auto graph_proto = graph_proto_list_.front();
  524. SendGraphAndSuspend(graph_proto);
  525. }
  526. }
  527. }
  528. }
  529. void Debugger::CheckDatasetGraph() {
  530. // print parameter node names
  531. MS_EXCEPTION_IF_NULL(graph_ptr_);
  532. const auto &params = graph_ptr_->inputs();
  533. for (const auto &param : params) {
  534. MS_LOG(INFO) << "param: " << GetKernelNodeName(param);
  535. }
  536. // check if there is GetNext or InitDataSetQueue node
  537. const auto &nodes = graph_ptr_->execution_order();
  538. for (const auto &node : nodes) {
  539. auto node_name = AnfAlgo::GetCNodeName(node);
  540. MS_LOG(INFO) << "node: " << GetKernelNodeName(node);
  541. if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
  542. MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
  543. << node_name;
  544. is_dataset_graph_ = true;
  545. return;
  546. }
  547. }
  548. is_dataset_graph_ = false;
  549. }
  550. GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  551. // convert kernel graph to debugger modelproto
  552. ModelProto model = GetDebuggerFuncGraphProto(graph_ptr);
  553. return model.graph();
  554. }
  555. void Debugger::SendHeartbeat(int32_t period) {
  556. int num_heartbeat_fail = 0;
  557. const int max_num_heartbeat_fail = 5;
  558. const int retry_milliseconds = 500;
  559. Heartbeat heartbeat;
  560. heartbeat.set_message("Debugger is alive");
  561. heartbeat.set_period(heartbeat_period_second);
  562. SetEnableHeartbeat(CheckDebuggerEnabled());
  563. while (enable_heartbeat_) {
  564. MS_EXCEPTION_IF_NULL(grpc_client_);
  565. EventReply reply = grpc_client_->SendHeartbeat(heartbeat);
  566. if (reply.status() != reply.OK) {
  567. MS_LOG(ERROR) << "Error: SendHeartbeat failed";
  568. num_heartbeat_fail++;
  569. if (num_heartbeat_fail >= max_num_heartbeat_fail) {
  570. MS_LOG(ERROR) << "Maximum number of failure for SendHeartbeat reached : exiting training session.";
  571. SetEnableHeartbeat(false);
  572. break;
  573. } else {
  574. MS_LOG(ERROR) << "Number of consecutive SendHeartbeat fail:" << num_heartbeat_fail;
  575. std::this_thread::sleep_for(std::chrono::milliseconds(retry_milliseconds));
  576. }
  577. } else {
  578. int recheck_period_ms = 200;
  579. for (int i = 0; i < (period * 1000 / recheck_period_ms); i++) {
  580. if (enable_heartbeat_) {
  581. std::this_thread::sleep_for(std::chrono::milliseconds(recheck_period_ms));
  582. } else {
  583. break;
  584. }
  585. }
  586. }
  587. }
  588. }
  589. void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  590. if (SendMetadata(true)) {
  591. // send graph to Mindinsight server
  592. MS_EXCEPTION_IF_NULL(grpc_client_);
  593. EventReply reply = grpc_client_->SendGraph(graph_proto);
  594. if (reply.status() != reply.OK) {
  595. MS_LOG(ERROR) << "Error: SendGraph failed";
  596. }
  597. // enter command loop, wait and process commands
  598. CommandLoop();
  599. }
  600. }
  601. bool Debugger::SendMetadata(bool version_check) {
  602. // prepare metadata
  603. MS_EXCEPTION_IF_NULL(graph_ptr_);
  604. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  605. Metadata metadata;
  606. metadata.set_device_name(device_name);
  607. metadata.set_cur_step(num_step_);
  608. metadata.set_backend(device_target_);
  609. metadata.set_cur_node(cur_name_);
  610. metadata.set_training_done(training_done_);
  611. metadata.set_ms_version(version_);
  612. MS_LOG(INFO) << "Is training done?" << training_done_;
  613. // set graph number to not_dataset_graph_sum_
  614. metadata.set_graph_num(not_dataset_graph_sum_);
  615. MS_EXCEPTION_IF_NULL(grpc_client_);
  616. EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  617. bool ret = false;
  618. if (reply_metadata.status() == reply_metadata.OK) {
  619. if (version_check) {
  620. // get type of the command in meta data reply, it should be version matched
  621. DebuggerCommand cmd = GetCommand(reply_metadata);
  622. if (cmd != DebuggerCommand::kVersionMatchedCMD) {
  623. MS_LOG(ERROR) << "MindInsight version is too old, Mindspore version is " << version_;
  624. Exit();
  625. } else {
  626. if (GetMiVersionMatched(reply_metadata)) {
  627. MS_LOG(INFO) << "MindSpore version is " << version_ << " matches MindInsight version.";
  628. ret = true;
  629. } else {
  630. MS_LOG(ERROR) << "MindSpore version " << version_ << ", did not match MindInsight version.";
  631. CommandLoop();
  632. }
  633. }
  634. } else {
  635. // version check is done before so we can just return true here
  636. ret = true;
  637. }
  638. } else {
  639. MS_LOG(ERROR) << "Error: SendMetadata failed";
  640. }
  641. return ret;
  642. }
  643. void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list) {
  644. if (!SendMetadata(true)) {
  645. return;
  646. }
  647. MS_EXCEPTION_IF_NULL(grpc_client_);
  648. // send multiple graphs to mindinght server
  649. // split graph into chunks if one graph is larger than chunk size
  650. std::list<Chunk> chunked_graph_proto_list;
  651. Chunk chunk;
  652. for (auto graph : graph_proto_list) {
  653. std::string str = graph.SerializeAsString();
  654. auto graph_size = graph.ByteSize();
  655. if (graph_size > g_chunk_size) {
  656. auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);
  657. for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
  658. chunk.set_buffer(sub_graph_str[i]);
  659. if (i < sub_graph_str.size() - 1) {
  660. chunk.set_finished(false);
  661. } else {
  662. chunk.set_finished(true);
  663. }
  664. chunked_graph_proto_list.push_back(chunk);
  665. }
  666. } else {
  667. chunk.set_buffer(str);
  668. chunk.set_finished(true);
  669. chunked_graph_proto_list.push_back(chunk);
  670. }
  671. }
  672. EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  673. if (reply.status() != reply.OK) {
  674. MS_LOG(ERROR) << "Error: SendGraph failed";
  675. }
  676. // enter command loop, wait and process commands
  677. CommandLoop();
  678. }
  679. void Debugger::CommandLoop() {
  680. // prepare metadata
  681. MS_EXCEPTION_IF_NULL(graph_ptr_);
  682. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  683. Metadata metadata;
  684. metadata.set_device_name(device_name);
  685. metadata.set_cur_step(num_step_);
  686. metadata.set_backend(device_target_);
  687. metadata.set_cur_node(cur_name_);
  688. metadata.set_training_done(training_done_);
  689. // loop exit flag
  690. bool run = false;
  691. int num_wait_fail = 0;
  692. const int max_num_wait_fail = 5;
  693. while (!run) {
  694. // wait for command
  695. MS_EXCEPTION_IF_NULL(grpc_client_);
  696. EventReply reply = grpc_client_->WaitForCommand(metadata);
  697. if (reply.status() != reply.OK) {
  698. MS_LOG(ERROR) << "Error: WaitForCommand failed";
  699. num_wait_fail++;
  700. if (num_wait_fail > max_num_wait_fail) {
  701. MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session.";
  702. MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
  703. "of debugger host and port.";
  704. Exit();
  705. run = true;
  706. } else {
  707. MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after "
  708. << num_wait_fail << "s";
  709. std::this_thread::sleep_for(std::chrono::seconds(num_wait_fail));
  710. }
  711. continue;
  712. }
  713. // get type of the command in reply
  714. DebuggerCommand cmd = GetCommand(reply);
  715. if (cmd == DebuggerCommand::kUnknownCMD) {
  716. MS_LOG(DEBUG) << "Debug: debugger received unknown command";
  717. continue;
  718. }
  719. MS_LOG(INFO) << "received command: ";
  720. switch (cmd) {
  721. case DebuggerCommand::kUnknownCMD:
  722. MS_LOG(INFO) << "UnknownCMD";
  723. break;
  724. case DebuggerCommand::kExitCMD:
  725. MS_LOG(INFO) << "ExitCMD";
  726. Exit(true);
  727. // Used for debugger termination
  728. run = true;
  729. break;
  730. case DebuggerCommand::kRunCMD:
  731. ProcessRunCMD(reply);
  732. if (GetRunLevel(reply) != "recheck") {
  733. // exit loop
  734. run = true;
  735. }
  736. break;
  737. case DebuggerCommand::kSetCMD:
  738. ProcessKSetCMD(reply);
  739. break;
  740. case DebuggerCommand::kViewCMD:
  741. ProcessKViewCMD(reply);
  742. break;
  743. case DebuggerCommand::kVersionMatchedCMD:
  744. MS_LOG(ERROR) << "Received unexpected Version Matched CMD from Mindinsight.";
  745. Exit();
  746. break;
  747. default:
  748. MS_LOG(ERROR) << "Received unknown CMD from Mindinsight";
  749. Exit();
  750. break;
  751. }
  752. }
  753. }
  754. void Debugger::ProcessRunCMD(const EventReply &reply) {
  755. MS_LOG(INFO) << "RunCMD";
  756. if (GetRunLevel(reply) == "recheck") {
  757. MS_LOG(INFO) << "rechecking all watchpoints";
  758. SendWatchpoints(CheckWatchpoints("", nullptr, true));
  759. } else {
  760. // no longer the initial suspension.
  761. initial_suspend_ = false;
  762. // print run cmd content
  763. // get run_level and node_name
  764. run_level_ = GetRunLevel(reply);
  765. node_name_ = GetNodeName(reply);
  766. MS_LOG(INFO) << "run_level: " << run_level_;
  767. MS_LOG(INFO) << "node_name_: " << node_name_;
  768. }
  769. }
  770. void Debugger::ProcessKSetCMD(const EventReply &reply) {
  771. MS_LOG(INFO) << "SetCMD";
  772. MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  773. MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  774. if (GetWatchpointDelete(reply)) {
  775. MS_LOG(INFO) << "Deleting watchpoint";
  776. RemoveWatchpoint(GetWatchpointID(reply));
  777. } else {
  778. MS_LOG(INFO) << "Setting watchpoint";
  779. MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
  780. ProtoVector<WatchNode> recieved_nodes = GetWatchnodes(reply);
  781. for (const auto &node : recieved_nodes) {
  782. MS_LOG(INFO) << "node name: " << node.node_name();
  783. MS_LOG(INFO) << "node type: " << node.node_type();
  784. }
  785. ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
  786. for (const auto &parameter : parameters) {
  787. MS_LOG(INFO) << "parameter name: " << parameter.name();
  788. MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
  789. MS_LOG(INFO) << "parameter value: " << parameter.value();
  790. }
  791. SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
  792. }
  793. }
  794. void Debugger::ProcessKViewCMD(const EventReply &reply) {
  795. MS_LOG(INFO) << "ViewCMD";
  796. // print view cmd content
  797. ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  798. for (auto received_tensor : received_tensors) {
  799. MS_LOG(INFO) << "tensor node name: " << received_tensor.node_name();
  800. MS_LOG(INFO) << "tensor slot: " << received_tensor.slot();
  801. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << received_tensor.finished() << std::noboolalpha;
  802. MS_LOG(INFO) << "tensor iter: " << received_tensor.iter();
  803. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << received_tensor.truncate() << std::noboolalpha;
  804. }
  805. switch (reply.view_cmd().level()) {
  806. case debugger::ViewCMD_Level::ViewCMD_Level_base:
  807. MS_LOG(INFO) << "Tensor base request.";
  808. ViewBaseLevel(reply);
  809. break;
  810. case debugger::ViewCMD_Level::ViewCMD_Level_statistics:
  811. MS_LOG(INFO) << "Tensor statistics request.";
  812. ViewStatLevel(reply);
  813. break;
  814. case debugger::ViewCMD_Level::ViewCMD_Level_value:
  815. MS_LOG(INFO) << "Tensor value request.";
  816. ViewValueLevel(reply);
  817. break;
  818. default:
  819. MS_LOG(DEBUG) << "Debug: Unknown tensor info level";
  820. break;
  821. }
  822. }
  823. void Debugger::ViewValueLevel(const EventReply &reply) {
  824. MS_LOG(INFO) << "Sending tensors";
  825. std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  826. // print view cmd reply
  827. for (auto tensor : tensors) {
  828. MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
  829. MS_LOG(INFO) << "tensor slot: " << tensor.slot();
  830. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
  831. MS_LOG(INFO) << "tensor iter: " << tensor.iter();
  832. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
  833. MS_LOG(INFO) << "tensor dims: ";
  834. for (auto dim : tensor.dims()) {
  835. MS_LOG(INFO) << dim << ",";
  836. }
  837. MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  838. }
  839. MS_EXCEPTION_IF_NULL(grpc_client_);
  840. EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  841. if (send_tensors_reply.status() != debugger::EventReply::OK) {
  842. MS_LOG(ERROR) << "Error: SendTensors failed";
  843. }
  844. }
  845. void Debugger::ViewStatLevel(const EventReply &reply) {
  846. std::list<TensorSummary> tensor_stats_list = LoadTensorsStat(GetTensors(reply));
  847. EventReply send_tensors_stat_reply = grpc_client_->SendTensorStats(tensor_stats_list);
  848. if (send_tensors_stat_reply.status() != debugger::EventReply::OK) {
  849. MS_LOG(ERROR) << "Error: SendTensorsStats failed.";
  850. }
  851. }
  852. void Debugger::ViewBaseLevel(const EventReply &reply) {
  853. std::list<TensorBase> tensor_base_list = LoadTensorsBase(GetTensors(reply));
  854. EventReply send_tensor_base_reply = grpc_client_->SendTensorBase(tensor_base_list);
  855. if (send_tensor_base_reply.status() != debugger::EventReply::OK) {
  856. MS_LOG(ERROR) << "Error: SendTensorsBase failed.";
  857. }
  858. }
  859. void AddTensorProtoInfo(TensorProto *tensor_item, const TensorProto &tensor) {
  860. tensor_item->set_node_name(tensor.node_name());
  861. tensor_item->set_slot(tensor.slot());
  862. tensor_item->set_iter(tensor.iter());
  863. tensor_item->set_truncate(tensor.truncate());
  864. tensor_item->clear_tensor_content();
  865. tensor_item->clear_data_type();
  866. tensor_item->clear_dims();
  867. }
  868. void AddTensorStatInfo(const DebugServices::TensorStat &tensor_stat,
  869. std::list<TensorSummary> *const tensor_summary_list) {
  870. if (tensor_summary_list == nullptr) {
  871. MS_LOG(DEBUG) << "tensor_summary_list is nullptr.";
  872. return;
  873. }
  874. TensorSummary tensor_summary_item;
  875. TensorBase *tensor_base = tensor_summary_item.mutable_tensor_base();
  876. tensor_base->set_data_type(tensor_stat.dtype);
  877. tensor_base->set_data_size((int64_t)tensor_stat.data_size);
  878. for (auto elem : tensor_stat.shape) {
  879. tensor_base->add_shape(elem);
  880. }
  881. Statistics *tensor_statistics = tensor_summary_item.mutable_statistics();
  882. tensor_statistics->set_is_bool(tensor_stat.is_bool);
  883. tensor_statistics->set_max_value(static_cast<float>(tensor_stat.max_value));
  884. tensor_statistics->set_min_value(static_cast<float>(tensor_stat.min_value));
  885. tensor_statistics->set_avg_value(static_cast<float>(tensor_stat.avg_value));
  886. tensor_statistics->set_count(tensor_stat.count);
  887. tensor_statistics->set_neg_zero_count(tensor_stat.neg_zero_count);
  888. tensor_statistics->set_pos_zero_count(tensor_stat.pos_zero_count);
  889. tensor_statistics->set_nan_count(tensor_stat.nan_count);
  890. tensor_statistics->set_neg_inf_count(tensor_stat.neg_inf_count);
  891. tensor_statistics->set_pos_inf_count(tensor_stat.pos_inf_count);
  892. tensor_statistics->set_zero_count(tensor_stat.zero_count);
  893. tensor_summary_list->push_back(tensor_summary_item);
  894. }
  895. void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
  896. const ProtoVector<WatchCondition_Parameter> &parameters) {
  897. std::vector<std::tuple<std::string, bool>> check_node_list;
  898. std::vector<DebugServices::parameter_t> parameter_list;
  899. std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
  900. [](const WatchNode &node) -> std::tuple<std::string, bool> {
  901. return make_tuple(node.node_name(), node.node_type() == "scope");
  902. });
  903. std::transform(
  904. parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
  905. [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
  906. return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
  907. });
  908. debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
  909. }
  910. void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
  911. std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  912. std::vector<std::string> name;
  913. std::vector<std::string> ret_name;
  914. std::vector<const char *> data_ptr;
  915. std::vector<ssize_t> data_size;
  916. std::vector<unsigned int> dtype;
  917. std::vector<std::vector<int64_t>> shape;
  918. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  919. // ret_name will contain tensor names that are found in TensorLoader
  920. // items in ret_name will be in the same order with tensors if found
  921. debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  922. std::list<TensorProto> tensor_list;
  923. size_t result_index = 0;
  924. for (auto tensor : tensors) {
  925. ssize_t size_iter = 0;
  926. if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
  927. TensorProto tensor_item;
  928. tensor_item.set_finished(true);
  929. AddTensorProtoInfo(&tensor_item, tensor);
  930. tensor_list.push_back(tensor_item);
  931. continue;
  932. }
  933. ssize_t tensor_size = data_size[result_index];
  934. while (size_iter < tensor_size) {
  935. ssize_t chunk_size = g_chunk_size;
  936. TensorProto tensor_item;
  937. tensor_item.set_finished(false);
  938. if (tensor_size - size_iter <= g_chunk_size) {
  939. chunk_size = tensor_size - size_iter;
  940. tensor_item.set_finished(true);
  941. }
  942. AddTensorProtoInfo(&tensor_item, tensor);
  943. // return empty tensor if didn't find the requested tensor
  944. tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
  945. tensor_item.set_data_type((debugger::DataType)dtype[result_index]);
  946. for (auto &elem : shape[result_index]) {
  947. tensor_item.add_dims(elem);
  948. }
  949. // add tensor to result list and increment result_index to check next item in ret_name
  950. tensor_list.push_back(tensor_item);
  951. if (size_iter > INT_MAX - g_chunk_size) {
  952. MS_EXCEPTION(ValueError) << size_iter << " + " << g_chunk_size << " would lead to integer overflow!";
  953. }
  954. size_iter += g_chunk_size;
  955. }
  956. result_index++;
  957. }
  958. return tensor_list;
  959. }
  960. std::list<TensorBase> Debugger::LoadTensorsBase(const ProtoVector<TensorProto> &tensors) const {
  961. std::list<TensorBase> tensor_base_list;
  962. std::vector<std::string> name;
  963. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  964. std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
  965. debug_services_->SearchNodesTensors(name, &result_list);
  966. for (auto result : result_list) {
  967. auto tensor = std::get<1>(result);
  968. if (!tensor) {
  969. // tensor was not found, creating empty tensor base.
  970. TensorBase tensor_base_item;
  971. tensor_base_item.set_data_size(0);
  972. tensor_base_item.set_data_type(0);
  973. tensor_base_item.add_shape(0);
  974. tensor_base_list.push_back(tensor_base_item);
  975. continue;
  976. }
  977. // tensor was found creating tensor base object.
  978. TensorBase tensor_base_item;
  979. tensor_base_item.set_data_size((int64_t)tensor->GetByteSize());
  980. tensor_base_item.set_data_type((int32_t)tensor->GetType());
  981. for (auto elem : tensor->GetShape()) {
  982. tensor_base_item.add_shape(elem);
  983. }
  984. tensor_base_list.push_back(tensor_base_item);
  985. }
  986. return tensor_base_list;
  987. }
  988. std::list<TensorSummary> Debugger::LoadTensorsStat(const ProtoVector<TensorProto> &tensors) const {
  989. std::list<TensorSummary> tensor_summary_list;
  990. std::vector<std::string> name;
  991. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  992. std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
  993. debug_services_->SearchNodesTensors(name, &result_list);
  994. for (auto result : result_list) {
  995. auto tensor = std::get<1>(result);
  996. if (!tensor) {
  997. // tensor was not found, creating empty tensor summary.
  998. DebugServices::TensorStat tensor_stat;
  999. AddTensorStatInfo(tensor_stat, &tensor_summary_list);
  1000. continue;
  1001. }
  1002. // tensor was found creating tensor summary object.
  1003. DebugServices::TensorStat tensor_stat = debug_services_->GetTensorStatistics(tensor);
  1004. AddTensorStatInfo(tensor_stat, &tensor_summary_list);
  1005. }
  1006. return tensor_summary_list;
  1007. }
  1008. std::shared_ptr<TensorData> Debugger::GetTensor(const std::string &tensor_name) const {
  1009. return debug_services_->GetTensor(tensor_name);
  1010. }
  1011. DebugServices::TensorStat Debugger::GetTensorStatistics(std::shared_ptr<TensorData> tensor_data) const {
  1012. return DebugServices::GetTensorStatistics(tensor_data);
  1013. }
  1014. void Debugger::Exit(bool exit_success) {
  1015. // debugger will notify main thread to exit because main thread can only exit at step boundary.
  1016. MS_LOG(INFO) << "Exit Debugger";
  1017. SetEnableHeartbeat(false);
  1018. pipeline::GraphExecutorPy::DebugTerminate(true, exit_success);
  1019. }
  1020. std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel,
  1021. bool recheck) {
  1022. std::vector<std::string> name;
  1023. std::vector<std::string> slot;
  1024. std::vector<int> condition;
  1025. std::vector<unsigned int> watchpoint_id;
  1026. std::vector<std::string> overflow_ops;
  1027. std::vector<std::vector<DebugServices::parameter_t>> parameters;
  1028. std::vector<int32_t> error_codes;
  1029. std::vector<std::shared_ptr<TensorData>> tensor_list;
  1030. if (watchnode.empty()) {
  1031. tensor_list = debug_services_->GetTensor();
  1032. } else {
  1033. tensor_list = debug_services_->GetNodeTensor(kernel);
  1034. }
  1035. std::vector<std::string> file_list;
  1036. MS_LOG(INFO) << "checkwatchpoints call for step " << num_step_;
  1037. debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, &error_codes, overflow_ops,
  1038. file_list, &tensor_list, initial_suspend_, watchnode.empty(), recheck);
  1039. std::list<WatchpointHit> hits;
  1040. for (unsigned int i = 0; i < name.size(); i++) {
  1041. WatchpointHit hit;
  1042. std::vector<DebugServices::parameter_t> &parameter = parameters[i];
  1043. hit.set_id(watchpoint_id[i]);
  1044. hit.set_error_code(error_codes[i]);
  1045. // here TensorProto act as a tensor indicator, not sending tensor content
  1046. TensorProto *tensor_item = hit.mutable_tensor();
  1047. tensor_item->set_node_name(name[i]);
  1048. tensor_item->set_slot(slot[i]);
  1049. tensor_item->set_finished(true);
  1050. WatchCondition *condition_item = hit.mutable_watch_condition();
  1051. condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
  1052. for (const auto &p : parameter) {
  1053. auto x = condition_item->mutable_params()->Add();
  1054. x->set_name(p.name);
  1055. x->set_disabled(p.disabled);
  1056. x->set_value(p.value);
  1057. x->set_hit(p.hit);
  1058. x->set_actual_value(p.actual_value);
  1059. }
  1060. hits.push_back(hit);
  1061. }
  1062. return hits;
  1063. }
  1064. void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  1065. // send info about watchpoint
  1066. if (!points.empty()) {
  1067. MS_EXCEPTION_IF_NULL(grpc_client_);
  1068. EventReply reply = grpc_client_->SendWatchpointHits(points);
  1069. if (reply.status() != reply.OK) {
  1070. MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
  1071. }
  1072. }
  1073. }
  1074. bool Debugger::DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
  1075. const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
  1076. TypeId device_type, const std::string &addr_format, size_t slot) const {
  1077. return debug_services_.get()->DumpTensorToFile(tensor_name, trans_flag, filepath, host_fmt, host_shape, host_type,
  1078. device_type, addr_format, slot);
  1079. }
  1080. bool Debugger::LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev) {
  1081. return debug_services_.get()->LoadNewTensor(tensor, keep_prev);
  1082. }
  1083. bool Debugger::debugger_enabled() const { return debugger_enabled_; }
  1084. DebuggerCommand GetCommand(const EventReply &reply) {
  1085. DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  1086. switch (reply.cmd_case()) {
  1087. case debugger::EventReply::CmdCase::kExit:
  1088. cmd = DebuggerCommand::kExitCMD;
  1089. break;
  1090. case debugger::EventReply::CmdCase::kRunCmd:
  1091. cmd = DebuggerCommand::kRunCMD;
  1092. break;
  1093. case debugger::EventReply::CmdCase::kSetCmd:
  1094. cmd = DebuggerCommand::kSetCMD;
  1095. break;
  1096. case debugger::EventReply::CmdCase::kViewCmd:
  1097. cmd = DebuggerCommand::kViewCMD;
  1098. break;
  1099. case debugger::EventReply::CmdCase::kVersionMatched:
  1100. cmd = DebuggerCommand::kVersionMatchedCMD;
  1101. break;
  1102. default:
  1103. MS_LOG(DEBUG) << "Debug: UnknownCMD";
  1104. break;
  1105. }
  1106. return cmd;
  1107. }
  1108. ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  1109. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  1110. MS_LOG(ERROR) << "Error: Can not get Parameters from command. Returning default value: ProtoVector<Parameter>().";
  1111. return ProtoVector<WatchCondition_Parameter>();
  1112. }
  1113. return reply.set_cmd().watch_condition().params();
  1114. }
  1115. ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  1116. if (!reply.has_set_cmd()) {
  1117. MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
  1118. return ProtoVector<WatchNode>();
  1119. }
  1120. return reply.set_cmd().watch_nodes();
  1121. }
  1122. std::string GetRunLevel(const EventReply &reply) {
  1123. if (!reply.has_run_cmd()) {
  1124. MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: "
  1125. "";
  1126. return "";
  1127. }
  1128. return reply.run_cmd().run_level();
  1129. }
  1130. std::string GetNodeName(const EventReply &reply) {
  1131. if (!reply.has_run_cmd()) {
  1132. MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: "
  1133. "";
  1134. return "";
  1135. }
  1136. return reply.run_cmd().node_name();
  1137. }
  1138. WatchCondition GetWatchcondition(const EventReply &reply) {
  1139. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  1140. MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
  1141. return WatchCondition();
  1142. }
  1143. return reply.set_cmd().watch_condition();
  1144. }
  1145. int32_t GetWatchpointID(const EventReply &reply) {
  1146. if (!reply.has_set_cmd()) {
  1147. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
  1148. return 0;
  1149. }
  1150. return reply.set_cmd().id();
  1151. }
  1152. bool GetWatchpointDelete(const EventReply &reply) {
  1153. if (!reply.has_set_cmd()) {
  1154. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
  1155. return false;
  1156. }
  1157. return reply.set_cmd().delete_();
  1158. }
  1159. ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  1160. if (!reply.has_view_cmd()) {
  1161. MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
  1162. return ProtoVector<TensorProto>();
  1163. }
  1164. return reply.view_cmd().tensors();
  1165. }
  1166. std::string GetTensorFullName(const TensorProto &tensor) {
  1167. string node_name = tensor.node_name();
  1168. if (tensor.truncate()) {
  1169. // scopes in node name are separated by '/'
  1170. // use the name without scope if truncate is true
  1171. std::size_t found = node_name.find_last_of("/");
  1172. node_name = node_name.substr(found + 1);
  1173. }
  1174. return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
  1175. }
  1176. bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }
  1177. bool Debugger::partial_memory() const { return partial_memory_; }
  1178. void Debugger::SetEnableHeartbeat(bool enabled) { enable_heartbeat_ = enabled; }
  1179. void Debugger::SetCurNode(const std::string &cur_name) {
  1180. // access lock for public method
  1181. std::lock_guard<std::mutex> a_lock(access_lock_);
  1182. cur_name_ = cur_name;
  1183. }
  1184. std::string Debugger::run_level() const { return run_level_; }
  1185. void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }
  1186. bool Debugger::CheckPort(const std::string &port) const {
  1187. int num = 0;
  1188. const int min_port_num = 1;
  1189. const int max_port_num = 65535;
  1190. const int decimal = 10;
  1191. if (port[0] == '0' && port[1] != '\0') return false;
  1192. int i = 0;
  1193. while (port[i] != '\0') {
  1194. if (port[i] < '0' || port[i] > '9') return false;
  1195. num = num * decimal + (port[i] - '0');
  1196. if (num > max_port_num) return false;
  1197. i++;
  1198. }
  1199. if (num < min_port_num) return false;
  1200. return true;
  1201. }
  1202. bool Debugger::CheckIp(const std::string &host) const {
  1203. std::regex reg_ip(
  1204. "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
  1205. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  1206. "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
  1207. "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
  1208. std::smatch smat;
  1209. std::string host_str = host;
  1210. return std::regex_match(host_str, smat, reg_ip);
  1211. }
  1212. uint32_t Debugger::GetFirstRunGraphId() const { return rungraph_id_list_.front(); }
  1213. void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  1214. MS_EXCEPTION_IF_NULL(anf_node);
  1215. if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
  1216. return;
  1217. }
  1218. // When MindRT is used, only ValueNodes and ParameterWeights can be loaded from device to host
  1219. if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) && (device_target_ == kGPUDevice)) {
  1220. if (!anf_node->isa<ValueNode>() &&
  1221. !(anf_node->isa<Parameter>() && AnfAlgo::IsParameterWeight(anf_node->cast<ParameterPtr>()))) {
  1222. return;
  1223. }
  1224. }
  1225. // for parameters and value nodes, set its execution order to be 0;
  1226. int exec_order = 0;
  1227. std::string node_name = GetKernelNodeName(anf_node);
  1228. GetFileKernelName(NOT_NULL(&node_name));
  1229. // check if output adde exists, if not, return;
  1230. if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
  1231. return;
  1232. }
  1233. auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  1234. MS_EXCEPTION_IF_NULL(addr);
  1235. auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  1236. if (!IsTypeDebuggerSupported(type)) {
  1237. return;
  1238. }
  1239. auto format = kOpFormat_DEFAULT;
  1240. string tensor_name = node_name + ':' + "0";
  1241. ShapeVector int_shapes = trans::GetRuntimePaddingShape(anf_node, output_index);
  1242. bool keep_prev;
  1243. if (anf_node->isa<Parameter>()) {
  1244. keep_prev = true;
  1245. debug_services_->MoveTensorCurrentToPrev(tensor_name);
  1246. } else {
  1247. keep_prev = false;
  1248. }
  1249. bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  1250. if (!ret) {
  1251. MS_LOG(ERROR) << "LoadMemToHost:"
  1252. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1253. }
  1254. }
  1255. void Debugger::LoadParametersAndConst() {
  1256. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  1257. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1258. // load parameters
  1259. MS_LOG(INFO) << "Start to load Parameters for graph " << graph_ptr_->graph_id() << ".";
  1260. const auto &parameters = graph_ptr_->inputs();
  1261. for (auto &item : parameters) {
  1262. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  1263. }
  1264. // load value nodes
  1265. // get all constant values from the graph
  1266. MS_LOG(INFO) << "Start to load value nodes for graph " << graph_ptr_->graph_id() << ".";
  1267. const auto value_nodes = graph_ptr_->graph_value_nodes();
  1268. for (auto &item : value_nodes) {
  1269. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  1270. }
  1271. }
  1272. void Debugger::LoadParametersAndConst(const KernelGraphPtr &graph) {
  1273. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  1274. MS_EXCEPTION_IF_NULL(graph);
  1275. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1276. // load parameters
  1277. MS_LOG(INFO) << "Start to load Parameters for graph " << graph->graph_id() << ".";
  1278. const auto &parameters = graph_ptr_->inputs();
  1279. for (auto &item : parameters) {
  1280. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  1281. }
  1282. // load value nodes
  1283. // get all constant values from the graph
  1284. MS_LOG(INFO) << "Start to load value nodes for graph " << graph->graph_id() << ".";
  1285. const auto value_nodes = graph_ptr_->graph_value_nodes();
  1286. for (auto &item : value_nodes) {
  1287. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  1288. }
  1289. }
  1290. void Debugger::LoadGraphOutputs() {
  1291. if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  1292. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1293. const auto &apply_kernels = graph_ptr_->execution_order();
  1294. // for kernels, execution order starts from 1
  1295. int exec_order = 1;
  1296. for (const auto &node : apply_kernels) {
  1297. MS_EXCEPTION_IF_NULL(node);
  1298. std::string kernel_name = GetKernelNodeName(node);
  1299. auto output_size = AnfAlgo::GetOutputTensorNum(node);
  1300. if (partial_memory_) {
  1301. if (!debug_services_->IsWatchPoint(kernel_name, node)) {
  1302. continue;
  1303. }
  1304. }
  1305. for (size_t j = 0; j < output_size; ++j) {
  1306. if (!AnfAlgo::OutputAddrExist(node, j)) {
  1307. MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
  1308. continue;
  1309. }
  1310. auto addr = AnfAlgo::GetOutputAddr(node, j);
  1311. MS_EXCEPTION_IF_NULL(addr);
  1312. auto type = AnfAlgo::GetOutputInferDataType(node, j);
  1313. if (!IsTypeDebuggerSupported(type)) {
  1314. continue;
  1315. }
  1316. auto format = kOpFormat_DEFAULT;
  1317. string tensor_name = kernel_name + ':' + std::to_string(j);
  1318. ShapeVector int_shapes = trans::GetRuntimePaddingShape(node, j);
  1319. auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
  1320. if (!ret) {
  1321. MS_LOG(ERROR) << "LoadMemToHost:"
  1322. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1323. }
  1324. }
  1325. exec_order = exec_order + 1;
  1326. }
  1327. }
  1328. void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  1329. MS_EXCEPTION_IF_NULL(graph);
  1330. MS_EXCEPTION_IF_NULL(debugger_);
  1331. // update step number if we are processing the first graph (to support multigraph)
  1332. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
  1333. (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
  1334. // access lock for public method
  1335. std::lock_guard<std::mutex> a_lock(access_lock_);
  1336. ++num_step_;
  1337. }
  1338. }
  1339. void Debugger::UpdateStepNumGPU() {
  1340. // UpdateStepNum with DebugActor::DebugOnStepEnd
  1341. if (device_target_ == kGPUDevice && (debugger_enabled_ || DumpDataEnabledIteration())) {
  1342. // access lock for public method
  1343. std::lock_guard<std::mutex> a_lock(access_lock_);
  1344. ++num_step_;
  1345. }
  1346. }
  1347. void Debugger::ClearCurrentData() {
  1348. if ((device_target_ == kGPUDevice) && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration())) {
  1349. if (debug_services_) {
  1350. debug_services_->EmptyCurrentTensor();
  1351. } else {
  1352. MS_LOG(ERROR) << "debug_services_ is nullptr";
  1353. }
  1354. }
  1355. }
  1356. bool Debugger::TensorExistsInCurrent(const std::string &tensor_name) {
  1357. return debug_services_->TensorExistsInCurrent(tensor_name);
  1358. }
  1359. } // namespace mindspore