You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger.cc 46 kB

4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <dirent.h>
  17. #include <cstdio>
  18. #include <fstream>
  19. #include <tuple>
  20. #include <vector>
  21. #include <algorithm>
  22. #include <iostream>
  23. #include <cstring>
  24. #include <utility>
  25. #include <map>
  26. #include <regex>
  27. #include "debug/debugger/debugger.h"
  28. #include "debug/data_dump/dump_json_parser.h"
  29. #include "pipeline/jit/pipeline.h"
  30. #include "backend/session/anf_runtime_algorithm.h"
  31. #include "runtime/device/kernel_runtime_manager.h"
  32. #include "runtime/device/kernel_runtime.h"
  33. #include "debug/data_dump/e2e_dump.h"
  34. #include "utils/config_manager.h"
  35. #include "debug/env_config_parser.h"
  36. #include "utils/comm_manager.h"
  37. #include "runtime/hardware/device_context_manager.h"
  38. #include "debug/anf_ir_dump.h"
  39. #include "debug/anf_ir_utils.h"
  40. #ifdef ENABLE_DEBUGGER
  41. #include "debug/debugger/proto_exporter.h"
  42. #else
  43. #include "debug/debugger/proto_exporter_stub.h"
  44. #endif
  45. using debugger::Chunk;
  46. using debugger::EventReply;
  47. using debugger::GraphProto;
  48. using debugger::ModelProto;
  49. using debugger::TensorProto;
  50. using debugger::WatchCondition;
  51. using debugger::WatchCondition_Condition_inf;
  52. using debugger::WatchCondition_Condition_nan;
  53. using debugger::WatchCondition_Parameter;
  54. using debugger::WatchNode;
  55. using debugger::WatchpointHit;
  56. namespace mindspore {
  57. static constexpr auto g_chunk_size = 1024 * 1024 * 3;
  58. DebuggerPtr Debugger::debugger_ = nullptr;
  59. std::mutex Debugger::instance_lock_;
  60. Debugger::Debugger()
  61. : grpc_client_(nullptr),
  62. debug_services_(nullptr),
  63. device_id_(0),
  64. device_target_(""),
  65. num_step_(0),
  66. debugger_enabled_(false),
  67. suspended_at_last_kernel_(false),
  68. run_level_(""),
  69. node_name_(""),
  70. cur_name_(""),
  71. training_done_(false),
  72. is_dataset_graph_(false),
  73. partial_memory_(false),
  74. initial_suspend_(true),
  75. not_dataset_graph_sum_(0),
  76. version_("") {
  77. CheckDebuggerEnabledParam();
  78. auto ms_context = MsContext::GetInstance();
  79. MS_EXCEPTION_IF_NULL(ms_context);
  80. std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  81. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  82. if (device_target == kCPUDevice) {
  83. MS_LOG(WARNING) << "Not enabling debugger. Debugger does not support CPU.";
  84. } else if (CheckDebuggerEnabled()) {
  85. // configure partial memory reuse
  86. partial_memory_ = CheckDebuggerPartialMemoryEnabled();
  87. // switch memory reuse on or off
  88. EnvConfigParser::GetInstance().SetSysMemreuse(partial_memory_);
  89. // print some message about memory reuse to user
  90. if (partial_memory_) {
  91. MS_LOG(WARNING)
  92. << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
  93. "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
  94. } else {
  95. MS_LOG(WARNING)
  96. << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
  97. "usage for large models.";
  98. }
  99. }
  100. }
  101. void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  102. // access lock for public method
  103. std::lock_guard<std::mutex> a_lock(access_lock_);
  104. // save device_id
  105. MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  106. device_id_ = device_id;
  107. MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  108. device_target_ = device_target;
  109. version_ = "1.3.0";
  110. }
  111. bool IsTypeDebuggerSupported(TypeId type) {
  112. if (type < TypeId::kNumberTypeEnd && type > TypeId::kNumberTypeBegin && type != kNumberTypeComplex64) {
  113. return true;
  114. } else {
  115. MS_LOG(INFO) << "Debugger does not support type: " << TypeIdLabel(type);
  116. return false;
  117. }
  118. }
  119. void Debugger::EnableDebugger() {
  120. // reset some of the class members
  121. num_step_ = 0;
  122. debugger_enabled_ = false;
  123. partial_memory_ = false;
  124. grpc_client_ = nullptr;
  125. debug_services_ = nullptr;
  126. // see if dump using debugger backend is enabled
  127. bool dump_enabled = CheckDebuggerDumpEnabled();
  128. MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;
  129. // check if debugger enabled
  130. debugger_enabled_ = CheckDebuggerEnabled();
  131. MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;
  132. if (!debugger_enabled_ && !dump_enabled) {
  133. MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
  134. return;
  135. }
  136. if (debugger_enabled_) {
  137. std::string host = "localhost";
  138. // configure grpc port
  139. std::string env_port_str = common::GetEnv("MS_DEBUGGER_PORT");
  140. std::string port;
  141. if (!env_port_str.empty()) {
  142. if (CheckPort(env_port_str)) {
  143. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
  144. port = env_port_str;
  145. } else {
  146. debugger_enabled_ = false;
  147. MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to "
  148. "65535";
  149. }
  150. } else {
  151. port = "50051";
  152. if (!CheckPort(port)) {
  153. MS_EXCEPTION(ValueError) << "Default MS_DEBUGGER_PORT is not valid. Custom port ranging from 1 to 65535";
  154. }
  155. MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
  156. }
  157. // initialize grpc client
  158. grpc_client_ = std::make_unique<GrpcClient>(host, port);
  159. }
  160. debug_services_ = std::make_unique<DebugServices>();
  161. }
  162. void Debugger::CheckDatasetSinkMode() {
  163. if (CheckDebuggerDumpEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  164. MS_EXCEPTION(NotSupportError)
  165. << "e2e_dump not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  166. }
  167. if (CheckDebuggerEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
  168. MS_EXCEPTION(NotSupportError)
  169. << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  170. }
  171. }
  172. bool Debugger::CheckDebuggerDumpEnabled() const {
  173. // see if dump is enabled
  174. if (device_target_ == kGPUDevice) {
  175. return device::KernelRuntime::DumpDataEnabled();
  176. }
  177. return false;
  178. }
  179. bool Debugger::CheckDebuggerEnabled() const {
  180. // get env variables to configure debugger
  181. std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  182. if (!env_enable_str.empty()) {
  183. (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
  184. if ((env_enable_str == "1" || env_enable_str == "true") && device_target_ != kCPUDevice) {
  185. return true;
  186. }
  187. }
  188. return false;
  189. }
  190. void Debugger::CheckDebuggerEnabledParam() const {
  191. // check the value of env variable ENABLE_MS_DEBUGGER
  192. std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  193. if (!env_enable_str.empty()) {
  194. (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
  195. if (env_enable_str != "0" && env_enable_str != "1" && env_enable_str != "false" && env_enable_str != "true") {
  196. MS_LOG(WARNING) << "Env variable ENABLE_MS_DEBUGGER should be True/False/1/0 (case insensitive), but get: "
  197. << env_enable_str;
  198. }
  199. }
  200. }
  201. bool Debugger::CheckDebuggerPartialMemoryEnabled() const {
  202. std::string env_partial_mem_str = common::GetEnv("MS_DEBUGGER_PARTIAL_MEM");
  203. if (!env_partial_mem_str.empty()) {
  204. MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
  205. if (env_partial_mem_str == "1") {
  206. return true;
  207. }
  208. }
  209. return false;
  210. }
  211. bool Debugger::DebuggerBackendEnabled() const { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }
  212. void Debugger::Reset() {
  213. // access lock for public method
  214. std::lock_guard<std::mutex> a_lock(access_lock_);
  215. // reset components
  216. device_id_ = 0;
  217. device_target_ = "";
  218. num_step_ = 0;
  219. debugger_enabled_ = false;
  220. is_dataset_graph_ = false;
  221. partial_memory_ = false;
  222. graph_ptr_ = nullptr;
  223. grpc_client_ = nullptr;
  224. debug_services_ = nullptr;
  225. graph_proto_list_.clear();
  226. graph_ptr_list_.clear();
  227. }
  228. void Debugger::PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs) {
  229. // Only GPU is supported for MindRTBackend
  230. if (device_target_ != kGPUDevice) {
  231. return;
  232. }
  233. for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
  234. const auto &graph = graphs[graph_index];
  235. if (debugger_) {
  236. debugger_->PreExecute(graph);
  237. }
  238. DumpSetup(graph);
  239. }
  240. }
  241. void Debugger::PreExecute(const KernelGraphPtr &graph_ptr) {
  242. // access lock for public method
  243. std::lock_guard<std::mutex> a_lock(access_lock_);
  244. CheckDatasetSinkMode();
  245. auto graph_id = graph_ptr->graph_id();
  246. // collect rungrap_ids to update step number in multigraph case
  247. if (!rungraph_id_list_.size()) {
  248. rungraph_id_list_.push_back(graph_id);
  249. } else {
  250. if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
  251. rungraph_id_list_.push_back(graph_id);
  252. }
  253. }
  254. // multiple graphs
  255. if (graph_proto_list_.size() > 1) {
  256. // there are more than one graphs are not dataset_graph
  257. if (not_dataset_graph_sum_ > 0) {
  258. // only try to enable debugger if they are not all dataset graphs
  259. if (!debugger_enabled_) {
  260. EnableDebugger();
  261. }
  262. if (debugger_enabled_) {
  263. // only send compiled graphs once at the initial step.
  264. auto dbg_graph_ptr = graph_ptr_;
  265. // use current graph ptr to load parameters
  266. graph_ptr_ = graph_ptr;
  267. LoadParametersAndConst();
  268. // revert graph ptr to original value
  269. graph_ptr_ = dbg_graph_ptr;
  270. SendMultiGraphsAndSuspend(graph_proto_list_);
  271. graph_proto_list_.clear();
  272. }
  273. }
  274. } else if (graph_proto_list_.size() == 1) {
  275. // single graph, and not the initial step
  276. if (device_target_ == kGPUDevice && num_step_ != 0) {
  277. if (debugger_enabled_ && !(run_level_ == "node" && suspended_at_last_kernel_)) {
  278. CommandLoop();
  279. }
  280. debug_services_->ResetLoadedTensors();
  281. }
  282. // In single graph case, reset graph_ptr_ to be nullptr for the initial step
  283. if (num_step_ == 0) {
  284. graph_ptr_ = nullptr;
  285. CheckGraphPtr(graph_ptr);
  286. }
  287. } else if (debugger_enabled_ && graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
  288. // Multiple graph, and not the initial step,
  289. // stop only when receive the first sub run graph for each step
  290. // if we have stopped for the last kernel before, no need to stop again
  291. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  292. return;
  293. }
  294. if (!(run_level_ == "node" && suspended_at_last_kernel_)) {
  295. CommandLoop();
  296. }
  297. debug_services_->ResetLoadedTensors();
  298. }
  299. // resets for the new graph
  300. suspended_at_last_kernel_ = false;
  301. }
  302. bool Debugger::DumpDataEnabledIteration() const {
  303. auto &dump_json_parser = DumpJsonParser::GetInstance();
  304. if (!dump_json_parser.e2e_dump_enabled()) {
  305. return false;
  306. }
  307. auto cur_iter = dump_json_parser.cur_dump_iter();
  308. if (dump_json_parser.IsDumpIter(cur_iter)) {
  309. return true;
  310. }
  311. return false;
  312. }
  313. uint32_t Debugger::GetRankID() {
  314. auto ms_context = MsContext::GetInstance();
  315. MS_EXCEPTION_IF_NULL(ms_context);
  316. std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  317. uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  318. const auto &device_context =
  319. device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
  320. uint32_t rank_id = device_context->GetRankID();
  321. return rank_id;
  322. }
  323. void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
  324. uint32_t rank_id = GetRankID();
  325. if (debugger_->DebuggerBackendEnabled()) {
  326. MS_EXCEPTION_IF_NULL(kernel_graph);
  327. (void)E2eDump::DumpParametersAndConstData(kernel_graph.get(), rank_id, debugger_.get());
  328. } else {
  329. DumpJsonParser::GetInstance().UpdateDumpIter();
  330. }
  331. }
  332. void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id) {
  333. if (debugger_->DebuggerBackendEnabled()) {
  334. uint32_t rank_id = GetRankID();
  335. (void)E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get());
  336. }
  337. }
  338. void Debugger::DumpSetup(const KernelGraphPtr &kernel_graph) const {
  339. MS_LOG(INFO) << "Start!";
  340. uint32_t rank_id = GetRankID();
  341. MS_EXCEPTION_IF_NULL(kernel_graph);
  342. E2eDump::DumpSetup(kernel_graph.get(), rank_id);
  343. MS_LOG(INFO) << "Finish!";
  344. }
  345. void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
  346. // This function will be called for new GPU runtime using MindRTBackend
  347. auto &json_parser = DumpJsonParser::GetInstance();
  348. if (json_parser.e2e_dump_enabled()) {
  349. uint32_t rank_id = GetRankID();
  350. kernel_graph->set_root_graph_id(kernel_graph->graph_id());
  351. std::string final_graph = "trace_code_graph_" + std::to_string(kernel_graph->graph_id());
  352. std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
  353. std::string target_dir = root_dir + "/graphs";
  354. std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
  355. DumpIRProtoWithSrcInfo(kernel_graph, final_graph, target_dir, kDebugWholeStack);
  356. DumpIR("trace_code_graph", kernel_graph, true, kWholeStack, ir_file_path);
  357. DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(kernel_graph->graph_id()) + ".csv", root_dir,
  358. kernel_graph->execution_order());
  359. }
  360. }
  361. void Debugger::PostExecuteGraphDebugger() {
  362. // On CPU, update dump iteration, Parameters and consts are not dumped here
  363. if (device_target_ == kCPUDevice) {
  364. DumpJsonParser::GetInstance().UpdateDumpIter();
  365. return;
  366. }
  367. // Only GPU is supported for MindRTBackend
  368. if (device_target_ != kGPUDevice) {
  369. return;
  370. }
  371. // LoadParametersAndConst for all the graphs
  372. for (auto graph : graph_ptr_list_) {
  373. debugger_->LoadParametersAndConst(graph);
  374. }
  375. // debug used for dump
  376. if (debugger_ && debugger_->CheckDebuggerDumpEnabled()) {
  377. // Dump Parameters and consts
  378. for (auto graph : graph_ptr_list_) {
  379. debugger_->Dump(graph);
  380. if (!debugger_->debugger_enabled()) {
  381. debugger_->ClearCurrentData();
  382. }
  383. }
  384. }
  385. if (debugger_) {
  386. debugger_->PostExecute();
  387. }
  388. }
  389. void Debugger::PostExecute() {
  390. // access lock for public method
  391. std::lock_guard<std::mutex> a_lock(access_lock_);
  392. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  393. return;
  394. }
  395. if (debugger_->DebuggerBackendEnabled()) {
  396. // analyze tensor data and send the watchpoints been hit
  397. if (debugger_enabled_ && !is_dataset_graph_) {
  398. if (device_target_ != kGPUDevice) {
  399. num_step_++;
  400. }
  401. SendWatchpoints(CheckWatchpoints());
  402. // no need to suspend at each graph for GPU, suspension happens in preExecute
  403. if (device_target_ != kGPUDevice) {
  404. CommandLoop();
  405. }
  406. }
  407. // Only keep parameters in the current map
  408. // GPU ResetLoadedTensors happens in preExecute
  409. if (device_target_ != kGPUDevice) {
  410. debug_services_->ResetLoadedTensors();
  411. }
  412. }
  413. }
  414. bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) const {
  415. if (debugger_enabled_ && !is_dataset_graph_) {
  416. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  417. // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data
  418. if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
  419. return true;
  420. }
  421. }
  422. return false;
  423. }
  424. void Debugger::PostExecuteNode(const CNodePtr &kernel, bool last_kernel) {
  425. // access lock for public method
  426. std::lock_guard<std::mutex> a_lock(access_lock_);
  427. if (pipeline::ExecutorPy::GetDebugTerminate()) {
  428. return;
  429. }
  430. if (debugger_enabled_ && !is_dataset_graph_) {
  431. auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
  432. // if kernel is watchpoint,and get hit. suspend.
  433. bool hit_empty_flag = true;
  434. if (is_watchpoint) {
  435. auto hits = CheckWatchpoints(cur_name_, kernel);
  436. if (!hits.empty()) {
  437. SendWatchpoints(hits);
  438. CommandLoop();
  439. hit_empty_flag = false;
  440. }
  441. }
  442. if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
  443. // if kernel is not watchpoint and is next_to or continue_to node, suspend
  444. // sets a bool to be checked in preExecute to avoid double stopping at last kernel in the last graph
  445. if (last_kernel) {
  446. suspended_at_last_kernel_ = true;
  447. }
  448. CommandLoop();
  449. }
  450. return;
  451. }
  452. }
  453. void Debugger::PostDebugOp() {
  454. // access lock for public method
  455. std::lock_guard<std::mutex> a_lock(access_lock_);
  456. // suspend if debugger is enabled
  457. if (debugger_enabled_ && !is_dataset_graph_) {
  458. MS_LOG(INFO) << "Debugger suspend at debug_op";
  459. CommandLoop();
  460. }
  461. }
  462. void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  463. if (graph_ptr_ != graph_ptr) {
  464. MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
  465. // save new graph_ptr
  466. graph_ptr_ = graph_ptr;
  467. CheckDatasetGraph();
  468. if (!is_dataset_graph_) {
  469. // get proto for new graph_ptr
  470. auto graph_proto = GetGraphProto(graph_ptr);
  471. // add new graph proto to graph_proto_list_
  472. graph_proto_list_.push_back(graph_proto);
  473. graph_ptr_list_.push_back(graph_ptr);
  474. not_dataset_graph_sum_++;
  475. }
  476. // reset is_dataset_graph to be false
  477. is_dataset_graph_ = false;
  478. }
  479. }
  480. // In single graph cases, check single graph ptr
  481. void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  482. if (graph_ptr_ != graph_ptr) {
  483. MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
  484. // save new graph_ptr
  485. graph_ptr_ = graph_ptr;
  486. if (!is_dataset_graph_) {
  487. // only try to enable debugger if it is not a dataset graph
  488. EnableDebugger();
  489. if (debugger_enabled_) {
  490. LoadParametersAndConst();
  491. // get graph proto and send to Mindinsight
  492. auto graph_proto = graph_proto_list_.front();
  493. SendGraphAndSuspend(graph_proto);
  494. }
  495. }
  496. }
  497. }
  498. void Debugger::CheckDatasetGraph() {
  499. // print parameter node names
  500. const auto &params = graph_ptr_->inputs();
  501. for (const auto &param : params) {
  502. MS_LOG(INFO) << "param: " << GetKernelNodeName(param);
  503. }
  504. // check if there is GetNext or InitDataSetQueue node
  505. const auto &nodes = graph_ptr_->execution_order();
  506. for (const auto &node : nodes) {
  507. auto node_name = AnfAlgo::GetCNodeName(node);
  508. MS_LOG(INFO) << "node: " << GetKernelNodeName(node);
  509. if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
  510. MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
  511. << node_name;
  512. is_dataset_graph_ = true;
  513. return;
  514. }
  515. }
  516. is_dataset_graph_ = false;
  517. }
  518. GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  519. // convert kernel graph to debugger modelproto
  520. ModelProto model = GetDebuggerFuncGraphProto(graph_ptr);
  521. return model.graph();
  522. }
  523. void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  524. if (SendMetadata(true)) {
  525. // send graph to Mindinsight server
  526. EventReply reply = grpc_client_->SendGraph(graph_proto);
  527. if (reply.status() != reply.OK) {
  528. MS_LOG(ERROR) << "Error: SendGraph failed";
  529. }
  530. // enter command loop, wait and process commands
  531. CommandLoop();
  532. }
  533. }
  534. bool Debugger::SendMetadata(bool version_check) {
  535. // prepare metadata
  536. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  537. Metadata metadata;
  538. metadata.set_device_name(device_name);
  539. metadata.set_cur_step(num_step_);
  540. metadata.set_backend(device_target_);
  541. metadata.set_cur_node(cur_name_);
  542. metadata.set_training_done(training_done_);
  543. metadata.set_ms_version(version_);
  544. MS_LOG(INFO) << "Is training done?" << training_done_;
  545. // set graph munber to not_dataset_graph_sum_
  546. metadata.set_graph_num(not_dataset_graph_sum_);
  547. EventReply reply_metadata = grpc_client_->SendMetadata(metadata);
  548. bool ret = false;
  549. if (reply_metadata.status() == reply_metadata.OK) {
  550. if (version_check) {
  551. // get type of the command in meta data reply, it should be version matched
  552. DebuggerCommand cmd = GetCommand(reply_metadata);
  553. if (cmd != DebuggerCommand::kVersionMatchedCMD) {
  554. MS_LOG(ERROR) << "MindInsight version is too old, Mindspore version is " << version_;
  555. Exit();
  556. } else {
  557. if (GetMiVersionMatched(reply_metadata)) {
  558. MS_LOG(INFO) << "MindSpore version is " << version_ << " matches MindInsight version.";
  559. ret = true;
  560. } else {
  561. MS_LOG(ERROR) << "MindSpore version " << version_ << ", did not match MindInsight version.";
  562. CommandLoop();
  563. }
  564. }
  565. } else {
  566. // version check is done before so we can just return true here
  567. ret = true;
  568. }
  569. } else {
  570. MS_LOG(ERROR) << "Error: SendMetadata failed";
  571. }
  572. return ret;
  573. }
  574. void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list) {
  575. if (!SendMetadata(true)) {
  576. return;
  577. }
  578. // send multiple graphs to mindinght server
  579. // split graph into chunks if one graph is larger than chunk size
  580. std::list<Chunk> chunked_graph_proto_list;
  581. Chunk chunk;
  582. for (auto graph : graph_proto_list) {
  583. std::string str = graph.SerializeAsString();
  584. auto graph_size = graph.ByteSize();
  585. if (graph_size > g_chunk_size) {
  586. auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);
  587. for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
  588. chunk.set_buffer(sub_graph_str[i]);
  589. if (i < sub_graph_str.size() - 1) {
  590. chunk.set_finished(false);
  591. } else {
  592. chunk.set_finished(true);
  593. }
  594. chunked_graph_proto_list.push_back(chunk);
  595. }
  596. } else {
  597. chunk.set_buffer(str);
  598. chunk.set_finished(true);
  599. chunked_graph_proto_list.push_back(chunk);
  600. }
  601. }
  602. EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  603. if (reply.status() != reply.OK) {
  604. MS_LOG(ERROR) << "Error: SendGraph failed";
  605. }
  606. // enter command loop, wait and process commands
  607. CommandLoop();
  608. }
  609. void Debugger::CommandLoop() {
  610. // prepare metadata
  611. std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  612. Metadata metadata;
  613. metadata.set_device_name(device_name);
  614. metadata.set_cur_step(num_step_);
  615. metadata.set_backend(device_target_);
  616. metadata.set_cur_node(cur_name_);
  617. metadata.set_training_done(training_done_);
  618. // loop exit flag
  619. bool run = false;
  620. int num_wait_fail = 0;
  621. const int max_num_wait_fail = 5;
  622. while (!run) {
  623. // wait for command
  624. EventReply reply = grpc_client_->WaitForCommand(metadata);
  625. if (reply.status() != reply.OK) {
  626. MS_LOG(ERROR) << "Error: WaitForCommand failed";
  627. num_wait_fail++;
  628. if (num_wait_fail > max_num_wait_fail) {
  629. MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session.";
  630. MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
  631. "of debugger host and port.";
  632. Exit();
  633. run = true;
  634. } else {
  635. MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after "
  636. << num_wait_fail << "s";
  637. std::this_thread::sleep_for(std::chrono::seconds(num_wait_fail));
  638. }
  639. continue;
  640. }
  641. // get type of the command in reply
  642. DebuggerCommand cmd = GetCommand(reply);
  643. if (cmd == DebuggerCommand::kUnknownCMD) {
  644. MS_LOG(DEBUG) << "Debug: debugger received unknown command";
  645. continue;
  646. }
  647. MS_LOG(INFO) << "received command: ";
  648. switch (cmd) {
  649. case DebuggerCommand::kUnknownCMD:
  650. MS_LOG(INFO) << "UnknownCMD";
  651. break;
  652. case DebuggerCommand::kExitCMD:
  653. MS_LOG(INFO) << "ExitCMD";
  654. Exit();
  655. // Used for debugger termination
  656. run = true;
  657. break;
  658. case DebuggerCommand::kRunCMD:
  659. ProcessRunCMD(reply);
  660. if (GetRunLevel(reply) != "recheck") {
  661. // exit loop
  662. run = true;
  663. }
  664. break;
  665. case DebuggerCommand::kSetCMD:
  666. ProcessKSetCMD(reply);
  667. break;
  668. case DebuggerCommand::kViewCMD:
  669. ProcessKViewCMD(reply);
  670. break;
  671. case DebuggerCommand::kVersionMatchedCMD:
  672. MS_LOG(ERROR) << "Received unexpected Version Matched CMD from Mindinsight.";
  673. Exit();
  674. break;
  675. default:
  676. MS_LOG(ERROR) << "Received unknown CMD from Mindinsight";
  677. Exit();
  678. break;
  679. }
  680. }
  681. }
  682. void Debugger::ProcessRunCMD(const EventReply &reply) {
  683. MS_LOG(INFO) << "RunCMD";
  684. if (GetRunLevel(reply) == "recheck") {
  685. MS_LOG(INFO) << "rechecking all watchpoints";
  686. SendWatchpoints(CheckWatchpoints("", nullptr, true));
  687. } else {
  688. // no longer the initial suspension.
  689. initial_suspend_ = false;
  690. // print run cmd content
  691. // get run_level and node_name
  692. run_level_ = GetRunLevel(reply);
  693. node_name_ = GetNodeName(reply);
  694. MS_LOG(INFO) << "run_level: " << run_level_;
  695. MS_LOG(INFO) << "node_name_: " << node_name_;
  696. }
  697. }
  698. void Debugger::ProcessKSetCMD(const EventReply &reply) {
  699. MS_LOG(INFO) << "SetCMD";
  700. MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  701. MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  702. if (GetWatchpointDelete(reply)) {
  703. MS_LOG(INFO) << "Deleting watchpoint";
  704. RemoveWatchpoint(GetWatchpointID(reply));
  705. } else {
  706. MS_LOG(INFO) << "Setting watchpoint";
  707. MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
  708. ProtoVector<WatchNode> recieved_nodes = GetWatchnodes(reply);
  709. for (const auto &node : recieved_nodes) {
  710. MS_LOG(INFO) << "node name: " << node.node_name();
  711. MS_LOG(INFO) << "node type: " << node.node_type();
  712. }
  713. ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
  714. for (const auto &parameter : parameters) {
  715. MS_LOG(INFO) << "parameter name: " << parameter.name();
  716. MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
  717. MS_LOG(INFO) << "parameter value: " << parameter.value();
  718. }
  719. SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
  720. }
  721. }
  722. void Debugger::ProcessKViewCMD(const EventReply &reply) {
  723. MS_LOG(INFO) << "ViewCMD";
  724. // print view cmd content
  725. ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  726. for (auto received_tensor : received_tensors) {
  727. MS_LOG(INFO) << "tensor node name: " << received_tensor.node_name();
  728. MS_LOG(INFO) << "tensor slot: " << received_tensor.slot();
  729. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << received_tensor.finished() << std::noboolalpha;
  730. MS_LOG(INFO) << "tensor iter: " << received_tensor.iter();
  731. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << received_tensor.truncate() << std::noboolalpha;
  732. }
  733. MS_LOG(INFO) << "Sending tensors";
  734. std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  735. // print view cmd reply
  736. for (auto tensor : tensors) {
  737. MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
  738. MS_LOG(INFO) << "tensor slot: " << tensor.slot();
  739. MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
  740. MS_LOG(INFO) << "tensor iter: " << tensor.iter();
  741. MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
  742. MS_LOG(INFO) << "tensor dims: ";
  743. for (auto dim : tensor.dims()) {
  744. MS_LOG(INFO) << dim << ",";
  745. }
  746. MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  747. }
  748. EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  749. if (send_tensors_reply.status() != debugger::EventReply::OK) {
  750. MS_LOG(ERROR) << "Error: SendTensors failed";
  751. }
  752. }
  753. void AddTensorProtoInfo(TensorProto *tensor_item, const TensorProto &tensor) {
  754. tensor_item->set_node_name(tensor.node_name());
  755. tensor_item->set_slot(tensor.slot());
  756. tensor_item->set_iter(tensor.iter());
  757. tensor_item->set_truncate(tensor.truncate());
  758. tensor_item->clear_tensor_content();
  759. tensor_item->clear_data_type();
  760. tensor_item->clear_dims();
  761. }
  762. void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
  763. const ProtoVector<WatchCondition_Parameter> &parameters) {
  764. std::vector<std::tuple<std::string, bool>> check_node_list;
  765. std::vector<DebugServices::parameter_t> parameter_list;
  766. std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
  767. [](const WatchNode &node) -> std::tuple<std::string, bool> {
  768. return make_tuple(node.node_name(), node.node_type() == "scope");
  769. });
  770. std::transform(
  771. parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
  772. [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
  773. return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
  774. });
  775. debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
  776. }
  777. void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
  778. std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  779. std::vector<std::string> name;
  780. std::vector<std::string> ret_name;
  781. std::vector<char *> data_ptr;
  782. std::vector<ssize_t> data_size;
  783. std::vector<unsigned int> dtype;
  784. std::vector<std::vector<int64_t>> shape;
  785. std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  786. // ret_name will contain tensor names that are found in TensorLoader
  787. // items in ret_name will be in the same order with tensors if found
  788. debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  789. std::list<TensorProto> tensor_list;
  790. size_t result_index = 0;
  791. for (auto tensor : tensors) {
  792. ssize_t size_iter = 0;
  793. if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
  794. TensorProto tensor_item;
  795. tensor_item.set_finished(true);
  796. AddTensorProtoInfo(&tensor_item, tensor);
  797. tensor_list.push_back(tensor_item);
  798. continue;
  799. }
  800. ssize_t tensor_size = data_size[result_index];
  801. while (size_iter < tensor_size) {
  802. ssize_t chunk_size = g_chunk_size;
  803. TensorProto tensor_item;
  804. tensor_item.set_finished(false);
  805. if (tensor_size - size_iter <= g_chunk_size) {
  806. chunk_size = tensor_size - size_iter;
  807. tensor_item.set_finished(true);
  808. }
  809. AddTensorProtoInfo(&tensor_item, tensor);
  810. // return empty tensor if didn't find the requested tensor
  811. tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);
  812. tensor_item.set_data_type((debugger::DataType)dtype[result_index]);
  813. for (auto &elem : shape[result_index]) {
  814. tensor_item.add_dims(elem);
  815. }
  816. // add tensor to result list and increment result_index to check next item in ret_name
  817. tensor_list.push_back(tensor_item);
  818. if (size_iter > INT_MAX - g_chunk_size) {
  819. MS_EXCEPTION(ValueError) << size_iter << " + " << g_chunk_size << " would lead to integer overflow!";
  820. }
  821. size_iter += g_chunk_size;
  822. }
  823. result_index++;
  824. }
  825. return tensor_list;
  826. }
  827. void Debugger::Exit() {
  828. // clear resource before exit
  829. // debugger will notify main thread to exit because main thread can only exit at step boundary
  830. pipeline::ExecutorPy::DebugTerminate(true);
  831. }
  832. std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel,
  833. bool recheck) {
  834. std::vector<std::string> name;
  835. std::vector<std::string> slot;
  836. std::vector<int> condition;
  837. std::vector<unsigned int> watchpoint_id;
  838. std::vector<std::string> overflow_ops;
  839. std::vector<std::vector<DebugServices::parameter_t>> parameters;
  840. std::vector<int32_t> error_codes;
  841. std::vector<std::shared_ptr<TensorData>> tensor_list;
  842. if (watchnode.empty()) {
  843. tensor_list = debug_services_->GetTensor();
  844. } else {
  845. tensor_list = debug_services_->GetNodeTensor(kernel);
  846. }
  847. std::vector<std::string> file_list;
  848. MS_LOG(INFO) << "checkwatchpoints call for step " << num_step_;
  849. debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, &error_codes, overflow_ops,
  850. file_list, &tensor_list, initial_suspend_, watchnode.empty(), recheck);
  851. std::list<WatchpointHit> hits;
  852. for (unsigned int i = 0; i < name.size(); i++) {
  853. WatchpointHit hit;
  854. std::vector<DebugServices::parameter_t> &parameter = parameters[i];
  855. hit.set_id(watchpoint_id[i]);
  856. hit.set_error_code(error_codes[i]);
  857. // here TensorProto act as a tensor indicator, not sending tensor content
  858. TensorProto *tensor_item = hit.mutable_tensor();
  859. tensor_item->set_node_name(name[i]);
  860. tensor_item->set_slot(slot[i]);
  861. tensor_item->set_finished(true);
  862. WatchCondition *condition_item = hit.mutable_watch_condition();
  863. condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
  864. for (const auto &p : parameter) {
  865. auto x = condition_item->mutable_params()->Add();
  866. x->set_name(p.name);
  867. x->set_disabled(p.disabled);
  868. x->set_value(p.value);
  869. x->set_hit(p.hit);
  870. x->set_actual_value(p.actual_value);
  871. }
  872. hits.push_back(hit);
  873. }
  874. return hits;
  875. }
  876. void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  877. // send info about watchpoint
  878. if (!points.empty()) {
  879. EventReply reply = grpc_client_->SendWatchpointHits(points);
  880. if (reply.status() != reply.OK) {
  881. MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
  882. }
  883. }
  884. }
  885. bool Debugger::DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
  886. const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
  887. TypeId device_type, const std::string &addr_format, size_t slot) const {
  888. return debug_services_.get()->DumpTensorToFile(tensor_name, trans_flag, filepath, host_fmt, host_shape, host_type,
  889. device_type, addr_format, slot);
  890. }
  891. bool Debugger::DebugServicesIsWatchPoint(const std::string &kernel_name, const CNodePtr &kernel) const {
  892. return debug_services_.get()->IsWatchPoint(kernel_name, kernel);
  893. }
  894. void Debugger::EmptyTensor() { debug_services_.get()->EmptyTensor(); }
  895. void Debugger::SetTensorLoaderIterNum(uint32_t iter_num) { debug_services_.get()->SetTensorLoaderIterNum(iter_num); }
  896. void Debugger::EmptyPrevTensor() { debug_services_.get()->EmptyPrevTensor(); }
  897. uint32_t Debugger::GetTensorLoaderIterNum() const { return debug_services_.get()->GetTensorLoaderIterNum(); }
  898. bool Debugger::LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev) {
  899. return debug_services_.get()->LoadNewTensor(tensor, keep_prev);
  900. }
  901. bool Debugger::debugger_enabled() const { return debugger_enabled_; }
  902. DebuggerCommand GetCommand(const EventReply &reply) {
  903. DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  904. switch (reply.cmd_case()) {
  905. case debugger::EventReply::CmdCase::kExit:
  906. cmd = DebuggerCommand::kExitCMD;
  907. break;
  908. case debugger::EventReply::CmdCase::kRunCmd:
  909. cmd = DebuggerCommand::kRunCMD;
  910. break;
  911. case debugger::EventReply::CmdCase::kSetCmd:
  912. cmd = DebuggerCommand::kSetCMD;
  913. break;
  914. case debugger::EventReply::CmdCase::kViewCmd:
  915. cmd = DebuggerCommand::kViewCMD;
  916. break;
  917. case debugger::EventReply::CmdCase::kVersionMatched:
  918. cmd = DebuggerCommand::kVersionMatchedCMD;
  919. break;
  920. default:
  921. MS_LOG(DEBUG) << "Debug: UnknownCMD";
  922. break;
  923. }
  924. return cmd;
  925. }
  926. ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  927. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  928. MS_LOG(ERROR) << "Error: Can not get Parameters from command. Returning default value: ProtoVector<Parameter>().";
  929. return ProtoVector<WatchCondition_Parameter>();
  930. }
  931. return reply.set_cmd().watch_condition().params();
  932. }
  933. ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  934. if (!reply.has_set_cmd()) {
  935. MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
  936. return ProtoVector<WatchNode>();
  937. }
  938. return reply.set_cmd().watch_nodes();
  939. }
  940. std::string GetRunLevel(const EventReply &reply) {
  941. if (!reply.has_run_cmd()) {
  942. MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: "
  943. "";
  944. return "";
  945. }
  946. return reply.run_cmd().run_level();
  947. }
  948. std::string GetNodeName(const EventReply &reply) {
  949. if (!reply.has_run_cmd()) {
  950. MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: "
  951. "";
  952. return "";
  953. }
  954. return reply.run_cmd().node_name();
  955. }
  956. WatchCondition GetWatchcondition(const EventReply &reply) {
  957. if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
  958. MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
  959. return WatchCondition();
  960. }
  961. return reply.set_cmd().watch_condition();
  962. }
  963. int32_t GetWatchpointID(const EventReply &reply) {
  964. if (!reply.has_set_cmd()) {
  965. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
  966. return 0;
  967. }
  968. return reply.set_cmd().id();
  969. }
  970. bool GetWatchpointDelete(const EventReply &reply) {
  971. if (!reply.has_set_cmd()) {
  972. MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
  973. return false;
  974. }
  975. return reply.set_cmd().delete_();
  976. }
  977. ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  978. if (!reply.has_view_cmd()) {
  979. MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
  980. return ProtoVector<TensorProto>();
  981. }
  982. return reply.view_cmd().tensors();
  983. }
  984. std::string GetTensorFullName(const TensorProto &tensor) {
  985. string node_name = tensor.node_name();
  986. if (tensor.truncate()) {
  987. // scopes in node name are separated by '/'
  988. // use the name without scope if truncate is true
  989. std::size_t found = node_name.find_last_of("/");
  990. node_name = node_name.substr(found + 1);
  991. }
  992. return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
  993. }
  994. bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }
  995. bool Debugger::partial_memory() const { return partial_memory_; }
  996. void Debugger::SetCurNode(const std::string &cur_name) {
  997. // access lock for public method
  998. std::lock_guard<std::mutex> a_lock(access_lock_);
  999. cur_name_ = cur_name;
  1000. }
  1001. std::string Debugger::run_level() const { return run_level_; }
  1002. void Debugger::SetStepNum(int32_t cur_num_step) {
  1003. // access lock for public method
  1004. std::lock_guard<std::mutex> a_lock(access_lock_);
  1005. num_step_ = cur_num_step;
  1006. }
  1007. int32_t Debugger::step_num() const { return num_step_; }
  1008. void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }
  1009. bool Debugger::CheckPort(const std::string &port) const {
  1010. int num = 0;
  1011. const int min_port_num = 1;
  1012. const int max_port_num = 65535;
  1013. const int decimal = 10;
  1014. if (port[0] == '0' && port[1] != '\0') return false;
  1015. int i = 0;
  1016. while (port[i] != '\0') {
  1017. if (port[i] < '0' || port[i] > '9') return false;
  1018. num = num * decimal + (port[i] - '0');
  1019. if (num > max_port_num) return false;
  1020. i++;
  1021. }
  1022. if (num < min_port_num) return false;
  1023. return true;
  1024. }
  1025. uint32_t Debugger::GetFirstRunGraphId() const { return rungraph_id_list_.front(); }
  1026. void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  1027. MS_EXCEPTION_IF_NULL(anf_node);
  1028. if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
  1029. return;
  1030. }
  1031. // When MindRT is used, only ValueNodes and ParameterWeights can be loaded from device to host
  1032. if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) && (device_target_ == kGPUDevice)) {
  1033. if (!anf_node->isa<ValueNode>() &&
  1034. !(anf_node->isa<Parameter>() && AnfAlgo::IsParameterWeight(anf_node->cast<ParameterPtr>()))) {
  1035. return;
  1036. }
  1037. }
  1038. // for parameters and value nodes, set its execution order to be 0;
  1039. int exec_order = 0;
  1040. std::string node_name = GetKernelNodeName(anf_node);
  1041. GetFileKernelName(NOT_NULL(&node_name));
  1042. // check if output adde exists, if not, return;
  1043. if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
  1044. return;
  1045. }
  1046. auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  1047. MS_EXCEPTION_IF_NULL(addr);
  1048. auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  1049. if (!IsTypeDebuggerSupported(type)) {
  1050. return;
  1051. }
  1052. auto format = kOpFormat_DEFAULT;
  1053. string tensor_name = node_name + ':' + "0";
  1054. ShapeVector int_shapes = trans::GetRuntimePaddingShape(anf_node, output_index);
  1055. bool keep_prev;
  1056. if (anf_node->isa<Parameter>()) {
  1057. keep_prev = true;
  1058. debug_services_->MoveTensorCurrentToPrev(tensor_name);
  1059. } else {
  1060. keep_prev = false;
  1061. }
  1062. bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  1063. if (!ret) {
  1064. MS_LOG(ERROR) << "LoadMemToHost:"
  1065. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1066. }
  1067. }
  1068. void Debugger::LoadParametersAndConst() {
  1069. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  1070. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1071. // load parameters
  1072. MS_LOG(INFO) << "Start to load Parameters!";
  1073. const auto &parameters = graph_ptr_->inputs();
  1074. for (auto &item : parameters) {
  1075. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  1076. }
  1077. // load value nodes
  1078. // get all constant avlues from the graph
  1079. MS_LOG(INFO) << "Start to load value nodes!";
  1080. const auto value_nodes = graph_ptr_->graph_value_nodes();
  1081. for (auto &item : value_nodes) {
  1082. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  1083. }
  1084. }
  1085. void Debugger::LoadParametersAndConst(const KernelGraphPtr &graph) {
  1086. if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  1087. MS_EXCEPTION_IF_NULL(graph);
  1088. // load parameters
  1089. MS_LOG(INFO) << "Start to load Parameters for graph " << graph->graph_id();
  1090. const auto &parameters = graph_ptr_->inputs();
  1091. for (auto &item : parameters) {
  1092. LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  1093. }
  1094. // load value nodes
  1095. // get all constant avlues from the graph
  1096. MS_LOG(INFO) << "Start to load value nodes for graph " << graph->graph_id();
  1097. const auto value_nodes = graph_ptr_->graph_value_nodes();
  1098. for (auto &item : value_nodes) {
  1099. LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  1100. }
  1101. }
  1102. void Debugger::LoadGraphOutputs() {
  1103. if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  1104. MS_EXCEPTION_IF_NULL(graph_ptr_);
  1105. const auto &apply_kernels = graph_ptr_->execution_order();
  1106. // for kernels, execution order starts from 1
  1107. int exec_order = 1;
  1108. for (const auto &node : apply_kernels) {
  1109. MS_EXCEPTION_IF_NULL(node);
  1110. std::string kernel_name = GetKernelNodeName(node);
  1111. auto output_size = AnfAlgo::GetOutputTensorNum(node);
  1112. if (partial_memory_) {
  1113. if (!debug_services_->IsWatchPoint(kernel_name, node)) {
  1114. continue;
  1115. }
  1116. }
  1117. for (size_t j = 0; j < output_size; ++j) {
  1118. if (!AnfAlgo::OutputAddrExist(node, j)) {
  1119. MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
  1120. continue;
  1121. }
  1122. auto addr = AnfAlgo::GetOutputAddr(node, j);
  1123. MS_EXCEPTION_IF_NULL(addr);
  1124. auto type = AnfAlgo::GetOutputInferDataType(node, j);
  1125. if (!IsTypeDebuggerSupported(type)) {
  1126. continue;
  1127. }
  1128. auto format = kOpFormat_DEFAULT;
  1129. string tensor_name = kernel_name + ':' + std::to_string(j);
  1130. ShapeVector int_shapes = trans::GetRuntimePaddingShape(node, j);
  1131. auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
  1132. if (!ret) {
  1133. MS_LOG(ERROR) << "LoadMemToHost:"
  1134. << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
  1135. }
  1136. }
  1137. exec_order = exec_order + 1;
  1138. }
  1139. }
  1140. void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  1141. // update step number if we are processing the first graph (to support multigraph)
  1142. if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
  1143. (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
  1144. // access lock for public method
  1145. std::lock_guard<std::mutex> a_lock(access_lock_);
  1146. ++num_step_;
  1147. }
  1148. }
  1149. void Debugger::UpdateStepNumGPU() {
  1150. // UpdateStepNum with DebugActor::DebugOnStepEnd
  1151. if (device_target_ == kGPUDevice && (debugger_enabled_ || DumpDataEnabledIteration())) {
  1152. // access lock for public method
  1153. std::lock_guard<std::mutex> a_lock(access_lock_);
  1154. ++num_step_;
  1155. }
  1156. }
  1157. void Debugger::ClearCurrentData() {
  1158. if ((device_target_ == kGPUDevice) && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration())) {
  1159. if (debug_services_) {
  1160. debug_services_->EmptyCurrentTensor();
  1161. } else {
  1162. MS_LOG(ERROR) << "debug_services_ is nullptr";
  1163. }
  1164. }
  1165. }
  1166. bool Debugger::TensorExistsInCurrent(const std::string &tensor_name) {
  1167. return debug_services_->TensorExistsInCurrent(tensor_name);
  1168. }
  1169. } // namespace mindspore